nurulajt commited on
Commit
fd3e04f
·
verified ·
1 Parent(s): cc89204

Update api.py

Browse files
Files changed (1) hide show
  1. api.py +60 -9
api.py CHANGED
@@ -99,6 +99,12 @@ async def verify_api_key(credentials: Optional[HTTPAuthorizationCredentials] = S
99
 
100
  return True
101
 
 
 
 
 
 
 
102
  @app.on_event("startup")
103
  async def startup_event():
104
  load_models()
@@ -113,6 +119,22 @@ class ElasticsearchInferenceRequest(BaseModel):
113
  }
114
  }
115
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
116
  class ElasticsearchInferenceResponse(BaseModel):
117
  embedding: List[float] = Field(..., description="Embedding vector for single input")
118
 
@@ -172,7 +194,7 @@ async def health():
172
  "api_key_required": REQUIRE_API_KEY
173
  }
174
 
175
- @app.post("/embed", response_model=Union[ElasticsearchInferenceResponse, ElasticsearchInferenceBatchResponse])
176
  async def create_embeddings_elasticsearch(
177
  request: ElasticsearchInferenceRequest,
178
  model: str = Query("jobbertv3", description="Model: jobbertv2, jobbertv3, jina, or voyage"),
@@ -224,10 +246,21 @@ async def create_embeddings_elasticsearch(
224
  )
225
  embeddings = result.embeddings
226
 
227
- if is_single:
228
- return ElasticsearchInferenceResponse(embedding=embeddings[0])
229
- else:
230
- return ElasticsearchInferenceBatchResponse(embeddings=embeddings)
 
 
 
 
 
 
 
 
 
 
 
231
  except Exception as e:
232
  raise HTTPException(status_code=500, detail=f"Voyage AI error: {str(e)}")
233
 
@@ -249,10 +282,28 @@ async def create_embeddings_elasticsearch(
249
 
250
  embeddings_list = embeddings.tolist()
251
 
252
- if is_single:
253
- return ElasticsearchInferenceResponse(embedding=embeddings_list[0])
254
- else:
255
- return ElasticsearchInferenceBatchResponse(embeddings=embeddings_list)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
256
  except Exception as e:
257
  raise HTTPException(status_code=500, detail=f"Model error: {str(e)}")
258
 
 
99
 
100
  return True
101
 
102
+ def estimate_token_count(texts: List[str]) -> int:
103
+ """Estimate token count for input texts (rough approximation)"""
104
+ # Simple estimation: ~1 token per 4 characters
105
+ total_chars = sum(len(text) for text in texts)
106
+ return max(1, total_chars // 4)
107
+
108
  @app.on_event("startup")
109
  async def startup_event():
110
  load_models()
 
119
  }
120
  }
121
 
122
+ class EmbeddingObject(BaseModel):
123
+ object: str = Field("embedding", description="Object type")
124
+ index: int = Field(..., description="Index of the embedding")
125
+ embedding: List[float] = Field(..., description="Embedding vector")
126
+
127
+ class UsageInfo(BaseModel):
128
+ total_tokens: int = Field(..., description="Total tokens processed")
129
+ prompt_tokens: int = Field(..., description="Prompt tokens")
130
+
131
+ class OpenAIEmbeddingResponse(BaseModel):
132
+ model: str = Field(..., description="Model used for embeddings")
133
+ object: str = Field("list", description="Object type")
134
+ usage: UsageInfo = Field(..., description="Token usage information")
135
+ data: List[EmbeddingObject] = Field(..., description="List of embeddings")
136
+
137
+ # Legacy response models (kept for backward compatibility if needed)
138
  class ElasticsearchInferenceResponse(BaseModel):
139
  embedding: List[float] = Field(..., description="Embedding vector for single input")
140
 
 
194
  "api_key_required": REQUIRE_API_KEY
195
  }
196
 
197
+ @app.post("/embed", response_model=OpenAIEmbeddingResponse)
198
  async def create_embeddings_elasticsearch(
199
  request: ElasticsearchInferenceRequest,
200
  model: str = Query("jobbertv3", description="Model: jobbertv2, jobbertv3, jina, or voyage"),
 
246
  )
247
  embeddings = result.embeddings
248
 
249
+ # Calculate token usage
250
+ token_count = estimate_token_count(texts)
251
+
252
+ # Create OpenAI-compatible response
253
+ data = [
254
+ EmbeddingObject(index=i, embedding=emb)
255
+ for i, emb in enumerate(embeddings)
256
+ ]
257
+
258
+ return OpenAIEmbeddingResponse(
259
+ model="voyage-3",
260
+ object="list",
261
+ usage=UsageInfo(total_tokens=token_count, prompt_tokens=token_count),
262
+ data=data
263
+ )
264
  except Exception as e:
265
  raise HTTPException(status_code=500, detail=f"Voyage AI error: {str(e)}")
266
 
 
282
 
283
  embeddings_list = embeddings.tolist()
284
 
285
+ # Calculate token usage
286
+ token_count = estimate_token_count(texts)
287
+
288
+ # Create OpenAI-compatible response
289
+ data = [
290
+ EmbeddingObject(index=i, embedding=emb)
291
+ for i, emb in enumerate(embeddings_list)
292
+ ]
293
+
294
+ # Determine the full model name for response
295
+ model_display_name = {
296
+ "jobbertv2": "TechWolf/JobBERT-v2",
297
+ "jobbertv3": "TechWolf/JobBERT-v3",
298
+ "jina": "jina-embeddings-v3"
299
+ }.get(model_name, model_name)
300
+
301
+ return OpenAIEmbeddingResponse(
302
+ model=model_display_name,
303
+ object="list",
304
+ usage=UsageInfo(total_tokens=token_count, prompt_tokens=token_count),
305
+ data=data
306
+ )
307
  except Exception as e:
308
  raise HTTPException(status_code=500, detail=f"Model error: {str(e)}")
309