Spaces:
Sleeping
Sleeping
| # app.py - InterFuser Self-Driving API Server | |
| import uuid | |
| import base64 | |
| import cv2 | |
| import torch | |
| import numpy as np | |
| from fastapi import FastAPI, HTTPException | |
| from fastapi.responses import HTMLResponse | |
| from pydantic import BaseModel | |
| from torchvision import transforms | |
| from typing import List, Dict, Any, Optional | |
| import logging | |
| # استيراد من ملفاتنا المحلية | |
| from model_definition import InterfuserModel, load_and_prepare_model, create_model_config | |
| from simulation_modules import ( | |
| InterfuserController, ControllerConfig, Tracker, DisplayInterface, | |
| render, render_waypoints, render_self_car, WAYPOINT_SCALE_FACTOR, | |
| T1_FUTURE_TIME, T2_FUTURE_TIME | |
| ) | |
| # إعداد التسجيل | |
| logging.basicConfig(level=logging.INFO) | |
| logger = logging.getLogger(__name__) | |
| # ================== إعدادات عامة وتحميل النموذج ================== | |
| app = FastAPI( | |
| title="Baseer Self-Driving API", | |
| description="API للقيادة الذاتية باستخدام نموذج InterFuser", | |
| version="1.0.0" | |
| ) | |
| device = torch.device("cpu") | |
| logger.info(f"Using device: {device}") | |
| # تحميل النموذج باستخدام الدالة المحسنة | |
| try: | |
| # إنشاء إعدادات النموذج باستخدام الإعدادات الصحيحة من التدريب | |
| model_config = create_model_config( | |
| model_path="model/best_model.pth" | |
| # الإعدادات الصحيحة من التدريب ستطبق تلقائياً: | |
| # embed_dim=256, rgb_backbone_name='r50', waypoints_pred_head='gru' | |
| # with_lidar=False, with_right_left_sensors=False, with_center_sensor=False | |
| ) | |
| # تحميل النموذج مع الأوزان | |
| model = load_and_prepare_model(model_config, device) | |
| logger.info("✅ تم تحميل النموذج بنجاح") | |
| except Exception as e: | |
| logger.error(f"❌ خطأ في تحميل النموذج: {e}") | |
| logger.info("🔄 محاولة إنشاء نموذج بأوزان عشوائية...") | |
| try: | |
| model = InterfuserModel() | |
| model.to(device) | |
| model.eval() | |
| logger.warning("⚠️ تم إنشاء النموذج بأوزان عشوائية") | |
| except Exception as e2: | |
| logger.error(f"❌ فشل في إنشاء النموذج: {e2}") | |
| model = None | |
| # تهيئة واجهة العرض | |
| display = DisplayInterface() | |
| # قاموس لتخزين جلسات المستخدمين | |
| SESSIONS: Dict[str, Dict] = {} | |
| # ================== هياكل بيانات Pydantic ================== | |
| class Measurements(BaseModel): | |
| pos: List[float] = [0.0, 0.0] # [x, y] position | |
| theta: float = 0.0 # orientation angle | |
| speed: float = 0.0 # current speed | |
| steer: float = 0.0 # current steering | |
| throttle: float = 0.0 # current throttle | |
| brake: bool = False # brake status | |
| command: int = 4 # driving command (4 = FollowLane) | |
| target_point: List[float] = [0.0, 0.0] # target point [x, y] | |
| class ModelOutputs(BaseModel): | |
| traffic: List[List[List[float]]] # 20x20x7 grid | |
| waypoints: List[List[float]] # Nx2 waypoints | |
| is_junction: float | |
| traffic_light_state: float | |
| stop_sign: float | |
| class ControlCommands(BaseModel): | |
| steer: float | |
| throttle: float | |
| brake: bool | |
| class RunStepInput(BaseModel): | |
| session_id: str | |
| image_b64: str | |
| measurements: Measurements | |
| class RunStepOutput(BaseModel): | |
| model_outputs: ModelOutputs | |
| control_commands: ControlCommands | |
| dashboard_image_b64: str | |
| class SessionResponse(BaseModel): | |
| session_id: str | |
| message: str | |
| # ================== دوال المساعدة ================== | |
| def get_image_transform(): | |
| """إنشاء تحويلات الصورة كما في PDMDataset""" | |
| return transforms.Compose([ | |
| transforms.ToTensor(), | |
| transforms.Resize((224, 224), antialias=True), | |
| transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) | |
| ]) | |
| # إنشاء كائن التحويل مرة واحدة | |
| image_transform = get_image_transform() | |
| def preprocess_input(frame_rgb: np.ndarray, measurements: Measurements, device: torch.device) -> Dict[str, torch.Tensor]: | |
| """ | |
| تحاكي ما يفعله PDMDataset.__getitem__ لإنشاء دفعة (batch) واحدة. | |
| """ | |
| # 1. معالجة الصورة الرئيسية | |
| from PIL import Image | |
| if isinstance(frame_rgb, np.ndarray): | |
| frame_rgb = Image.fromarray(frame_rgb) | |
| image_tensor = image_transform(frame_rgb).unsqueeze(0).to(device) # إضافة بُعد الدفعة | |
| # 2. إنشاء مدخلات الكاميرات الأخرى عن طريق الاستنساخ | |
| batch = { | |
| 'rgb': image_tensor, | |
| 'rgb_left': image_tensor.clone(), | |
| 'rgb_right': image_tensor.clone(), | |
| 'rgb_center': image_tensor.clone(), | |
| } | |
| # 3. إنشاء مدخل ليدار وهمي (أصفار) | |
| batch['lidar'] = torch.zeros(1, 3, 224, 224, dtype=torch.float32).to(device) | |
| # 4. تجميع القياسات بنفس ترتيب PDMDataset | |
| m = measurements | |
| measurements_tensor = torch.tensor([[ | |
| m.pos[0], m.pos[1], m.theta, | |
| m.steer, m.throttle, float(m.brake), | |
| m.speed, float(m.command) | |
| ]], dtype=torch.float32).to(device) | |
| batch['measurements'] = measurements_tensor | |
| # 5. إنشاء نقطة هدف | |
| batch['target_point'] = torch.tensor([m.target_point], dtype=torch.float32).to(device) | |
| # لا نحتاج إلى قيم ground truth (gt_*) أثناء التنبؤ | |
| return batch | |
| def decode_base64_image(image_b64: str) -> np.ndarray: | |
| """ | |
| فك تشفير صورة Base64 | |
| """ | |
| try: | |
| image_bytes = base64.b64decode(image_b64) | |
| nparr = np.frombuffer(image_bytes, np.uint8) | |
| image = cv2.imdecode(nparr, cv2.IMREAD_COLOR) | |
| return image | |
| except Exception as e: | |
| raise HTTPException(status_code=400, detail=f"Invalid image format: {str(e)}") | |
| def encode_image_to_base64(image: np.ndarray) -> str: | |
| """ | |
| تشفير صورة إلى Base64 | |
| """ | |
| _, buffer = cv2.imencode('.jpg', image, [cv2.IMWRITE_JPEG_QUALITY, 85]) | |
| return base64.b64encode(buffer).decode('utf-8') | |
| # ================== نقاط نهاية الـ API ================== | |
| async def root(): | |
| """ | |
| الصفحة الرئيسية للـ API | |
| """ | |
| html_content = f""" | |
| <!DOCTYPE html> | |
| <html dir="rtl" lang="ar"> | |
| <head> | |
| <meta charset="UTF-8"> | |
| <meta name="viewport" content="width=device-width, initial-scale=1.0"> | |
| <title>🚗 Baseer Self-Driving API</title> | |
| <style> | |
| * {{ | |
| margin: 0; | |
| padding: 0; | |
| box-sizing: border-box; | |
| }} | |
| body {{ | |
| font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif; | |
| background: linear-gradient(135deg, #667eea 0%, #764ba2 100%); | |
| min-height: 100vh; | |
| display: flex; | |
| align-items: center; | |
| justify-content: center; | |
| padding: 20px; | |
| }} | |
| .container {{ | |
| background: rgba(255, 255, 255, 0.95); | |
| backdrop-filter: blur(10px); | |
| border-radius: 20px; | |
| padding: 40px; | |
| box-shadow: 0 20px 40px rgba(0, 0, 0, 0.1); | |
| text-align: center; | |
| max-width: 600px; | |
| width: 100%; | |
| }} | |
| .logo {{ | |
| font-size: 4rem; | |
| margin-bottom: 20px; | |
| animation: bounce 2s infinite; | |
| }} | |
| @keyframes bounce {{ | |
| 0%, 20%, 50%, 80%, 100% {{ transform: translateY(0); }} | |
| 40% {{ transform: translateY(-10px); }} | |
| 60% {{ transform: translateY(-5px); }} | |
| }} | |
| h1 {{ | |
| color: #333; | |
| margin-bottom: 10px; | |
| font-size: 2.5rem; | |
| }} | |
| .subtitle {{ | |
| color: #666; | |
| margin-bottom: 30px; | |
| font-size: 1.2rem; | |
| }} | |
| .status {{ | |
| display: inline-block; | |
| background: #4CAF50; | |
| color: white; | |
| padding: 8px 16px; | |
| border-radius: 20px; | |
| margin: 10px 0; | |
| font-weight: bold; | |
| }} | |
| .stats {{ | |
| display: grid; | |
| grid-template-columns: repeat(auto-fit, minmax(150px, 1fr)); | |
| gap: 20px; | |
| margin: 30px 0; | |
| }} | |
| .stat-card {{ | |
| background: #f8f9fa; | |
| padding: 20px; | |
| border-radius: 15px; | |
| border-left: 4px solid #667eea; | |
| }} | |
| .stat-number {{ | |
| font-size: 2rem; | |
| font-weight: bold; | |
| color: #667eea; | |
| }} | |
| .stat-label {{ | |
| color: #666; | |
| margin-top: 5px; | |
| }} | |
| .buttons {{ | |
| display: flex; | |
| gap: 15px; | |
| justify-content: center; | |
| flex-wrap: wrap; | |
| margin-top: 30px; | |
| }} | |
| .btn {{ | |
| display: inline-block; | |
| padding: 12px 24px; | |
| border-radius: 25px; | |
| text-decoration: none; | |
| font-weight: bold; | |
| transition: all 0.3s ease; | |
| border: none; | |
| cursor: pointer; | |
| }} | |
| .btn-primary {{ | |
| background: #667eea; | |
| color: white; | |
| }} | |
| .btn-secondary {{ | |
| background: #6c757d; | |
| color: white; | |
| }} | |
| .btn:hover {{ | |
| transform: translateY(-2px); | |
| box-shadow: 0 5px 15px rgba(0, 0, 0, 0.2); | |
| }} | |
| .features {{ | |
| text-align: right; | |
| margin-top: 30px; | |
| padding: 20px; | |
| background: #f8f9fa; | |
| border-radius: 15px; | |
| }} | |
| .features h3 {{ | |
| color: #333; | |
| margin-bottom: 15px; | |
| }} | |
| .features ul {{ | |
| list-style: none; | |
| padding: 0; | |
| }} | |
| .features li {{ | |
| padding: 5px 0; | |
| color: #666; | |
| }} | |
| .features li:before {{ | |
| content: "✅ "; | |
| margin-left: 10px; | |
| }} | |
| </style> | |
| </head> | |
| <body> | |
| <div class="container"> | |
| <div class="logo">🚗</div> | |
| <h1>Baseer Self-Driving API</h1> | |
| <p class="subtitle">نظام القيادة الذاتية المتقدم</p> | |
| <div class="status">🟢 يعمل بنجاح</div> | |
| <div class="stats"> | |
| <div class="stat-card"> | |
| <div class="stat-number">{len(SESSIONS)}</div> | |
| <div class="stat-label">الجلسات النشطة</div> | |
| </div> | |
| <div class="stat-card"> | |
| <div class="stat-number">v1.0</div> | |
| <div class="stat-label">الإصدار</div> | |
| </div> | |
| <div class="stat-card"> | |
| <div class="stat-number">FastAPI</div> | |
| <div class="stat-label">التقنية</div> | |
| </div> | |
| </div> | |
| <div class="buttons"> | |
| <a href="/docs" class="btn btn-primary">📚 توثيق API</a> | |
| <a href="/sessions" class="btn btn-secondary">📊 الجلسات</a> | |
| </div> | |
| <div class="features"> | |
| <h3>🌟 الميزات الرئيسية</h3> | |
| <ul> | |
| <li>نموذج InterFuser للقيادة الذاتية</li> | |
| <li>معالجة الصور في الوقت الفعلي</li> | |
| <li>اكتشاف الكائنات المرورية</li> | |
| <li>تحديد المسارات الذكية</li> | |
| <li>واجهة RESTful سهلة الاستخدام</li> | |
| <li>إدارة جلسات متعددة</li> | |
| </ul> | |
| </div> | |
| </div> | |
| </body> | |
| </html> | |
| """ | |
| return html_content | |
| async def start_session(): | |
| """ | |
| بدء جلسة جديدة للمحاكاة | |
| """ | |
| session_id = str(uuid.uuid4()) | |
| # إنشاء جلسة جديدة | |
| SESSIONS[session_id] = { | |
| 'tracker': Tracker(frequency=10), | |
| 'controller': InterfuserController(ControllerConfig()), | |
| 'frame_num': 0, | |
| 'created_at': np.datetime64('now'), | |
| 'last_activity': np.datetime64('now') | |
| } | |
| logger.info(f"New session created: {session_id}") | |
| return SessionResponse( | |
| session_id=session_id, | |
| message="Session started successfully" | |
| ) | |
| async def run_step(data: RunStepInput): | |
| """ | |
| تنفيذ خطوة محاكاة كاملة | |
| """ | |
| # التحقق من وجود الجلسة | |
| if data.session_id not in SESSIONS: | |
| raise HTTPException(status_code=404, detail="Session not found") | |
| session = SESSIONS[data.session_id] | |
| tracker = session['tracker'] | |
| controller = session['controller'] | |
| # تحديث وقت النشاط | |
| session['last_activity'] = np.datetime64('now') | |
| try: | |
| # 1. فك تشفير الصورة | |
| frame_bgr = decode_base64_image(data.image_b64) | |
| frame_rgb = cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2RGB) | |
| # 2. معالجة المدخلات | |
| inputs = preprocess_input(frame_rgb, data.measurements, device) | |
| # 3. تشغيل النموذج | |
| if model is None: | |
| raise HTTPException(status_code=500, detail="Model not loaded") | |
| with torch.no_grad(): | |
| traffic, waypoints, is_junction, traffic_light, stop_sign, _ = model(inputs) | |
| # 4. معالجة مخرجات النموذج | |
| traffic_np = traffic.cpu().numpy()[0] # أخذ أول عنصر من الـ batch | |
| waypoints_np = waypoints.cpu().numpy()[0] | |
| is_junction_prob = torch.sigmoid(is_junction)[0, 1].item() | |
| traffic_light_prob = torch.sigmoid(traffic_light)[0, 0].item() | |
| stop_sign_prob = torch.sigmoid(stop_sign)[0, 1].item() | |
| # 5. تحديث التتبع | |
| # تحويل traffic grid إلى detections للتتبع | |
| detections = [] | |
| h, w, c = traffic_np.shape | |
| for y in range(h): | |
| for x in range(w): | |
| for ch in range(c): | |
| if traffic_np[y, x, ch] > 0.2: # عتبة الكشف | |
| world_x = (x / w - 0.5) * 64 # تحويل إلى إحداثيات العالم | |
| world_y = (y / h - 0.5) * 64 | |
| detections.append({ | |
| 'position': [world_x, world_y], | |
| 'feature': traffic_np[y, x, ch] | |
| }) | |
| updated_traffic = tracker.update_and_predict(detections, session['frame_num']) | |
| # 6. تشغيل المتحكم | |
| steer, throttle, brake, metadata = controller.run_step( | |
| current_speed=data.measurements.speed, | |
| waypoints=waypoints_np, | |
| junction=is_junction_prob, | |
| traffic_light_state=traffic_light_prob, | |
| stop_sign=stop_sign_prob, | |
| meta_data={'frame': session['frame_num']} | |
| ) | |
| # 7. إنشاء خرائط العرض | |
| surround_t0, counts_t0 = render(updated_traffic, t=0) | |
| surround_t1, counts_t1 = render(updated_traffic, t=T1_FUTURE_TIME) | |
| surround_t2, counts_t2 = render(updated_traffic, t=T2_FUTURE_TIME) | |
| # إضافة المسار المقترح | |
| wp_map = render_waypoints(waypoints_np) | |
| map_t0 = cv2.add(surround_t0, wp_map) | |
| # إضافة السيارة الذاتية | |
| map_t0 = render_self_car(map_t0) | |
| map_t1 = render_self_car(surround_t1) | |
| map_t2 = render_self_car(surround_t2) | |
| # 8. إنشاء لوحة العرض النهائية | |
| interface_data = { | |
| 'camera_view': frame_bgr, | |
| 'map_t0': map_t0, | |
| 'map_t1': map_t1, | |
| 'map_t2': map_t2, | |
| 'text_info': { | |
| 'Frame': f"Frame: {session['frame_num']}", | |
| 'Control': f"Steer: {steer:.2f}, Throttle: {throttle:.2f}, Brake: {brake}", | |
| 'Speed': f"Speed: {data.measurements.speed:.1f} km/h", | |
| 'Junction': f"Junction: {is_junction_prob:.2f}", | |
| 'Traffic Light': f"Red Light: {traffic_light_prob:.2f}", | |
| 'Stop Sign': f"Stop Sign: {stop_sign_prob:.2f}", | |
| 'Metadata': metadata | |
| }, | |
| 'object_counts': { | |
| 't0': counts_t0, | |
| 't1': counts_t1, | |
| 't2': counts_t2 | |
| } | |
| } | |
| dashboard_image = display.run_interface(interface_data) | |
| dashboard_b64 = encode_image_to_base64(dashboard_image) | |
| # 9. تجميع المخرجات النهائية | |
| response = RunStepOutput( | |
| model_outputs=ModelOutputs( | |
| traffic=traffic_np.tolist(), | |
| waypoints=waypoints_np.tolist(), | |
| is_junction=is_junction_prob, | |
| traffic_light_state=traffic_light_prob, | |
| stop_sign=stop_sign_prob | |
| ), | |
| control_commands=ControlCommands( | |
| steer=float(steer), | |
| throttle=float(throttle), | |
| brake=bool(brake) | |
| ), | |
| dashboard_image_b64=dashboard_b64 | |
| ) | |
| # تحديث رقم الإطار | |
| session['frame_num'] += 1 | |
| logger.info(f"Step completed for session {data.session_id}, frame {session['frame_num']}") | |
| return response | |
| except Exception as e: | |
| logger.error(f"Error in run_step: {str(e)}") | |
| raise HTTPException(status_code=500, detail=f"Processing error: {str(e)}") | |
| async def end_session(session_id: str): | |
| """ | |
| إنهاء جلسة المحاكاة | |
| """ | |
| if session_id not in SESSIONS: | |
| raise HTTPException(status_code=404, detail="Session not found") | |
| # حذف الجلسة | |
| del SESSIONS[session_id] | |
| logger.info(f"Session ended: {session_id}") | |
| return SessionResponse( | |
| session_id=session_id, | |
| message="Session ended successfully" | |
| ) | |
| async def list_sessions(): | |
| """ | |
| عرض قائمة الجلسات النشطة | |
| """ | |
| active_sessions = [] | |
| current_time = np.datetime64('now') | |
| for session_id, session_data in SESSIONS.items(): | |
| time_diff = current_time - session_data['last_activity'] | |
| active_sessions.append({ | |
| 'session_id': session_id, | |
| 'frame_count': session_data['frame_num'], | |
| 'created_at': str(session_data['created_at']), | |
| 'last_activity': str(session_data['last_activity']), | |
| 'inactive_minutes': float(time_diff / np.timedelta64(1, 'm')) | |
| }) | |
| return { | |
| 'total_sessions': len(active_sessions), | |
| 'sessions': active_sessions | |
| } | |
| async def cleanup_inactive_sessions(max_inactive_minutes: int = 30): | |
| """ | |
| تنظيف الجلسات غير النشطة | |
| """ | |
| current_time = np.datetime64('now') | |
| cleaned_sessions = [] | |
| for session_id in list(SESSIONS.keys()): | |
| session = SESSIONS[session_id] | |
| time_diff = current_time - session['last_activity'] | |
| inactive_minutes = float(time_diff / np.timedelta64(1, 'm')) | |
| if inactive_minutes > max_inactive_minutes: | |
| del SESSIONS[session_id] | |
| cleaned_sessions.append(session_id) | |
| logger.info(f"Cleaned up {len(cleaned_sessions)} inactive sessions") | |
| return { | |
| 'message': f"Cleaned up {len(cleaned_sessions)} inactive sessions", | |
| 'cleaned_sessions': cleaned_sessions, | |
| 'remaining_sessions': len(SESSIONS) | |
| } | |
| # ================== معالج الأخطاء ================== | |
| async def global_exception_handler(request, exc): | |
| logger.error(f"Global exception: {str(exc)}") | |
| return { | |
| "error": "Internal server error", | |
| "detail": str(exc) | |
| } | |
| # ================== تشغيل الخادم ================== | |
| if __name__ == "__main__": | |
| import uvicorn | |
| uvicorn.run(app, host="0.0.0.0", port=7860) | |