| """ | |
| RunPod Manager - High-level management for RunPod instances | |
| Provides higher-level functions for managing RunPod instances including | |
| deployment, monitoring, and SSH access. | |
| """ | |
| import paramiko | |
| import time | |
| from typing import Optional, Dict, List | |
| from dataclasses import dataclass, field | |
| from runpod_client import RunPodClient, PodInfo | |


@dataclass
class DeploymentConfig:
    """Configuration for RunPod deployment."""
    name: str = "aura-training-pod"
    gpu_type: str = "NVIDIA A100 80GB PCIe"
    gpu_count: int = 1
    storage_gb: int = 100
    image: str = "runpod/pytorch:2.1.0-py3.10-cuda11.8.0-devel-ubuntu22.04"
    ports: str = "8888/http,22/tcp,7860/http"  # Jupyter, SSH, Gradio


@dataclass
class TrainingConfig:
    """Configuration for model training on RunPod."""
    model_name: str = "Qwen/Qwen2.5-7B-Instruct"
    lora_rank: int = 8
    learning_rate: float = 2e-4
    num_epochs: int = 3
    batch_size: int = 4
    gradient_accumulation_steps: int = 4
    use_4bit: bool = True
    max_length: int = 2048
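

# Usage sketch (illustrative, not part of the original module): the config
# dataclasses above can be instantiated with overrides for a smaller run, e.g.
#
#     deploy_cfg = DeploymentConfig(gpu_type="NVIDIA RTX A5000", storage_gb=50)
#     train_cfg = TrainingConfig(num_epochs=1, batch_size=2)
#
# The field names match the dataclass definitions; the GPU type string is an
# assumed example value.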


class RunPodManager:
    """Manager for RunPod instances with deployment and monitoring."""

    def __init__(self, api_key: Optional[str] = None):
        self.client = RunPodClient(api_key)

    def deploy_training_pod(
        self,
        name: str,
        gpu_type: str = "NVIDIA A100 80GB PCIe",
        gpu_count: int = 1,
        storage_gb: int = 100
    ) -> Optional[str]:
        """Deploy a pod configured for model training."""
        # Use a PyTorch image with CUDA support
        image = "runpod/pytorch:2.1.0-py3.10-cuda11.8.0-devel-ubuntu22.04"

        print(f"Deploying training pod '{name}'...")
        print(f"  GPU: {gpu_type} x{gpu_count}")
        print(f"  Storage: {storage_gb}GB")

        pod_id = self.client.create_pod(
            name=name,
            image_name=image,
            gpu_type_id=gpu_type,
            gpu_count=gpu_count,
            volume_in_gb=storage_gb,
            container_disk_in_gb=50,
            ports="8888/http,22/tcp,7860/http"  # Jupyter, SSH, Gradio
        )

        if pod_id:
            print(f"Pod created: {pod_id}")
            print("Waiting for pod to start...")
            time.sleep(10)  # Give it time to start

        return pod_id

    def get_pod_status(self, pod_id: str) -> Optional[Dict]:
        """Get the current status of a pod."""
        pods = self.client.list_pods()
        for pod in pods:
            if pod.id == pod_id:
                return {
                    "id": pod.id,
                    "name": pod.name,
                    "status": pod.status,
                    "gpu_type": pod.gpu_type,
                    "cost_per_hour": pod.cost_per_hour
                }
        return None

    def list_all_pods(self) -> List[PodInfo]:
        """List all pods."""
        return self.client.list_pods()

    def stop_pod(self, pod_id: str) -> bool:
        """Stop a running pod."""
        print(f"Stopping pod {pod_id}...")
        return self.client.stop_pod(pod_id)

    def terminate_pod(self, pod_id: str) -> bool:
        """Terminate a pod."""
        print(f"Terminating pod {pod_id}...")
        return self.client.terminate_pod(pod_id)

    def get_ssh_connection(
        self,
        pod_ip: str,
        username: str = "root",
        key_file: Optional[str] = None,
        password: Optional[str] = None,
        port: int = 22
    ) -> Optional[paramiko.SSHClient]:
        """Get an SSH connection to a pod."""
        ssh = paramiko.SSHClient()
        ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy())

        try:
            if key_file:
                ssh.connect(
                    pod_ip,
                    port=port,
                    username=username,
                    key_filename=key_file,
                    timeout=10
                )
            elif password:
                ssh.connect(
                    pod_ip,
                    port=port,
                    username=username,
                    password=password,
                    timeout=10
                )
            else:
                print("Either key_file or password must be provided")
                return None
            return ssh
        except Exception as e:
            print(f"SSH connection failed: {e}")
            return None

    def execute_command(
        self,
        ssh: paramiko.SSHClient,
        command: str
    ) -> tuple[str, str]:
        """Execute a command via SSH and return (stdout, stderr)."""
        stdin, stdout, stderr = ssh.exec_command(command)
        return stdout.read().decode(), stderr.read().decode()

    def upload_file(
        self,
        ssh: paramiko.SSHClient,
        local_path: str,
        remote_path: str
    ) -> bool:
        """Upload a file to the pod via SFTP."""
        try:
            sftp = ssh.open_sftp()
            sftp.put(local_path, remote_path)
            sftp.close()
            return True
        except Exception as e:
            print(f"File upload failed: {e}")
            return False

    def download_file(
        self,
        ssh: paramiko.SSHClient,
        remote_path: str,
        local_path: str
    ) -> bool:
        """Download a file from the pod via SFTP."""
        try:
            sftp = ssh.open_sftp()
            sftp.get(remote_path, local_path)
            sftp.close()
            return True
        except Exception as e:
            print(f"File download failed: {e}")
            return False

    def setup_training_environment(
        self,
        ssh: paramiko.SSHClient,
        requirements_file: Optional[str] = None
    ) -> bool:
        """Set up the training environment on a pod."""
        print("Setting up training environment...")

        # Update pip
        print("Updating pip...")
        stdout, stderr = self.execute_command(ssh, "pip install --upgrade pip")

        if requirements_file:
            # Upload requirements file
            print("Uploading requirements...")
            if not self.upload_file(ssh, requirements_file, "/tmp/requirements.txt"):
                return False

            # Install requirements
            print("Installing requirements...")
            stdout, stderr = self.execute_command(
                ssh,
                "pip install -r /tmp/requirements.txt"
            )
            if stderr and "error" in stderr.lower():
                print(f"Installation errors: {stderr}")
                return False

        print("Environment setup complete!")
        return True

    def monitor_training(
        self,
        ssh: paramiko.SSHClient,
        log_file: str = "/workspace/training.log",
        interval: int = 30
    ):
        """Monitor training progress by tailing the remote log file."""
        print(f"Monitoring training log: {log_file}")
        print(f"Checking every {interval} seconds...")
        print("Press Ctrl+C to stop monitoring\n")

        last_line_count = 0
        try:
            while True:
                # Get log file content
                stdout, stderr = self.execute_command(
                    ssh,
                    f"cat {log_file} 2>/dev/null || echo 'Log file not found'"
                )

                lines = stdout.strip().split('\n')
                new_lines = lines[last_line_count:]

                if new_lines and new_lines[0] != 'Log file not found':
                    for line in new_lines:
                        print(line)
                    last_line_count = len(lines)

                time.sleep(interval)
        except KeyboardInterrupt:
            print("\nStopped monitoring")

    def get_available_gpus(self) -> List[Dict]:
        """Get a list of available GPU types."""
        return self.client.get_gpu_types()

    def estimate_cost(
        self,
        gpu_type: str,
        gpu_count: int,
        hours: float
    ) -> Optional[float]:
        """Estimate the cost of a training job.

        Looks for a currently running pod with a matching GPU type and count
        and extrapolates from its hourly rate; returns None if no matching
        pod is found.
        """
        pods = self.client.list_pods()

        # Find the cost per hour for this GPU configuration
        for pod in pods:
            if pod.gpu_type == gpu_type and pod.gpu_count == gpu_count:
                total_cost = pod.cost_per_hour * hours
                return total_cost
        return None

    def run_training_on_pod(
        self,
        pod_id: str,
        training_data: List[Dict],
        model_name: str,
        lora_config: Dict,
        training_config: Dict
    ) -> bool:
        """Run training on a RunPod pod instead of locally."""
        import json
        import tempfile

        print(f"Starting remote training on pod {pod_id}...")

        # 1. Get pod details to find SSH info
        pod_details = self.client.get_pod_details(pod_id)
        if not pod_details:
            print("Error: Could not get pod details")
            return False

        # Extract SSH connection info
        runtime = pod_details.get("runtime")
        if not runtime or not runtime.get("ports"):
            print("Error: Pod runtime not available. Pod may still be starting.")
            return False

        # Find the SSH port
        ssh_port = None
        ssh_ip = None
        for port in runtime["ports"]:
            if port.get("privatePort") == 22:
                ssh_ip = port.get("ip")
                ssh_port = port.get("publicPort")
                break

        if not ssh_ip or not ssh_port:
            print("Error: SSH port not found in pod details")
            return False

        print(f"SSH Connection: {ssh_ip}:{ssh_port}")

        # 2. Save training data to a temp file
        with tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=False) as f:
            json.dump(training_data, f)
            data_file = f.name
        # 3. Create the training script (uploaded to the pod and run there)
        training_script = f"""
import json
import sys
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments, Trainer
from peft import LoraConfig, get_peft_model, TaskType, prepare_model_for_kbit_training
from datasets import Dataset
import torch

print("Loading training data...")
with open('/workspace/training_data.json', 'r') as f:
    data = json.load(f)
print(f"Loaded {{len(data)}} training examples")

print("Loading model: {model_name}")
model = AutoModelForCausalLM.from_pretrained(
    "{model_name}",
    load_in_4bit=True,
    device_map="auto",
    torch_dtype=torch.float16
)
tokenizer = AutoTokenizer.from_pretrained("{model_name}")
tokenizer.pad_token = tokenizer.eos_token

print("Preparing model for training...")
model = prepare_model_for_kbit_training(model)

lora_config = LoraConfig(
    r={lora_config.get('r', 16)},
    lora_alpha={lora_config.get('lora_alpha', 32)},
    target_modules=["q_proj", "v_proj", "k_proj", "o_proj"],
    lora_dropout=0.05,
    bias="none",
    task_type=TaskType.CAUSAL_LM
)
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()

print("Preparing dataset...")
def format_data(example):
    text = f"###Instruction: {{example['instruction']}}\\n###Response: {{example['output']}}"
    tokens = tokenizer(text, truncation=True, max_length=2048, padding="max_length")
    # Causal-LM fine-tuning needs labels; use the input ids
    tokens["labels"] = tokens["input_ids"].copy()
    return tokens

dataset = Dataset.from_list(data)
dataset = dataset.map(format_data, batched=False)

training_args = TrainingArguments(
    output_dir="/workspace/outputs",
    num_train_epochs={training_config.get('num_epochs', 3)},
    per_device_train_batch_size={training_config.get('batch_size', 1)},
    gradient_accumulation_steps={training_config.get('gradient_accumulation_steps', 16)},
    learning_rate={training_config.get('learning_rate', 2e-4)},
    logging_steps=10,
    save_steps=100,
    save_total_limit=2,
    fp16=True,
    report_to="none"
)

print("Starting training...")
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=dataset
)
trainer.train()

print("Saving model...")
model.save_pretrained("/workspace/final_model")
tokenizer.save_pretrained("/workspace/final_model")
print("Training complete!")
"""
        # Save the script to a temp file
        with tempfile.NamedTemporaryFile(mode='w', suffix='.py', delete=False) as f:
            f.write(training_script)
            script_file = f.name

        print("Connecting to pod via SSH...")

        # Get the path to the SSH key
        import os
        key_path = os.path.join(os.getcwd(), ".ssh", "runpod_key")
        if not os.path.exists(key_path):
            print(f"Error: SSH key not found at {key_path}")
            print("Run: ssh-keygen -t ed25519 -f .ssh/runpod_key -N ''")
            print("Then add the public key to RunPod: https://www.runpod.io/console/user/settings")
            return False

        # Get SSH connection (RunPod uses the root user by default)
        ssh = self.get_ssh_connection(
            pod_ip=ssh_ip,
            username="root",
            password=None,
            key_file=key_path,
            port=ssh_port
        )
        if not ssh:
            print("Error: Could not establish SSH connection")
            print(f"Tried using key: {key_path}")
            print("Verify the public key is added to RunPod: https://www.runpod.io/console/user/settings")
            return False
        try:
            # Upload training data
            print("Uploading training data...")
            if not self.upload_file(ssh, data_file, "/workspace/training_data.json"):
                return False

            # Upload training script
            print("Uploading training script...")
            if not self.upload_file(ssh, script_file, "/workspace/train.py"):
                return False

            # Install required packages
            print("Installing required packages...")
            stdout, stderr = self.execute_command(
                ssh,
                "pip install transformers peft datasets accelerate bitsandbytes"
            )

            # Execute training
            print("Starting training on pod...")
            print("Training will run in the background on the pod.")
            print("You can monitor progress by checking the pod's logs.")

            # Run training in the background with nohup
            stdout, stderr = self.execute_command(
                ssh,
                "nohup python /workspace/train.py > /workspace/training.log 2>&1 &"
            )

            print("\nTraining initiated successfully!")
            print("Training data uploaded to: /workspace/training_data.json")
            print("Training script uploaded to: /workspace/train.py")
            print("Training log available at: /workspace/training.log")
            print("\nTo monitor progress, you can:")
            print(f"  1. SSH to pod: ssh root@{ssh_ip} -p {ssh_port}")
            print("  2. View logs: tail -f /workspace/training.log")

            return True
        except Exception as e:
            print(f"Error during remote training setup: {e}")
            return False
        finally:
            ssh.close()
            # Clean up local temp files
            try:
                os.unlink(data_file)
                os.unlink(script_file)
            except OSError:
                pass
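

# A minimal usage sketch, not part of the original module. It assumes the API
# key placeholder below is replaced with a real key (or that RunPodClient can
# resolve one when none is passed), and that training_data records use the
# 'instruction'/'output' keys expected by the generated training script above.
if __name__ == "__main__":
    manager = RunPodManager(api_key="YOUR_RUNPOD_API_KEY")  # placeholder key

    # Deploy a training pod with the default A100 configuration
    pod_id = manager.deploy_training_pod(name="aura-training-pod")

    if pod_id:
        # Toy training data; real data would be loaded from disk
        training_data = [
            {"instruction": "Say hello.", "output": "Hello!"},
        ]
        manager.run_training_on_pod(
            pod_id=pod_id,
            training_data=training_data,
            model_name="Qwen/Qwen2.5-7B-Instruct",
            lora_config={"r": 16, "lora_alpha": 32},
            training_config={"num_epochs": 3, "batch_size": 1},
        )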