Spaces:
Sleeping
Sleeping
File size: 7,660 Bytes
b6710f4 e68ee99 a0e1005 ba9863e f2dc286 bd90063 be3e7e8 bb34072 f2dc286 bd90063 f2dc286 b6710f4 bd90063 ba9863e be3e7e8 bd90063 bb34072 f2dc286 bd90063 f2dc286 b6710f4 f2dc286 c5f9116 ba9863e c5f9116 ba9863e be3e7e8 ba9863e d02e800 ba9863e f2dc286 ba9863e bd90063 ba9863e bb34072 ba9863e be3e7e8 ba9863e bb34072 ba9863e be3e7e8 ba9863e bb34072 be3e7e8 b6710f4 be3e7e8 b6710f4 be3e7e8 b6710f4 be3e7e8 b6710f4 be3e7e8 b6710f4 be3e7e8 b6710f4 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 |
from gradual.models import GradualInput, GradualOutput
# from gradual.computations import compute_gradual_semantics
from gradual.computations import compute_gradual_space
from aba.aba_builder import prepare_aba_plus_framework, build_aba_framework_from_text
from relations.predict_bert import predict_relation
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from fastapi.responses import FileResponse, StreamingResponse, JSONResponse
from fastapi.middleware.cors import CORSMiddleware
from fastapi import FastAPI, UploadFile, File, Form, HTTPException
import torch
import pandas as pd
from pathlib import Path
import asyncio
import json
import io
import os
# Hugging Face downloads are stored in a scratch directory under /tmp.
cache_dir = "/tmp/hf_cache"
# NOTE: `transformers` is already imported above, and the library resolves its
# default cache location when it is imported, so this assignment alone is not
# relied upon. The directory is also passed explicitly to `from_pretrained`.
os.environ["TRANSFORMERS_CACHE"] = cache_dir
os.makedirs(cache_dir, exist_ok=True)

# -------------------- Config -------------------- #
# Directories holding the bundled example files served by the endpoints.
ABA_EXAMPLES_DIR = Path("./aba/examples")
SAMPLES_DIR = Path("./relations/examples/samples")
GRADUAL_EXAMPLES_DIR = Path("./gradual/examples")

# Run inference on the GPU when one is available, otherwise on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Relation-classification model and tokenizer, loaded once at import time.
model_name = "edgar-demeude/bert-argument"
tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir=cache_dir)
model = AutoModelForSequenceClassification.from_pretrained(
    model_name, cache_dir=cache_dir)
model.to(device)
# -------------------- App -------------------- #
# ASGI application object; every endpoint is registered on it.
app = FastAPI(title="Argument Mining API")
# Local development front-end origins. This list is not passed to the
# middleware below, which uses a wildcard instead.
origins = ["http://localhost:3000", "http://127.0.0.1:3000"]
# NOTE(review): the FastAPI CORS documentation says wildcard origins, methods
# and headers must not be combined with allow_credentials=True. Confirm the
# intended policy: either pass `origins` here or disable credentials.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# -------------------- Endpoints -------------------- #
@app.get("/")
def root():
    """Return a fixed status message confirming that the API is reachable."""
    status = {"message": "Argument Mining API is running..."}
    return status
# --- Predictions --- #
@app.post("/predict-text")
def predict_text(arg1: str = Form(...), arg2: str = Form(...)):
    """Predict the relation between two argument texts using BERT.

    Both texts are required form fields. The response echoes them back
    together with the value returned by ``predict_relation``.
    """
    relation = predict_relation(arg1, arg2, model, tokenizer, device)
    response = {"arg1": arg1, "arg2": arg2}
    response["relation"] = relation
    return response
def _csv_cell_to_json(value):
    """Return a CSV cell in a form that ``json.dumps`` can always encode.

    Strings pass through unchanged, missing cells (``None`` / NaN) become
    ``None``, and any other scalar is converted with ``str``.
    """
    if isinstance(value, str):
        return value
    if value is None or pd.isna(value):
        return None
    return str(value)


