7.15. Oman kirjoittaminen

Kun luettelo ei kata mallia – tutkimusverkko, jonka ulostuloasettelu on räätälöity, hienosäätö olemassa olevaan arkkitehtuuriin, tensori, jonka semanttinen tulkinta on sovelluskohtainen – sovellus tarjoaa oman jälkikäsittelijänsä. Protokolla on yksinkertainen: kutsuttava, joka ottaa (model, inputs, outputs) ja palauttaa sen, mitä sovellus odottaa predict()-metodilta.

Luokka, jolla on __call__, on tavanomainen muoto:

class MyPostprocessor:
    def __init__(self, threshold=0.5):
        self.threshold = threshold

    def __call__(self, model, inputs, outputs):
        ...
        return result

Tavallinen funktio toimii myös – moottori tarkistaa vain, että olio on kutsuttavissa.

7.15.1. Sen kytkeminen

Kaksi kiinnityskohtaa. Konstruktorin postprocess=-avainsana-argumentti sitoo kutsuttavan jokaiselle predict()-kutsulle mallissa:

model = ml.Model("/rom/my_model.tflite",
                 postprocess=MyPostprocessor())

Sidonnan ohittamiseksi yksittäisellä kutsulla – dekoodereiden vaihtamiseksi mallia uudelleen lataamatta – välitä callback= suoraan predict-metodille:

result = model.predict([img], callback=MyOtherPostprocessor())

Kutsuttavan allekirjoitus on sama kummassakin tapauksessa.

7.15.2. Mitä kutsuttava saa

  • modelModel-instanssi, hyödyllinen kvantisointiparametreille (output_scale, output_zero_point, output_dtype) ja syötemittasuhteille (input_shape).

  • inputs – lista syötteistä, jotka sovellus välitti predict()-metodille. Ensimmäinen elementti on yleensä sidottu Normalization-instanssi; sen roi-attribuutti on se, mitä NMS odottaa laatikoiden uudelleenkuvaamiseksi takaisin alkuperäiseen kuvaan.

  • outputs – raakaulostulotensorit listana ndarray-olioita niiden alkuperäisessä dtype-tyypissä. Liukuluku-ulostulot saapuvat sellaisenaan; kokonaisluku-ulostulot saapuvat kvantisoituina.

7.15.3. Kvantisoitu laskenta

Mukana toimitetut dekooderit turvautuvat kaikki samoihin apufunktioihin ml.utils-moduulissa, ja mukautettu haluaa yleensä saman kaavan: quantize() nostaa liukulukukynnysarvon mallin kvantisoituun avaruuteen, threshold() suodattaa dekvantisoimatta koko tensoria, ja dequantize() ajetaan kerran säilyneille. sigmoid() ja logit() ovat käytettävissä verkoille, joiden ulostulokanavat ovat sigmoidia edeltäviä logitteja (MediaPipe-tunnistimet ovat kanoninen tapaus).

Malleille, joissa on liukuluku-ulostulot – regressiopäät, mallit joissa on sisäänrakennettu viimeinen dekvantisointikerros – kvantisoinnin apufunktiot kulkevat läpi muuttumattomina, joten sama jälkikäsittelijäkoodi toimii kumpaa tahansa dtype-tyyppiä vasten ilman erikoiskäsittelyä.

7.15.4. Paluuarvo

Mitä tahansa kutsuttava palauttaa, sen predict() palauttaa. Laatikoita tuottaville dekoodereille käytäntö on työntää ehdokkaat NMS-olion läpi ja palauttaa sen luokkakohtaiset listat – kutsumuoto, jonka non-max suppression dokumentoi ja jonka YOLOv8:n läpikäynti rakentaa kontekstissa. Mitä tahansa muuta varten palauta se, minkä sovellus kokee kätevimmäksi: yksittäinen ndarray, luokkanimimerkkijono, monikko (class, score, embedding) tai sanakirja.