7.4. モデルの読み込み

ml.Model はフラッシュからモデルファイルを読み込み、それを解析し、ネットワークが推論中に必要とするRAMを割り当て、読み込まれたネットワークについてスクリプトの残りの部分が知るべきすべての情報を保持するオブジェクトを返します。

7.4.1. コンストラクタ

コンストラクタはパスとオプションのポストプロセッサを受け取ります:

model = ml.Model("/rom/blazeface_front_128.tflite",
                 postprocess=BlazeFace())

/rom/(フラッシュ常駐のファイルシステム)上のモデルはその場で読み込まれます。ネットワークの重みはフラッシュに留まり、読み込まれたモデルはテンサーアリーナ分のRAMしか消費しません。/sdcard/ 上のモデルは読み込み時にRAMへコピーされるため、総コストはモデルファイルサイズにテンサーアリーナを加えたものになります。どちらのパスでも動作します。トレードオフはRAMです。

同じベース名を持つ兄弟関係の .txt ファイルが存在する場合、その内容は自動的に labels に読み込まれます。postprocess= キーワードは、各推論の後に predict() が実行する呼び出し可能オブジェクトを登録します。

7.4.2. 読み取り専用プロパティ

読み込まれたモデルは、誰も実行しなくてもネットワークを記述する小さな読み取り専用プロパティのセットを公開します。

ファイルとメモリ。

  • len -- ディスク上のモデルファイルサイズ(バイト単位)。

  • ram -- 推論中の中間アクティベーションのためにネットワークが必要とするテンサーアリーナのサイズ(バイト単位)。

入力テンサー。

  • input_shape -- 入力テンサーごとに1つのタプルのリストで、ネットワークが期待する形状を示します。ビジョンネットワークは形状 (1, H, W, C) の入力を1つ持ちます。

  • input_dtype -- 1文字のdtypeコード('b' int8、'B' uint8、'h' int16、'H' uint16、'f' float32)のリストで、入力ごとに1つです。

  • input_scaleinput_zero_point -- ネットワークが学習された実数値の入力と、カメラが実行する整数表現との間を変換する量子化パラメータ

出力テンサー。 入力セットの対応物です。output_shapeoutput_dtypeoutput_scaleoutput_zero_point。検出ネットワークは2つまたは3つの出力テンサー(ボックス、信頼度スコア、場合によってはクラス確率)を生成します。分類ネットワークは1つを生成します。

追加情報。 labels は兄弟関係の .txt ファイルから読み込まれたクラス名のリスト、または None です。postprocess は登録されたポストプロセッサ、または None です。

7.4.3. BlazeFaceの調査

同梱のBlazeFaceモデルを読み込み、各プロパティを出力すると実際の数値が得られます:

import ml
from ml.postprocessing.mediapipe import BlazeFace

model = ml.Model("/rom/blazeface_front_128.tflite",
                 postprocess=BlazeFace())

print("file size:    ", model.len, "bytes")
print("tensor arena: ", model.ram, "bytes")
print("input shape:  ", model.input_shape)
print("input dtype:  ", model.input_dtype)
print("input scale:  ", model.input_scale)
print("input zp:     ", model.input_zero_point)
print("output shape: ", model.output_shape)
print("output dtype: ", model.output_dtype)
print("output scale: ", model.output_scale)
print("output zp:    ", model.output_zero_point)

これらの数値はネットワークのインターフェースを具体的に示しています。単一の (1, 128, 128, 3) int8 入力テンサーと2つの int8 出力 -- 1つはボックス回帰係数用、もう1つはアンカーごとの信頼度スコア用です。量子化パラメータは、それらのint8値がネットワークが学習された実際のfloatにどうマッピングされるかを記述します。ポストプロセッサはそれらを使ってボックスをデコードする前に量子化を元に戻します。

すべてのプロパティは、それが記述する内容についての唯一の信頼できる情報源です。スクリプトは何をキャプチャすべきかを知るために input_shape を読み、テンサーを手動でデコードするために output_scaleoutput_zero_point を読み、人間が読めるクラス名のために labels を読みます -- ハードコードされることも、推測されることも決してありません。