7.4. 加载模型¶
ml.Model 从闪存中读取模型文件、解析它、分配网络在推理期间所需的 RAM,并返回一个对象,该对象携带脚本其余部分需要了解的关于已加载网络的所有信息。
7.4.1. 构造函数¶
构造函数接受一个路径和一个可选的后处理器:
model = ml.Model("/rom/blazeface_front_128.tflite",
postprocess=BlazeFace())
位于 /rom/(常驻闪存的文件系统)上的模型会就地读取:网络的权重保留在闪存中,加载后的模型仅占用张量竞技场(tensor arena)大小的 RAM。位于 /sdcard/ 上的模型在加载时会被复制到 RAM 中,因此总开销为模型文件大小加上张量竞技场。两种路径都可行;权衡的是 RAM。
如果存在一个与之同名(basename 相同)的 .txt 文件,其内容会被自动加载到 labels 中。postprocess= 关键字注册一个可调用对象,predict() 会在每次推理后运行它。
7.4.2. 只读属性¶
已加载的模型暴露了一小组只读属性,用于在无需任何人运行它的情况下描述网络。
文件与内存。
输入张量。
input_shape—— 一个元组列表,每个输入张量对应一个,给出网络期望的形状。视觉网络有一个形状为(1, H, W, C)的输入。input_dtype—— 单字符 dtype 代码列表('b'int8、'B'uint8、'h'int16、'H'uint16、'f'float32),每个输入对应一个。input_scale和input_zero_point—— 量化参数,用于在网络训练时所基于的实数值输入与摄像头运行所用的整数表示之间进行转换。
输出张量。 输入集合的镜像:output_shape、output_dtype、output_scale、output_zero_point。检测网络产生两个或三个输出张量(边界框、置信度分数,有时还有类别概率);分类网络产生一个。
额外项。 labels 是从同名 .txt 文件加载的类名列表,或为 None。postprocess 是已注册的后处理器,或为 None。
7.4.3. 检视 BlazeFace¶
加载随附的 BlazeFace 模型并打印每个属性会给出实际的数值:
import ml
from ml.postprocessing.mediapipe import BlazeFace
model = ml.Model("/rom/blazeface_front_128.tflite",
postprocess=BlazeFace())
print("file size: ", model.len, "bytes")
print("tensor arena: ", model.ram, "bytes")
print("input shape: ", model.input_shape)
print("input dtype: ", model.input_dtype)
print("input scale: ", model.input_scale)
print("input zp: ", model.input_zero_point)
print("output shape: ", model.output_shape)
print("output dtype: ", model.output_dtype)
print("output scale: ", model.output_scale)
print("output zp: ", model.output_zero_point)
这些数值具体地标识了网络的接口:一个 (1, 128, 128, 3) 的 int8 输入张量和两个 int8 输出 —— 一个用于边界框回归系数,一个用于每个锚框的置信度分数。量化参数描述了这些 int8 值如何映射到网络训练时所基于的实数浮点值;后处理器使用它们在解码边界框之前撤销量化。
每个属性都是其所描述内容的唯一可信来源。脚本读取 input_shape 以了解应以何种规格进行捕获,读取 output_scale 和 output_zero_point 以手动解码张量,读取 labels 以获取人类可读的类名 —— 绝不硬编码,绝不臆测。