7.15. 编写你自己的¶
当目录无法覆盖某个模型时——例如某个输出布局是定制的研究网络、对现有架构的一处调整、或某个语义解释因应用而异的张量——应用程序需提供自己的后处理器。其协议很简单:一个接收 (model, inputs, outputs) 并返回应用程序期望从 predict() 得到的任何内容的可调用对象。
带有 __call__ 的类是惯用的形式:
class MyPostprocessor:
def __init__(self, threshold=0.5):
self.threshold = threshold
def __call__(self, model, inputs, outputs):
...
return result
普通函数也可以——引擎只检查该对象是否可调用。
7.15.1. 将其挂接进去¶
有两个挂接点。构造函数上的 postprocess= 关键字参数会为该模型上的每一次 predict() 调用绑定该可调用对象:
model = ml.Model("/rom/my_model.tflite",
postprocess=MyPostprocessor())
若要为单次调用覆盖该绑定——即在不重新加载模型的情况下切换解码器——可直接向 predict 传入 callback=:
result = model.predict([img], callback=MyOtherPostprocessor())
两种情况下可调用对象的签名都相同。
7.15.2. 可调用对象接收什么¶
model——Model实例,用于获取量化参数(output_scale、output_zero_point、output_dtype)和输入维度(input_shape)。inputs——应用程序传给predict()的输入列表。第一个元素通常是绑定的Normalization实例;它的roi属性正是NMS在将框重新映射回原始图像时所需要的。outputs——原始输出张量,以ndarray对象列表的形式给出,保持其原生 dtype。浮点输出按原样到达;整数输出以量化形式到达。
7.15.3. 量化运算¶
内置解码器都会用到 ml.utils 中相同的辅助函数,自定义解码器通常也希望采用同样的模式:quantize() 将浮点阈值提升到模型的量化空间,threshold() 在不对整个张量进行反量化的情况下进行过滤,而 dequantize() 只对存活下来的元素运行一次。对于输出通道为预 sigmoid logit 的网络(MediaPipe 检测器是典型案例),可以使用 sigmoid() 和 logit()。
对于具有浮点输出的模型——回归头、内置了最终反量化层的模型——量化辅助函数会原样透传,因此同一套后处理器代码无需特殊处理即可适用于任一 dtype。
7.15.4. 返回值¶
可调用对象返回什么,predict() 就返回什么。对于输出框的解码器,惯例是将候选框推送通过一个 NMS 并返回其按类划分的列表——也就是 非极大值抑制 所记录、并由 YOLOv8 详解 在具体场景中构建的调用结构。对于其他任何情况,返回应用程序认为方便的任何内容:单个 ndarray、一个标签字符串、一个 (class, score, embedding) 元组、一个字典。