rkihacker committed on
Commit f5fc07f · verified · 1 Parent(s): 4e6a1d2

Update main.py

Files changed (1):
  1. main.py +131 -141
main.py CHANGED
--- a/main.py (before)
@@ -1,5 +1,5 @@
  import httpx
- from fastapi import FastAPI, Request, HTTPException
  from starlette.responses import StreamingResponse, JSONResponse
  from starlette.background import BackgroundTask
  import os
@@ -7,93 +7,125 @@ import random
  import logging
  import time
  import json
  from contextlib import asynccontextmanager

  # --- Production-Ready Configuration ---
  LOG_LEVEL = os.getenv("LOG_LEVEL", "INFO").upper()
  logging.basicConfig(
      level=LOG_LEVEL,
-     format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
  )
  logger = logging.getLogger(__name__)

- # URL to fetch the list of all available models and their endpoints
  ARTIFACT_URL = os.getenv("ARTIFACT_URL", "https://console.gmicloud.ai/api/v1/ie/artifact/get_public_artifacts")

- # Retry logic configuration
  MAX_RETRIES = int(os.getenv("MAX_RETRIES", "5"))
- DEFAULT_RETRY_CODES = "429,500,502,503,504"
- RETRY_CODES_STR = os.getenv("RETRY_CODES", DEFAULT_RETRY_CODES)
- try:
-     RETRY_STATUS_CODES = {int(code.strip()) for code in RETRY_CODES_STR.split(',')}
-     logger.info(f"Will retry on the following status codes: {RETRY_STATUS_CODES}")
- except ValueError:
-     logger.error(f"Invalid RETRY_CODES format: '{RETRY_CODES_STR}'. Falling back to default: {DEFAULT_RETRY_CODES}")
-     RETRY_STATUS_CODES = {int(code.strip()) for code in DEFAULT_RETRY_CODES.split(',')}
-
- # --- Helper Functions ---
-
- def generate_random_ip():
-     """Generates a random, valid-looking IPv4 address."""
-     return ".".join(str(random.randint(1, 254)) for _ in range(4))
-
- async def fetch_and_cache_models(app: FastAPI):
      """
-     Fetches the list of public artifacts and caches a routing table.
-     This runs once on application startup.
      """
-     logger.info(f"Fetching model artifacts from: {ARTIFACT_URL}")
-     model_routing_table = {}
      try:
-         async with httpx.AsyncClient() as client:
-             response = await client.get(ARTIFACT_URL, timeout=30.0)
-             response.raise_for_status()
-             artifacts = response.json()
-
-         for artifact in artifacts:
-             model_name = artifact.get("artifact_metadata", {}).get("artifact_name")
-             endpoints = artifact.get("endpoints", [])
-
-             # We only care about models that have a running endpoint
-             if model_name and endpoints:
-                 # A model could have multiple endpoints, we'll just use the first one
-                 # A more advanced setup could load-balance between them
-                 endpoint_url = endpoints[0].get("endpoint_url")
-                 if endpoint_url:
-                     model_routing_table[model_name] = endpoint_url
-
-         if not model_routing_table:
-             logger.warning("No active model endpoints found from artifact URL.")
-         else:
-             logger.info(f"Successfully loaded {len(model_routing_table)} active models.")
-             for name, url in model_routing_table.items():
-                 logger.debug(f" - Model: '{name}' -> Endpoint: '{url}'")
-
-     except httpx.RequestError as e:
-         logger.critical(f"Failed to fetch model artifacts on startup: {e}")
-         # In a real-world scenario, you might want the app to fail starting
-         # or handle this more gracefully. For now, we start with an empty table.
      except Exception as e:
-         logger.critical(f"An unexpected error occurred during model fetching: {e}")
-
-     app.state.model_routing_table = model_routing_table


- # --- HTTPX Client Lifecycle Management ---

  @asynccontextmanager
  async def lifespan(app: FastAPI):
      """Manages the app's lifecycle for startup and shutdown."""
-     # Create a single, long-lived HTTP client for forwarding requests
-     # No base_url as we will be calling different hosts dynamically
-     async with httpx.AsyncClient(timeout=None) as client:
-         app.state.http_client = client
-         # Fetch and cache model routes on startup
-         await fetch_and_cache_models(app)
-         yield
      logger.info("Application shutdown complete.")

- # Initialize the FastAPI app with the lifespan manager and disabled docs
  app = FastAPI(docs_url=None, redoc_url=None, lifespan=lifespan)

  # --- API Endpoints ---
@@ -103,36 +135,29 @@ async def health_check():
      """Provides a basic health check endpoint."""
      return JSONResponse({
          "status": "ok",
-         "active_models": len(app.state.model_routing_table)
      })

- @app.get("/v1/models")
- async def list_models(request: Request):
      """
-     Lists all available models discovered at startup.
-     Formatted to be compatible with the OpenAI API.
      """
-     model_routing_table = request.app.state.model_routing_table
      model_list = [
-         {
-             "id": model_id,
-             "object": "model",
-             "created": int(time.time()),
-             "owned_by": "gmi-serving",
-         }
-         for model_id in model_routing_table.keys()
      ]
      return JSONResponse(content={"object": "list", "data": model_list})

-
- @app.post("/v1/chat/completions")
  async def chat_completions_proxy(request: Request):
      """
      Forwards chat completion requests to the correct model endpoint.
      """
      start_time = time.monotonic()

-     # --- 1. Get Model Name and Find Target Host ---
      body = await request.body()
      try:
          data = json.loads(body)
@@ -142,80 +167,45 @@ async def chat_completions_proxy(request: Request):
      except json.JSONDecodeError:
          raise HTTPException(status_code=400, detail="Invalid JSON in request body.")

-     model_routing_table = request.app.state.model_routing_table
-     target_host = model_routing_table.get(model_name)
-
      if not target_host:
          raise HTTPException(
              status_code=404,
-             detail=f"Model '{model_name}' not found or is not currently active."
          )

-     # --- 2. Prepare and Forward the Request ---
      client: httpx.AsyncClient = request.app.state.http_client
-
-     # Construct the full URL to the backend service
      target_url = f"https://{target_host}{request.url.path}"

-     request_headers = dict(request.headers)
-     request_headers.pop("host", None)
-
-     random_ip = generate_random_ip()
-     spoof_headers = {
          "user-agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/141.0.0.0 Safari/537.36",
-         "x-forwarded-for": random_ip,
-         "x-real-ip": random_ip,
-     }
-     request_headers.update(spoof_headers)
-
-     logger.info(
-         f"Routing request for model '{model_name}' to {target_url} "
-         f"(Client: '{request.client.host}', Spoofed IP: {random_ip})"
-     )

-     # --- 3. Execute with Retry Logic ---
-     last_exception = None
      for attempt in range(MAX_RETRIES):
          try:
-             rp_req = client.build_request(
-                 method=request.method, url=target_url, headers=request_headers, content=body
-             )
-             rp_resp = await client.send(rp_req, stream=True)

-             # If status is not retryable OR it's the last attempt, stream the response
-             if rp_resp.status_code not in RETRY_STATUS_CODES or attempt == MAX_RETRIES - 1:
                  duration_ms = (time.monotonic() - start_time) * 1000
-                 log_func = logger.info if rp_resp.is_success else logger.warning
-                 log_func(f"Request finished for '{model_name}': {request.method} {request.url.path} status_code={rp_resp.status_code} latency={duration_ms:.2f}ms")
-
-                 return StreamingResponse(
-                     rp_resp.aiter_raw(),
-                     status_code=rp_resp.status_code,
-                     headers=rp_resp.headers,
-                     background=BackgroundTask(rp_resp.aclose),
-                 )
-
-             # Otherwise, log and prepare for retry
-             logger.warning(
-                 f"Attempt {attempt + 1}/{MAX_RETRIES} for '{model_name}' failed with status {rp_resp.status_code}. Retrying..."
-             )
-             await rp_resp.aclose()  # Ensure the connection is closed before retrying
-             await asyncio.sleep(1 * (2 ** attempt))  # Exponential backoff
-
-         except httpx.ConnectError as e:
-             last_exception = e
-             logger.warning(f"Attempt {attempt + 1}/{MAX_RETRIES} for '{model_name}' failed with connection error: {e}")
-
-         except Exception as e:
-             last_exception = e
-             logger.error(f"An unexpected error occurred during request forwarding: {e}")
-             break  # Don't retry on unexpected errors

-     # --- 4. Handle Final Failure ---
-     duration_ms = (time.monotonic() - start_time) * 1000
-     logger.critical(f"Request failed for model '{model_name}' after {MAX_RETRIES} attempts. Cannot connect to target: {target_url}. Latency: {duration_ms:.2f}ms")

-     raise HTTPException(
-         status_code=502,
-         detail=f"Bad Gateway: Cannot connect to model backend for '{model_name}' after {MAX_RETRIES} attempts. Last error: {last_exception}"
-     )

+++ b/main.py (after)
@@ -1,5 +1,5 @@
  import httpx
+ from fastapi import FastAPI, Request, HTTPException, Depends
  from starlette.responses import StreamingResponse, JSONResponse
  from starlette.background import BackgroundTask
  import os
@@ -7,93 +7,125 @@ import random
  import logging
  import time
  import json
+ import asyncio
  from contextlib import asynccontextmanager
+ from filelock import FileLock, Timeout

  # --- Production-Ready Configuration ---
  LOG_LEVEL = os.getenv("LOG_LEVEL", "INFO").upper()
  logging.basicConfig(
      level=LOG_LEVEL,
+     format='%(asctime)s - PID:%(process)d - %(name)s - %(levelname)s - %(message)s'
  )
  logger = logging.getLogger(__name__)

+ # --- Service Configuration ---
  ARTIFACT_URL = os.getenv("ARTIFACT_URL", "https://console.gmicloud.ai/api/v1/ie/artifact/get_public_artifacts")
+ REFRESH_INTERVAL_SECONDS = int(os.getenv("REFRESH_INTERVAL_SECONDS", "30"))

+ # --- Shared Cache File Configuration ---
+ # Using /dev/shm is faster as it's a RAM disk on Linux. Fall back to /tmp.
+ CACHE_DIR = "/dev/shm" if os.path.exists("/dev/shm") else "/tmp"
+ CACHE_FILE_PATH = os.path.join(CACHE_DIR, "gmi_routing_table.json")
+ LOCK_FILE_PATH = os.path.join(CACHE_DIR, "gmi_routing_table.lock")
+
+ # --- In-Memory State for each Worker ---
+ # These are global variables *per worker process*.
+ worker_model_routing_table = {}
+ last_cache_check_time = 0
+
+ # --- Retry Logic ---
  MAX_RETRIES = int(os.getenv("MAX_RETRIES", "5"))
+ RETRY_STATUS_CODES = {429, 500, 502, 503, 504}
+
+ # --- Core Caching and Refreshing Logic ---
+
+ async def load_or_refresh_models():
      """
+     Checks if the shared cache is stale. If so, attempts to acquire a lock
+     and refresh it. This is designed to be safe for multiple processes.
      """
+     global last_cache_check_time, worker_model_routing_table
+
+     now = time.monotonic()
+     # 1. Quick check: if the in-memory cache is fresh, do nothing.
+     if (now - last_cache_check_time) < REFRESH_INTERVAL_SECONDS:
+         return
+
+     # 2. The in-memory cache is stale; acquire a lock to check the shared file cache.
+     # The lock prevents all workers from hitting the API at once.
+     lock = FileLock(LOCK_FILE_PATH)
      try:
+         with lock.acquire(timeout=5):  # Wait at most 5s for the lock
+             # Re-check inside the lock; another process might have just updated the file.
+             if os.path.exists(CACHE_FILE_PATH):
+                 mtime = os.path.getmtime(CACHE_FILE_PATH)
+                 if (time.time() - mtime) < REFRESH_INTERVAL_SECONDS:
+                     # The file is fresh; just load it into this worker's memory.
+                     with open(CACHE_FILE_PATH, 'r') as f:
+                         worker_model_routing_table = json.load(f)
+                     last_cache_check_time = now
+                     logger.info(f"Loaded fresh cache from file. {len(worker_model_routing_table)} models.")
+                     return
+
+             # 3. We hold the lock and the file cache is stale. This worker will be the updater.
+             logger.warning("Cache is stale. This worker is refreshing the model list...")
+             try:
+                 async with httpx.AsyncClient() as client:
+                     response = await client.get(ARTIFACT_URL, timeout=30.0)
+                     response.raise_for_status()
+                     artifacts = response.json()
+
+                 new_routing_table = {}
+                 for artifact in artifacts:
+                     model_name = artifact.get("artifact_metadata", {}).get("artifact_name")
+                     endpoints = artifact.get("endpoints", [])
+                     if model_name and endpoints and endpoints[0].get("endpoint_url"):
+                         new_routing_table[model_name] = endpoints[0]["endpoint_url"]
+
+                 # Write to a temporary file, then atomically rename it.
+                 temp_path = CACHE_FILE_PATH + f".{os.getpid()}"
+                 with open(temp_path, 'w') as f:
+                     json.dump(new_routing_table, f)
+                 os.rename(temp_path, CACHE_FILE_PATH)
+
+                 worker_model_routing_table = new_routing_table
+                 logger.info(f"Successfully refreshed cache file with {len(worker_model_routing_table)} models.")
+
+             except Exception as e:
+                 logger.error(f"Failed to refresh model cache: {e}. Will use stale data if available.")
+
+     except Timeout:
+         logger.warning("Could not acquire lock to refresh cache; another process is likely updating. Reading from file.")
+
      except Exception as e:
+         logger.error(f"An unexpected error occurred in cache management: {e}")
+
+     finally:
+         # 4. Ensure this worker's memory is up to date from the file,
+         # especially if it failed to get the lock or an error occurred.
+         if os.path.exists(CACHE_FILE_PATH):
+             try:
+                 with open(CACHE_FILE_PATH, 'r') as f:
+                     worker_model_routing_table = json.load(f)
+             except (json.JSONDecodeError, FileNotFoundError):
+                 logger.error("Could not read cache file. Routing table may be empty.")
+
+         last_cache_check_time = now
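
Editor's note: the refresh logic above combines two patterns that are easy to miss in the diff: double-checked locking across processes (check freshness, take the lock, check again) and an atomic write via rename. A minimal sketch of the same pattern outside FastAPI, assuming only `pip install filelock`; the paths and the `fetch` callable are illustrative, not part of the commit:

# Sketch only: cross-process cached fetch, double-checked under a file lock.
import json
import os
import time
from filelock import FileLock, Timeout

CACHE = "/tmp/example_cache.json"   # illustrative path
LOCK = "/tmp/example_cache.lock"
TTL = 30  # seconds a cached copy stays fresh

def get_data(fetch):
    """Return cached data, refreshing at most once per TTL across processes."""
    try:
        with FileLock(LOCK).acquire(timeout=5):
            # Re-check under the lock: another process may have refreshed already.
            if os.path.exists(CACHE) and time.time() - os.path.getmtime(CACHE) < TTL:
                with open(CACHE) as f:
                    return json.load(f)
            data = fetch()
            tmp = f"{CACHE}.{os.getpid()}"   # write to a temp file first...
            with open(tmp, "w") as f:
                json.dump(data, f)
            os.rename(tmp, CACHE)            # ...then rename: atomic on POSIX
            return data
    except Timeout:
        # Someone else holds the lock; serve whatever copy is on disk.
        if os.path.exists(CACHE):
            with open(CACHE) as f:
                return json.load(f)
        return {}

print(get_data(lambda: {"model-a": "host-a.example"}))

The rename is what lets readers skip the lock entirely: a process that only reads sees either the old complete file or the new complete file, never a half-written one.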

+ # --- FastAPI Lifecycle & App Initialization ---

  @asynccontextmanager
  async def lifespan(app: FastAPI):
      """Manages the app's lifecycle for startup and shutdown."""
+     app.state.http_client = httpx.AsyncClient(timeout=None)
+     # Perform an initial fetch on startup for the first worker that starts.
+     await load_or_refresh_models()
+     yield
+     await app.state.http_client.aclose()
      logger.info("Application shutdown complete.")

  app = FastAPI(docs_url=None, redoc_url=None, lifespan=lifespan)
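
Editor's note: the file-based cache only matters once several worker processes serve the app (each worker holds its own globals, hence the PID in the new log format). A hedged sketch of a multi-worker launch, assuming the module is saved as main.py and uvicorn is installed; host, port, and worker count are illustrative:

# Sketch only: run with several worker processes so the shared cache is exercised.
import uvicorn

if __name__ == "__main__":
    # The app must be passed as an import string for uvicorn to fork workers.
    uvicorn.run("main:app", host="0.0.0.0", port=8000, workers=4)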

  # --- API Endpoints ---
@@ -103,36 +135,29 @@ async def health_check():
      """Provides a basic health check endpoint."""
      return JSONResponse({
          "status": "ok",
+         "active_models_in_memory": len(worker_model_routing_table)
      })

+ @app.get("/v1/models", dependencies=[Depends(load_or_refresh_models)])
+ async def list_models():
      """
+     Lists all available models from the worker's in-memory cache.
+     The dependency ensures the cache is checked for freshness before responding.
      """
      model_list = [
+         {"id": model_id, "object": "model", "owned_by": "gmi-serving"}
+         for model_id in worker_model_routing_table.keys()
      ]
      return JSONResponse(content={"object": "list", "data": model_list})
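
Editor's note: the `dependencies=[Depends(load_or_refresh_models)]` pattern used on both routes relies on FastAPI awaiting every route-level dependency before the handler runs. A minimal sketch of that behavior in isolation; the route and names are illustrative, not from the commit:

# Sketch only: a route-level dependency runs before the handler on every request.
from fastapi import Depends, FastAPI

demo = FastAPI()

async def before_each_request():
    print("runs first, on every request to this route")

@demo.get("/ping", dependencies=[Depends(before_each_request)])
async def ping():
    return {"pong": True}

Because the return value of a route-level dependency is discarded, it is a good fit for side-effect-only work like the cache freshness check here.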

+ @app.post("/v1/chat/completions", dependencies=[Depends(load_or_refresh_models)])
  async def chat_completions_proxy(request: Request):
      """
      Forwards chat completion requests to the correct model endpoint.
+     The dependency ensures the cache is checked for freshness before routing.
      """
      start_time = time.monotonic()

      body = await request.body()
      try:
          data = json.loads(body)
@@ -142,80 +167,45 @@ async def chat_completions_proxy(request: Request):
      except json.JSONDecodeError:
          raise HTTPException(status_code=400, detail="Invalid JSON in request body.")

+     target_host = worker_model_routing_table.get(model_name)
      if not target_host:
          raise HTTPException(
              status_code=404,
+             detail=f"Model '{model_name}' not found. It may be inactive or may not exist. Please check /v1/models."
          )

      client: httpx.AsyncClient = request.app.state.http_client
      target_url = f"https://{target_host}{request.url.path}"

+     # --- Prepare and Forward Request (logic is the same as before) ---
+     request_headers = {k: v for k, v in request.headers.items() if k.lower() != 'host'}
+     random_ip = ".".join(str(random.randint(1, 254)) for _ in range(4))
+     request_headers.update({
          "user-agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/141.0.0.0 Safari/537.36",
+         "x-forwarded-for": random_ip, "x-real-ip": random_ip
+     })
+
+     logger.info(f"Routing '{model_name}' to {target_url} (Client: {request.client.host})")

      for attempt in range(MAX_RETRIES):
          try:
+             req = client.build_request(method=request.method, url=target_url, headers=request_headers, content=body)
+             resp = await client.send(req, stream=True)

+             if resp.status_code not in RETRY_STATUS_CODES or attempt == MAX_RETRIES - 1:
                  duration_ms = (time.monotonic() - start_time) * 1000
+                 log_func = logger.info if resp.is_success else logger.warning
+                 log_func(f"Request finished for '{model_name}': status={resp.status_code} latency={duration_ms:.2f}ms")
+                 return StreamingResponse(resp.aiter_raw(), status_code=resp.status_code, headers=resp.headers, background=BackgroundTask(resp.aclose))

+             logger.warning(f"Attempt {attempt + 1}/{MAX_RETRIES} for '{model_name}' failed with status {resp.status_code}. Retrying...")
+             await resp.aclose()
+             await asyncio.sleep(0.5 * (2 ** attempt))
+
+         except Exception as e:
+             logger.error(f"Request forwarding failed for '{model_name}' on attempt {attempt + 1}: {e}")
+             if attempt == MAX_RETRIES - 1:
+                 raise HTTPException(status_code=502, detail=f"Bad Gateway: Error connecting to model backend. {e}")

+     # This should not be reached, but serves as a fallback:
+     raise HTTPException(status_code=502, detail="Bad Gateway: Request failed after all retries.")