mohammed-aljafry committed (verified)
Commit: 67bc258 · Parent(s): d75aba6

Upload app.py with huggingface_hub

Files changed (1):
  1. app.py (+16 -29)
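The commit message indicates app.py was pushed with the huggingface_hub client. For reference, a minimal sketch of that kind of upload is shown below; the repo_id is a placeholder, not taken from this page.

# Hypothetical upload sketch (placeholder repo_id); HfApi.upload_file performs the push.
from huggingface_hub import HfApi

api = HfApi()  # authenticates via `huggingface-cli login` or the HF_TOKEN environment variable
api.upload_file(
    path_or_fileobj="app.py",          # local file to upload
    path_in_repo="app.py",             # destination path inside the repo
    repo_id="user/space-name",         # placeholder Space id
    repo_type="space",
    commit_message="Upload app.py with huggingface_hub",
)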
app.py CHANGED
@@ -5,7 +5,6 @@ import json
 import traceback
 import torch
 import gradio as gr
-
 import numpy as np
 from PIL import Image
 import cv2
@@ -38,8 +37,6 @@ def find_available_models():
 # ==============================================================================
 # 2. Core functions (load_model, run_single_frame)
 # ==============================================================================
-# (These functions are unchanged from the previous, session-aware version)
-
 def load_model(model_name: str):
     if not model_name or "لم يتم" in model_name:
         return None, "الرجاء اختيار نموذج صالح."
@@ -48,16 +45,14 @@ def load_model(model_name: str):
     model_config = MODELS_SPECIFIC_CONFIGS.get(model_name, {})
     model = build_interfuser_model(model_config)
     if not os.path.exists(weights_path):
-        # Replace gr.Warning with raising an error exception
-        raise gr.Error(f"ملف الأوزان '{weights_path}' غير موجود.")
+        gr.Warning(f"ملف الأوزان '{weights_path}' غير موجود.")
     else:
         try:
             state_dic = torch.load(weights_path, map_location=device, weights_only=True)
             model.load_state_dict(state_dic)
             print(f"تم تحميل أوزان النموذج '{model_name}' بنجاح.")
         except Exception as e:
-            # Replace gr.Warning with raising an error exception
-            raise gr.Error(f"فشل تحميل الأوزان للنموذج '{model_name}': {e}.")
+            gr.Warning(f"فشل تحميل الأوزان للنموذج '{model_name}': {e}.")
     model.to(device)
     model.eval()
     return model, f"تم تحميل نموذج: {model_name}"
@@ -72,7 +67,6 @@ def run_single_frame(
     if not (rgb_image_path and measurements_path):
         raise gr.Error("الرجاء توفير الصورة الأمامية وملف القياسات على الأقل.")
 
-    # --- 1. Input processing ---
     rgb_image_pil = Image.open(rgb_image_path).convert("RGB")
     rgb_left_pil = Image.open(rgb_left_image_path).convert("RGB") if rgb_left_image_path else rgb_image_pil
     rgb_right_pil = Image.open(rgb_right_image_path).convert("RGB") if rgb_right_image_path else rgb_image_pil
@@ -107,19 +101,16 @@ def run_single_frame(
         'measurements': measurements_tensor, 'target_point': target_point_tensor
     }
 
-    # --- 2. Run the model ---
     with torch.no_grad():
         outputs = model_from_state(inputs)
     traffic, waypoints, is_junction, traffic_light, stop_sign, _ = outputs
-
-    # --- 3. Post-processing and visualization ---
+
     speed, pos, theta = m_dict.get('speed',5.0), [m_dict.get('x',0.0), m_dict.get('y',0.0)], m_dict.get('theta',0.0)
     traffic_np, waypoints_np = traffic[0].detach().cpu().numpy().reshape(20,20,-1), waypoints[0].detach().cpu().numpy() * WAYPOINT_SCALE_FACTOR
     tracker, controller = Tracker(), InterfuserController(ControllerConfig())
     updated_traffic = tracker.update_and_predict(traffic_np.copy(), pos, theta, 0)
     steer, throttle, brake, metadata = controller.run_step(speed, waypoints_np, is_junction.sigmoid()[0,1].item(), traffic_light.sigmoid()[0,0].item(), stop_sign.sigmoid()[0,1].item(), updated_traffic)
 
-    # ... (the rest of the drawing/visualization code is unchanged) ...
     map_t0, counts_t0 = render(updated_traffic, t=0)
     map_t1, counts_t1 = render(updated_traffic, t=T1_FUTURE_TIME)
     map_t2, counts_t2 = render(updated_traffic, t=T2_FUTURE_TIME)
@@ -136,7 +127,6 @@ def run_single_frame(
         'object_counts': {'t0': counts_t0,'t1': counts_t1,'t2': counts_t2}}
     dashboard_image = display.run_interface(interface_data)
 
-    # --- 4. Prepare the outputs ---
     result_dict = {"predicted_waypoints": waypoints_np.tolist(), "control_commands": {"steer": steer,"throttle": throttle,"brake": bool(brake)},
         "perception": {"traffic_light_status": light_state,"stop_sign_detected": (stop_sign_state == "Yes"),"is_at_junction_prob": round(is_junction.sigmoid()[0,1].item(), 3)},
         "metadata": {"speed_info": metadata[0],"perception_info": metadata[1],"stop_info": metadata[2],"safe_distance": metadata[3]}}
@@ -147,12 +137,11 @@ def run_single_frame(
         raise gr.Error(f"حدث خطأ أثناء معالجة الإطار: {e}")
 
 # ==============================================================================
-# 4. Improved Gradio interface definition
+# 4. Improved Gradio interface definition (with the fix)
 # ==============================================================================
 available_models = find_available_models()
 
 with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue", secondary_hue="sky"), css=".gradio-container {max-width: 95% !important;}") as demo:
-    # Hidden state component that stores each session's model
    model_state = gr.State(value=None)
 
     gr.Markdown("# 🚗 محاكاة القيادة الذاتية باستخدام Interfuser")
@@ -176,22 +165,19 @@ with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue", secondary_hue="sky"), css=".gradio-container {max-width: 95% !important;}") as demo:
     with gr.Group():
         gr.Markdown("## 🗂️ الخطوة 2: ارفع ملفات السيناريو")
 
-        # Required inputs
         with gr.Group():
             gr.Markdown("**(مطلوب)**")
-            api_rgb_image_path = gr.File(label="صورة الكاميرا الأمامية (RGB)")
-            api_measurements_path = gr.File(label="ملف القياسات (JSON)")
+            api_rgb_image_path = gr.File(label="صورة الكاميرا الأمامية (RGB)", type="filepath")
+            api_measurements_path = gr.File(label="ملف القياسات (JSON)", type="filepath")
 
-        # Optional inputs
         with gr.Accordion("📷 مدخلات اختيارية (كاميرات ومستشعرات إضافية)", open=False):
-            api_rgb_left_image_path = gr.File(label="كاميرا اليسار (RGB)")
-            api_rgb_right_image_path = gr.File(label="كاميرا اليمين (RGB)")
-            api_rgb_center_image_path = gr.File(label="كاميرا الوسط (RGB)")
-            api_lidar_image_path = gr.File(label="بيانات الليدار (NPY)")
+            api_rgb_left_image_path = gr.File(label="كاميرا اليسار (RGB)", type="filepath")
+            api_rgb_right_image_path = gr.File(label="كاميرا اليمين (RGB)", type="filepath")
+            api_rgb_center_image_path = gr.File(label="كاميرا الوسط (RGB)", type="filepath")
+            api_lidar_image_path = gr.File(label="بيانات الليدار (NPY)", type="filepath")
 
         api_target_point_list = gr.JSON(label="📍 النقطة المستهدفة (x, y)", value=[0.0, 100.0])
 
-        # Run button
         api_run_button = gr.Button("🚀 شغل المحاكاة", variant="primary", scale=2)
 
     # --- Ready-made examples ---
@@ -200,10 +186,11 @@ with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue", secondary_hue="sky"), css=".gradio-container {max-width: 95% !important;}") as demo:
         gr.Markdown("انقر على مثال لتعبئة الحقول تلقائياً (يتطلب وجود مجلد `examples` بنفس بنية البيانات).")
         gr.Examples(
             examples=[
-                [os.path.join(EXAMPLES_DIR, "sample1", "rgb.png"), None, None, None, None, os.path.join(EXAMPLES_DIR, "sample1", "measurements.json")],
-                [os.path.join(EXAMPLES_DIR, "sample2", "rgb.png"), None, None, None, None, os.path.join(EXAMPLES_DIR, "sample2", "measurements.json")]
+                [os.path.join(EXAMPLES_DIR, "sample1", "rgb.png"), os.path.join(EXAMPLES_DIR, "sample1", "measurements.json")],
+                [os.path.join(EXAMPLES_DIR, "sample2", "rgb.png"), os.path.join(EXAMPLES_DIR, "sample2", "measurements.json")]
             ],
-            inputs=[api_rgb_image_path, api_rgb_left_image_path, api_rgb_right_image_path, api_rgb_center_image_path, api_lidar_image_path, api_measurements_path],
+            # The inputs must match the fields required by the example rows
+            inputs=[api_rgb_image_path, api_measurements_path],
             label="اختر سيناريو اختبار"
         )
 
@@ -234,5 +221,5 @@ with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue", secondary_hue="sky"), css=".gradio-container {max-width: 95% !important;}") as demo:
 # ==============================================================================
 if __name__ == "__main__":
     if not available_models:
-        print("تحذير: لم يتم العثور على أي ملفات نماذج (.pth) في مجلد 'model'.")
-    demo.queue().launch(debug=True, share=True)  # share=True to create a temporary public link
+        print("تحذير: لم يتم العثور على أي ملفات نماذج (.pth) في مجلد 'model/weights'.")
+    demo.queue().launch(debug=True, share=True)
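A note on the main behavioral change above: raise gr.Error(...) aborts the event handler and surfaces an error modal in the UI, whereas gr.Warning(...) only shows a warning toast and lets execution continue, so the updated load_model can still return a (randomly initialized) model when the weights file is missing or fails to load. A minimal sketch of the difference, using a toy app rather than the Space's actual code:

# Toy contrast between raise gr.Error and gr.Warning (not the Space's code).
import gradio as gr

def strict_load(name):
    if not name:
        raise gr.Error("No model selected.")  # stops the handler, shows an error modal
    return f"loaded {name}"

def lenient_load(name):
    if not name:
        gr.Warning("No model selected; falling back to a default.")  # toast only, execution continues
        name = "default"
    return f"loaded {name}"

with gr.Blocks() as demo:
    name_box = gr.Textbox(label="model name")
    status = gr.Textbox(label="status")
    gr.Button("strict").click(strict_load, name_box, status)
    gr.Button("lenient").click(lenient_load, name_box, status)

if __name__ == "__main__":
    demo.launch()

The other edits follow the same direction: gr.File(..., type="filepath") makes each upload component hand run_single_frame a plain path string, and the gr.Examples inputs list is trimmed to two components so that each example row supplies exactly one value per input.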