# training_space/app.py (FastAPI Backend)
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
import subprocess
import os
import uuid
import logging

app = FastAPI()

# Configure Logging
logging.basicConfig(
    filename='training.log',
    filemode='a',
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO
)

# CORS Configuration
origins = [
    "https://Vishwas1-LLMBuilderPro.hf.space",  # Replace with your Gradio frontend Space URL
    "http://localhost",                         # For local testing
    "https://web.postman.co",
]

app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Define the expected payload structure
class TrainingRequest(BaseModel):
    task: str            # 'generation' or 'classification'
    model_params: dict   # Architecture hyperparameters (defaults applied below)
    model_name: str
    dataset_name: str    # The name of an existing Hugging Face dataset
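
# For reference, a /train request body matching this schema might look like the
# following (illustrative values only; the dataset is assumed to already exist
# on the Hub):
#
# {
#     "task": "generation",
#     "model_name": "my-tiny-gpt",
#     "dataset_name": "username/my-dataset",
#     "model_params": {
#         "num_layers": 12,
#         "attention_heads": 1,
#         "hidden_size": 64,
#         "vocab_size": 30000,
#         "sequence_length": 512
#     }
# }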

# Root Endpoint
@app.get("/")
def read_root():
    return {
        "message": "Welcome to the Training Space API!",
        "instructions": "To train a model, send a POST request to /train with the required parameters."
    }

# Train Endpoint
@app.post("/train")
def train_model(request: TrainingRequest):
    try:
        logging.info(f"Received training request for model: {request.model_name}, Task: {request.task}")

        # Create a unique directory for this training session
        session_id = str(uuid.uuid4())
        session_dir = f"./training_sessions/{session_id}"
        os.makedirs(session_dir, exist_ok=True)

        # No need to save dataset content; use dataset_name directly
        dataset_name = request.dataset_name

        # Define the absolute path to train_model.py
        TRAIN_MODEL_PATH = os.path.join(os.path.dirname(__file__), "train_model.py")

        # Prepare the command to run the training script with dataset_name
        cmd = [
            "python", TRAIN_MODEL_PATH,
            "--task", request.task,
            "--model_name", request.model_name,
            "--dataset_name", dataset_name,  # Pass dataset_name instead of a dataset file path
            "--num_layers", str(request.model_params.get('num_layers', 12)),
            "--attention_heads", str(request.model_params.get('attention_heads', 1)),
            "--hidden_size", str(request.model_params.get('hidden_size', 64)),
            "--vocab_size", str(request.model_params.get('vocab_size', 30000)),
            "--sequence_length", str(request.model_params.get('sequence_length', 512))
        ]

        # Start the training process in the background from the app directory,
        # redirecting its output into the session's log file so the /status
        # endpoint below has something to read
        session_log = os.path.join(session_dir, "training.log")
        with open(session_log, "a", encoding="utf-8") as log_f:
            subprocess.Popen(cmd, cwd=os.path.dirname(__file__), stdout=log_f, stderr=subprocess.STDOUT)

        logging.info(f"Training started for model: {request.model_name}, Session ID: {session_id}")
        return {"status": "Training started", "session_id": session_id}
    except Exception as e:
        logging.error(f"Error during training request: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))
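
# train_model.py itself is not shown here. A minimal argparse skeleton consistent
# with the command line built above might look like this (a sketch under that
# assumption; the actual training loop is omitted):
#
#     import argparse
#
#     parser = argparse.ArgumentParser()
#     parser.add_argument("--task", choices=["generation", "classification"], required=True)
#     parser.add_argument("--model_name", required=True)
#     parser.add_argument("--dataset_name", required=True)
#     parser.add_argument("--num_layers", type=int, default=12)
#     parser.add_argument("--attention_heads", type=int, default=1)
#     parser.add_argument("--hidden_size", type=int, default=64)
#     parser.add_argument("--vocab_size", type=int, default=30000)
#     parser.add_argument("--sequence_length", type=int, default=512)
#     args = parser.parse_args()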

# Optional: Status Endpoint
# (The route path is an assumption; any GET route that receives the session_id
# returned by /train works the same way.)
@app.get("/status/{session_id}")
def get_status(session_id: str):
    session_dir = f"./training_sessions/{session_id}"
    log_file = os.path.join(session_dir, "training.log")
    if not os.path.exists(log_file):
        raise HTTPException(status_code=404, detail="Session ID not found.")
    with open(log_file, "r", encoding="utf-8") as f:
        logs = f.read()
    return {"session_id": session_id, "logs": logs}
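
A minimal client sketch for exercising the two endpoints, assuming the backend is reachable at the placeholder URL below (swap in your own backend Space URL; model and dataset names are hypothetical):

# client_example.py (illustrative; not part of the backend)
import requests

BASE_URL = "https://your-username-your-space.hf.space"  # placeholder backend Space URL

payload = {
    "task": "generation",
    "model_name": "my-tiny-gpt",             # hypothetical model name
    "dataset_name": "username/my-dataset",   # hypothetical dataset on the Hub
    "model_params": {"num_layers": 12, "attention_heads": 1, "hidden_size": 64},
}

# Kick off training; the response carries the session_id used for polling
resp = requests.post(f"{BASE_URL}/train", json=payload, timeout=30)
resp.raise_for_status()
session_id = resp.json()["session_id"]

# Poll the status endpoint for the accumulated training logs
status = requests.get(f"{BASE_URL}/status/{session_id}", timeout=30)
print(status.json()["logs"])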