LaunchLLM / runpod_manager.py
Bmccloud22's picture
Deploy LaunchLLM - Production AI Training Platform
90a59c9 verified
"""
RunPod Manager - High-level management for RunPod instances
Provides higher-level functions for managing RunPod instances including
deployment, monitoring, and SSH access.
"""
import paramiko
import time
from typing import Optional, Dict, List
from dataclasses import dataclass, field
from runpod_client import RunPodClient, PodInfo
@dataclass
class DeploymentConfig:
"""Configuration for RunPod deployment."""
name: str = "aura-training-pod"
gpu_type: str = "NVIDIA A100 80GB PCIe"
gpu_count: int = 1
storage_gb: int = 100
image: str = "runpod/pytorch:2.1.0-py3.10-cuda11.8.0-devel-ubuntu22.04"
ports: str = "8888/http,22/tcp,7860/http" # Jupyter, SSH, Gradio
@dataclass
class TrainingConfig:
"""Configuration for model training on RunPod."""
model_name: str = "Qwen/Qwen2.5-7B-Instruct"
lora_rank: int = 8
learning_rate: float = 2e-4
num_epochs: int = 3
batch_size: int = 4
gradient_accumulation_steps: int = 4
use_4bit: bool = True
max_length: int = 2048
class RunPodManager:
"""Manager for RunPod instances with deployment and monitoring"""
def __init__(self, api_key: Optional[str] = None):
self.client = RunPodClient(api_key)
def deploy_training_pod(
self,
name: str,
gpu_type: str = "NVIDIA A100 80GB PCIe",
gpu_count: int = 1,
storage_gb: int = 100
) -> Optional[str]:
"""Deploy a pod configured for model training"""
# Use PyTorch image with CUDA support
image = "runpod/pytorch:2.1.0-py3.10-cuda11.8.0-devel-ubuntu22.04"
print(f"Deploying training pod '{name}'...")
print(f" GPU: {gpu_type} x{gpu_count}")
print(f" Storage: {storage_gb}GB")
pod_id = self.client.create_pod(
name=name,
image_name=image,
gpu_type_id=gpu_type,
gpu_count=gpu_count,
volume_in_gb=storage_gb,
container_disk_in_gb=50,
ports="8888/http,22/tcp,7860/http" # Jupyter, SSH, Gradio
)
if pod_id:
print(f"Pod created: {pod_id}")
print("Waiting for pod to start...")
time.sleep(10) # Give it time to start
return pod_id
def get_pod_status(self, pod_id: str) -> Optional[Dict]:
"""Get current status of a pod"""
pods = self.client.list_pods()
for pod in pods:
if pod.id == pod_id:
return {
"id": pod.id,
"name": pod.name,
"status": pod.status,
"gpu_type": pod.gpu_type,
"cost_per_hour": pod.cost_per_hour
}
return None
def list_all_pods(self) -> List[PodInfo]:
"""List all pods"""
return self.client.list_pods()
def stop_pod(self, pod_id: str) -> bool:
"""Stop a running pod"""
print(f"Stopping pod {pod_id}...")
return self.client.stop_pod(pod_id)
def terminate_pod(self, pod_id: str) -> bool:
"""Terminate a pod"""
print(f"Terminating pod {pod_id}...")
return self.client.terminate_pod(pod_id)
def get_ssh_connection(
self,
pod_ip: str,
username: str = "root",
key_file: Optional[str] = None,
password: Optional[str] = None
) -> Optional[paramiko.SSHClient]:
"""Get SSH connection to a pod"""
ssh = paramiko.SSHClient()
ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy())
try:
if key_file:
ssh.connect(
pod_ip,
username=username,
key_filename=key_file,
timeout=10
)
elif password:
ssh.connect(
pod_ip,
username=username,
password=password,
timeout=10
)
else:
print("Either key_file or password must be provided")
return None
return ssh
except Exception as e:
print(f"SSH connection failed: {e}")
return None
def execute_command(
self,
ssh: paramiko.SSHClient,
command: str
) -> tuple[str, str]:
"""Execute a command via SSH"""
stdin, stdout, stderr = ssh.exec_command(command)
return stdout.read().decode(), stderr.read().decode()
def upload_file(
self,
ssh: paramiko.SSHClient,
local_path: str,
remote_path: str
) -> bool:
"""Upload a file to the pod"""
try:
sftp = ssh.open_sftp()
sftp.put(local_path, remote_path)
sftp.close()
return True
except Exception as e:
print(f"File upload failed: {e}")
return False
def download_file(
self,
ssh: paramiko.SSHClient,
remote_path: str,
local_path: str
) -> bool:
"""Download a file from the pod"""
try:
sftp = ssh.open_sftp()
sftp.get(remote_path, local_path)
sftp.close()
return True
except Exception as e:
print(f"File download failed: {e}")
return False
def setup_training_environment(
self,
ssh: paramiko.SSHClient,
requirements_file: Optional[str] = None
) -> bool:
"""Setup the training environment on a pod"""
print("Setting up training environment...")
# Update pip
print("Updating pip...")
stdout, stderr = self.execute_command(ssh, "pip install --upgrade pip")
if requirements_file:
# Upload requirements file
print("Uploading requirements...")
if not self.upload_file(ssh, requirements_file, "/tmp/requirements.txt"):
return False
# Install requirements
print("Installing requirements...")
stdout, stderr = self.execute_command(
ssh,
"pip install -r /tmp/requirements.txt"
)
if stderr and "error" in stderr.lower():
print(f"Installation errors: {stderr}")
return False
print("Environment setup complete!")
return True
def monitor_training(
self,
ssh: paramiko.SSHClient,
log_file: str = "/workspace/training.log",
interval: int = 30
):
"""Monitor training progress"""
print(f"Monitoring training log: {log_file}")
print(f"Checking every {interval} seconds...")
print("Press Ctrl+C to stop monitoring\n")
last_line_count = 0
try:
while True:
# Get log file content
stdout, stderr = self.execute_command(
ssh,
f"cat {log_file} 2>/dev/null || echo 'Log file not found'"
)
lines = stdout.strip().split('\n')
new_lines = lines[last_line_count:]
if new_lines and new_lines[0] != 'Log file not found':
for line in new_lines:
print(line)
last_line_count = len(lines)
time.sleep(interval)
except KeyboardInterrupt:
print("\nStopped monitoring")
def get_available_gpus(self) -> List[Dict]:
"""Get list of available GPU types"""
return self.client.get_gpu_types()
def estimate_cost(
self,
gpu_type: str,
gpu_count: int,
hours: float
) -> Optional[float]:
"""Estimate cost for a training job"""
pods = self.client.list_pods()
# Find cost per hour for this GPU type
for pod in pods:
if pod.gpu_type == gpu_type and pod.gpu_count == gpu_count:
total_cost = pod.cost_per_hour * hours
return total_cost
return None
def run_training_on_pod(
self,
pod_id: str,
training_data: List[Dict],
model_name: str,
lora_config: Dict,
training_config: Dict
) -> bool:
"""Run training on RunPod pod instead of locally"""
import json
import tempfile
print(f"Starting remote training on pod {pod_id}...")
# 1. Get pod details to find SSH info
pod_details = self.client.get_pod_details(pod_id)
if not pod_details:
print("Error: Could not get pod details")
return False
# Extract SSH connection info
runtime = pod_details.get("runtime")
if not runtime or not runtime.get("ports"):
print("Error: Pod runtime not available. Pod may still be starting.")
return False
# Find SSH port
ssh_port = None
ssh_ip = None
for port in runtime["ports"]:
if port.get("privatePort") == 22:
ssh_ip = port.get("ip")
ssh_port = port.get("publicPort")
break
if not ssh_ip or not ssh_port:
print("Error: SSH port not found in pod details")
return False
print(f"SSH Connection: {ssh_ip}:{ssh_port}")
# 2. Save training data to temp file
with tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=False) as f:
json.dump(training_data, f)
data_file = f.name
# 3. Create training script
training_script = f"""
import json
import sys
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments, Trainer
from peft import LoraConfig, get_peft_model, TaskType, prepare_model_for_kbit_training
from datasets import Dataset
import torch
print("Loading training data...")
with open('/workspace/training_data.json', 'r') as f:
data = json.load(f)
print(f"Loaded {{len(data)}} training examples")
print("Loading model: {model_name}")
model = AutoModelForCausalLM.from_pretrained(
"{model_name}",
load_in_4bit=True,
device_map="auto",
torch_dtype=torch.float16
)
tokenizer = AutoTokenizer.from_pretrained("{model_name}")
tokenizer.pad_token = tokenizer.eos_token
print("Preparing model for training...")
model = prepare_model_for_kbit_training(model)
lora_config = LoraConfig(
r={lora_config.get('r', 16)},
lora_alpha={lora_config.get('lora_alpha', 32)},
target_modules=["q_proj", "v_proj", "k_proj", "o_proj"],
lora_dropout=0.05,
bias="none",
task_type=TaskType.CAUSAL_LM
)
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()
print("Preparing dataset...")
def format_data(example):
text = f"###Instruction: {{example['instruction']}}\\n###Response: {{example['output']}}"
return tokenizer(text, truncation=True, max_length=2048, padding="max_length")
dataset = Dataset.from_list(data)
dataset = dataset.map(format_data, batched=False)
training_args = TrainingArguments(
output_dir="/workspace/outputs",
num_train_epochs={training_config.get('num_epochs', 3)},
per_device_train_batch_size={training_config.get('batch_size', 1)},
gradient_accumulation_steps={training_config.get('gradient_accumulation_steps', 16)},
learning_rate={training_config.get('learning_rate', 2e-4)},
logging_steps=10,
save_steps=100,
save_total_limit=2,
fp16=True,
report_to="none"
)
print("Starting training...")
trainer = Trainer(
model=model,
args=training_args,
train_dataset=dataset
)
trainer.train()
print("Saving model...")
model.save_pretrained("/workspace/final_model")
tokenizer.save_pretrained("/workspace/final_model")
print("Training complete!")
"""
# Save script to temp file
with tempfile.NamedTemporaryFile(mode='w', suffix='.py', delete=False) as f:
f.write(training_script)
script_file = f.name
print("Connecting to pod via SSH...")
# Get path to SSH key
import os
key_path = os.path.join(os.getcwd(), ".ssh", "runpod_key")
if not os.path.exists(key_path):
print(f"Error: SSH key not found at {key_path}")
print("Run: ssh-keygen -t ed25519 -f .ssh/runpod_key -N ''")
print("Then add the public key to RunPod: https://www.runpod.io/console/user/settings")
return False
# Get SSH connection (RunPod uses root user by default)
ssh = self.get_ssh_connection(
pod_ip=ssh_ip,
username="root",
password=None,
key_file=key_path
)
if not ssh:
print("Error: Could not establish SSH connection")
print(f"Tried using key: {key_path}")
print("Verify the public key is added to RunPod: https://www.runpod.io/console/user/settings")
return False
try:
# Upload training data
print("Uploading training data...")
if not self.upload_file(ssh, data_file, "/workspace/training_data.json"):
return False
# Upload training script
print("Uploading training script...")
if not self.upload_file(ssh, script_file, "/workspace/train.py"):
return False
# Install required packages
print("Installing required packages...")
stdout, stderr = self.execute_command(
ssh,
"pip install transformers peft datasets accelerate bitsandbytes"
)
# Execute training
print("Starting training on pod...")
print("Training will run in the background on the pod.")
print("You can monitor progress by checking the pod's logs.")
# Run training in background with nohup
stdout, stderr = self.execute_command(
ssh,
"nohup python /workspace/train.py > /workspace/training.log 2>&1 &"
)
print("\nTraining initiated successfully!")
print("Training data uploaded to: /workspace/training_data.json")
print("Training script uploaded to: /workspace/train.py")
print("Training log available at: /workspace/training.log")
print("\nTo monitor progress, you can:")
print(f" 1. SSH to pod: ssh root@{ssh_ip} -p {ssh_port}")
print(" 2. View logs: tail -f /workspace/training.log")
return True
except Exception as e:
print(f"Error during remote training setup: {e}")
return False
finally:
ssh.close()
# Clean up temp files
import os
try:
os.unlink(data_file)
os.unlink(script_file)
except:
pass