felixbet committed on
Commit 9bd55b2 · verified · 1 Parent(s): 25bfd3b

Update app.py

Files changed (1)
  1. app.py +87 -15
app.py CHANGED
@@ -1,22 +1,95 @@
- import os
+ # app.py — self-bootstrapping TF BioBERT embeddings API (HF Spaces-friendly)
+
+ import os, tarfile, glob, json, shutil, urllib.request
  from fastapi import FastAPI
  from pydantic import BaseModel
+ from typing import List
  from transformers import BertTokenizer, BertConfig, TFBertModel
- import tensorflow as tf
+ import tensorflow as tf  # noqa

  app = FastAPI()

- # start.sh exports this after extraction; keep a fallback for local/dev
- MODEL_DIR = os.environ.get("MODEL_DIR", "/app/bert_tf")
- os.makedirs(MODEL_DIR, exist_ok=True)
+ # --- Config
+ MODEL_ROOT = os.environ.get("MODEL_ROOT", "/app/bert_tf")
+ WEIGHTS_URL = os.environ.get("WEIGHTS_URL_TAR_GZ", "").strip()  # direct .tar.gz link (Dropbox must end with dl=1)
+ FALLBACK_VOCAB_URL = "https://huggingface.co/bert-base-uncased/resolve/main/vocab.txt"
+
+ os.makedirs(MODEL_ROOT, exist_ok=True)
+
+ def _extract_tar_gz(src: str, dest: str) -> None:
+     with tarfile.open(src, "r:gz") as tar:
+         def is_within(directory, target):
+             abs_directory = os.path.abspath(directory)
+             abs_target = os.path.abspath(target)
+             return os.path.commonpath([abs_directory]) == os.path.commonpath([abs_directory, abs_target])
+         for member in tar.getmembers():
+             target_path = os.path.join(dest, member.name)
+             if not is_within(dest, target_path):
+                 raise RuntimeError("Blocked path traversal in tar")
+         tar.extractall(dest)
+
+ def ensure_weights_and_get_model_dir() -> str:
+     # If already prepared (vocab + any ckpt index) → reuse
+     maybe_vocab = glob.glob(os.path.join(MODEL_ROOT, "**", "vocab.txt"), recursive=True)
+     maybe_idx = glob.glob(os.path.join(MODEL_ROOT, "**", "model.ckpt-*.index"), recursive=True)
+     if maybe_vocab and maybe_idx:
+         # choose dir containing the first ckpt index
+         return os.path.dirname(maybe_idx[0])
+
+     # Otherwise download and extract the archive
+     if not WEIGHTS_URL:
+         print("[app] WEIGHTS_URL_TAR_GZ not set; will still try to run with fallback vocab if files exist.")
+     else:
+         print("[app] downloading weights:", WEIGHTS_URL)
+         local_tar = "/tmp/model.tar.gz"
+         urllib.request.urlretrieve(WEIGHTS_URL, local_tar)
+         print("[app] extracting:", local_tar, "->", MODEL_ROOT)
+         _extract_tar_gz(local_tar, MODEL_ROOT)
+
+     # Pick the folder that has a ckpt index
+     idx_files = glob.glob(os.path.join(MODEL_ROOT, "**", "model.ckpt-*.index"), recursive=True)
+     if not idx_files:
+         raise RuntimeError("No TensorFlow checkpoint index found under " + MODEL_ROOT)
+     model_dir = os.path.dirname(idx_files[0])
+
+     # Ensure checkpoint file points at the basename
+     basename = os.path.basename(idx_files[0]).replace(".index", "")
+     ckpt_meta = os.path.join(model_dir, "checkpoint")
+     if not os.path.isfile(ckpt_meta):
+         with open(ckpt_meta, "w") as f:
+             f.write(f'model_checkpoint_path: "{basename}"\n')
+
+     # Ensure config.json
+     cfg = os.path.join(model_dir, "config.json")
+     bcfg = os.path.join(model_dir, "bert_config.json")
+     if not os.path.isfile(cfg):
+         if os.path.isfile(bcfg):
+             shutil.copy(bcfg, cfg)
+         else:
+             with open(cfg, "w") as f:
+                 json.dump({
+                     "hidden_size": 768,
+                     "num_attention_heads": 12,
+                     "num_hidden_layers": 12,
+                     "intermediate_size": 3072,
+                     "hidden_act": "gelu",
+                     "hidden_dropout_prob": 0.1,
+                     "attention_probs_dropout_prob": 0.1,
+                     "max_position_embeddings": 512,
+                     "type_vocab_size": 2,
+                     "vocab_size": 30522
+                 }, f)
+
+     # Ensure vocab.txt (BioBERT uses BERT base uncased vocab)
+     vocab = os.path.join(model_dir, "vocab.txt")
+     if not os.path.isfile(vocab):
+         print("[app] vocab.txt missing; fetching BERT base uncased vocab…")
+         urllib.request.urlretrieve(FALLBACK_VOCAB_URL, vocab)

- # extra safety: if no vocab here, look 2 levels deep
- if not os.path.isfile(os.path.join(MODEL_DIR, "vocab.txt")):
-     for root, dirs, files in os.walk(MODEL_DIR):
-         if "vocab.txt" in files and "config.json" in files:
-             MODEL_DIR = root
-             break
+     return model_dir

+ # Prepare weights (download/extract if needed), then load model
+ MODEL_DIR = ensure_weights_and_get_model_dir()
  print("[app] Using MODEL_DIR:", MODEL_DIR)

  tok = BertTokenizer(vocab_file=os.path.join(MODEL_DIR, "vocab.txt"), do_lower_case=True)
@@ -31,9 +104,8 @@ def health():
      return {"ok": True}

  @app.post("/v1/embeddings")
- def emb(req: EmbReq):
-     ids = tok(req.input, return_tensors="tf", truncation=True, max_length=128)
-     out = model(**ids)
-     # [CLS] pooled output
+ def embeddings(req: EmbReq):
+     enc = tok(req.input, return_tensors="tf", truncation=True, max_length=128)
+     out = model(**enc)
      vec = out.pooler_output[0].numpy().tolist()
      return {"embedding": vec, "dim": len(vec)}