7.4. 載入模型

ml.Model 會從快閃記憶體讀取模型檔案、進行解析、配置網路推論期間所需的 RAM,並回傳一個物件,該物件帶有指令碼其餘部分需要知道的關於已載入網路的一切資訊。

7.4.1. 建構子

建構子接受一個路徑與一個選用的後處理器::

model = ml.Model("/rom/blazeface_front_128.tflite",
                 postprocess=BlazeFace())

位於 /rom/(常駐於快閃記憶體的檔案系統)上的模型會就地讀取:網路的權重保留在快閃記憶體中,而載入的模型只會耗用張量競技場(tensor arena)大小的 RAM。位於 /sdcard/ 上的模型則會在載入時複製到 RAM 中,因此總成本為模型檔案大小加上張量競技場。兩種路徑都可行;取捨的是 RAM。

如果存在一個同基底檔名的相鄰 .txt 檔案,其內容會自動載入到 labels 中。postprocess= 關鍵字會註冊一個可呼叫物件,predict() 會在每次推論之後執行它。

7.4.2. 唯讀屬性

已載入的模型公開了一小組唯讀屬性,這些屬性無需任何人執行網路即可描述該網路。

檔案與記憶體。

  • len -- 磁碟上模型檔案的大小,以位元組為單位。

  • ram -- 網路在推論期間為其中間激活值所需的張量競技場(tensor arena)大小,以位元組為單位。

輸入張量。

  • input_shape -- 一個由元組組成的列表,每個輸入張量一個元組,給出網路所預期的形狀。視覺網路有一個形狀為 (1, H, W, C) 的輸入。

  • input_dtype -- 單字元 dtype 代碼的列表('b' int8、'B' uint8、'h' int16、'H' uint16、'f' float32),每個輸入一個。

  • input_scaleinput_zero_point -- 用於在網路訓練時所採用的實數值輸入與相機所執行的整數表示之間進行轉換的量化參數

輸出張量。 為輸入集合的鏡像:output_shapeoutput_dtypeoutput_scaleoutput_zero_point。偵測網路會產生兩到三個輸出張量(邊界框、信賴分數,有時還有類別機率);分類網路則產生一個。

額外項目。 labels 是從相鄰 .txt 檔案載入的類別名稱列表,否則為 Nonepostprocess 是已註冊的後處理器,否則為 None

7.4.3. 檢視 BlazeFace

載入隨附的 BlazeFace 模型並印出每個屬性,即可得到實際的數值::

import ml
from ml.postprocessing.mediapipe import BlazeFace

model = ml.Model("/rom/blazeface_front_128.tflite",
                 postprocess=BlazeFace())

print("file size:    ", model.len, "bytes")
print("tensor arena: ", model.ram, "bytes")
print("input shape:  ", model.input_shape)
print("input dtype:  ", model.input_dtype)
print("input scale:  ", model.input_scale)
print("input zp:     ", model.input_zero_point)
print("output shape: ", model.output_shape)
print("output dtype: ", model.output_dtype)
print("output scale: ", model.output_scale)
print("output zp:    ", model.output_zero_point)

這些數值具體地識別出網路的介面:單一的 (1, 128, 128, 3) int8 輸入張量與兩個 int8 輸出 -- 一個用於邊界框迴歸係數,一個用於每個錨點的信賴分數。量化參數描述了那些 int8 值如何對應到網路訓練時所採用的實數浮點值;後處理器會利用它們在解碼邊界框之前還原量化。

每個屬性都是其所描述內容的唯一真實來源。指令碼讀取 input_shape 以得知要以何種規格擷取、讀取 output_scaleoutput_zero_point 以手動解碼張量、並讀取 labels 以取得人類可讀的類別名稱 -- 絕不寫死,絕不臆測。