@app.post("/predict-csv-stream")
async def predict_csv_stream(file: UploadFile):
    """Stream CSV predictions progressively using Server-Sent Events.

    The uploaded CSV must contain ``parent`` and ``child`` columns. Only the
    first 250 rows are processed. One ``data:`` event is emitted per row:

    * success: ``parent``, ``child``, ``relation`` and ``progress``
    * failure: ``error``, ``parent``, ``child`` and ``progress``

    ``progress`` is the fraction of rows processed so far (successful or
    not), so the final event always reports ``1.0``.

    Raises:
        HTTPException: 400 when the upload is not valid UTF-8 CSV or when a
            required column is missing.
    """
    content = await file.read()
    try:
        df = pd.read_csv(io.StringIO(content.decode("utf-8")), quotechar='"')
    except (UnicodeDecodeError,
            pd.errors.EmptyDataError,
            pd.errors.ParserError) as err:
        raise HTTPException(
            status_code=400,
            detail=f"Could not parse the uploaded CSV: {err}") from err

    # Reject the upload once instead of emitting the same KeyError per row.
    missing = [col for col in ("parent", "child") if col not in df.columns]
    if missing:
        raise HTTPException(
            status_code=400,
            detail=f"CSV is missing required column(s): {', '.join(missing)}")

    # Cap the amount of work done per request.
    df = df.head(250)

    async def event_generator():
        total = len(df)
        for processed, (_, row) in enumerate(df.iterrows(), start=1):
            # Empty cells are NaN in pandas, which json.dumps would emit as
            # the invalid JSON token `NaN`; normalise before serialising.
            parent = _csv_cell_to_json(row["parent"])
            child = _csv_cell_to_json(row["child"])
            progress = processed / total
            try:
                # NOTE: predict_relation is called synchronously here.
                relation = predict_relation(
                    row["parent"], row["child"], model, tokenizer, device)
                event = json.dumps({
                    "parent": parent,
                    "child": child,
                    "relation": relation,
                    "progress": progress,
                })
            except Exception as e:
                # Stream boundary: report the failing row and keep going so
                # one bad row does not abort the whole stream.
                event = json.dumps({
                    "error": str(e),
                    "parent": parent,
                    "child": child,
                    "progress": progress,
                })
            yield f"data: {event}\n\n"
            # Force a flush: give control back to the event loop so the
            # event is sent before the next row is processed.
            await asyncio.sleep(0)

    return StreamingResponse(event_generator(), media_type="text/event-stream")
@app.get("/samples")
def list_samples():
    """List the names of the ``.csv`` sample files, sorted alphabetically.

    Returns ``{"samples": []}`` when the samples directory does not exist,
    instead of failing with an unhandled ``FileNotFoundError``.
    """
    if not SAMPLES_DIR.is_dir():
        return {"samples": []}
    files = sorted(
        entry.name
        for entry in SAMPLES_DIR.iterdir()
        if entry.name.endswith(".csv")
    )
    return {"samples": files}
@app.get("/samples/{filename}")
def get_sample(filename: str):
    """Return one sample CSV file from the samples directory.

    Args:
        filename: Bare file name of the sample (no directory components).

    Raises:
        HTTPException: 404 when the name contains path components or does
            not refer to a regular file in the samples directory.
    """
    file_path = SAMPLES_DIR / filename
    # A bare file name equals its own final path component; anything else
    # could point outside the samples directory.
    is_bare_name = Path(filename).name == filename
    if not is_bare_name or not file_path.is_file():
        raise HTTPException(status_code=404, detail="Sample not found")
    return FileResponse(file_path, media_type="text/csv")
# --- ABA --- #
@app.post("/aba-upload")
async def aba_upload(file: UploadFile = File(...)):
    """Build an ABA framework from an uploaded text file and describe it.

    The upload is decoded as UTF-8 and parsed into a framework, after which
    its arguments and attacks are generated. The response lists the string
    form of every assumption, argument and attack.
    """
    raw_bytes = await file.read()
    framework = build_aba_framework_from_text(raw_bytes.decode("utf-8"))
    framework.generate_arguments()
    framework.generate_attacks()

    def as_strings(items):
        return [str(item) for item in items]

    return {
        "assumptions": as_strings(framework.assumptions),
        "arguments": as_strings(framework.arguments),
        "attacks": as_strings(framework.attacks),
    }
@app.post("/aba-plus-upload")
async def aba_plus_upload(file: UploadFile = File(...)):
    """Build an ABA+ framework from an uploaded text file and describe it.

    The upload is decoded as UTF-8, parsed into an ABA framework, prepared
    with ``prepare_aba_plus_framework`` and converted with
    ``make_aba_plus``. The response lists the string form of every
    assumption, argument, attack and reverse attack.

    This handler was previously also named ``aba_upload``, which shadowed
    the ``/aba-upload`` handler at module level. The route path is unchanged.
    """
    content = await file.read()
    text = content.decode("utf-8")
    aba_framework = build_aba_framework_from_text(text)
    aba_framework = prepare_aba_plus_framework(aba_framework)
    aba_framework.make_aba_plus()
    results = {
        "assumptions": [str(a) for a in aba_framework.assumptions],
        "arguments": [str(arg) for arg in aba_framework.arguments],
        "attacks": [str(att) for att in aba_framework.attacks],
        "reverse_attacks": [str(ratt) for ratt in aba_framework.reverse_attacks],
    }
    return results
