7.15. Pisanie własnego

Gdy katalog nie obejmuje danego modelu – sieci badawczej o niestandardowym układzie wyjścia, modyfikacji istniejącej architektury, tensora, którego interpretacja semantyczna jest specyficzna dla aplikacji – aplikacja dostarcza własny post-processor. Protokół jest prosty: obiekt wywoływalny, który przyjmuje (model, inputs, outputs) i zwraca cokolwiek, czego aplikacja oczekuje od predict().

Klasa z metodą __call__ to konwencjonalna forma:

class MyPostprocessor:
    def __init__(self, threshold=0.5):
        self.threshold = threshold

    def __call__(self, model, inputs, outputs):
        ...
        return result

Zwykła funkcja również działa – silnik sprawdza jedynie, czy obiekt jest wywoływalny.

7.15.1. Podłączanie go

Dwa punkty podłączenia. Argument nazwany postprocess= w konstruktorze wiąże obiekt wywoływalny dla każdego wywołania predict() na modelu:

model = ml.Model("/rom/my_model.tflite",
                 postprocess=MyPostprocessor())

Aby nadpisać to powiązanie dla pojedynczego wywołania – zamienić dekodery bez ponownego wczytywania modelu – przekaż callback= bezpośrednio do predict:

result = model.predict([img], callback=MyOtherPostprocessor())

Sygnatura obiektu wywoływalnego jest w obu przypadkach taka sama.

7.15.2. Co otrzymuje obiekt wywoływalny

  • model – instancja Model, przydatna do parametrów kwantyzacji (output_scale, output_zero_point, output_dtype) oraz wymiarów wejścia (input_shape).

  • inputs – lista wejść przekazanych przez aplikację do predict(). Pierwszym elementem jest zwykle związana instancja Normalization; jej atrybut roi jest tym, czego NMS oczekuje do odwzorowania ramek z powrotem na oryginalny obraz.

  • outputs – surowe tensory wyjściowe jako lista obiektów ndarray, w ich natywnym dtype. Wyjścia zmiennoprzecinkowe przychodzą bez zmian; wyjścia całkowitoliczbowe przychodzą skwantyzowane.

7.15.3. Arytmetyka skwantyzowana

Wszystkie dostarczane dekodery sięgają po te same funkcje pomocnicze w ml.utils, a niestandardowy zwykle potrzebuje tego samego wzorca: quantize() przenosi zmiennoprzecinkowy próg do skwantyzowanej przestrzeni modelu, threshold() filtruje bez dekwantyzacji całego tensora, a dequantize() uruchamia się raz na tych, które przetrwały. sigmoid() i logit() są dostępne dla sieci, których kanały wyjściowe są logitami sprzed sigmoidy (detektory MediaPipe są kanonicznym przypadkiem).

Dla modeli z wyjściami zmiennoprzecinkowymi – głowic regresyjnych, modeli z wbudowaną końcową warstwą dekwantyzacji – funkcje pomocnicze kwantyzacji przechodzą bez zmian, więc ten sam kod post-processora działa dla obu dtype bez specjalnego traktowania.

7.15.4. Wartość zwracana

Cokolwiek zwraca obiekt wywoływalny, to właśnie zwraca predict(). Dla dekoderów emitujących ramki konwencją jest przepuszczenie kandydatów przez NMS i zwrócenie jego list dla poszczególnych klas – kształt wywołania, który dokumentuje tłumienie niemaksymalne, a przewodnik po YOLOv8 buduje w kontekście. W przypadku czegokolwiek innego zwróć cokolwiek, co aplikacja uzna za wygodne: pojedynczą ndarray, łańcuch etykiety, krotkę (class, score, embedding), słownik.