7.15. Écrire le vôtre

Lorsque le catalogue ne couvre pas un modèle – un réseau de recherche dont la disposition de sortie est sur mesure, une modification d’une architecture existante, un tenseur dont l’interprétation sémantique est propre à l’application – l’application fournit son propre post-traitement. Le protocole est simple : un appelable qui prend (model, inputs, outputs) et retourne tout ce que l’application attend de predict().

Une classe avec __call__ est la forme conventionnelle

class MyPostprocessor:
    def __init__(self, threshold=0.5):
        self.threshold = threshold

    def __call__(self, model, inputs, outputs):
        ...
        return result

Une simple fonction fonctionne aussi – le moteur vérifie seulement que l’objet est appelable.

7.15.1. Le brancher

Deux points de rattachement. L’argument nommé postprocess= du constructeur lie l’appelable pour chaque appel à predict() sur le modèle

model = ml.Model("/rom/my_model.tflite",
                 postprocess=MyPostprocessor())

Pour remplacer la liaison pour un seul appel – changer de décodeur sans recharger le modèle – passez callback= directement à predict

result = model.predict([img], callback=MyOtherPostprocessor())

La signature de l’appelable est la même dans les deux cas.

7.15.2. Ce que l’appelable reçoit

  • model – l’instance Model, utile pour les paramètres de quantification (output_scale, output_zero_point, output_dtype) et les dimensions d’entrée (input_shape).

  • inputs – la liste des entrées que l’application a passées à predict(). Le premier élément est généralement l’instance Normalization liée ; son attribut roi est ce que NMS attend pour reprojeter les boîtes dans l’image d’origine.

  • outputs – les tenseurs de sortie bruts sous forme de liste d’objets ndarray, dans leur dtype natif. Les sorties flottantes arrivent telles quelles ; les sorties entières arrivent quantifiées.

7.15.3. Arithmétique quantifiée

Les décodeurs fournis font tous appel aux mêmes auxiliaires de ml.utils, et un décodeur personnalisé veut généralement le même modèle : quantize() transpose un seuil flottant dans l’espace quantifié du modèle, threshold() filtre sans déquantifier tout le tenseur, et dequantize() s’exécute une seule fois sur les survivantes. sigmoid() et logit() sont disponibles pour les réseaux dont les canaux de sortie sont des logits pré-sigmoïde (les détecteurs MediaPipe en sont le cas canonique).

Pour les modèles à sorties flottantes – têtes de régression, modèles avec une couche de déquantification finale intégrée – les auxiliaires de quantification laissent passer les valeurs inchangées, de sorte que le même code de post-traitement fonctionne avec l’un ou l’autre dtype sans cas particulier.

7.15.4. Valeur de retour

Ce que l’appelable retourne est ce que predict() retourne. Pour les décodeurs émettant des boîtes, la convention est de pousser les candidats à travers une NMS et de retourner ses listes par classe – la forme d’appel que documente la suppression des non-maxima et que construit en contexte la présentation détaillée de YOLOv8. Pour tout le reste, retournez ce qui convient à l’application : un seul ndarray, une chaîne d’étiquette, un tuple (class, score, embedding), un dictionnaire.