Spaces:
Sleeping
Sleeping
p2002814
commited on
Commit
·
f2dc286
1
Parent(s):
c5f9116
now predicting csv progressively for a smooth animation
Browse files
app.py
CHANGED
|
@@ -1,98 +1,106 @@
|
|
| 1 |
import os
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2 |
|
| 3 |
-
cache_dir = "/tmp/hf_cache"
|
| 4 |
-
os.environ["TRANSFORMERS_CACHE"] = cache_dir
|
| 5 |
-
os.makedirs(cache_dir, exist_ok=True)
|
| 6 |
-
|
| 7 |
-
from fastapi import FastAPI, UploadFile, File, Form
|
| 8 |
-
from fastapi.responses import FileResponse
|
| 9 |
-
from fastapi.middleware.cors import CORSMiddleware
|
| 10 |
import pandas as pd
|
| 11 |
-
from pathlib import Path
|
| 12 |
import torch
|
| 13 |
-
import
|
| 14 |
-
|
| 15 |
-
from
|
| 16 |
-
|
| 17 |
from transformers import AutoTokenizer, AutoModelForSequenceClassification
|
| 18 |
|
| 19 |
-
|
| 20 |
from aba.aba_builder import prepare_aba_plus_framework, build_aba_framework_from_text
|
| 21 |
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
"http://localhost:3000",
|
| 27 |
-
"http://127.0.0.1:3000",
|
| 28 |
-
]
|
| 29 |
-
|
| 30 |
-
app.add_middleware(
|
| 31 |
-
CORSMiddleware,
|
| 32 |
-
allow_origins=["*"], # allow all origins
|
| 33 |
-
allow_credentials=True,
|
| 34 |
-
allow_methods=["*"],
|
| 35 |
-
allow_headers=["*"],
|
| 36 |
-
)
|
| 37 |
|
| 38 |
EXAMPLES_DIR = Path("./aba/exemples")
|
| 39 |
SAMPLES_DIR = Path("./relations/exemples/samples")
|
| 40 |
-
|
| 41 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 42 |
|
| 43 |
-
# Load model at startup once
|
| 44 |
model_name = "edgar-demeude/bert-argument"
|
| 45 |
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
| 46 |
model = AutoModelForSequenceClassification.from_pretrained(model_name)
|
| 47 |
model.to(device)
|
| 48 |
|
|
|
|
|
|
|
| 49 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 50 |
@app.get("/")
|
| 51 |
def root():
|
| 52 |
return {"message": "Argument Mining API is running..."}
|
| 53 |
|
| 54 |
|
| 55 |
-
# ---------------- BERT Prediction Endpoints ---------------- #
|
| 56 |
-
|
| 57 |
@app.post("/predict-text")
|
| 58 |
def predict_text(arg1: str = Form(...), arg2: str = Form(...)):
|
| 59 |
"""Predict relation between two text arguments using BERT."""
|
| 60 |
result = predict_relation(arg1, arg2, model, tokenizer, device)
|
| 61 |
-
return {
|
| 62 |
-
"arg1": arg1,
|
| 63 |
-
"arg2": arg2,
|
| 64 |
-
"relation": result
|
| 65 |
-
}
|
| 66 |
|
| 67 |
|
| 68 |
@app.post("/predict-csv")
|
| 69 |
async def predict_csv(file: UploadFile):
|
| 70 |
"""Predict relations for pairs of arguments from a CSV file (max 250 rows)."""
|
| 71 |
content = await file.read()
|
| 72 |
-
# Utiliser StringIO + quotechar='"'
|
| 73 |
df = pd.read_csv(io.StringIO(content.decode("utf-8")), quotechar='"')
|
| 74 |
-
|
| 75 |
if len(df) > 250:
|
| 76 |
df = df.head(250)
|
| 77 |
|
| 78 |
results = []
|
| 79 |
for _, row in df.iterrows():
|
| 80 |
-
result = predict_relation(
|
| 81 |
-
|
| 82 |
-
row["child"],
|
| 83 |
-
model,
|
| 84 |
-
tokenizer,
|
| 85 |
-
device
|
| 86 |
-
)
|
| 87 |
-
results.append({
|
| 88 |
-
"parent": row["parent"],
|
| 89 |
-
"child": row["child"],
|
| 90 |
-
"relation": result
|
| 91 |
-
})
|
| 92 |
|
| 93 |
return {"results": results, "note": "Limited to 250 rows max"}
|
| 94 |
|
| 95 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 96 |
@app.get("/samples")
|
| 97 |
def list_samples():
|
| 98 |
files = [f for f in os.listdir(SAMPLES_DIR) if f.endswith(".csv")]
|
|
@@ -106,19 +114,12 @@ def get_sample(filename: str):
|
|
| 106 |
return {"error": "Sample not found"}
|
| 107 |
return FileResponse(file_path, media_type="text/csv")
|
| 108 |
|
| 109 |
-
# ---------------- ABA API ---------------- #
|
| 110 |
|
| 111 |
@app.post("/aba-upload")
|
| 112 |
async def aba_upload(file: UploadFile = File(...)):
|
| 113 |
-
"""
|
| 114 |
-
Upload a .txt file containing an ABA framework definition
|
| 115 |
-
and return the generated ABA+ framework.
|
| 116 |
-
"""
|
| 117 |
-
# Read file contents
|
| 118 |
content = await file.read()
|
| 119 |
-
text = content.decode("utf-8")
|
| 120 |
|
| 121 |
-
# Build ABA framework
|
| 122 |
aba_framework = build_aba_framework_from_text(text)
|
| 123 |
aba_framework = prepare_aba_plus_framework(aba_framework)
|
| 124 |
aba_framework.make_aba_plus()
|
|
@@ -134,14 +135,12 @@ async def aba_upload(file: UploadFile = File(...)):
|
|
| 134 |
|
| 135 |
@app.get("/aba-examples")
|
| 136 |
def list_aba_examples():
|
| 137 |
-
"""Lists all sample files available on the server side."""
|
| 138 |
examples = [f.name for f in EXAMPLES_DIR.glob("*.txt")]
|
| 139 |
return {"examples": examples}
|
| 140 |
|
| 141 |
|
| 142 |
@app.get("/aba-examples/{filename}")
|
| 143 |
def get_aba_example(filename: str):
|
| 144 |
-
"""Returns the contents of a specific ABA sample file."""
|
| 145 |
file_path = EXAMPLES_DIR / filename
|
| 146 |
if not file_path.exists() or not file_path.is_file():
|
| 147 |
return {"error": "File not found"}
|
|
|
|
| 1 |
import os
|
| 2 |
+
import io
|
| 3 |
+
import json
|
| 4 |
+
import asyncio
|
| 5 |
+
from pathlib import Path
|
| 6 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 7 |
import pandas as pd
|
|
|
|
| 8 |
import torch
|
| 9 |
+
from fastapi import FastAPI, UploadFile, File, Form
|
| 10 |
+
from fastapi.middleware.cors import CORSMiddleware
|
| 11 |
+
from fastapi.responses import FileResponse, StreamingResponse
|
|
|
|
| 12 |
from transformers import AutoTokenizer, AutoModelForSequenceClassification
|
| 13 |
|
| 14 |
+
from relations.predict_bert import predict_relation
|
| 15 |
from aba.aba_builder import prepare_aba_plus_framework, build_aba_framework_from_text
|
| 16 |
|
| 17 |
+
# -------------------- Config -------------------- #
|
| 18 |
+
cache_dir = "/tmp/hf_cache"
|
| 19 |
+
os.environ["TRANSFORMERS_CACHE"] = cache_dir
|
| 20 |
+
os.makedirs(cache_dir, exist_ok=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 21 |
|
| 22 |
EXAMPLES_DIR = Path("./aba/exemples")
|
| 23 |
SAMPLES_DIR = Path("./relations/exemples/samples")
|
|
|
|
| 24 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 25 |
|
|
|
|
| 26 |
model_name = "edgar-demeude/bert-argument"
|
| 27 |
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
| 28 |
model = AutoModelForSequenceClassification.from_pretrained(model_name)
|
| 29 |
model.to(device)
|
| 30 |
|
| 31 |
+
# -------------------- App -------------------- #
|
| 32 |
+
app = FastAPI(title="Argument Mining API")
|
| 33 |
|
| 34 |
+
origins = ["http://localhost:3000", "http://127.0.0.1:3000"]
|
| 35 |
+
app.add_middleware(
|
| 36 |
+
CORSMiddleware,
|
| 37 |
+
allow_origins=["*"],
|
| 38 |
+
allow_credentials=True,
|
| 39 |
+
allow_methods=["*"],
|
| 40 |
+
allow_headers=["*"],
|
| 41 |
+
)
|
| 42 |
+
|
| 43 |
+
# -------------------- Endpoints -------------------- #
|
| 44 |
@app.get("/")
|
| 45 |
def root():
|
| 46 |
return {"message": "Argument Mining API is running..."}
|
| 47 |
|
| 48 |
|
|
|
|
|
|
|
| 49 |
@app.post("/predict-text")
|
| 50 |
def predict_text(arg1: str = Form(...), arg2: str = Form(...)):
|
| 51 |
"""Predict relation between two text arguments using BERT."""
|
| 52 |
result = predict_relation(arg1, arg2, model, tokenizer, device)
|
| 53 |
+
return {"arg1": arg1, "arg2": arg2, "relation": result}
|
|
|
|
|
|
|
|
|
|
|
|
|
| 54 |
|
| 55 |
|
| 56 |
@app.post("/predict-csv")
|
| 57 |
async def predict_csv(file: UploadFile):
|
| 58 |
"""Predict relations for pairs of arguments from a CSV file (max 250 rows)."""
|
| 59 |
content = await file.read()
|
|
|
|
| 60 |
df = pd.read_csv(io.StringIO(content.decode("utf-8")), quotechar='"')
|
|
|
|
| 61 |
if len(df) > 250:
|
| 62 |
df = df.head(250)
|
| 63 |
|
| 64 |
results = []
|
| 65 |
for _, row in df.iterrows():
|
| 66 |
+
result = predict_relation(row["parent"], row["child"], model, tokenizer, device)
|
| 67 |
+
results.append({"parent": row["parent"], "child": row["child"], "relation": result})
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 68 |
|
| 69 |
return {"results": results, "note": "Limited to 250 rows max"}
|
| 70 |
|
| 71 |
|
| 72 |
+
@app.post("/predict-csv-stream")
|
| 73 |
+
async def predict_csv_stream(file: UploadFile):
|
| 74 |
+
"""Stream CSV predictions progressively using SSE."""
|
| 75 |
+
content = await file.read()
|
| 76 |
+
df = pd.read_csv(io.StringIO(content.decode("utf-8")), quotechar='"')
|
| 77 |
+
if len(df) > 250:
|
| 78 |
+
df = df.head(250)
|
| 79 |
+
|
| 80 |
+
async def event_generator():
|
| 81 |
+
total = len(df)
|
| 82 |
+
completed = 0
|
| 83 |
+
for _, row in df.iterrows():
|
| 84 |
+
try:
|
| 85 |
+
result = predict_relation(row["parent"], row["child"], model, tokenizer, device)
|
| 86 |
+
completed += 1
|
| 87 |
+
payload = {
|
| 88 |
+
"parent": row["parent"],
|
| 89 |
+
"child": row["child"],
|
| 90 |
+
"relation": result,
|
| 91 |
+
"progress": completed / total
|
| 92 |
+
}
|
| 93 |
+
yield f"data: {json.dumps(payload)}\n\n"
|
| 94 |
+
# FORCER flush
|
| 95 |
+
await asyncio.sleep(0)
|
| 96 |
+
except Exception as e:
|
| 97 |
+
yield f"data: {json.dumps({'error': str(e), 'parent': row.get('parent'), 'child': row.get('child')})}\n\n"
|
| 98 |
+
await asyncio.sleep(0)
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
return StreamingResponse(event_generator(), media_type="text/event-stream")
|
| 102 |
+
|
| 103 |
+
|
| 104 |
@app.get("/samples")
|
| 105 |
def list_samples():
|
| 106 |
files = [f for f in os.listdir(SAMPLES_DIR) if f.endswith(".csv")]
|
|
|
|
| 114 |
return {"error": "Sample not found"}
|
| 115 |
return FileResponse(file_path, media_type="text/csv")
|
| 116 |
|
|
|
|
| 117 |
|
| 118 |
@app.post("/aba-upload")
|
| 119 |
async def aba_upload(file: UploadFile = File(...)):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 120 |
content = await file.read()
|
| 121 |
+
text = content.decode("utf-8")
|
| 122 |
|
|
|
|
| 123 |
aba_framework = build_aba_framework_from_text(text)
|
| 124 |
aba_framework = prepare_aba_plus_framework(aba_framework)
|
| 125 |
aba_framework.make_aba_plus()
|
|
|
|
| 135 |
|
| 136 |
@app.get("/aba-examples")
|
| 137 |
def list_aba_examples():
|
|
|
|
| 138 |
examples = [f.name for f in EXAMPLES_DIR.glob("*.txt")]
|
| 139 |
return {"examples": examples}
|
| 140 |
|
| 141 |
|
| 142 |
@app.get("/aba-examples/{filename}")
|
| 143 |
def get_aba_example(filename: str):
|
|
|
|
| 144 |
file_path = EXAMPLES_DIR / filename
|
| 145 |
if not file_path.exists() or not file_path.is_file():
|
| 146 |
return {"error": "File not found"}
|