nurulajt committed on
Commit
31f5cc4
·
verified ·
1 Parent(s): fd3e04f

Update api.py

Browse files
Files changed (1) hide show
  1. api.py +143 -17
api.py CHANGED
@@ -32,11 +32,22 @@ app.add_middleware(
32
 
33
  MODELS = {}
34
  VOYAGE_API_KEY = os.environ.get('VOYAGE_API_KEY', '')
 
35
  API_KEY = os.environ.get('API_KEY', '')
36
  REQUIRE_API_KEY = os.environ.get('REQUIRE_API_KEY', 'false').lower() == 'true'
37
 
 
 
 
 
 
 
 
 
 
38
  security = HTTPBearer(auto_error=False)
39
  voyage_client = None
 
40
 
41
  logger.info(f"API Key authentication: {'ENABLED' if REQUIRE_API_KEY else 'DISABLED'}")
42
  if API_KEY:
@@ -54,25 +65,74 @@ if VOYAGE_API_KEY:
54
  except Exception as e:
55
  logger.warning(f"⚠️ Voyage AI initialization failed: {e}")
56
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
57
  def load_models():
58
- """Load embedding models on startup"""
 
 
59
  try:
60
  logger.info("Loading JobBERT-v2...")
61
- MODELS['jobbertv2'] = SentenceTransformer('TechWolf/JobBERT-v2')
62
  logger.info("✓ JobBERT-v2 loaded")
63
-
 
 
 
 
64
  logger.info("Loading JobBERT-v3...")
65
  MODELS['jobbertv3'] = SentenceTransformer('TechWolf/JobBERT-v3')
66
  logger.info("✓ JobBERT-v3 loaded")
67
-
 
 
 
 
68
  logger.info("Loading Jina AI embeddings-v3...")
69
  MODELS['jina'] = SentenceTransformer('jinaai/jina-embeddings-v3', trust_remote_code=True)
70
  logger.info("✓ Jina AI v3 loaded")
71
-
72
- logger.info("All models loaded successfully!")
73
  except Exception as e:
74
- logger.error(f"Error loading models: {e}")
75
- raise
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
76
 
77
  async def verify_api_key(credentials: Optional[HTTPAuthorizationCredentials] = Security(security)):
78
  """Verify API key from Authorization header"""
@@ -105,6 +165,46 @@ def estimate_token_count(texts: List[str]) -> int:
105
  total_chars = sum(len(text) for text in texts)
106
  return max(1, total_chars // 4)
107
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
108
  @app.on_event("startup")
109
  async def startup_event():
110
  load_models()
@@ -166,6 +266,7 @@ class HealthResponse(BaseModel):
166
  status: str
167
  models_loaded: List[str]
168
  voyage_available: bool
 
169
  api_key_required: bool
170
 
171
  @app.get("/", response_model=dict)
@@ -191,6 +292,7 @@ async def health():
191
  "status": "healthy",
192
  "models_loaded": models_loaded,
193
  "voyage_available": voyage_client is not None,
 
194
  "api_key_required": REQUIRE_API_KEY
195
  }
196
 
@@ -213,6 +315,7 @@ async def create_embeddings_elasticsearch(
213
  - `jobbertv2`: JobBERT-v2 (768-dim, job-specific)
214
  - `jobbertv3`: JobBERT-v3 (768-dim, job-specific, improved performance) - default
215
  - `jina`: Jina AI embeddings-v3 (1024-dim, general purpose)
 
216
  - `voyage`: Voyage AI (1024-dim, requires API key)
217
 
218
  **Jina AI Tasks (via query parameter):**
@@ -220,6 +323,10 @@ async def create_embeddings_elasticsearch(
220
  - `retrieval.passage`: For documents/passages
221
  - `text-matching`: For similarity matching (default)
222
 
 
 
 
 
223
  **Voyage AI Input Types (via query parameter):**
224
  - `document`: For documents/passages
225
  - `query`: For search queries
@@ -268,19 +375,23 @@ async def create_embeddings_elasticsearch(
268
  try:
269
  selected_model = MODELS[model_name]
270
 
271
- if model_name == "jina" and task:
 
 
 
 
272
  embeddings = selected_model.encode(
273
  texts,
274
  task=task,
275
  convert_to_numpy=True
276
  )
 
277
  else:
278
  embeddings = selected_model.encode(
279
  texts,
280
  convert_to_numpy=True
281
  )
282
-
283
- embeddings_list = embeddings.tolist()
284
 
285
  # Calculate token usage
286
  token_count = estimate_token_count(texts)
@@ -295,7 +406,8 @@ async def create_embeddings_elasticsearch(
295
  model_display_name = {
296
  "jobbertv2": "TechWolf/JobBERT-v2",
297
  "jobbertv3": "TechWolf/JobBERT-v3",
298
- "jina": "jina-embeddings-v3"
 
299
  }.get(model_name, model_name)
300
 
301
  return OpenAIEmbeddingResponse(
@@ -310,7 +422,7 @@ async def create_embeddings_elasticsearch(
310
  else:
311
  raise HTTPException(
312
  status_code=400,
313
- detail=f"Invalid model '{model_name}'. Choose from: jobbertv2, jobbertv3, jina, voyage"
314
  )
315
 
316
  @app.post("/embed/batch", response_model=BatchEmbeddingResponse)
@@ -325,6 +437,7 @@ async def create_embeddings_batch(
325
  - `jobbertv2`: JobBERT-v2 (768-dim, job-specific)
326
  - `jobbertv3`: JobBERT-v3 (768-dim, job-specific, improved performance)
327
  - `jina`: Jina AI embeddings-v3 (1024-dim, general purpose, supports task types)
 
328
  - `voyage`: Voyage AI (1024-dim, requires API key)
329
 
330
  **Jina AI Tasks:**
@@ -370,19 +483,23 @@ async def create_embeddings_batch(
370
  try:
371
  selected_model = MODELS[model_name]
372
 
373
- if model_name == "jina" and request.task:
 
 
 
 
374
  embeddings = selected_model.encode(
375
  request.texts,
376
  task=request.task,
377
  convert_to_numpy=True
378
  )
 
379
  else:
380
  embeddings = selected_model.encode(
381
  request.texts,
382
  convert_to_numpy=True
383
  )
384
-
385
- embeddings_list = embeddings.tolist()
386
  dimension = len(embeddings_list[0]) if embeddings_list else 0
387
 
388
  return BatchEmbeddingResponse(
@@ -397,7 +514,7 @@ async def create_embeddings_batch(
397
  else:
398
  raise HTTPException(
399
  status_code=400,
400
- detail=f"Invalid model '{model_name}'. Choose from: jobbertv2, jobbertv3, jina, voyage"
401
  )
402
 
403
  @app.get("/models")
@@ -426,6 +543,15 @@ async def list_models(authenticated: bool = Depends(verify_api_key)):
426
  "available": "jina" in MODELS,
427
  "tasks": ["retrieval.query", "retrieval.passage", "text-matching", "classification", "separation"]
428
  },
 
 
 
 
 
 
 
 
 
429
  "voyage": {
430
  "name": "voyage-3",
431
  "dimension": 1024,
 
32
 
33
  MODELS = {}
34
  VOYAGE_API_KEY = os.environ.get('VOYAGE_API_KEY', '')
35
+ FIREWORKS_API_KEY = os.environ.get('FIREWORKS_API_KEY', '')
36
  API_KEY = os.environ.get('API_KEY', '')
37
  REQUIRE_API_KEY = os.environ.get('REQUIRE_API_KEY', 'false').lower() == 'true'
38
 
39
# Point the Hugging Face / sentence-transformers caches at a writable location
# (important for Docker/HF Spaces). A value already present in the environment
# always wins over the default listed here.
# NOTE(review): some of these variables may only be honoured if they are set
# before the corresponding library is imported — confirm the import order.
_CACHE_DIR_DEFAULTS = {
    'TRANSFORMERS_CACHE': '/tmp/transformers_cache',
    'HF_HOME': '/tmp/huggingface',
    'SENTENCE_TRANSFORMERS_HOME': '/tmp/sentence_transformers',
}

# Phase 1: fill in any variable the environment did not provide.
for _cache_var, _cache_default in _CACHE_DIR_DEFAULTS.items():
    os.environ.setdefault(_cache_var, _cache_default)

# Phase 2: create the effective cache directories if they don't exist yet.
for _cache_var in _CACHE_DIR_DEFAULTS:
    os.makedirs(os.environ[_cache_var], exist_ok=True)
47
+
48
  security = HTTPBearer(auto_error=False)
49
  voyage_client = None
50
+ fireworks_available = False
51
 
52
  logger.info(f"API Key authentication: {'ENABLED' if REQUIRE_API_KEY else 'DISABLED'}")
53
  if API_KEY:
 
65
  except Exception as e:
66
  logger.warning(f"⚠️ Voyage AI initialization failed: {e}")
67
 
68
# Best-effort startup probe of the Fireworks AI API key.
#
# Outcome (written to the module-level ``fireworks_available`` flag):
#   * ``requests`` missing      -> stays False (the embeddings call needs it)
#   * HTTP 401                  -> stays False (the key itself was rejected)
#   * HTTP 200 / 403            -> True (403: listing models may be restricted
#                                  even though the key is accepted)
#   * any other status          -> True, with a warning (result is unclear)
#   * network / other exception -> True, with a warning (could not verify)
if FIREWORKS_API_KEY:
    try:
        import requests
    except ImportError:
        logger.warning("⚠️ requests package not installed (needed for Fireworks AI)")
    else:
        try:
            # Cheap authenticated GET used only to see how the key is received.
            test_response = requests.get(
                "https://api.fireworks.ai/inference/v1/models",
                headers={"Authorization": f"Bearer {FIREWORKS_API_KEY}"},
                timeout=5,
            )
        except Exception as e:
            # The probe could not run (timeout, DNS, ...). A key is set, so
            # stay optimistic: the embeddings endpoint might still work.
            logger.warning(f"⚠️ Fireworks AI validation failed: {e}")
            fireworks_available = True
        else:
            if test_response.status_code == 401:
                # 401 means the credentials were rejected; advertising Qwen3
                # here would only make every embedding request fail later.
                logger.warning("⚠️ Fireworks AI rejected the API key (status: 401) - Qwen3 disabled")
            elif test_response.status_code in (200, 403):
                fireworks_available = True
                logger.info("✓ Fireworks AI API key configured (Qwen3 available)")
            else:
                logger.warning(f"⚠️ Fireworks AI API key validation unclear (status: {test_response.status_code})")
                # Still mark as available - the embeddings endpoint might work
                fireworks_available = True
90
+
91
def load_models():
    """Load embedding models on startup, tolerating individual failures.

    Each local SentenceTransformer model is loaded in its own try/except so
    that one failing download does not prevent the others from being
    registered in ``MODELS``. Qwen3-Embedding-8B is not loaded locally; when
    ``fireworks_available`` is true it is registered under ``'qwen3'`` with
    the marker value ``'fireworks'``.

    Raises:
        RuntimeError: If ``MODELS`` is still empty after all attempts.
    """
    # JobBERT-v2 is intentionally not loaded in this revision. Say so in the
    # log instead of reporting it as loaded while MODELS has no such entry.
    logger.info("JobBERT-v2 loading is disabled; skipping")

    # (MODELS key, Hugging Face repo id, label for the "Loading" line,
    #  label for the result lines, extra SentenceTransformer kwargs)
    local_models = (
        ('jobbertv3', 'TechWolf/JobBERT-v3',
         'JobBERT-v3', 'JobBERT-v3', {}),
        ('jina', 'jinaai/jina-embeddings-v3',
         'Jina AI embeddings-v3', 'Jina AI v3', {'trust_remote_code': True}),
    )

    for key, repo_id, loading_label, label, extra_kwargs in local_models:
        try:
            logger.info(f"Loading {loading_label}...")
            MODELS[key] = SentenceTransformer(repo_id, **extra_kwargs)
            logger.info(f"✓ {label} loaded")
        except Exception as e:
            # Deliberately broad: any load failure only disables this model.
            logger.warning(f"⚠️ {label} not loaded: {e}")

    # Qwen3-Embedding-8B via Fireworks AI (API-based, no download needed!)
    if fireworks_available:
        MODELS['qwen3'] = 'fireworks'  # Mark as available via Fireworks AI
        logger.info("✓ Qwen3-Embedding-8B available via Fireworks AI API (MTEB #1, no local model needed)")
    else:
        logger.warning("⚠️ Qwen3-Embedding-8B not available")
        logger.warning("   To enable: Set FIREWORKS_API_KEY environment variable")
        logger.warning("   Get API key at: https://fireworks.ai")
        logger.warning("   This avoids 15GB local download!")

    # The service is useless without at least one usable model.
    if not MODELS:
        error_msg = "No embedding models could be loaded! Check logs above for details."
        logger.error(error_msg)
        raise RuntimeError(error_msg)

    logger.info(f"Loaded models: {list(MODELS.keys())}")
    logger.info("API ready!")
136
 
137
  async def verify_api_key(credentials: Optional[HTTPAuthorizationCredentials] = Security(security)):
138
  """Verify API key from Authorization header"""
 
165
  total_chars = sum(len(text) for text in texts)
166
  return max(1, total_chars // 4)
167
 
168
# Task description prepended to search queries, using the
# "Instruct: ...\nQuery:..." prompt format published for Qwen3-Embedding.
_QWEN3_QUERY_INSTRUCTION = (
    "Given a web search query, retrieve relevant passages that answer the query"
)


def get_fireworks_embeddings(texts: List[str], task: Optional[str] = None) -> List[List[float]]:
    """Get embeddings from Fireworks AI for Qwen3-Embedding-8B.

    Args:
        texts: Texts to embed. An empty list returns ``[]`` without
            contacting the API.
        task: ``'query'`` wraps every text in the instruction-aware query
            prompt; any other value (including ``None``) sends the texts
            unchanged, i.e. as documents/passages.

    Returns:
        One embedding vector per input text, in input order.

    Raises:
        RuntimeError: If ``FIREWORKS_API_KEY`` is not configured or the API
            answers with a non-200 status.
    """
    # Nothing to embed: skip the network round-trip entirely.
    if not texts:
        return []

    if not FIREWORKS_API_KEY:
        raise RuntimeError("FIREWORKS_API_KEY not configured")

    # Imported lazily so the rest of the service works without `requests`.
    import requests

    if task == "query":
        inputs = [
            f"Instruct: {_QWEN3_QUERY_INSTRUCTION}\nQuery:{text}"
            for text in texts
        ]
    else:
        inputs = list(texts)

    # Fireworks AI embeddings endpoint; `json=` serialises the payload and
    # sets the JSON Content-Type header.
    response = requests.post(
        "https://api.fireworks.ai/inference/v1/embeddings",
        headers={
            "Accept": "application/json",
            "Authorization": f"Bearer {FIREWORKS_API_KEY}",
        },
        json={
            "model": "accounts/fireworks/models/qwen3-embedding-8b",
            "input": inputs,
        },
        timeout=30,
    )

    if response.status_code != 200:
        raise RuntimeError(f"Fireworks AI API error: {response.status_code} - {response.text}")

    items = response.json()["data"]
    # Order by the reported `index` when present so vectors line up with the
    # inputs; sorted() is stable, so items without it keep response order.
    ordered = sorted(items, key=lambda item: item.get("index", 0))
    return [item["embedding"] for item in ordered]
207
+
208
  @app.on_event("startup")
209
  async def startup_event():
210
  load_models()
 
266
  status: str
267
  models_loaded: List[str]
268
  voyage_available: bool
269
+ fireworks_available: bool
270
  api_key_required: bool
271
 
272
  @app.get("/", response_model=dict)
 
292
  "status": "healthy",
293
  "models_loaded": models_loaded,
294
  "voyage_available": voyage_client is not None,
295
+ "fireworks_available": fireworks_available,
296
  "api_key_required": REQUIRE_API_KEY
297
  }
298
 
 
315
  - `jobbertv2`: JobBERT-v2 (768-dim, job-specific)
316
  - `jobbertv3`: JobBERT-v3 (768-dim, job-specific, improved performance) - default
317
  - `jina`: Jina AI embeddings-v3 (1024-dim, general purpose)
318
+ - `qwen3`: Qwen3-Embedding-8B (4096-dim, MTEB #1, multilingual, 32k context)
319
  - `voyage`: Voyage AI (1024-dim, requires API key)
320
 
321
  **Jina AI Tasks (via query parameter):**
 
323
  - `retrieval.passage`: For documents/passages
324
  - `text-matching`: For similarity matching (default)
325
 
326
+ **Qwen3 Task (via query parameter):**
327
+ - `query`: For search queries (uses instruction-aware prompt)
328
+ - Default: Documents/passages (no instruction)
329
+
330
  **Voyage AI Input Types (via query parameter):**
331
  - `document`: For documents/passages
332
  - `query`: For search queries
 
375
  try:
376
  selected_model = MODELS[model_name]
377
 
378
+ # Qwen3 via Fireworks AI API (no local model)
379
+ if model_name == "qwen3" and selected_model == 'fireworks':
380
+ embeddings_list = get_fireworks_embeddings(texts, task=task)
381
+ # Jina AI with task type
382
+ elif model_name == "jina" and task:
383
  embeddings = selected_model.encode(
384
  texts,
385
  task=task,
386
  convert_to_numpy=True
387
  )
388
+ embeddings_list = embeddings.tolist()
389
  else:
390
  embeddings = selected_model.encode(
391
  texts,
392
  convert_to_numpy=True
393
  )
394
+ embeddings_list = embeddings.tolist()
 
395
 
396
  # Calculate token usage
397
  token_count = estimate_token_count(texts)
 
406
  model_display_name = {
407
  "jobbertv2": "TechWolf/JobBERT-v2",
408
  "jobbertv3": "TechWolf/JobBERT-v3",
409
+ "jina": "jina-embeddings-v3",
410
+ "qwen3": "Qwen/Qwen3-Embedding-8B"
411
  }.get(model_name, model_name)
412
 
413
  return OpenAIEmbeddingResponse(
 
422
  else:
423
  raise HTTPException(
424
  status_code=400,
425
+ detail=f"Invalid model '{model_name}'. Choose from: jobbertv2, jobbertv3, jina, qwen3, voyage"
426
  )
427
 
428
  @app.post("/embed/batch", response_model=BatchEmbeddingResponse)
 
437
  - `jobbertv2`: JobBERT-v2 (768-dim, job-specific)
438
  - `jobbertv3`: JobBERT-v3 (768-dim, job-specific, improved performance)
439
  - `jina`: Jina AI embeddings-v3 (1024-dim, general purpose, supports task types)
440
+ - `qwen3`: Qwen3-Embedding-8B (4096-dim, MTEB #1, multilingual, 32k context)
441
  - `voyage`: Voyage AI (1024-dim, requires API key)
442
 
443
  **Jina AI Tasks:**
 
483
  try:
484
  selected_model = MODELS[model_name]
485
 
486
+ # Qwen3 via Fireworks AI API (no local model)
487
+ if model_name == "qwen3" and selected_model == 'fireworks':
488
+ embeddings_list = get_fireworks_embeddings(request.texts, task=request.task)
489
+ # Jina AI with task type
490
+ elif model_name == "jina" and request.task:
491
  embeddings = selected_model.encode(
492
  request.texts,
493
  task=request.task,
494
  convert_to_numpy=True
495
  )
496
+ embeddings_list = embeddings.tolist()
497
  else:
498
  embeddings = selected_model.encode(
499
  request.texts,
500
  convert_to_numpy=True
501
  )
502
+ embeddings_list = embeddings.tolist()
 
503
  dimension = len(embeddings_list[0]) if embeddings_list else 0
504
 
505
  return BatchEmbeddingResponse(
 
514
  else:
515
  raise HTTPException(
516
  status_code=400,
517
+ detail=f"Invalid model '{model_name}'. Choose from: jobbertv2, jobbertv3, jina, qwen3, voyage"
518
  )
519
 
520
  @app.get("/models")
 
543
  "available": "jina" in MODELS,
544
  "tasks": ["retrieval.query", "retrieval.passage", "text-matching", "classification", "separation"]
545
  },
546
+ "qwen3": {
547
+ "name": "Qwen/Qwen3-Embedding-8B",
548
+ "dimension": 4096,
549
+ "description": "🏆 MTEB #1 multilingual model (100+ languages, 32k context, instruction-aware)",
550
+ "max_tokens": 32768,
551
+ "available": "qwen3" in MODELS,
552
+ "tasks": ["query", "document"],
553
+ "features": ["multilingual", "instruction-aware", "long-context"]
554
+ },
555
  "voyage": {
556
  "name": "voyage-3",
557
  "dimension": 1024,