rkihacker committed on
Commit
41b75fe
·
verified ·
1 Parent(s): 0bc619c

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +89 -165
main.py CHANGED
@@ -1,202 +1,126 @@
1
  import httpx
2
- from fastapi import FastAPI, Request, HTTPException, Depends
3
  from starlette.responses import StreamingResponse, JSONResponse
4
  from starlette.background import BackgroundTask
5
  import os
6
  import random
7
  import logging
8
  import time
9
- import json
10
- import asyncio
11
  from contextlib import asynccontextmanager
12
- from filelock import FileLock, Timeout
13
 
14
  # --- Production-Ready Configuration ---
15
  LOG_LEVEL = os.getenv("LOG_LEVEL", "INFO").upper()
16
  logging.basicConfig(
17
  level=LOG_LEVEL,
18
- format='%(asctime)s - PID:%(process)d - %(name)s - %(levelname)s - %(message)s'
19
  )
20
- logger = logging.getLogger(__name__)
21
-
22
- # --- Service Configuration ---
23
- ARTIFACT_URL = os.getenv("ARTIFACT_URL", "https://console.gmicloud.ai/api/v1/ie/artifact/get_public_artifacts")
24
- REFRESH_INTERVAL_SECONDS = int(os.getenv("REFRESH_INTERVAL_SECONDS", "30"))
25
-
26
- # --- Shared Cache File Configuration ---
27
- CACHE_DIR = "/dev/shm" if os.path.exists("/dev/shm") else "/tmp"
28
- CACHE_FILE_PATH = os.path.join(CACHE_DIR, "gmi_routing_table.json")
29
- LOCK_FILE_PATH = os.path.join(CACHE_DIR, "gmi_routing_table.lock")
30
-
31
- # --- In-Memory State for each Worker ---
32
- worker_model_routing_table = {}
33
- last_cache_check_time = 0
34
-
35
- # --- Retry Logic ---
36
- MAX_RETRIES = int(os.getenv("MAX_RETRIES", "5"))
37
- RETRY_STATUS_CODES = {429, 500, 502, 503, 504}
38
-
39
- # --- Core Caching and Refreshing Logic ---
40
-
41
- async def load_or_refresh_models():
42
- """
43
- Checks if the shared cache is stale. If so, attempts to acquire a lock
44
- and refresh it. This is designed to be safe for multiple processes.
45
- """
46
- global last_cache_check_time, worker_model_routing_table
47
-
48
- now = time.monotonic()
49
- if (now - last_cache_check_time) < REFRESH_INTERVAL_SECONDS:
50
- return
51
-
52
- lock = FileLock(LOCK_FILE_PATH)
53
- try:
54
- with lock.acquire(timeout=5):
55
- if os.path.exists(CACHE_FILE_PATH):
56
- mtime = os.path.getmtime(CACHE_FILE_PATH)
57
- if (time.time() - mtime) < REFRESH_INTERVAL_SECONDS:
58
- with open(CACHE_FILE_PATH, 'r') as f:
59
- worker_model_routing_table = json.load(f)
60
- last_cache_check_time = now
61
- logger.debug(f"Loaded fresh cache from file. {len(worker_model_routing_table)} models.")
62
- return
63
-
64
- logger.warning("Cache is stale. This worker is refreshing the model list from the API...")
65
- try:
66
- async with httpx.AsyncClient() as client:
67
- response = await client.get(ARTIFACT_URL, timeout=30.0)
68
- response.raise_for_status()
69
- artifacts = response.json()
70
-
71
- new_routing_table = {}
72
- for artifact in artifacts:
73
- # CORRECTLY get the API model identifier from `model_price.modelName`
74
- model_price_info = artifact.get("model_price")
75
- api_model_id = model_price_info.get("modelName") if model_price_info else None
76
-
77
- display_name = artifact.get("artifact_metadata", {}).get("artifact_name", "Unknown")
78
- endpoints = artifact.get("endpoints", [])
79
-
80
- # Condition: Must have a valid API model ID and at least one active endpoint URL
81
- if api_model_id and endpoints and endpoints[0].get("endpoint_url"):
82
- endpoint_url = endpoints[0]["endpoint_url"]
83
- new_routing_table[api_model_id] = endpoint_url
84
- logger.debug(f"Mapped model ID '{api_model_id}' to endpoint '{endpoint_url}'")
85
- else:
86
- logger.debug(f"Skipping model '{display_name}': Missing API model ID or active endpoint.")
87
-
88
- temp_path = CACHE_FILE_PATH + f".{os.getpid()}"
89
- with open(temp_path, 'w') as f:
90
- json.dump(new_routing_table, f)
91
- os.rename(temp_path, CACHE_FILE_PATH)
92
-
93
- worker_model_routing_table = new_routing_table
94
- logger.info(f"Successfully refreshed cache file with {len(worker_model_routing_table)} active models.")
95
-
96
- except Exception as e:
97
- logger.error(f"Failed to refresh model cache: {e}. Will use stale data if available.")
98
-
99
- except Timeout:
100
- logger.warning("Could not acquire lock, another process is updating. Reading from file.")
101
- except Exception as e:
102
- logger.error(f"An unexpected error occurred in cache management: {e}")
103
- finally:
104
- if os.path.exists(CACHE_FILE_PATH):
105
- try:
106
- with open(CACHE_FILE_PATH, 'r') as f:
107
- worker_model_routing_table = json.load(f)
108
- except (json.JSONDecodeError, FileNotFoundError):
109
- logger.error("Could not read cache file. Routing table may be empty.")
110
- last_cache_check_time = now
111
-
112
-
113
- # --- FastAPI Lifecycle & App Initialization ---
114
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
115
  @asynccontextmanager
