7.6. predict 剖析

Model.predict(inputs, *, callback=None) 是已加载的模型对象真正开展工作的地方。在输入进入与结果输出之间,三个阶段依次运行:预处理引擎分派后处理。其中两个阶段接受脚本直接控制的参数;中间的引擎则由摄像头决定。

五个相连的方框从左到右水平排列。最左边是"Image input";一个箭头指向 带有副标题"Normalization"的"Pre-process";一个箭头指向 带有副标题"TFLM / STAI"的"Engine";一个箭头指向 带有副标题"postprocess="的"Post-process";最后一个 箭头指向"Result"。中间三个阶段下方各带一个 标签 —— Pre-process 下方为"user-controllable", Engine 下方为"automatic",Post-process 下方为 "user-controllable"。

predict() 的三个阶段。预处理和后处理接受脚本控制的参数;中间的引擎由摄像头固定。

7.6.1. 预处理

预处理阶段将每个输入转换成网络所期望的稠密张量。最常见的输入是以 RGB565 捕获的 image.Image。该阶段将其裁剪并缩放到网络的 input_shape,从 RGB565 转换为网络训练时所基于的通道格式(大多数视觉网络为 RGB888),应用逐通道的缩放和偏移,并且 —— 当网络期望整数输入时 —— 在同一遍处理中量化到模型的 input_dtype。为浮点输入训练的网络会跳过量化步骤,直接接收缩放和偏移的结果。

默认的 ml.preprocessing.Normalization 读取模型的输入 dtype 并自动运行正确的变换。手动调优的 Normalization 会为基于自定义通道统计量训练的模型覆盖缩放、均值和标准差值(基于 ImageNet 得出的均值和标准差是一个常见情形)。一个普通的可调用对象则完全覆盖该阶段 —— 当输入根本不是图像,或应用程序已自行生成稠密张量时很有用。

7.6.2. 引擎分派

引擎阶段运行网络。它分派到哪个引擎由摄像头固定:H7 和 RT1062 运行 TFLM(用于微控制器的 TensorFlow Lite 解释器,在存在 ARM 优化的 CMSIS-NN 内核时进行分派);AE3 运行同样的 TFLM 解释器,配以其 Cortex-M55 回退路径,由 Ethos-U NPU 处理离线 Vela 编译器标记给加速器的任何算子;N6 运行 STAI,即 ST 为 N6 专门打造的 NPU 提供的运行时。

脚本不挑选引擎。随摄像头一起发布的引擎会运行摄像头加载的每一个模型。

7.6.3. 后处理

后处理阶段将网络的原始输出张量转回可用的结果。默认行为是将每个输出张量反量化为浮点数(对于具有浮点输出的网络则原样传递),并以 ndarray 对象列表的形式返回。大多数应用程序会注册一个后处理器 —— 一个了解网络输出布局的可调用对象 —— 将张量解码为应用程序据以行动的结果形式:一个边界框列表、一个关键点列表、一个类别列表。

脚本以两种方式控制此阶段。构造函数上的 postprocess= 关键字注册一个在每次调用时都运行的后处理器。predict() 上的 callback= 关键字仅为一次调用覆盖已注册的后处理器 —— 在无需重新加载模型即可在多个解码器之间切换时很有用。两种形式都接收 (model, inputs, outputs) 并返回应用程序所期望的任何内容。

7.6.4. 脚本所控制的内容

预处理和后处理是脚本的两个把手。默认的预处理器能处理大多数视觉模型;针对给定网络系列的正确后处理器可从 ml.postprocessing 下的目录中选取。中间的引擎由构建决定,无论脚本要求什么都以相同的方式运行。