7.7. 归一化¶
ml.Model.predict() 接受一个输入列表,因为有些网络有不止一个输入张量,但该列表无法以内联方式携带逐输入的参数 —— 没有一个 kwarg 槽位用于表达"将这个输入裁剪为 (x, y, w, h) 但保持其他输入不变"。ml.preprocessing.Normalization 就是填补这一空缺的封装器。一个 Normalization 实例为一个输入持有参数;当脚本需要默认值以外的任何东西时,就在 predict 列表中传入被封装的输入。
求助于它的最常见原因是将捕获帧的特定区域裁剪进网络,而不是整幅图像。
7.7.1. 参数¶
Normalization(scale=(0.0, 1.0),
mean=(0.0, 0.0, 0.0),
stdev=(1.0, 1.0, 1.0),
roi=None)
roi—— 源帧中的(x, y, w, h)矩形,在缩放之前裁剪。默认为整帧。Normalization的大多数用法只设置这一个参数。scale—— 浮点输入张量在归一化后所期望的(min, max)范围。像素范围0..255会被线性映射到此范围中。常见值有用于 ReLU 训练网络的(0.0, 1.0)和用于对称归一化网络的(-1.0, 1.0)。mean—— 逐通道的(R, G, B)均值,在缩放后从图像中减去。与网络训练时所基于的通道统计量相匹配 —— 用于基于 ImageNet 的网络的(0.485, 0.456, 0.406)是典型示例。灰度网络使用标准的0.299*R + 0.587*G + 0.114*B将均值化简为一个亮度值。stdev—— 逐通道的(R, G, B)标准差,在减去均值后图像被它除,同样与网络的训练统计量相匹配。对于灰度网络以同样方式化简为亮度值。
7.7.2. 参数何时起作用¶
当网络的 input_dtype 为 int8 或 uint8 时,scale、mean 和 stdev 会被忽略。对于整数输入的网络,裁剪后的图像字节会直接写入张量,由网络自身的 input_scale 和 input_zero_point 处理整数到实数的转换。这三个参数仅在网络期望浮点输入时才起作用。
roi 在所有情况下都会被读取 —— 无论输入 dtype 如何,它都控制着源帧的哪一部分到达网络。
7.7.3. ROI 与缩放¶
ROI 会从其源尺寸双线性缩放到网络的输入尺寸。图像在目标区域中居中,缩放会填满目标区域 —— 它不保持纵横比。送入正方形网络输入的非正方形 ROI 会出现水平或垂直拉伸。
拉伸是否要紧取决于网络。像 MediaPipe 系列(BlazeFace、FaceLandmarks、HandLandmarks、MoveNet)这样的人脸检测和关键点模型是基于正方形裁剪训练的,当输入纵横比偏离时会迅速退化;对于这些模型,应用程序需要给它们一个正方形 ROI —— 要么通过 window() 以正方形帧尺寸捕获,要么用 roi= 参数裁剪。YOLO 系列目标检测器通常使用包含随机拉伸在内的数据增强训练,能在不损失多少精度的情况下接受非正方形 ROI;直接传入完整的捕获帧通常没问题。
当网络的输入尺寸与 ROI 完全匹配时,缩放会退化为一次拷贝,这是开销最小的情形。
7.7.4. 覆盖默认行为¶
predict() 会自动用 Normalization() 封装每个 image.Image 输入 —— 即上述的默认参数。随摄像头一起发布的大多数模型都是基于默认值已经覆盖的像素范围训练的,因此常见情形是直接传入图像:
result = model.predict([img])
要使用自定义 ROI —— 最常见的覆盖方式 —— 构建一个设置了 ROI 的 Normalization 并将图像绑定到它:
from ml.preprocessing import Normalization
norm = Normalization(roi=(80, 60, 160, 120))
result = model.predict([norm(img)])
要匹配网络的训练时通道统计量,请设置浮点参数:
norm = Normalization(scale=(0.0, 1.0),
mean=(0.485, 0.456, 0.406),
stdev=(0.229, 0.224, 0.225))
result = model.predict([norm(img)])
在图像上调用 Normalization 实例会返回一个新的已绑定实例,引擎从它填充张量。该已绑定实例就是 predict 用来代替原始图像所接受的对象,由于它是逐输入的对象,多输入网络可以在同一个 predict 列表中混用具有不同 ROI 的图像。
对于期望应用程序已以张量形式生成的输入的网络 —— 来自外设的缓冲区、由另一条流水线计算出的 ndarray、非图像的数值数据 —— 请完全跳过 Normalization,直接传入 ndarray 或一个生成它的可调用对象。predict() 会将这些原样传给引擎而不进行封装。