mohammed-aljafry commited on
Commit
cbcacad
·
verified ·
1 Parent(s): 29e1e9d

Upload model_utils.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. model_utils.py +136 -0
model_utils.py ADDED
@@ -0,0 +1,136 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ============================================================================
2
+ # model_utils.py - أدوات مساعدة لإدارة نموذج Interfuser
3
+ # ============================================================================
4
+ # هذا الملف مسؤول عن كل العمليات المتعلقة بنموذج PyTorch:
5
+ # 1. العثور على النماذج المتاحة.
6
+ # 2. تحميل نموذج محدد إلى الذاكرة (CPU/GPU).
7
+ # 3. توفير وصول سهل إلى النموذج المحمل حاليًا.
8
+ # هذا يفصل منطق النموذج بشكل كامل عن منطق واجهة المستخدم.
9
+ # ============================================================================
10
+
11
+ import os
12
+ import torch
13
+ import logging
14
+
15
+ # استيراد الأدوات اللازمة من ملف تعريف النموذج
16
+ try:
17
+ from model_definition import load_and_prepare_model, create_model_config
18
+ except ImportError as e:
19
+ print(f"خطأ في الاستيراد: تأكد من وجود ملف model_definition.py. الخطأ: {e}")
20
+ exit()
21
+
22
+ # --- المتغيرات العامة الخاصة بالنموذج ---
23
+
24
+ # الدليل الذي يحتوي على ملفات النماذج
25
+ MODEL_DIR = "model"
26
+
27
+ # تحديد الجهاز تلقائيًا (سيستخدم GPU إذا كان متاحًا)
28
+ DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
29
+
30
+ # هذه المتغيرات ستحتفظ بالنموذج المحمل حاليًا في الذاكرة لتجنب إعادة التحميل
31
+ CURRENTLY_LOADED_MODEL: torch.nn.Module = None
32
+ CURRENT_MODEL_NAME: str = None
33
+
34
+ # إعداد نظام التسجيل لمتابعة عمليات التحميل
35
+ logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
36
+
37
+
38
+ def get_available_models():
39
+ """
40
+ يبحث في مجلد 'model' ويعيد قائمة بأسماء النماذج المتاحة.
41
+
42
+ Returns:
43
+ list[str]: قائمة بأسماء الملفات للنماذج المتاحة.
44
+ """
45
+ if not os.path.isdir(MODEL_DIR):
46
+ logging.warning(f"مجلد النماذج '{MODEL_DIR}' غير موجود.")
47
+ return []
48
+
49
+ try:
50
+ models = [f for f in os.listdir(MODEL_DIR) if f.endswith(('.pth', '.pt'))]
51
+ logging.info(f"تم العثور على النماذج التالية: {models}")
52
+ return models
53
+ except Exception as e:
54
+ logging.error(f"حدث خطأ أثناء قراءة مجلد النماذج: {e}")
55
+ return []
56
+
57
+
58
+ def load_model_by_name(model_name: str):
59
+ """
60
+ يحمل نموذجًا محددًا بالاسم. إذا كان النموذج المطلوب محملًا بالفعل،
61
+ فإنه يتخطى عملية التحميل.
62
+
63
+ Args:
64
+ model_name (str): اسم ملف النموذج المراد تحميله (e.g., 'best_model.pth').
65
+
66
+ Returns:
67
+ str: رسالة نصية تشير إلى حالة عملية التحميل.
68
+ """
69
+ global CURRENTLY_LOADED_MODEL, CURRENT_MODEL_NAME
70
+
71
+ if not model_name:
72
+ return "لم يتم اختيار نموذج."
73
+
74
+ # إذا كان النموذج المطلوب هو نفسه المحمل حاليًا، فلا داعي لفعل أي شيء
75
+ if model_name == CURRENT_MODEL_NAME and CURRENTLY_LOADED_MODEL is not None:
76
+ message = f"النموذج '{model_name}' محمل بالفعل."
77
+ logging.info(message)
78
+ return message
79
+
80
+ logging.info(f"بدء تحميل النموذج: '{model_name}' على الجهاز {DEVICE}...")
81
+ model_path = os.path.join(MODEL_DIR, model_name)
82
+
83
+ if not os.path.exists(model_path):
84
+ error_message = f"ملف النموذج '{model_path}' غير موجود."
85
+ logging.error(error_message)
86
+ # تفريغ النموذج الحالي إذا كان المسار خاطئًا
87
+ CURRENTLY_LOADED_MODEL = None
88
+ CURRENT_MODEL_NAME = None
89
+ raise FileNotFoundError(error_message)
90
+
91
+ try:
92
+ # استخدام الدوال من model_definition.py لإنشاء وتحميل النموذج
93
+ model_config = create_model_config(model_path=model_path)
94
+ model = load_and_prepare_model(model_config, DEVICE)
95
+
96
+ # تحديث المتغيرات العامة بالنموذج الجديد
97
+ CURRENTLY_LOADED_MODEL = model
98
+ CURRENT_MODEL_NAME = model_name
99
+
100
+ success_message = f"✅ تم تحميل النموذج بنجاح: {model_name}"
101
+ logging.info(success_message)
102
+ return success_message
103
+
104
+ except Exception as e:
105
+ logging.error(f"❌ حدث خطأ فادح أثناء تحميل النموذج '{model_name}': {e}", exc_info=True)
106
+ # إعادة تعيين المتغيرات العامة في حالة الفشل
107
+ CURRENTLY_LOADED_MODEL = None
108
+ CURRENT_MODEL_NAME = None
109
+ # إرسال الخطأ للأعلى ليتم عرضه في واجهة Gradio
110
+ raise e
111
+
112
+
113
+ def get_current_model() -> torch.nn.Module:
114
+ """
115
+ يعيد كائن النموذج المحمل حاليًا.
116
+ إذا لم يكن هناك نموذج محمل، يحاول تحميل أول نموذج متاح كخيار افتراضي.
117
+
118
+ Returns:
119
+ torch.nn.Module or None: كائن النموذج المحمل أو None إذا فشل التحميل.
120
+ """
121
+ if CURRENTLY_LOADED_MODEL is None:
122
+ logging.info("لا يوجد نموذج محمل حاليًا. محاولة تحميل النموذج الافتراضي...")
123
+ available_models = get_available_models()
124
+
125
+ if available_models:
126
+ # محاولة تحميل أول نموذج في القائمة
127
+ try:
128
+ load_model_by_name(available_models[0])
129
+ except Exception as e:
130
+ logging.error(f"فشل تحميل النموذج الافتراضي '{available_models[0]}': {e}")
131
+ return None
132
+ else:
133
+ logging.warning("لا توجد نماذج متاحة في مجلد 'model' لتحميلها.")
134
+ return None
135
+
136
+ return CURRENTLY_LOADED_MODEL