@app.get("/aba-examples")
def list_aba_examples():
    """List the file names of the ``.txt`` files in the ABA examples directory."""
    names = []
    for example_file in ABA_EXAMPLES_DIR.glob("*.txt"):
        names.append(example_file.name)
    return {"examples": names}
@app.get("/aba-examples/{filename}")
def get_aba_example(filename: str):
    """Return one ABA example file as a plain-text download.

    Args:
        filename: Bare file name of the example (no directory components).

    Raises:
        HTTPException: 404 when the name contains path components or does
            not refer to a regular file in the ABA examples directory.
    """
    file_path = ABA_EXAMPLES_DIR / filename
    # A bare file name equals its own final path component; anything else
    # could point outside the examples directory.
    is_bare_name = Path(filename).name == filename
    if not is_bare_name or not file_path.is_file():
        raise HTTPException(status_code=404, detail="File not found")
    return FileResponse(file_path, media_type="text/plain", filename=filename)
# --- Gradual semantics --- #
# @app.post("/gradual", response_model=GradualOutput)
# def compute_gradual(input_data: GradualInput):
# """API endpoint to compute Weighted h-Categorizer samples and convex hull."""
# return compute_gradual_semantics(
# A=input_data.A,
# R=input_data.R,
# n_samples=input_data.n_samples,
# max_iter=input_data.max_iter
# )
@app.post("/gradual", response_model=GradualOutput)
def compute_gradual(input_data: GradualInput):
    """
    Compute Weighted h-Categorizer samples and their convex hull
    (the acceptability degree space) for the submitted framework.

    The request fields are forwarded to ``compute_gradual_space``, and its
    six results are repackaged into a ``GradualOutput``.
    """
    space = compute_gradual_space(
        num_args=input_data.num_args,
        R=input_data.R,
        n_samples=input_data.n_samples,
        axes=input_data.axes,
        controlled_args=input_data.controlled_args,
    )
    num_args, hull_volume, hull_area, hull_points, samples, axes = space
    return GradualOutput(
        num_args=num_args,
        hull_volume=hull_volume,
        hull_area=hull_area,
        hull_points=hull_points,
        samples=samples,
        axes=axes,
    )
@app.get("/gradual-examples")
def list_gradual_examples():
    """
    List the gradual semantics example files that are available.

    Every ``.json`` file in the examples directory is reported with its
    stem as ``name`` and its file name as ``path``; ``content`` is always
    ``None`` in this listing. An empty list is returned when the directory
    does not exist.

    Each example is expected to be a JSON file shaped like:
    {
        "num_args": 3,
        "R": [["A", "B"], ["B", "C"], ["C", "A"]]
    }
    """
    if GRADUAL_EXAMPLES_DIR.exists():
        found = GRADUAL_EXAMPLES_DIR.glob("*.json")
    else:
        found = ()
    return {
        "examples": [
            {"name": entry.stem, "path": entry.name, "content": None}
            for entry in found
        ]
    }
@app.get("/gradual-examples/{example_name}")
def get_gradual_example(example_name: str):
    """
    Return the parsed JSON content of one gradual example file.
    Example: GET /gradual-examples/simple.json

    Args:
        example_name: Bare file name of the example, including extension.

    Raises:
        HTTPException: 404 when the name contains path components or does
            not refer to a regular file in the examples directory; 400 when
            the file does not contain valid JSON.
    """
    file_path = GRADUAL_EXAMPLES_DIR / example_name
    # A bare file name equals its own final path component; anything else
    # could point outside the examples directory. `is_file` also rejects
    # directories, which would otherwise fail later in `open`.
    is_bare_name = Path(example_name).name == example_name
    if not is_bare_name or not file_path.is_file():
        raise HTTPException(status_code=404, detail="Example not found")
    try:
        with open(file_path, "r", encoding="utf-8") as f:
            content = json.load(f)
    except json.JSONDecodeError as err:
        raise HTTPException(
            status_code=400, detail="Invalid JSON format in example file") from err
    return JSONResponse(content=content)
|