felixbet committed
Commit 25bfd3b · verified · 1 Parent(s): ac7f7f1

Update app.py

Files changed (1)
  1. app.py +30 -21
app.py CHANGED
@@ -1,30 +1,39 @@
  import os
- from transformers import BertTokenizer, BertConfig, TFBertModel
  from fastapi import FastAPI
+ from pydantic import BaseModel
+ from transformers import BertTokenizer, BertConfig, TFBertModel
+ import tensorflow as tf

  app = FastAPI()

+ # start.sh exports this after extraction; keep a fallback for local/dev
  MODEL_DIR = os.environ.get("MODEL_DIR", "/app/bert_tf")
-
- # Guard: create dir if missing; avoid listing non-existent paths
  os.makedirs(MODEL_DIR, exist_ok=True)

- # Probe one level deep only if there are entries
- candidates = [MODEL_DIR]
- try:
-     for x in os.listdir(MODEL_DIR):
-         p = os.path.join(MODEL_DIR, x)
-         if os.path.isdir(p):
-             candidates.append(p)
- except FileNotFoundError:
-     pass
-
- for d in candidates:
-     if (os.path.isfile(os.path.join(d, "vocab.txt"))
-         and os.path.isfile(os.path.join(d, "config.json"))):
-         MODEL_DIR = d
-         break
-
- tok = BertTokenizer(vocab_file=f"{MODEL_DIR}/vocab.txt", do_lower_case=True)
- cfg = BertConfig.from_json_file(f"{MODEL_DIR}/config.json")
+ # extra safety: if no vocab here, look 2 levels deep
+ if not os.path.isfile(os.path.join(MODEL_DIR, "vocab.txt")):
+     for root, dirs, files in os.walk(MODEL_DIR):
+         if "vocab.txt" in files and "config.json" in files:
+             MODEL_DIR = root
+             break
+
+ print("[app] Using MODEL_DIR:", MODEL_DIR)
+
+ tok = BertTokenizer(vocab_file=os.path.join(MODEL_DIR, "vocab.txt"), do_lower_case=True)
+ cfg = BertConfig.from_json_file(os.path.join(MODEL_DIR, "config.json"))
  model= TFBertModel.from_pretrained(MODEL_DIR, from_tf=True, config=cfg)
+
+ class EmbReq(BaseModel):
+     input: str
+
+ @app.get("/health")
+ def health():
+     return {"ok": True}
+
+ @app.post("/v1/embeddings")
+ def emb(req: EmbReq):
+     ids = tok(req.input, return_tensors="tf", truncation=True, max_length=128)
+     out = model(**ids)
+     # [CLS] pooled output
+     vec = out.pooler_output[0].numpy().tolist()
+     return {"embedding": vec, "dim": len(vec)}
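For reference, a minimal client sketch of how the two routes added in this commit could be exercised once the Space is running. The base URL and port are assumptions (7860 is the usual default for a Hugging Face Space), not something defined by this commit, and requests is just one convenient HTTP client.

import requests

# Assumed base URL; adjust to wherever the FastAPI app is actually served.
BASE_URL = "http://localhost:7860"

# Liveness probe against the new GET /health route.
print(requests.get(f"{BASE_URL}/health").json())  # expected: {"ok": true}

# Embed a single string via the new POST /v1/embeddings route.
resp = requests.post(f"{BASE_URL}/v1/embeddings", json={"input": "hello world"})
resp.raise_for_status()
data = resp.json()
print(data["dim"], data["embedding"][:5])  # vector size plus a short preview

The response mirrors the handler above: the pooled [CLS] vector under "embedding" and its length under "dim".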