7.4. Tải mô hình

ml.Model đọc tệp mô hình (ML) từ bộ nhớ flash, phân tích cú pháp, cấp phát RAM mà mạng nơ-ron cần trong quá trình suy luận, và trả về một đối tượng chứa mọi thông tin mà phần còn lại của tập lệnh cần biết về mạng đã tải.

7.4.1. Hàm khởi tạo

Hàm khởi tạo nhận một đường dẫn và một bộ xử lý hậu kỳ tùy chọn:

model = ml.Model("/rom/blazeface_front_128.tflite",
                 postprocess=BlazeFace())

Các mô hình (ML) trên /rom/ (hệ thống tệp thường trú trong flash) được đọc trực tiếp tại chỗ: trọng số của mạng nơ-ron nằm trong flash và mô hình (ML) đã tải chỉ dùng lượng RAM tương đương vùng tensor. Các mô hình (ML) trên /sdcard/ được sao chép vào RAM khi tải, vì vậy tổng chi phí là kích thước tệp mô hình (ML) cộng với vùng tensor. Cả hai đường dẫn đều hoạt động; sự đánh đổi là RAM.

Nếu tồn tại tệp .txt cùng tên (chỉ khác phần mở rộng), nội dung của nó sẽ được tải tự động vào labels. Từ khóa postprocess= đăng ký một callable mà predict() chạy sau mỗi lần suy luận.

7.4.2. Thuộc tính chỉ đọc

Một mô hình (ML) đã tải cung cấp một tập hợp nhỏ các thuộc tính chỉ đọc mô tả mạng nơ-ron mà không cần thực thi nó.

Tệp và bộ nhớ.

  • len -- kích thước tệp mô hình (ML) trên đĩa, tính bằng byte.

  • ram -- kích thước của vùng tensor mà mạng nơ-ron cần cho các kích hoạt trung gian trong quá trình suy luận, tính bằng byte.

Tensor đầu vào.

  • input_shape -- danh sách các tuple, một phần tử cho mỗi tensor đầu vào, cho biết hình dạng mà mạng nơ-ron mong đợi. Các mạng thị giác có một đầu vào với hình dạng (1, H, W, C).

  • input_dtype -- danh sách các mã dtype một ký tự ('b' int8, 'B' uint8, 'h' int16, 'H' uint16, 'f' float32), một phần tử cho mỗi đầu vào.

  • input_scaleinput_zero_point -- các tham số lượng tử hóa chuyển đổi giữa đầu vào thực mà mạng nơ-ron được huấn luyện và biểu diễn số nguyên mà camera thực thi.

Tensor đầu ra. Phản chiếu của tập đầu vào: output_shape, output_dtype, output_scale, output_zero_point. Các mạng phát hiện tạo ra hai hoặc ba tensor đầu ra (hộp giới hạn, điểm tin cậy, đôi khi xác suất lớp); các mạng phân loại tạo ra một.

Thông tin bổ sung. labels là danh sách tên lớp được tải từ tệp .txt cùng tên, hoặc None. postprocess là bộ xử lý hậu kỳ đã đăng ký, hoặc None.

7.4.3. Kiểm tra BlazeFace

Tải mô hình (ML) BlazeFace đi kèm và in từng thuộc tính sẽ cho ra các con số thực tế:

import ml
from ml.postprocessing.mediapipe import BlazeFace

model = ml.Model("/rom/blazeface_front_128.tflite",
                 postprocess=BlazeFace())

print("file size:    ", model.len, "bytes")
print("tensor arena: ", model.ram, "bytes")
print("input shape:  ", model.input_shape)
print("input dtype:  ", model.input_dtype)
print("input scale:  ", model.input_scale)
print("input zp:     ", model.input_zero_point)
print("output shape: ", model.output_shape)
print("output dtype: ", model.output_dtype)
print("output scale: ", model.output_scale)
print("output zp:    ", model.output_zero_point)

Các con số xác định giao diện của mạng nơ-ron một cách cụ thể: một tensor đầu vào int8 duy nhất (1, 128, 128, 3) và hai đầu ra int8 -- một cho các hệ số hồi quy hộp giới hạn, một cho điểm tin cậy theo từng neo. Các tham số lượng tử hóa mô tả cách các giá trị int8 ánh xạ đến các số thực mà mạng nơ-ron được huấn luyện; bộ xử lý hậu kỳ sử dụng chúng để hoàn tác lượng tử hóa trước khi giải mã các hộp giới hạn.

Mỗi thuộc tính là nguồn thông tin duy nhất và chính xác cho những gì nó mô tả. Các tập lệnh đọc input_shape để biết cần chụp ở độ phân giải nào, đọc output_scaleoutput_zero_point để giải mã tensor thủ công, và đọc labels cho tên lớp dễ đọc -- không bao giờ được mã cứng hay giả định.