7.4. Wczytywanie modelu¶
ml.Model odczytuje plik modelu z pamięci flash, parsuje go, alokuje pamięć RAM, której sieć potrzebuje podczas wnioskowania, i zwraca obiekt zawierający wszystko, co reszta skryptu musi wiedzieć o wczytanej sieci.
7.4.1. Konstruktor¶
Konstruktor przyjmuje ścieżkę oraz opcjonalny post-procesor:
model = ml.Model("/rom/blazeface_front_128.tflite",
postprocess=BlazeFace())
Modele w /rom/ (systemie plików rezydującym w pamięci flash) są odczytywane w miejscu: wagi sieci pozostają w pamięci flash, a wczytany model zajmuje jedynie tyle RAM, ile wymaga arena tensorów. Modele w /sdcard/ są kopiowane do RAM w momencie wczytywania, więc całkowity koszt to rozmiar pliku modelu plus arena tensorów. Obie ścieżki działają; kompromisem jest pamięć RAM.
Jeśli istnieje sąsiadujący plik .txt o tej samej nazwie bazowej, jego zawartość jest automatycznie wczytywana do labels. Słowo kluczowe postprocess= rejestruje obiekt wywoływalny, który predict() uruchamia po każdym wnioskowaniu.
7.4.2. Właściwości tylko do odczytu¶
Wczytany model udostępnia niewielki zestaw właściwości tylko do odczytu, które opisują sieć, zanim ktokolwiek ją uruchomi.
Plik i pamięć.
len– rozmiar pliku modelu na dysku, w bajtach.ram– rozmiar areny tensorów, której sieć potrzebuje na swoje pośrednie aktywacje podczas wnioskowania, w bajtach.
Tensory wejściowe.
input_shape– lista krotek, po jednej na każdy tensor wejściowy, podająca kształt oczekiwany przez sieć. Sieci wizyjne mają jedno wejście o kształcie(1, H, W, C).input_dtype– lista jednoznakowych kodów dtype ('b'int8,'B'uint8,'h'int16,'H'uint16,'f'float32), po jednym na każde wejście.input_scaleorazinput_zero_point– parametry kwantyzacji, które konwertują pomiędzy wartościami rzeczywistymi, na których trenowano sieć, a reprezentacją całkowitoliczbową, na której operuje kamera.
Tensory wyjściowe. Odbicie zestawu wejściowego: output_shape, output_dtype, output_scale, output_zero_point. Sieci wykrywające produkują dwa lub trzy tensory wyjściowe (ramki, wyniki pewności, czasem prawdopodobieństwa klas); sieci klasyfikujące produkują jeden.
Dodatki. labels to lista nazw klas wczytana z sąsiadującego pliku .txt lub None. postprocess to zarejestrowany post-procesor lub None.
7.4.3. Inspekcja BlazeFace¶
Wczytanie dostarczanego modelu BlazeFace i wypisanie każdej właściwości daje rzeczywiste liczby:
import ml
from ml.postprocessing.mediapipe import BlazeFace
model = ml.Model("/rom/blazeface_front_128.tflite",
postprocess=BlazeFace())
print("file size: ", model.len, "bytes")
print("tensor arena: ", model.ram, "bytes")
print("input shape: ", model.input_shape)
print("input dtype: ", model.input_dtype)
print("input scale: ", model.input_scale)
print("input zp: ", model.input_zero_point)
print("output shape: ", model.output_shape)
print("output dtype: ", model.output_dtype)
print("output scale: ", model.output_scale)
print("output zp: ", model.output_zero_point)
Liczby konkretnie identyfikują interfejs sieci: pojedynczy tensor wejściowy int8 o kształcie (1, 128, 128, 3) oraz dwa wyjścia int8 – jedno dla współczynników regresji ramek, drugie dla wyników pewności na każdą kotwicę. Parametry kwantyzacji opisują, jak te wartości int8 odwzorowują się na rzeczywiste liczby zmiennoprzecinkowe, na których trenowano sieć; post-procesor używa ich do cofnięcia kwantyzacji przed dekodowaniem ramek.
Każda właściwość jest jedynym źródłem prawdy o tym, co opisuje. Skrypty odczytują input_shape, aby wiedzieć, w jakiej rozdzielczości przechwytywać, odczytują output_scale oraz output_zero_point, aby ręcznie dekodować tensory, i odczytują labels, aby uzyskać czytelne dla człowieka nazwy klas – nigdy nie zaszyte na sztywno, nigdy nie zakładane.