mohammed-aljafry commited on
Commit
30e282e
·
verified ·
1 Parent(s): c38878a

Upload app.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. app.py +43 -70
app.py CHANGED
@@ -11,9 +11,7 @@ import cv2
11
  import math
12
 
13
  # --- استيراد من الملفات المنظمة في مشروعك ---
14
- # نفترض أن بنية النموذج موجودة في model/architecture.py
15
- from model import build_interfuser_model
16
- # نفترض أن بقية المنطق موجود في logic.py
17
  from logic import (
18
  transform, lidar_transform, InterfuserController, ControllerConfig,
19
  Tracker, DisplayInterface, render, render_waypoints, render_self_car,
@@ -23,30 +21,23 @@ from logic import (
23
  # ==============================================================================
24
  # 1. إعدادات ومسارات النماذج
25
  # ==============================================================================
26
- WEIGHTS_DIR = "model"
27
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
28
 
29
- # متغير عام لتخزين النموذج المحمّل حاليًا
30
- current_model = None
31
-
32
  # قاموس لتحديد الإعدادات الخاصة بكل نموذج.
33
- # اسم المفتاح يجب أن يطابق اسم ملف الأوزان (بدون .pth).
34
- # إذا لم يتم تحديد إعدادات لنموذج ما، سيتم استخدام الإعدادات الافتراضية في دالة البناء.
35
  MODELS_SPECIFIC_CONFIGS = {
36
  "interfuser_baseline": {
37
  "rgb_backbone_name": "r50",
38
  "embed_dim": 256,
39
- "direct_concat": True, # هذا النموذج يتوقع دمج الصور
40
  },
41
  "interfuser_lightweight": {
42
  "rgb_backbone_name": "r26",
43
  "embed_dim": 128,
44
  "enc_depth": 4,
45
  "dec_depth": 4,
46
- "direct_concat": True, # هذا النموذج يتوقع دمج الصور
47
  }
48
- # أضف هنا أي إعدادات لنماذج أخرى لديك
49
- # "my_other_model": { "direct_concat": False, ... }
50
  }
51
 
52
  def find_available_models():
@@ -56,53 +47,45 @@ def find_available_models():
56
  if not os.path.isdir(WEIGHTS_DIR):
57
  print(f"تحذير: مجلد الأوزان '{WEIGHTS_DIR}' غير موجود.")
58
  return []
59
-
60
- models = [f.replace(".pth", "") for f in os.listdir(WEIGHTS_DIR) if f.endswith(".pth")]
61
- return models
62
 
63
  # ==============================================================================
64
- # 2. دالة تحميل النموذج الديناميكية
65
  # ==============================================================================
66
  def load_model(model_name: str):
67
  """
68
- تحمل النموذج المحدد من القائمة المنسدلة وتضعه في المتغير العام current_model.
69
  """
70
- global current_model
71
-
72
- if not model_name:
73
- return "الرجاء اختيار نموذج من القائمة."
74
 
75
  weights_path = os.path.join(WEIGHTS_DIR, f"{model_name}.pth")
76
- print(f"Attempting to load model: '{model_name}' from '{weights_path}'")
77
 
78
- # الحصول على الإعدادات المخصصة للنموذج، أو قاموس فارغ إذا لم توجد
79
  model_config = MODELS_SPECIFIC_CONFIGS.get(model_name, {})
80
-
81
- # بناء النموذج باستخدام الإعدادات المحددة
82
  model = build_interfuser_model(model_config)
83
 
84
  if not os.path.exists(weights_path):
85
- gr.Warning(f"ملف الأوزان '{weights_path}' غير موجود. سيتم استخدام النموذج بأوزان عشوائية.")
86
  else:
87
  try:
88
- # استخدام weights_only=True للأمان
89
  state_dic = torch.load(weights_path, map_location=device, weights_only=True)
90
  model.load_state_dict(state_dic)
91
  print(f"تم تحميل أوزان النموذج '{model_name}' بنجاح.")
92
  except Exception as e:
93
- gr.Warning(f"فشل تحميل الأوزان للنموذج '{model_name}': {e}. تأكد من تطابق الإعدادات في 'MODELS_SPECIFIC_CONFIGS' مع الملف المحفوظ. سيتم استخدام أوزان عشوائية.")
94
 
95
  model.to(device)
96
  model.eval()
97
 
98
- current_model = model # تحديث النموذج العام
99
-
100
- return f"تم تحميل نموذج: {model_name}"
101
 
102
  # ==============================================================================
103
- # 3. دالة التشغيل الرئيسية لـ Gradio
104
  # ==============================================================================
105
  def run_single_frame(
 
106
  rgb_image_path,
107
  rgb_left_image_path,
108
  rgb_right_image_path,
@@ -111,40 +94,36 @@ def run_single_frame(
111
  measurements_path,
112
  target_point_list
113
  ):
114
- global current_model
115
-
116
- if current_model is None:
117
- raise gr.Error("الرجاء اختيار وتحميل نموذج أولاً من القائمة المنسدلة.")
118
 
119
  try:
120
  # --- 1. قراءة ومعالجة المدخلات ---
121
- if not rgb_image_path:
122
- raise gr.Error("الرجاء توفير مسار الصورة الأمامية (RGB).")
123
 
124
  rgb_image_pil = Image.open(rgb_image_path.name).convert("RGB")
 
125
  rgb_left_pil = Image.open(rgb_left_image_path.name).convert("RGB") if rgb_left_image_path else rgb_image_pil
126
  rgb_right_pil = Image.open(rgb_right_image_path.name).convert("RGB") if rgb_right_image_path else rgb_image_pil
127
  rgb_center_pil = Image.open(rgb_center_image_path.name).convert("RGB") if rgb_center_image_path else rgb_image_pil
128
 
129
- # تطبيق التحويلات لتحويل الصور إلى تنسورات
130
  front_tensor = transform(rgb_image_pil).unsqueeze(0).to(device)
131
  left_tensor = transform(rgb_left_pil).unsqueeze(0).to(device)
132
  right_tensor = transform(rgb_right_pil).unsqueeze(0).to(device)
133
  center_tensor = transform(rgb_center_pil).unsqueeze(0).to(device)
134
-
135
  if lidar_image_path:
136
  lidar_array = np.load(lidar_image_path.name)
137
- if lidar_array.max() > 0:
138
- lidar_array = (lidar_array / lidar_array.max()) * 255.0
139
  lidar_pil = Image.fromarray(lidar_array.astype(np.uint8)).convert('RGB')
140
  else:
141
  lidar_pil = Image.fromarray(np.zeros((112, 112, 3), dtype=np.uint8))
142
  lidar_tensor = lidar_transform(lidar_pil).unsqueeze(0).to(device)
143
 
144
- with open(measurements_path.name, 'r') as f:
145
- m_dict = json.load(f)
146
 
147
- # إنشاء تنسور القياسات الصحيح (10 عناصر)
148
  measurements_tensor = torch.tensor([[
149
  m_dict.get('x', 0.0), m_dict.get('y', 0.0), m_dict.get('theta', 0.0),
150
  m_dict.get('speed', 5.0), m_dict.get('steer', 0.0), m_dict.get('throttle', 0.0),
@@ -154,20 +133,15 @@ def run_single_frame(
154
 
155
  target_point_tensor = torch.tensor([target_point_list], dtype=torch.float32).to(device)
156
 
157
- # تجميع المدخلات للنموذج
158
  inputs = {
159
- 'rgb': front_tensor, # للنماذج التي لا تدمج
160
- 'rgb_left': left_tensor,
161
- 'rgb_right': right_tensor,
162
- 'rgb_center': center_tensor,
163
- 'lidar': lidar_tensor,
164
- 'measurements': measurements_tensor,
165
- 'target_point': target_point_tensor
166
  }
167
 
168
  # --- 2. تشغيل النموذج ---
169
  with torch.no_grad():
170
- outputs = current_model(inputs)
171
  traffic, waypoints, is_junction, traffic_light, stop_sign, _ = outputs
172
 
173
  # --- 3. المعالجة اللاحقة والتصوّر ---
@@ -182,18 +156,16 @@ def run_single_frame(
182
 
183
  controller = InterfuserController(ControllerConfig())
184
  steer, throttle, brake, metadata = controller.run_step(
185
- speed=speed, waypoints=waypoints_np, junction=is_junction.sigmoid()[0, 1].item(),
186
- traffic_light_state=traffic_light.sigmoid()[0, 0].item(),
187
- stop_sign=stop_sign.sigmoid()[0, 1].item(), meta_data=updated_traffic
188
  )
189
 
190
- # إنشاء لوحة التحكم المرئية
191
  map_t0, counts_t0 = render(updated_traffic, t=0)
192
  map_t1, counts_t1 = render(updated_traffic, t=T1_FUTURE_TIME)
193
  map_t2, counts_t2 = render(updated_traffic, t=T2_FUTURE_TIME)
194
 
195
  wp_map = render_waypoints(waypoints_np)
196
- self_car_map = render_self_car(loc=np.array([0,0]), ori=[math.cos(0), math.sin(0)], box=[4.0, 2.0])
197
 
198
  map_t0 = cv2.add(cv2.add(map_t0, wp_map), self_car_map)
199
  map_t0 = cv2.resize(map_t0, (400, 400))
@@ -201,15 +173,11 @@ def run_single_frame(
201
  map_t2 = cv2.add(ensure_rgb(map_t2), ensure_rgb(self_car_map)); map_t2 = cv2.resize(map_t2, (200, 200))
202
 
203
  display = DisplayInterface()
204
- light_state = "Red" if traffic_light.sigmoid()[0,0].item() > 0.5 else "Green"
205
- stop_sign_state = "Yes" if stop_sign.sigmoid()[0,1].item() > 0.5 else "No"
206
 
207
  interface_data = {
208
  'camera_view': np.array(rgb_image_pil), 'map_t0': map_t0, 'map_t1': map_t1, 'map_t2': map_t2,
209
- 'text_info': {
210
- 'Frame': 'API Frame', 'Control': f"S:{steer:.2f} T:{throttle:.2f} B:{int(brake)}",
211
- 'Light': f"L: {light_state}", 'Stop': f"St: {stop_sign_state}"
212
- },
213
  'object_counts': {'t0': counts_t0, 't1': counts_t1, 't2': counts_t2}
214
  }
215
 
@@ -233,12 +201,14 @@ def run_single_frame(
233
  # 4. تعريف واجهة Gradio
234
  # ==============================================================================
235
 
236
- # البحث عن النماذج المتاحة عند بدء تشغيل الواجهة
237
  available_models = find_available_models()
238
 
239
  with gr.Blocks(theme=gr.themes.Soft()) as demo:
240
  gr.Markdown("# 🚗 محاكاة القيادة الذاتية باستخدام Interfuser")
241
 
 
 
 
242
  with gr.Row():
243
  model_selector = gr.Dropdown(
244
  label="اختر النموذج من مجلد 'model/weights'",
@@ -249,14 +219,15 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
249
 
250
  # التحميل الأولي والتحميل عند التغيير
251
  if available_models:
252
- demo.load(fn=load_model, inputs=model_selector, outputs=status_textbox)
253
- model_selector.change(fn=load_model, inputs=model_selector, outputs=status_textbox)
 
254
 
255
  gr.Markdown("---")
256
 
257
  with gr.Tabs():
258
  with gr.TabItem("نقطة نهاية API (إطار واحد)", id=1):
259
- gr.Markdown("### اختبار النموذج بإدخال مباشر (Single Frame Inference)")
260
 
261
  with gr.Row():
262
  with gr.Column(scale=1):
@@ -278,6 +249,7 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
278
  api_run_button.click(
279
  fn=run_single_frame,
280
  inputs=[
 
281
  api_rgb_image_path, api_rgb_left_image_path, api_rgb_right_image_path,
282
  api_rgb_center_image_path, api_lidar_image_path,
283
  api_measurements_path, api_target_point_list
@@ -293,4 +265,5 @@ if __name__ == "__main__":
293
  if not available_models:
294
  print("تحذير: لم يتم العثور على أي ملفات نماذج (.pth) في مجلد 'model/weights'.")
295
  print("سيتم تشغيل الواجهة ولكن لن تتمكن من تحميل أي نموذج.")
 
296
  demo.queue().launch(debug=True)
 
11
  import math
12
 
13
  # --- استيراد من الملفات المنظمة في مشروعك ---
14
+ from model.architecture import build_interfuser_model
 
 
15
  from logic import (
16
  transform, lidar_transform, InterfuserController, ControllerConfig,
17
  Tracker, DisplayInterface, render, render_waypoints, render_self_car,
 
21
  # ==============================================================================
22
  # 1. إعدادات ومسارات النماذج
23
  # ==============================================================================
24
+ WEIGHTS_DIR = "model/weights"
25
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
26
 
 
 
 
27
  # قاموس لتحديد الإعدادات الخاصة بكل نموذج.
 
 
28
  MODELS_SPECIFIC_CONFIGS = {
29
  "interfuser_baseline": {
30
  "rgb_backbone_name": "r50",
31
  "embed_dim": 256,
32
+ "direct_concat": True,
33
  },
34
  "interfuser_lightweight": {
35
  "rgb_backbone_name": "r26",
36
  "embed_dim": 128,
37
  "enc_depth": 4,
38
  "dec_depth": 4,
39
+ "direct_concat": True,
40
  }
 
 
41
  }
42
 
43
  def find_available_models():
 
47
  if not os.path.isdir(WEIGHTS_DIR):
48
  print(f"تحذير: مجلد الأوزان '{WEIGHTS_DIR}' غير موجود.")
49
  return []
50
+ return [f.replace(".pth", "") for f in os.listdir(WEIGHTS_DIR) if f.endswith(".pth")]
 
 
51
 
52
  # ==============================================================================
53
+ # 2. دالة تحميل النموذج (لا تستخدم متغيرات عامة)
54
  # ==============================================================================
55
  def load_model(model_name: str):
56
  """
57
+ تبني وتحمل النموذج المختار وتُرجعه ككائن.
58
  """
59
+ if not model_name or "لم يتم" in model_name:
60
+ return None, "الرجاء اختيار نموذج صالح."
 
 
61
 
62
  weights_path = os.path.join(WEIGHTS_DIR, f"{model_name}.pth")
63
+ print(f"Building model: '{model_name}'")
64
 
 
65
  model_config = MODELS_SPECIFIC_CONFIGS.get(model_name, {})
 
 
66
  model = build_interfuser_model(model_config)
67
 
68
  if not os.path.exists(weights_path):
69
+ gr.Warning(f"ملف الأوزان '{weights_path}' غير موجود. النموذج سيعمل بأوزان عشوائية.")
70
  else:
71
  try:
 
72
  state_dic = torch.load(weights_path, map_location=device, weights_only=True)
73
  model.load_state_dict(state_dic)
74
  print(f"تم تحميل أوزان النموذج '{model_name}' بنجاح.")
75
  except Exception as e:
76
+ gr.Warning(f"فشل تحميل الأوزان للنموذج '{model_name}': {e}.")
77
 
78
  model.to(device)
79
  model.eval()
80
 
81
+ # إرجاع كائن النموذج نفسه + ر��الة للمستخدم
82
+ return model, f"تم تحميل نموذج: {model_name}"
 
83
 
84
  # ==============================================================================
85
+ # 3. دالة التشغيل الرئيسية (تستقبل النموذج كمدخل)
86
  # ==============================================================================
87
  def run_single_frame(
88
+ model_from_state, # <-- مدخل جديد من gr.State
89
  rgb_image_path,
90
  rgb_left_image_path,
91
  rgb_right_image_path,
 
94
  measurements_path,
95
  target_point_list
96
  ):
97
+ # لم نعد نستخدم المتغير العام، بل نستخدم النموذج الذي تم تمريره
98
+ if model_from_state is None:
99
+ raise gr.Error("الرجاء اختيار وتحميل نموذج صالح أولاً من القائمة المنسدلة.")
 
100
 
101
  try:
102
  # --- 1. قراءة ومعالجة المدخلات ---
103
+ if not (rgb_image_path and measurements_path):
104
+ raise gr.Error("الرجاء توفير الصورة الأمامية وملف القياسات على الأقل.")
105
 
106
  rgb_image_pil = Image.open(rgb_image_path.name).convert("RGB")
107
+ # بقية معالجة المدخلات
108
  rgb_left_pil = Image.open(rgb_left_image_path.name).convert("RGB") if rgb_left_image_path else rgb_image_pil
109
  rgb_right_pil = Image.open(rgb_right_image_path.name).convert("RGB") if rgb_right_image_path else rgb_image_pil
110
  rgb_center_pil = Image.open(rgb_center_image_path.name).convert("RGB") if rgb_center_image_path else rgb_image_pil
111
 
 
112
  front_tensor = transform(rgb_image_pil).unsqueeze(0).to(device)
113
  left_tensor = transform(rgb_left_pil).unsqueeze(0).to(device)
114
  right_tensor = transform(rgb_right_pil).unsqueeze(0).to(device)
115
  center_tensor = transform(rgb_center_pil).unsqueeze(0).to(device)
116
+
117
  if lidar_image_path:
118
  lidar_array = np.load(lidar_image_path.name)
119
+ if lidar_array.max() > 0: lidar_array = (lidar_array / lidar_array.max()) * 255.0
 
120
  lidar_pil = Image.fromarray(lidar_array.astype(np.uint8)).convert('RGB')
121
  else:
122
  lidar_pil = Image.fromarray(np.zeros((112, 112, 3), dtype=np.uint8))
123
  lidar_tensor = lidar_transform(lidar_pil).unsqueeze(0).to(device)
124
 
125
+ with open(measurements_path.name, 'r') as f: m_dict = json.load(f)
 
126
 
 
127
  measurements_tensor = torch.tensor([[
128
  m_dict.get('x', 0.0), m_dict.get('y', 0.0), m_dict.get('theta', 0.0),
129
  m_dict.get('speed', 5.0), m_dict.get('steer', 0.0), m_dict.get('throttle', 0.0),
 
133
 
134
  target_point_tensor = torch.tensor([target_point_list], dtype=torch.float32).to(device)
135
 
 
136
  inputs = {
137
+ 'rgb': front_tensor, 'rgb_left': left_tensor, 'rgb_right': right_tensor,
138
+ 'rgb_center': center_tensor, 'lidar': lidar_tensor,
139
+ 'measurements': measurements_tensor, 'target_point': target_point_tensor
 
 
 
 
140
  }
141
 
142
  # --- 2. تشغيل النموذج ---
143
  with torch.no_grad():
144
+ outputs = model_from_state(inputs)
145
  traffic, waypoints, is_junction, traffic_light, stop_sign, _ = outputs
146
 
147
  # --- 3. المعالجة اللاحقة والتصوّر ---
 
156
 
157
  controller = InterfuserController(ControllerConfig())
158
  steer, throttle, brake, metadata = controller.run_step(
159
+ speed, waypoints_np, is_junction.sigmoid()[0, 1].item(),
160
+ traffic_light.sigmoid()[0, 0].item(), stop_sign.sigmoid()[0, 1].item(), updated_traffic
 
161
  )
162
 
 
163
  map_t0, counts_t0 = render(updated_traffic, t=0)
164
  map_t1, counts_t1 = render(updated_traffic, t=T1_FUTURE_TIME)
165
  map_t2, counts_t2 = render(updated_traffic, t=T2_FUTURE_TIME)
166
 
167
  wp_map = render_waypoints(waypoints_np)
168
+ self_car_map = render_self_car(np.array([0,0]), [math.cos(0), math.sin(0)], [4.0, 2.0])
169
 
170
  map_t0 = cv2.add(cv2.add(map_t0, wp_map), self_car_map)
171
  map_t0 = cv2.resize(map_t0, (400, 400))
 
173
  map_t2 = cv2.add(ensure_rgb(map_t2), ensure_rgb(self_car_map)); map_t2 = cv2.resize(map_t2, (200, 200))
174
 
175
  display = DisplayInterface()
176
+ light_state, stop_sign_state = "Red" if traffic_light.sigmoid()[0,0].item() > 0.5 else "Green", "Yes" if stop_sign.sigmoid()[0,1].item() > 0.5 else "No"
 
177
 
178
  interface_data = {
179
  'camera_view': np.array(rgb_image_pil), 'map_t0': map_t0, 'map_t1': map_t1, 'map_t2': map_t2,
180
+ 'text_info': { 'Control': f"S:{steer:.2f} T:{throttle:.2f} B:{int(brake)}", 'Light': f"L: {light_state}", 'Stop': f"St: {stop_sign_state}" },
 
 
 
181
  'object_counts': {'t0': counts_t0, 't1': counts_t1, 't2': counts_t2}
182
  }
183
 
 
201
  # 4. تعريف واجهة Gradio
202
  # ==============================================================================
203
 
 
204
  available_models = find_available_models()
205
 
206
  with gr.Blocks(theme=gr.themes.Soft()) as demo:
207
  gr.Markdown("# 🚗 محاكاة القيادة الذاتية باستخدام Interfuser")
208
 
209
+ # مكون الحالة الخفي لتخزين النموذج الخاص بكل جلسة
210
+ model_state = gr.State(value=None)
211
+
212
  with gr.Row():
213
  model_selector = gr.Dropdown(
214
  label="اختر النموذج من مجلد 'model/weights'",
 
219
 
220
  # التحميل الأولي والتحميل عند التغيير
221
  if available_models:
222
+ demo.load(fn=load_model, inputs=model_selector, outputs=[model_state, status_textbox])
223
+
224
+ model_selector.change(fn=load_model, inputs=model_selector, outputs=[model_state, status_textbox])
225
 
226
  gr.Markdown("---")
227
 
228
  with gr.Tabs():
229
  with gr.TabItem("نقطة نهاية API (إطار واحد)", id=1):
230
+ gr.Markdown("### اختبار النموذج بإدخال مباشر")
231
 
232
  with gr.Row():
233
  with gr.Column(scale=1):
 
249
  api_run_button.click(
250
  fn=run_single_frame,
251
  inputs=[
252
+ model_state, # تمرير الحالة كأول مدخل
253
  api_rgb_image_path, api_rgb_left_image_path, api_rgb_right_image_path,
254
  api_rgb_center_image_path, api_lidar_image_path,
255
  api_measurements_path, api_target_point_list
 
265
  if not available_models:
266
  print("تحذير: لم يتم العثور على أي ملفات نماذج (.pth) في مجلد 'model/weights'.")
267
  print("سيتم تشغيل الواجهة ولكن لن تتمكن من تحميل أي نموذج.")
268
+ # .queue() ضروري للتعامل مع الجلسات المتعددة بشكل صحيح
269
  demo.queue().launch(debug=True)