Spaces:
Sleeping
Sleeping
p2002814
commited on
Commit
·
a0e1005
1
Parent(s):
f2dc286
now predicting csv progressively for a smooth animation
Browse files
app.py
CHANGED
|
@@ -1,4 +1,9 @@
|
|
| 1 |
import os
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2 |
import io
|
| 3 |
import json
|
| 4 |
import asyncio
|
|
@@ -15,9 +20,6 @@ from relations.predict_bert import predict_relation
|
|
| 15 |
from aba.aba_builder import prepare_aba_plus_framework, build_aba_framework_from_text
|
| 16 |
|
| 17 |
# -------------------- Config -------------------- #
|
| 18 |
-
cache_dir = "/tmp/hf_cache"
|
| 19 |
-
os.environ["TRANSFORMERS_CACHE"] = cache_dir
|
| 20 |
-
os.makedirs(cache_dir, exist_ok=True)
|
| 21 |
|
| 22 |
EXAMPLES_DIR = Path("./aba/exemples")
|
| 23 |
SAMPLES_DIR = Path("./relations/exemples/samples")
|
|
@@ -53,22 +55,6 @@ def predict_text(arg1: str = Form(...), arg2: str = Form(...)):
|
|
| 53 |
return {"arg1": arg1, "arg2": arg2, "relation": result}
|
| 54 |
|
| 55 |
|
| 56 |
-
@app.post("/predict-csv")
|
| 57 |
-
async def predict_csv(file: UploadFile):
|
| 58 |
-
"""Predict relations for pairs of arguments from a CSV file (max 250 rows)."""
|
| 59 |
-
content = await file.read()
|
| 60 |
-
df = pd.read_csv(io.StringIO(content.decode("utf-8")), quotechar='"')
|
| 61 |
-
if len(df) > 250:
|
| 62 |
-
df = df.head(250)
|
| 63 |
-
|
| 64 |
-
results = []
|
| 65 |
-
for _, row in df.iterrows():
|
| 66 |
-
result = predict_relation(row["parent"], row["child"], model, tokenizer, device)
|
| 67 |
-
results.append({"parent": row["parent"], "child": row["child"], "relation": result})
|
| 68 |
-
|
| 69 |
-
return {"results": results, "note": "Limited to 250 rows max"}
|
| 70 |
-
|
| 71 |
-
|
| 72 |
@app.post("/predict-csv-stream")
|
| 73 |
async def predict_csv_stream(file: UploadFile):
|
| 74 |
"""Stream CSV predictions progressively using SSE."""
|
|
|
|
| 1 |
import os
|
| 2 |
+
|
| 3 |
+
cache_dir = "/tmp/hf_cache"
|
| 4 |
+
os.environ["TRANSFORMERS_CACHE"] = cache_dir
|
| 5 |
+
os.makedirs(cache_dir, exist_ok=True)
|
| 6 |
+
|
| 7 |
import io
|
| 8 |
import json
|
| 9 |
import asyncio
|
|
|
|
| 20 |
from aba.aba_builder import prepare_aba_plus_framework, build_aba_framework_from_text
|
| 21 |
|
| 22 |
# -------------------- Config -------------------- #
|
|
|
|
|
|
|
|
|
|
| 23 |
|
| 24 |
EXAMPLES_DIR = Path("./aba/exemples")
|
| 25 |
SAMPLES_DIR = Path("./relations/exemples/samples")
|
|
|
|
| 55 |
return {"arg1": arg1, "arg2": arg2, "relation": result}
|
| 56 |
|
| 57 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 58 |
@app.post("/predict-csv-stream")
|
| 59 |
async def predict_csv_stream(file: UploadFile):
|
| 60 |
"""Stream CSV predictions progressively using SSE."""
|