116
  async def lifespan(app: FastAPI):
117
- app.state.http_client = httpx.AsyncClient(timeout=None)
118
- await load_or_refresh_models()
119
- yield
120
- await app.state.http_client.aclose()
121
- logger.info("Application shutdown complete.")
122
 
 
123
  app = FastAPI(docs_url=None, redoc_url=None, lifespan=lifespan)
124
 
125
  # --- API Endpoints ---
126
 
 
 
127
  @app.get("/")
128
  async def health_check():
129
- return JSONResponse({
130
- "status": "ok",
131
- "active_models_in_memory": len(worker_model_routing_table)
132
- })
133
-
134
- @app.get("/v1/models", dependencies=[Depends(load_or_refresh_models)])
135
- async def list_models():
 
136
  """
137
- Lists all available models from the worker's in-memory cache.
138
- The dependency ensures the cache is checked for freshness before responding.
139
- """
140
- model_list = [
141
- { "id": model_id, "object": "model", "owned_by": "gmi-serving" }
142
- for model_id in sorted(worker_model_routing_table.keys()) # Sort for consistency
143
- ]
144
- return JSONResponse(content={"object": "list", "data": model_list})
145
-
146
- @app.post("/v1/chat/completions", dependencies=[Depends(load_or_refresh_models)])
147
- async def chat_completions_proxy(request: Request):
148
- """
149
- Forwards chat completion requests to the correct model endpoint.
150
- The dependency ensures the cache is checked for freshness before routing.
151
  """
152
  start_time = time.monotonic()
153
 
154
- body = await request.body()
155
- try:
156
- data = json.loads(body)
157
- model_name = data.get("model")
158
- if not model_name:
159
- raise HTTPException(status_code=400, detail="Missing 'model' field in request body.")
160
- except json.JSONDecodeError:
161
- raise HTTPException(status_code=400, detail="Invalid JSON in request body.")
162
-
163
- target_host = worker_model_routing_table.get(model_name)
164
- if not target_host:
165
- raise HTTPException(
166
- status_code=404,
167
- detail=f"Model '{model_name}' not found. It may be inactive or does not exist. Please check /v1/models."
168
- )
169
-
170
  client: httpx.AsyncClient = request.app.state.http_client
171
- target_url = f"https://{target_host}{request.url.path}"
172
-
173
- request_headers = {k: v for k, v in request.headers.items() if k.lower() != 'host'}
174
- random_ip = ".".join(str(random.randint(1, 254)) for _ in range(4))
175
- request_headers.update({
 
 
 
 
176
  "user-agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/141.0.0.0 Safari/537.36",
177
- "x-forwarded-for": random_ip, "x-real-ip": random_ip
178
- })
 
 
 
 
 
 
 
 
 
 
179
 
180
- logger.info(f"Routing '{model_name}' to {target_url} (Client: {request.client.host})")
181
 
 
182
  for attempt in range(MAX_RETRIES):
183
  try:
184
- req = client.build_request(method=request.method, url=target_url, headers=request_headers, content=body)
185
- resp = await client.send(req, stream=True)
 
 
186
 
187
- if resp.status_code not in RETRY_STATUS_CODES or attempt == MAX_RETRIES - 1:
188
  duration_ms = (time.monotonic() - start_time) * 1000
