7.7. 正規化¶
ml.Model.predict() 接受一個輸入列表,因為有些網路具有多於一個輸入張量,但該列表無法以內嵌方式攜帶各輸入的引數 -- 沒有 kwarg 欄位可用來表示「將這個輸入裁切為 (x, y, w, h) 但不動其他輸入」。ml.preprocessing.Normalization 就是填補這個缺口的包裝器。一個 Normalization 實例持有單一輸入的參數;每當指令碼需要任何非預設值的內容時,便在 predict 列表中傳入這個被包裝的輸入。
使用它最常見的原因是將擷取影格的特定區域裁切送入網路,而非整張影像。
7.7.1. 參數¶
Normalization(scale=(0.0, 1.0),
mean=(0.0, 0.0, 0.0),
stdev=(1.0, 1.0, 1.0),
roi=None)
roi-- 來源影格中於調整大小之前要裁切的(x, y, w, h)矩形。預設為整個影格。Normalization的大多數用法只會設定這個參數。scale-- 浮點輸入張量在正規化之後所預期的(min, max)範圍。像素範圍0..255會線性映射到此範圍。常見值為用於 ReLU 訓練網路的(0.0, 1.0)與用於對稱正規化網路的(-1.0, 1.0)。mean-- 在縮放之後從影像中減去的各通道(R, G, B)平均值。與網路訓練時所採用的通道統計量相符 -- 對於衍生自 ImageNet 的網路而言,(0.485, 0.456, 0.406)是典型範例。灰階網路使用標準的0.299*R + 0.587*G + 0.114*B將平均值縮減為一個亮度值。stdev-- 在減去平均值之後用以除影像的各通道(R, G, B)標準差,同樣與網路的訓練統計量相符。對於灰階網路會以相同方式縮減為亮度。
7.7.2. 參數何時重要¶
當網路的 input_dtype 為 int8 或 uint8 時,scale、mean 與 stdev 會被忽略。對於整數輸入網路,裁切後的影像位元組會直接寫入張量中,而網路自身的 input_scale 與 input_zero_point 會處理整數到實數的轉換。這三個參數只有在網路預期浮點輸入時才重要。
roi 在所有情況下都會被讀取 -- 無論輸入 dtype 為何,它都控制來源影格的哪一部分送達網路。
7.7.3. ROI 與調整大小¶
ROI 會以雙線性方式從其來源尺寸縮放至網路的輸入尺寸。影像會在目標中置中,且縮放會填滿目標 -- 它不會保留長寬比。送入正方形網路輸入的非正方形 ROI 會在水平或垂直方向被拉伸出來。
拉伸是否重要取決於網路。臉部偵測與特徵點模型,如 MediaPipe 族系(BlazeFace、FaceLandmarks、HandLandmarks、MoveNet),是針對正方形裁切訓練的,當輸入長寬比偏離時會迅速劣化;對於這些模型,應用程式需要給它們一個正方形的 ROI -- 可透過 window() 以正方形影格大小擷取,或使用 roi= 參數進行裁切。YOLO 族系物件偵測器通常以包含隨機拉伸的擴增方式訓練,可接受非正方形 ROI 而無太多準確度損失;直接傳入完整擷取的影格通常沒問題。
當網路的輸入尺寸與 ROI 完全相符時,縮放會退化為一次複製,這是成本最低的情況。
7.7.4. 覆寫預設值¶
predict() 會自動以 Normalization() 包裝每個 image.Image 輸入 -- 即上述的預設參數。隨相機出貨的大多數模型都是針對預設值已涵蓋的像素範圍訓練的,因此常見的情況是直接傳入影像::
result = model.predict([img])
若要使用自訂 ROI -- 最常見的覆寫 -- 請建構一個已設定 ROI 的 Normalization 並將影像綁定到它::
from ml.preprocessing import Normalization
norm = Normalization(roi=(80, 60, 160, 120))
result = model.predict([norm(img)])
若要符合網路訓練時的通道統計量,請設定浮點參數::
norm = Normalization(scale=(0.0, 1.0),
mean=(0.485, 0.456, 0.406),
stdev=(0.229, 0.224, 0.225))
result = model.predict([norm(img)])
對影像呼叫 Normalization 實例會回傳一個新的已綁定實例,引擎會從中填充張量。已綁定的實例就是 predict 用以取代原始影像所接受的對象,而由於它是一個各輸入專屬的物件,多輸入網路可以在同一個 predict 列表中混用具有不同 ROI 的影像。
對於預期應用程式已以張量形式產生的輸入的網路 -- 來自周邊裝置的緩衝區、由另一條管線計算出的 ndarray、非影像的數值資料 -- 請完全略過 Normalization,直接傳入該 ndarray 或一個產生它的可呼叫物件。predict() 會將那些直接傳遞給引擎而不加包裝。