p2002814 commited on
Commit
a0e1005
·
1 Parent(s): f2dc286

now predicting csv progressively for a smooth animation

Browse files
Files changed (1) hide show
  1. app.py +5 -19
app.py CHANGED
@@ -1,4 +1,9 @@
1
  import os
 
 
 
 
 
2
  import io
3
  import json
4
  import asyncio
@@ -15,9 +20,6 @@ from relations.predict_bert import predict_relation
15
  from aba.aba_builder import prepare_aba_plus_framework, build_aba_framework_from_text
16
 
17
  # -------------------- Config -------------------- #
18
- cache_dir = "/tmp/hf_cache"
19
- os.environ["TRANSFORMERS_CACHE"] = cache_dir
20
- os.makedirs(cache_dir, exist_ok=True)
21
 
22
  EXAMPLES_DIR = Path("./aba/exemples")
23
  SAMPLES_DIR = Path("./relations/exemples/samples")
@@ -53,22 +55,6 @@ def predict_text(arg1: str = Form(...), arg2: str = Form(...)):
53
  return {"arg1": arg1, "arg2": arg2, "relation": result}
54
 
55
 
56
- @app.post("/predict-csv")
57
- async def predict_csv(file: UploadFile):
58
- """Predict relations for pairs of arguments from a CSV file (max 250 rows)."""
59
- content = await file.read()
60
- df = pd.read_csv(io.StringIO(content.decode("utf-8")), quotechar='"')
61
- if len(df) > 250:
62
- df = df.head(250)
63
-
64
- results = []
65
- for _, row in df.iterrows():
66
- result = predict_relation(row["parent"], row["child"], model, tokenizer, device)
67
- results.append({"parent": row["parent"], "child": row["child"], "relation": result})
68
-
69
- return {"results": results, "note": "Limited to 250 rows max"}
70
-
71
-
72
  @app.post("/predict-csv-stream")
73
  async def predict_csv_stream(file: UploadFile):
74
  """Stream CSV predictions progressively using SSE."""
 
1
  import os
2
+
3
+ cache_dir = "/tmp/hf_cache"
4
+ os.environ["TRANSFORMERS_CACHE"] = cache_dir
5
+ os.makedirs(cache_dir, exist_ok=True)
6
+
7
  import io
8
  import json
9
  import asyncio
 
20
  from aba.aba_builder import prepare_aba_plus_framework, build_aba_framework_from_text
21
 
22
  # -------------------- Config -------------------- #
 
 
 
23
 
24
  EXAMPLES_DIR = Path("./aba/exemples")
25
  SAMPLES_DIR = Path("./relations/exemples/samples")
 
55
  return {"arg1": arg1, "arg2": arg2, "relation": result}
56
 
57
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
58
  @app.post("/predict-csv-stream")
59
  async def predict_csv_stream(file: UploadFile):
60
  """Stream CSV predictions progressively using SSE."""