189
- log_func = logger.info if resp.is_success else logger.warning
190
- log_func(f"Request finished for '{model_name}': status={resp.status_code} latency={duration_ms:.2f}ms")
191
- return StreamingResponse(resp.aiter_raw(), status_code=resp.status_code, headers=resp.headers, background=BackgroundTask(resp.aclose))
192
-
193
- logger.warning(f"Attempt {attempt + 1}/{MAX_RETRIES} for '{model_name}' failed with status {resp.status_code}. Retrying...")
194
- await resp.aclose()
195
- await asyncio.sleep(0.5 * (2 ** attempt))
196
-
197
- except Exception as e:
198
- logger.error(f"Request forwarding failed for '{model_name}' on attempt {attempt + 1}: {e}")
199
- if attempt == MAX_RETRIES - 1:
200
- raise HTTPException(status_code=502, detail=f"Bad Gateway: Error connecting to model backend. {e}")
 
 
 
 
 
 
 
 
 
201
 
202
- raise HTTPException(status_code=502, detail="Bad Gateway: Request failed after all retries.")
 
 
 
 
1
  import httpx
2
+ from fastapi import FastAPI, Request, HTTPException
3
  from starlette.responses import StreamingResponse, JSONResponse
4
  from starlette.background import BackgroundTask
5
  import os
6
  import random
7
  import logging
8
  import time
 
 
9
  from contextlib import asynccontextmanager
 
10
 
11
  # --- Production-Ready Configuration ---
12
  LOG_LEVEL = os.getenv("LOG_LEVEL", "INFO").upper()
