7.6. predict 剖析¶
Model.predict(inputs, *, callback=None) 是已加载的模型对象真正开展工作的地方。在输入进入与结果输出之间,三个阶段依次运行:预处理、引擎分派、后处理。其中两个阶段接受脚本直接控制的参数;中间的引擎则由摄像头决定。
7.6.1. 预处理¶
预处理阶段将每个输入转换成网络所期望的稠密张量。最常见的输入是以 RGB565 捕获的 image.Image。该阶段将其裁剪并缩放到网络的 input_shape,从 RGB565 转换为网络训练时所基于的通道格式(大多数视觉网络为 RGB888),应用逐通道的缩放和偏移,并且 —— 当网络期望整数输入时 —— 在同一遍处理中量化到模型的 input_dtype。为浮点输入训练的网络会跳过量化步骤,直接接收缩放和偏移的结果。
默认的 ml.preprocessing.Normalization 读取模型的输入 dtype 并自动运行正确的变换。手动调优的 Normalization 会为基于自定义通道统计量训练的模型覆盖缩放、均值和标准差值(基于 ImageNet 得出的均值和标准差是一个常见情形)。一个普通的可调用对象则完全覆盖该阶段 —— 当输入根本不是图像,或应用程序已自行生成稠密张量时很有用。
7.6.2. 引擎分派¶
引擎阶段运行网络。它分派到哪个引擎由摄像头固定:H7 和 RT1062 运行 TFLM(用于微控制器的 TensorFlow Lite 解释器,在存在 ARM 优化的 CMSIS-NN 内核时进行分派);AE3 运行同样的 TFLM 解释器,配以其 Cortex-M55 回退路径,由 Ethos-U NPU 处理离线 Vela 编译器标记给加速器的任何算子;N6 运行 STAI,即 ST 为 N6 专门打造的 NPU 提供的运行时。
脚本不挑选引擎。随摄像头一起发布的引擎会运行摄像头加载的每一个模型。
7.6.3. 后处理¶
后处理阶段将网络的原始输出张量转回可用的结果。默认行为是将每个输出张量反量化为浮点数(对于具有浮点输出的网络则原样传递),并以 ndarray 对象列表的形式返回。大多数应用程序会注册一个后处理器 —— 一个了解网络输出布局的可调用对象 —— 将张量解码为应用程序据以行动的结果形式:一个边界框列表、一个关键点列表、一个类别列表。
脚本以两种方式控制此阶段。构造函数上的 postprocess= 关键字注册一个在每次调用时都运行的后处理器。predict() 上的 callback= 关键字仅为一次调用覆盖已注册的后处理器 —— 在无需重新加载模型即可在多个解码器之间切换时很有用。两种形式都接收 (model, inputs, outputs) 并返回应用程序所期望的任何内容。
7.6.4. 脚本所控制的内容¶
预处理和后处理是脚本的两个把手。默认的预处理器能处理大多数视觉模型;针对给定网络系列的正确后处理器可从 ml.postprocessing 下的目录中选取。中间的引擎由构建决定,无论脚本要求什么都以相同的方式运行。