rkihacker committed
Commit 0bc619c · verified · 1 Parent(s): f5fc07f

Update main.py

Files changed (1)
  1. main.py (+19 -28)
main.py CHANGED
@@ -24,13 +24,11 @@ ARTIFACT_URL = os.getenv("ARTIFACT_URL", "https://console.gmicloud.ai/api/v1/ie/
 REFRESH_INTERVAL_SECONDS = int(os.getenv("REFRESH_INTERVAL_SECONDS", "30"))
 
 # --- Shared Cache File Configuration ---
-# Using /dev/shm is faster as it's a RAM disk on Linux. Fallback to /tmp.
 CACHE_DIR = "/dev/shm" if os.path.exists("/dev/shm") else "/tmp"
 CACHE_FILE_PATH = os.path.join(CACHE_DIR, "gmi_routing_table.json")
 LOCK_FILE_PATH = os.path.join(CACHE_DIR, "gmi_routing_table.lock")
 
 # --- In-Memory State for each Worker ---
-# These are global variables *per worker process*.
 worker_model_routing_table = {}
 last_cache_check_time = 0
 
@@ -48,28 +46,22 @@ async def load_or_refresh_models():
     global last_cache_check_time, worker_model_routing_table
 
     now = time.monotonic()
-    # 1. Quick check: If in-memory cache is fresh, do nothing.
     if (now - last_cache_check_time) < REFRESH_INTERVAL_SECONDS:
         return
 
-    # 2. In-memory cache is stale, acquire a lock to check the shared file cache.
-    # The lock prevents all workers from hitting the API at once.
     lock = FileLock(LOCK_FILE_PATH)
     try:
-        with lock.acquire(timeout=5): # Wait max 5s for the lock
-            # Re-check inside the lock, another process might have just updated the file.
+        with lock.acquire(timeout=5):
             if os.path.exists(CACHE_FILE_PATH):
                 mtime = os.path.getmtime(CACHE_FILE_PATH)
                 if (time.time() - mtime) < REFRESH_INTERVAL_SECONDS:
-                    # File is fresh, just load it into this worker's memory
                     with open(CACHE_FILE_PATH, 'r') as f:
                         worker_model_routing_table = json.load(f)
                     last_cache_check_time = now
-                    logger.info(f"Loaded fresh cache from file. {len(worker_model_routing_table)} models.")
+                    logger.debug(f"Loaded fresh cache from file. {len(worker_model_routing_table)} models.")
                     return
 
-            # 3. We have the lock and the file cache is stale. This worker will be the updater.
-            logger.warning("Cache is stale. This worker is refreshing the model list...")
+            logger.warning("Cache is stale. This worker is refreshing the model list from the API...")
             try:
                 async with httpx.AsyncClient() as client:
                     response = await client.get(ARTIFACT_URL, timeout=30.0)
@@ -78,39 +70,43 @@ async def load_or_refresh_models():
 
                 new_routing_table = {}
                 for artifact in artifacts:
-                    model_name = artifact.get("artifact_metadata", {}).get("artifact_name")
+                    # CORRECTLY get the API model identifier from `model_price.modelName`
+                    model_price_info = artifact.get("model_price")
+                    api_model_id = model_price_info.get("modelName") if model_price_info else None
+
+                    display_name = artifact.get("artifact_metadata", {}).get("artifact_name", "Unknown")
                     endpoints = artifact.get("endpoints", [])
-                    if model_name and endpoints and endpoints[0].get("endpoint_url"):
-                        new_routing_table[model_name] = endpoints[0]["endpoint_url"]
+
+                    # Condition: Must have a valid API model ID and at least one active endpoint URL
+                    if api_model_id and endpoints and endpoints[0].get("endpoint_url"):
+                        endpoint_url = endpoints[0]["endpoint_url"]
+                        new_routing_table[api_model_id] = endpoint_url
+                        logger.debug(f"Mapped model ID '{api_model_id}' to endpoint '{endpoint_url}'")
+                    else:
+                        logger.debug(f"Skipping model '{display_name}': Missing API model ID or active endpoint.")
 
-                # Write to a temporary file and then atomically rename it
                 temp_path = CACHE_FILE_PATH + f".{os.getpid()}"
                 with open(temp_path, 'w') as f:
                     json.dump(new_routing_table, f)
                 os.rename(temp_path, CACHE_FILE_PATH)
 
                 worker_model_routing_table = new_routing_table
-                logger.info(f"Successfully refreshed cache file with {len(worker_model_routing_table)} models.")
+                logger.info(f"Successfully refreshed cache file with {len(worker_model_routing_table)} active models.")
 
             except Exception as e:
                 logger.error(f"Failed to refresh model cache: {e}. Will use stale data if available.")
 
     except Timeout:
-        logger.warning("Could not acquire lock to refresh cache, another process is likely updating. Reading from file.")
-
+        logger.warning("Could not acquire lock, another process is updating. Reading from file.")
     except Exception as e:
         logger.error(f"An unexpected error occurred in cache management: {e}")
-
     finally:
-        # 4. Ensure this worker's memory is up-to-date from the file,
-        # especially if it failed to get the lock or an error occurred.
         if os.path.exists(CACHE_FILE_PATH):
             try:
                 with open(CACHE_FILE_PATH, 'r') as f:
                     worker_model_routing_table = json.load(f)
             except (json.JSONDecodeError, FileNotFoundError):
                 logger.error("Could not read cache file. Routing table may be empty.")
-
         last_cache_check_time = now
 
 
@@ -118,9 +114,7 @@ async def load_or_refresh_models():
 
 @asynccontextmanager
 async def lifespan(app: FastAPI):
-    """Manages the app's lifecycle for startup and shutdown."""
     app.state.http_client = httpx.AsyncClient(timeout=None)
-    # Perform an initial fetch on startup for the first worker that starts.
     await load_or_refresh_models()
     yield
     await app.state.http_client.aclose()
@@ -132,7 +126,6 @@ app = FastAPI(docs_url=None, redoc_url=None, lifespan=lifespan)
 
 @app.get("/")
 async def health_check():
-    """Provides a basic health check endpoint."""
     return JSONResponse({
         "status": "ok",
         "active_models_in_memory": len(worker_model_routing_table)
@@ -146,7 +139,7 @@ async def list_models():
     """
     model_list = [
         { "id": model_id, "object": "model", "owned_by": "gmi-serving" }
-        for model_id in worker_model_routing_table.keys()
+        for model_id in sorted(worker_model_routing_table.keys())  # Sort for consistency
     ]
     return JSONResponse(content={"object": "list", "data": model_list})
 
@@ -177,7 +170,6 @@ async def chat_completions_proxy(request: Request):
     client: httpx.AsyncClient = request.app.state.http_client
     target_url = f"https://{target_host}{request.url.path}"
 
-    # --- Prepare and Forward Request (logic is the same as before) ---
     request_headers = {k: v for k, v in request.headers.items() if k.lower() != 'host'}
     random_ip = ".".join(str(random.randint(1, 254)) for _ in range(4))
     request_headers.update({
@@ -207,5 +199,4 @@ async def chat_completions_proxy(request: Request):
             if attempt == MAX_RETRIES - 1:
                 raise HTTPException(status_code=502, detail=f"Bad Gateway: Error connecting to model backend. {e}")
 
-    # This part should ideally not be reached, but as a fallback:
     raise HTTPException(status_code=502, detail="Bad Gateway: Request failed after all retries.")
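
Reviewer note: the substantive change in this commit is the routing key. The table is now keyed by model_price.modelName, the identifier clients actually send in API requests, instead of artifact_metadata.artifact_name, the display name, and artifacts without an active endpoint URL are skipped. Below is a self-contained sketch of that mapping; the payload is made up and shaped only after the fields the new code reads (model_price, artifact_metadata, endpoints), not real ARTIFACT_URL output.

def build_routing_table(artifacts):
    """Mirrors the new per-artifact logic: key by the API model id, keep the first endpoint URL."""
    table = {}
    for artifact in artifacts:
        price_info = artifact.get("model_price")
        api_model_id = price_info.get("modelName") if price_info else None
        endpoints = artifact.get("endpoints", [])
        if api_model_id and endpoints and endpoints[0].get("endpoint_url"):
            table[api_model_id] = endpoints[0]["endpoint_url"]
    return table

# Hypothetical payload for illustration only.
sample_artifacts = [
    {
        "artifact_metadata": {"artifact_name": "Example Model (display name)"},
        "model_price": {"modelName": "example-org/example-model"},
        "endpoints": [{"endpoint_url": "inference.example.com"}],
    },
    {
        # Skipped: no model_price, so there is no API model id to route on.
        "artifact_metadata": {"artifact_name": "Draft artifact"},
        "endpoints": [],
    },
]

print(build_routing_table(sample_artifacts))
# {'example-org/example-model': 'inference.example.com'}

Keying by the API identifier means the "model" value from an incoming request can be looked up directly in worker_model_routing_table, which is presumably how chat_completions_proxy resolves target_host.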
 
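For broader context, load_or_refresh_models implements a double-checked, file-backed cache shared across worker processes: each worker trusts its in-memory copy for REFRESH_INTERVAL_SECONDS, then takes a file lock, re-checks the shared file's mtime, and only the lock holder calls the upstream API before publishing the result via a write-to-temp-then-rename. A minimal standalone sketch of that pattern follows; the paths, TTL, and fetch callback are illustrative, not the values used in main.py.

import json
import os
import time

from filelock import FileLock, Timeout

CACHE_PATH = "/tmp/example_cache.json"   # illustrative path, not the one in main.py
LOCK_PATH = CACHE_PATH + ".lock"
TTL_SECONDS = 30                         # how long a cached copy counts as fresh


def read_cache():
    """Best-effort read of whatever is currently on disk."""
    if os.path.exists(CACHE_PATH):
        with open(CACHE_PATH, "r") as f:
            return json.load(f)
    return {}


def refresh_cache(fetch_fn):
    """Refresh the shared file at most once per TTL across all processes."""
    try:
        with FileLock(LOCK_PATH).acquire(timeout=5):
            # Another process may have refreshed the file while we waited for the lock.
            if os.path.exists(CACHE_PATH) and (time.time() - os.path.getmtime(CACHE_PATH)) < TTL_SECONDS:
                return read_cache()

            data = fetch_fn()                        # only the lock holder hits the upstream API
            tmp_path = f"{CACHE_PATH}.{os.getpid()}"
            with open(tmp_path, "w") as f:
                json.dump(data, f)
            os.rename(tmp_path, CACHE_PATH)          # atomic replace: readers never see a partial file
            return data
    except Timeout:
        # Someone else holds the lock and is refreshing; fall back to the file as-is.
        return read_cache()

The temp-file-plus-os.rename step is the part worth keeping: a rename within the same filesystem is atomic on POSIX, so concurrent readers see either the old JSON or the new one, never a truncated file.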