13
  logging.basicConfig(
14
  level=LOG_LEVEL,
15
+ format='%(asctime)s - %(levelname)s - %(message)s'
16
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
 
18
# --- Proxy Target & Retry Configuration ---
# Upstream base URL that every proxied request is forwarded to.
TARGET_URL = os.getenv("TARGET_URL", "https://api.gmi-serving.com")
# Maximum number of forwarding attempts per incoming request.
MAX_RETRIES = int(os.getenv("MAX_RETRIES", "15"))
# Comma-separated HTTP status codes that should trigger a retry of the upstream call.
DEFAULT_RETRY_CODES = "429,500,502,503,504"
RETRY_CODES_STR = os.getenv("RETRY_CODES", DEFAULT_RETRY_CODES)
try:
    # Parse the override (or default) into a set for O(1) membership tests.
    RETRY_STATUS_CODES = {int(code.strip()) for code in RETRY_CODES_STR.split(',')}
    logging.info(f"Will retry on the following status codes: {RETRY_STATUS_CODES}")
except ValueError:
    # Malformed RETRY_CODES override: log it and fall back to the known-good default set.
    logging.error(f"Invalid RETRY_CODES format: '{RETRY_CODES_STR}'. Falling back to default: {DEFAULT_RETRY_CODES}")
    RETRY_STATUS_CODES = {int(code.strip()) for code in DEFAULT_RETRY_CODES.split(',')}
28
+
29
+ # --- Helper Function ---
30
def generate_random_ip():
    """Generates a random, valid-looking IPv4 address.

    Each of the four octets is drawn uniformly from 1-254, avoiding the
    0 and 255 values that would look like network/broadcast addresses.
    """
    octets = [str(random.randint(1, 254)) for _ in range(4)]
    return ".".join(octets)
33
+
34
+ # --- HTTPX Client Lifecycle Management ---
35
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Manages the lifecycle of the HTTPX client.

    One shared client (pinned to TARGET_URL, no timeout so long-lived
    streams are not cut off) is created at startup and closed at shutdown.
    """
    client = httpx.AsyncClient(base_url=TARGET_URL, timeout=None)
    app.state.http_client = client
    try:
        yield
    finally:
        # Equivalent to the `async with` form: always release the client's
        # connection pool on shutdown, even if startup code raised.
        await client.aclose()
 
41
 
42
+ # Initialize the FastAPI app with the lifespan manager and disabled docs
43
  app = FastAPI(docs_url=None, redoc_url=None, lifespan=lifespan)
44
 
45
  # --- API Endpoints ---
46
 
47
+ # 1. Health Check Route (Defined FIRST)
48
+ # This specific route will be matched before the catch-all proxy route.
49
@app.get("/")
async def health_check():
    """Provides a basic health check endpoint."""
    payload = {"status": "ok", "target": TARGET_URL}
    return JSONResponse(payload)
53
+
54
+ # 2. Catch-All Reverse Proxy Route (Defined SECOND)
55
+ # This will capture ALL other requests (e.g., /completions, /v1/models, etc.)
56
+ # and forward them. This eliminates any redirect issues.
57
# 2. Catch-All Reverse Proxy Route (Defined SECOND)
# This will capture ALL other requests (e.g., /completions, /v1/models, etc.)
# and forward them. This eliminates any redirect issues.
@app.api_route("/{full_path:path}", methods=["GET", "POST", "PUT", "DELETE", "PATCH", "OPTIONS", "HEAD"])
async def reverse_proxy_handler(request: Request):
    """
    A catch-all reverse proxy that forwards requests to the target URL with
    enhanced retry logic and latency logging.

    Flow: copy the incoming headers (minus Host), overwrite identifying
    headers with a randomly generated client IP, then forward the request to
    the shared client's base_url. Responses whose status is in
    RETRY_STATUS_CODES are retried (no backoff) up to MAX_RETRIES times;
    connection errors are also retried. Everything else is streamed back.
    """
    start_time = time.monotonic()

    client: httpx.AsyncClient = request.app.state.http_client
    # Rebuild only path+query; the scheme/host come from the client's base_url.
    url = httpx.URL(path=request.url.path, query=request.url.query.encode("utf-8"))

    # Drop the inbound Host header so it doesn't override the target host.
    request_headers = dict(request.headers)
    request_headers.pop("host", None)

    random_ip = generate_random_ip()
    logging.info(f"Client '{request.client.host}' proxied with spoofed IP: {random_ip} for path: {url.path}")

    # Overwrite every common client-identifying header with the spoofed IP
    # and present a fixed browser user-agent to the upstream.
    specific_headers = {
        "user-agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/141.0.0.0 Safari/537.36",
        "x-forwarded-for": random_ip,
        "x-real-ip": random_ip,
        "x-originating-ip": random_ip,
        "x-remote-ip": random_ip,
        "x-remote-addr": random_ip,
        "x-host": random_ip,
        "x-forwarded-host": random_ip,
    }
    request_headers.update(specific_headers)

    # NOTE(review): authorization is already present in request_headers via
    # dict(request.headers); this re-assignment is a defensive no-op.
    if "authorization" in request.headers:
        request_headers["authorization"] = request.headers["authorization"]

    # Buffer the full body once so it can be re-sent on each retry attempt.
    body = await request.body()

    last_exception = None
    for attempt in range(MAX_RETRIES):
        try:
            rp_req = client.build_request(
                method=request.method, url=url, headers=request_headers, content=body
            )
            # stream=True defers body download so large/streaming responses
            # can be relayed without buffering them in this process.
            rp_resp = await client.send(rp_req, stream=True)

            # Return immediately on a non-retryable status, or on the final
            # attempt (even if its status is retryable).
            if rp_resp.status_code not in RETRY_STATUS_CODES or attempt == MAX_RETRIES - 1:
                duration_ms = (time.monotonic() - start_time) * 1000
                log_func = logging.info if rp_resp.is_success else logging.warning
                log_func(f"Request finished: {request.method} {request.url.path} status_code={rp_resp.status_code} latency={duration_ms:.2f}ms")

                # aiter_raw() relays the exact upstream bytes (no decoding),
                # so forwarding the upstream headers unchanged stays consistent.
                # The BackgroundTask closes the upstream response only after
                # the stream has been fully sent to the client.
                return StreamingResponse(
                    rp_resp.aiter_raw(),
                    status_code=rp_resp.status_code,
                    headers=rp_resp.headers,
                    background=BackgroundTask(rp_resp.aclose),
                )

            logging.warning(
                f"Attempt {attempt + 1}/{MAX_RETRIES} for {url.path} failed with status {rp_resp.status_code}. Retrying..."
            )
            # Release the connection before retrying.
            await rp_resp.aclose()

        except httpx.ConnectError as e:
            # Only connection-establishment failures are retried here; other
            # transport errors (e.g. read timeouts) propagate to the caller.
            last_exception = e
            logging.warning(f"Attempt {attempt + 1}/{MAX_RETRIES} for {url.path} failed with connection error: {e}")

    # Reached only when every attempt ended in a ConnectError (status-code
    # retries always return on the final attempt inside the loop).
    duration_ms = (time.monotonic() - start_time) * 1000
    logging.critical(f"Request failed, cannot connect to target: {request.method} {request.url.path} status_code=502 latency={duration_ms:.2f}ms")

    raise HTTPException(
        status_code=502,
        detail=f"Bad Gateway: Cannot connect to target service after {MAX_RETRIES} attempts. {last_exception}"
    )