"""
RunPod Client - Low-level GraphQL API client for RunPod.

Provides direct access to RunPod's GraphQL API for pod management.
"""

import os
import requests
from typing import Optional, List, Dict
from dataclasses import dataclass


class RunPodAPIError(Exception):
    """Raised when the RunPod GraphQL endpoint answers with a non-200 HTTP status.

    Subclasses ``Exception`` so callers that already catch ``Exception`` keep working.
    """


@dataclass
class PodInfo:
    """Information about a RunPod pod."""
    id: str
    name: str
    status: str
    gpu_type: str
    gpu_count: int
    cost_per_hour: float
    runtime: Optional[Dict] = None


class RunPodClient:
    """Low-level client for the RunPod GraphQL API.

    Read-style methods return empty/``None`` results and mutation-style methods
    return ``False``/``None`` when the GraphQL response contains ``errors``
    (the errors are printed). HTTP-level failures raise ``RunPodAPIError``.
    """

    def __init__(self, api_key: Optional[str] = None, timeout: float = 30):
        """Create a client.

        Args:
            api_key: RunPod API key; falls back to the ``RUNPOD_API_KEY``
                environment variable when not given.
            timeout: Per-request timeout in seconds passed to ``requests.post``.

        Raises:
            ValueError: If no API key is supplied or found in the environment.
        """
        self.api_key = api_key or os.getenv("RUNPOD_API_KEY")
        if not self.api_key:
            raise ValueError("RunPod API key required. Set RUNPOD_API_KEY environment variable.")

        self.endpoint = "https://api.runpod.io/graphql"
        self.timeout = timeout
        self.headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {self.api_key}",
        }

    # ------------------------------------------------------------------ helpers

    def _query(self, query: str, variables: Optional[Dict] = None) -> Dict:
        """Execute a GraphQL query/mutation and return the decoded JSON body.

        Raises:
            RunPodAPIError: If the HTTP status is not 200.
            requests.RequestException: On network errors / timeouts (propagated).
        """
        payload = {
            "query": query,
            "variables": variables or {},
        }

        response = requests.post(
            self.endpoint,
            json=payload,
            headers=self.headers,
            timeout=self.timeout,
        )

        if response.status_code != 200:
            raise RunPodAPIError(f"GraphQL request failed: {response.status_code} {response.text}")

        return response.json()

    @staticmethod
    def _report_errors(result: Dict, action: str) -> bool:
        """Print GraphQL errors in *result* (if any) and return True when present."""
        if "errors" not in result:
            return False
        print(f"Error {action}: {result['errors']}")
        return True

    @staticmethod
    def _data(result: Dict) -> Dict:
        """Return the ``data`` object of a GraphQL response, or ``{}`` if missing/null.

        GraphQL may send ``"data": null``, which ``dict.get(key, {})`` would not
        replace, so ``or {}`` is used instead.
        """
        return result.get("data") or {}

    @staticmethod
    def _gpu_label(runtime: Optional[Dict]) -> str:
        """Build a display label from the first GPU id in a pod's runtime block.

        The API response used here carries no GPU model name, so a shortened
        GPU id is used; falls back to the generic ``"GPU"``.
        """
        gpus = (runtime or {}).get("gpus") or []
        first_gpu = (gpus[0] or {}) if gpus else {}
        gpu_id = first_gpu.get("id") or ""
        return f"GPU-{gpu_id[:8]}" if gpu_id else "GPU"

    # ------------------------------------------------------------------- pods

    def list_pods(self) -> List[PodInfo]:
        """List all pods owned by the authenticated account.

        Returns:
            A list of ``PodInfo``; empty when the response contains errors.
        """
        query = """
        query {
            myself {
                pods {
                    id
                    name
                    desiredStatus
                    runtime {
                        gpus {
                            id
                        }
                    }
                    machine {
                        podHostId
                    }
                    costPerHr
                    gpuCount
                }
            }
        }
        """

        result = self._query(query)
        if self._report_errors(result, "listing pods"):
            return []

        # Every requested field is present in a GraphQL response but may be
        # null, so `or` fallbacks (not .get defaults) keep PodInfo well-typed.
        myself = self._data(result).get("myself") or {}
        pods_data = myself.get("pods") or []

        return [
            PodInfo(
                id=pod_data["id"],
                name=pod_data["name"],
                status=pod_data.get("desiredStatus") or "unknown",
                gpu_type=self._gpu_label(pod_data.get("runtime")),
                gpu_count=pod_data.get("gpuCount") or 0,
                cost_per_hour=pod_data.get("costPerHr") or 0.0,
                runtime=pod_data.get("runtime"),
            )
            for pod_data in pods_data
        ]

    def create_pod(
        self,
        name: str,
        image_name: str,
        gpu_type_id: str,
        gpu_count: int = 1,
        volume_in_gb: int = 100,
        container_disk_in_gb: int = 50,
        ports: str = "8888/http",
    ) -> Optional[str]:
        """Create a new on-demand pod.

        Returns:
            The new pod's id, or ``None`` if the API reported errors or
            returned no pod.
        """
        query = """
        mutation($input: PodFindAndDeployOnDemandInput!) {
            podFindAndDeployOnDemand(input: $input) {
                id
                name
                desiredStatus
            }
        }
        """

        variables = {
            "input": {
                "name": name,
                "imageName": image_name,
                "gpuTypeId": gpu_type_id,
                "gpuCount": gpu_count,
                "volumeInGb": volume_in_gb,
                "containerDiskInGb": container_disk_in_gb,
                "ports": ports,
                "cloudType": "ALL",
            }
        }

        result = self._query(query, variables)
        if self._report_errors(result, "creating pod"):
            return None

        pod_data = self._data(result).get("podFindAndDeployOnDemand")
        if pod_data:
            return pod_data["id"]
        return None

    def stop_pod(self, pod_id: str) -> bool:
        """Stop a running pod.

        Returns:
            ``False`` if the API reported errors, otherwise ``True``.
        """
        query = """
        mutation($input: PodStopInput!) {
            podStop(input: $input) {
                id
                desiredStatus
            }
        }
        """

        result = self._query(query, {"input": {"podId": pod_id}})
        return not self._report_errors(result, "stopping pod")

    def terminate_pod(self, pod_id: str) -> bool:
        """Terminate a pod.

        Returns:
            ``False`` if the API reported errors, otherwise ``True``.
        """
        query = """
        mutation($input: PodTerminateInput!) {
            podTerminate(input: $input)
        }
        """

        result = self._query(query, {"input": {"podId": pod_id}})
        return not self._report_errors(result, "terminating pod")

    def get_gpu_types(self) -> List[Dict]:
        """Get available GPU types.

        Returns:
            The raw ``gpuTypes`` list (``id``, ``displayName``, ``memoryInGb``,
            ``secureCloud``, ``communityCloud``); empty on errors.
        """
        query = """
        query {
            gpuTypes {
                id
                displayName
                memoryInGb
                secureCloud
                communityCloud
            }
        }
        """

        result = self._query(query)
        if self._report_errors(result, "getting GPU types"):
            return []

        return self._data(result).get("gpuTypes") or []

    def get_pod_details(self, pod_id: str) -> Optional[Dict]:
        """Get detailed information (including runtime ports) about one pod.

        Returns:
            The raw ``pod`` object, or ``None`` on errors / when not found.
        """
        query = """
        query($podId: String!) {
            pod(input: {podId: $podId}) {
                id
                name
                desiredStatus
                runtime {
                    gpus {
                        id
                    }
                    ports {
                        ip
                        isIpPublic
                        privatePort
                        publicPort
                        type
                    }
                }
                machine {
                    podHostId
                }
                gpuCount
                costPerHr
            }
        }
        """

        result = self._query(query, {"podId": pod_id})
        if self._report_errors(result, "getting pod details"):
            return None

        return self._data(result).get("pod")