7.4. Загрузка модели

ml.Model считывает файл модели из флеш-памяти, разбирает его, выделяет оперативную память, необходимую сети во время вывода, и возвращает объект, содержащий всё, что остальной части скрипта нужно знать о загруженной сети.

7.4.1. Конструктор

Конструктор принимает путь и необязательный постобработчик:

model = ml.Model("/rom/blazeface_front_128.tflite",
                 postprocess=BlazeFace())

Модели в /rom/ (флеш-резидентная файловая система) читаются на месте: веса сети остаются во флеш-памяти, и загруженная модель расходует ОЗУ лишь в объёме арены тензоров. Модели в /sdcard/ копируются в ОЗУ при загрузке, поэтому общая стоимость складывается из размера файла модели и арены тензоров. Оба пути работают; компромисс касается ОЗУ.

Если рядом существует файл .txt с тем же базовым именем, его содержимое автоматически загружается в labels. Ключевое слово postprocess= регистрирует вызываемый объект, который predict() запускает после каждого вывода.

7.4.2. Свойства только для чтения

Загруженная модель предоставляет небольшой набор свойств только для чтения, описывающих сеть без её запуска.

Файл и память.

  • len – размер файла модели на диске в байтах.

  • ram – размер арены тензоров, необходимой сети для промежуточных активаций во время вывода, в байтах.

Входные тензоры.

  • input_shape – список кортежей, по одному на каждый входной тензор, задающий форму, которую ожидает сеть. У сетей машинного зрения один вход с формой (1, H, W, C).

  • input_dtype – список односимвольных кодов dtype ('b' int8, 'B' uint8, 'h' int16, 'H' uint16, 'f' float32), по одному на каждый вход.

  • input_scale и input_zero_pointпараметры квантования, которые преобразуют входные данные с действительными значениями, на которых обучалась сеть, в целочисленное представление, с которым работает камера.

Выходные тензоры. Зеркало входного набора: output_shape, output_dtype, output_scale, output_zero_point. Сети обнаружения формируют два или три выходных тензора (рамки, оценки уверенности, иногда вероятности классов); сети классификации формируют один.

Дополнительно. labels – список имён классов, загруженный из соседнего файла .txt, или None. postprocess – зарегистрированный постобработчик или None.

7.4.3. Изучение BlazeFace

Загрузка поставляемой модели BlazeFace и вывод каждого свойства дают фактические значения:

import ml
from ml.postprocessing.mediapipe import BlazeFace

model = ml.Model("/rom/blazeface_front_128.tflite",
                 postprocess=BlazeFace())

print("file size:    ", model.len, "bytes")
print("tensor arena: ", model.ram, "bytes")
print("input shape:  ", model.input_shape)
print("input dtype:  ", model.input_dtype)
print("input scale:  ", model.input_scale)
print("input zp:     ", model.input_zero_point)
print("output shape: ", model.output_shape)
print("output dtype: ", model.output_dtype)
print("output scale: ", model.output_scale)
print("output zp:    ", model.output_zero_point)

Эти значения конкретно определяют интерфейс сети: один входной тензор int8 формы (1, 128, 128, 3) и два выхода int8 – один для коэффициентов регрессии рамок, другой для оценок уверенности по каждому якорю. Параметры квантования описывают, как эти значения int8 отображаются в действительные числа с плавающей точкой, на которых обучалась сеть; постобработчик использует их, чтобы отменить квантование перед декодированием рамок.

Каждое свойство является единственным источником истины для того, что оно описывает. Скрипты читают input_shape, чтобы знать, с каким разрешением захватывать, читают output_scale и output_zero_point, чтобы вручную декодировать тензоры, и читают labels для получения понятных человеку имён классов – никогда не зашитых в код, никогда не предполагаемых.