7.4. טעינת מודל¶
ml.Model קוראת קובץ מודל מזיכרון הפלאש (flash), מנתחת אותו, מקצה את ה-RAM שהרשת זקוקה לו במהלך ההסקה, ומחזירה אובייקט שנושא את כל מה ששאר הסקריפט צריך לדעת על הרשת הטעונה.
7.4.1. הבנאי (constructor)¶
הבנאי מקבל נתיב ופוסט-מעבד אופציונלי:
model = ml.Model("/rom/blazeface_front_128.tflite",
postprocess=BlazeFace())
מודלים שב-/rom/ (מערכת הקבצים השוכנת בזיכרון הפלאש) נקראים במקומם: משקלי הרשת נשארים בזיכרון הפלאש (flash), והמודל הטעון מוציא רק את כמות ה-RAM של זירת הטנזורים (tensor arena). מודלים שב-/sdcard/ מועתקים אל ה-RAM בזמן הטעינה, ולכן העלות הכוללת היא גודל קובץ המודל בתוספת זירת הטנזורים. כל אחד מהנתיבים עובד; ההתחשבנות היא על ה-RAM.
אם קיים קובץ .txt אחאי בעל אותו שם בסיס, תוכנו נטען באופן אוטומטי אל labels. מילת המפתח postprocess= רושמת ישות בת-קריאה (callable) ש-predict() מריצה לאחר כל הסקה.
7.4.2. מאפיינים לקריאה בלבד¶
מודל טעון חושף קבוצה קטנה של מאפיינים לקריאה בלבד המתארים את הרשת מבלי שאיש יריץ אותה.
קובץ וזיכרון.
len– גודל קובץ המודל בדיסק, בבייטים.ram– גודל זירת הטנזורים (tensor arena) שהרשת זקוקה לה עבור ההפעלות (activations) הביניימיות שלה במהלך ההסקה, בבייטים.
טנזורי קלט.
input_shape– רשימה של טאפלים, אחד לכל טנזור קלט, המציינת את הצורה שהרשת מצפה לה. לרשתות ראייה יש קלט אחד בצורה(1, H, W, C).input_dtype– רשימה של קודי dtype בני תו בודד ('b'int8,'B'uint8,'h'int16,'H'uint16,'f'float32), אחד לכל קלט.input_scaleו-input_zero_point– פרמטרי הקוונטיזציה הממירים בין הקלט בעל הערך הממשי שעליו הרשת אומנה לבין הייצוג השלם (integer) שמולו המצלמה מריצה.
טנזורי פלט. תמונת ראי של קבוצת הקלט: output_shape, output_dtype, output_scale, output_zero_point. רשתות זיהוי מייצרות שניים או שלושה טנזורי פלט (תיבות, ציוני ביטחון, ולעיתים הסתברויות מחלקה); רשתות סיווג מייצרות אחד.
תוספות. labels היא רשימת שמות המחלקות הנטענת מקובץ ה-.txt האחאי, או None. postprocess הוא הפוסט-מעבד הרשום, או None.
7.4.3. בחינת BlazeFace¶
טעינת מודל BlazeFace המצורף והדפסת כל מאפיין נותנות את המספרים בפועל:
import ml
from ml.postprocessing.mediapipe import BlazeFace
model = ml.Model("/rom/blazeface_front_128.tflite",
postprocess=BlazeFace())
print("file size: ", model.len, "bytes")
print("tensor arena: ", model.ram, "bytes")
print("input shape: ", model.input_shape)
print("input dtype: ", model.input_dtype)
print("input scale: ", model.input_scale)
print("input zp: ", model.input_zero_point)
print("output shape: ", model.output_shape)
print("output dtype: ", model.output_dtype)
print("output scale: ", model.output_scale)
print("output zp: ", model.output_zero_point)
המספרים מזהים באופן קונקרטי את הממשק של הרשת: טנזור קלט יחיד (1, 128, 128, 3) מסוג int8 ושני פלטים מסוג int8 – אחד למקדמי רגרסיית התיבות, ואחד לציוני ביטחון לכל עוגן (anchor). פרמטרי הקוונטיזציה מתארים כיצד אותם ערכי int8 ממופים אל המספרים הממשיים (float) שמולם הרשת אומנה; הפוסט-מעבד משתמש בהם כדי לבטל את הקוונטיזציה לפני פענוח התיבות.
כל מאפיין הוא מקור האמת היחיד לגבי מה שהוא מתאר. סקריפטים קוראים את input_shape כדי לדעת באיזו צורה ללכוד, קוראים את output_scale ואת output_zero_point כדי לפענח טנזורים ידנית, וקוראים את labels עבור שמות מחלקות קריאים לאדם – לעולם לא מקודדים בקשיחות, לעולם לא בהנחה.