Spaces:
Sleeping
Sleeping
| from gradual.models import GradualInput, GradualOutput | |
| # from gradual.computations import compute_gradual_semantics | |
| from gradual.computations import compute_gradual_space | |
| from aba.aba_builder import prepare_aba_plus_framework, build_aba_framework_from_text | |
| from relations.predict_bert import predict_relation | |
| from transformers import AutoTokenizer, AutoModelForSequenceClassification | |
| from fastapi.responses import FileResponse, StreamingResponse, JSONResponse | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from fastapi import FastAPI, UploadFile, File, Form, HTTPException | |
| import torch | |
| import pandas as pd | |
| from pathlib import Path | |
| import asyncio | |
| import json | |
| import io | |
| import os | |
# Directory used to cache Hugging Face downloads.
cache_dir = "/tmp/hf_cache"
# NOTE: `transformers` is imported before this assignment, and the library
# resolves its default cache location at import time, so this variable alone
# may not take effect. It is kept for any code that reads it later; the cache
# directory is also passed explicitly to `from_pretrained` below.
os.environ["TRANSFORMERS_CACHE"] = cache_dir
os.makedirs(cache_dir, exist_ok=True)

# -------------------- Config -------------------- #
# Example/sample directories, relative to the process working directory.
ABA_EXAMPLES_DIR = Path("./aba/examples")
SAMPLES_DIR = Path("./relations/examples/samples")
GRADUAL_EXAMPLES_DIR = Path("./gradual/examples")

# Run on the GPU when one is available, otherwise on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Tokenizer and sequence-classification model are loaded once at import time
# and shared by every request.
model_name = "edgar-demeude/bert-argument"
tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir=cache_dir)
model = AutoModelForSequenceClassification.from_pretrained(
    model_name, cache_dir=cache_dir)
model.to(device)

# -------------------- App -------------------- #
app = FastAPI(title="Argument Mining API")

# NOTE(review): `origins` is defined but never passed to the middleware below,
# which currently accepts every origin. FastAPI's CORS documentation advises
# against combining wildcard values with `allow_credentials=True` -- confirm
# whether the explicit list or the wildcard is intended.
origins = ["http://localhost:3000", "http://127.0.0.1:3000"]
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
| # -------------------- Endpoints -------------------- # | |
def root():
    """Return a small status payload indicating that the API is up."""
    status = {"message": "Argument Mining API is running..."}
    return status
| # --- Predictions --- # | |
def predict_text(arg1: str = Form(...), arg2: str = Form(...)):
    """Classify the relation between two argument texts.

    Both texts are required form fields. They are passed, together with the
    module-level model, tokenizer and device, to ``predict_relation``.

    Returns:
        A dict echoing ``arg1`` and ``arg2`` plus the prediction under
        ``relation``.
    """
    relation = predict_relation(arg1, arg2, model, tokenizer, device)
    response = dict(arg1=arg1, arg2=arg2, relation=relation)
    return response
async def predict_csv_stream(file: UploadFile):
    """Stream CSV predictions progressively using SSE.

    The uploaded file must be a UTF-8 CSV with ``parent`` and ``child``
    columns. Only the first 250 rows are processed. One ``data:`` event is
    emitted per row:

    * on success: ``parent``, ``child``, ``relation`` and ``progress``;
    * on failure: ``error``, ``parent``, ``child`` and ``progress``.

    ``progress`` is the fraction of rows handled so far (successful or not),
    so it reaches 1.0 on the last row even when some rows fail.

    Raises:
        HTTPException: 400 when the upload is not UTF-8, cannot be parsed as
            CSV, or lacks a required column.
    """
    max_rows = 250
    required_columns = ("parent", "child")

    raw = await file.read()
    try:
        df = pd.read_csv(io.StringIO(raw.decode("utf-8")), quotechar='"')
    except UnicodeDecodeError as err:
        raise HTTPException(
            status_code=400, detail="CSV file must be UTF-8 encoded") from err
    except (pd.errors.EmptyDataError, pd.errors.ParserError) as err:
        raise HTTPException(
            status_code=400, detail=f"Invalid CSV file: {err}") from err

    # Fail once, up front, instead of emitting one error event per row.
    missing = [col for col in required_columns if col not in df.columns]
    if missing:
        raise HTTPException(
            status_code=400,
            detail=f"Missing required column(s): {', '.join(missing)}")

    df = df.head(max_rows)

    async def event_generator():
        total = len(df)
        for processed, (_, row) in enumerate(df.iterrows(), start=1):
            parent, child = row["parent"], row["child"]
            progress = processed / total
            try:
                result = predict_relation(
                    parent, child, model, tokenizer, device)
                message = json.dumps({
                    "parent": parent,
                    "child": child,
                    "relation": result,
                    "progress": progress,
                })
            except Exception as e:
                # Per-row boundary: report the failure and keep streaming.
                # `default=str` keeps the error report itself from raising on
                # cell values that json cannot serialize.
                message = json.dumps({
                    "error": str(e),
                    "parent": parent,
                    "child": child,
                    "progress": progress,
                }, default=str)
            yield f"data: {message}\n\n"
            # Hand control back to the event loop so the event can be sent
            # before the next (blocking) prediction starts.
            await asyncio.sleep(0)

    return StreamingResponse(event_generator(), media_type="text/event-stream")
def list_samples():
    """List the CSV sample files available in ``SAMPLES_DIR``.

    Returns:
        ``{"samples": [...]}`` with the ``.csv`` file names in sorted order.
        The list is empty when the directory does not exist, mirroring how
        the gradual examples listing treats a missing directory.
    """
    samples_dir = Path(SAMPLES_DIR)
    if not samples_dir.is_dir():
        return {"samples": []}
    files = sorted(
        entry.name
        for entry in samples_dir.iterdir()
        if entry.is_file() and entry.name.endswith(".csv")
    )
    return {"samples": files}
def get_sample(filename: str):
    """Return one sample CSV from ``SAMPLES_DIR`` as a file response.

    Args:
        filename: Bare file name of the sample; it must not contain any
            directory component.

    Returns:
        A ``FileResponse`` with media type ``text/csv``, or
        ``{"error": "Sample not found"}`` when the name is not a bare file
        name or does not designate an existing regular file.
    """
    not_found = {"error": "Sample not found"}

    # `filename` comes from the client: refuse anything carrying a directory
    # part ("../x", "/etc/passwd", ...) so the lookup cannot leave SAMPLES_DIR.
    if Path(filename).name != filename:
        return not_found

    file_path = Path(SAMPLES_DIR) / filename
    # is_file() also rejects directories (including "" and ".."), which
    # a plain existence check would accept.
    if not file_path.is_file():
        return not_found
    return FileResponse(file_path, media_type="text/csv")
| # --- ABA --- # | |
async def aba_upload(file: UploadFile = File(...)):
    """Build an ABA framework from an uploaded text file.

    The upload is decoded as UTF-8, parsed with
    ``build_aba_framework_from_text``, and the framework's arguments and
    attacks are generated.

    NOTE(review): a later definition in this module reuses the name
    ``aba_upload``; after import the module-level name refers to that one.

    Returns:
        A dict with the string forms of the framework's ``assumptions``,
        ``arguments`` and ``attacks``.

    Raises:
        HTTPException: 400 when the upload is not valid UTF-8.
    """
    raw = await file.read()
    try:
        text = raw.decode("utf-8")
    except UnicodeDecodeError as err:
        raise HTTPException(
            status_code=400,
            detail="Uploaded file must be UTF-8 encoded text") from err

    aba_framework = build_aba_framework_from_text(text)
    aba_framework.generate_arguments()
    aba_framework.generate_attacks()
    return {
        "assumptions": [str(a) for a in aba_framework.assumptions],
        "arguments": [str(arg) for arg in aba_framework.arguments],
        "attacks": [str(att) for att in aba_framework.attacks],
    }
async def aba_upload(file: UploadFile = File(...)):
    """Build an ABA+ framework from an uploaded text file.

    The upload is decoded as UTF-8, parsed with
    ``build_aba_framework_from_text``, passed through
    ``prepare_aba_plus_framework`` and then ``make_aba_plus()`` is called.

    NOTE(review): this function has the same name as the ABA handler defined
    just before it and therefore rebinds the module-level name
    ``aba_upload``. The name is kept so existing references keep working;
    consider giving it a distinct name.

    Returns:
        A dict with the string forms of the framework's ``assumptions``,
        ``arguments``, ``attacks`` and ``reverse_attacks``.

    Raises:
        HTTPException: 400 when the upload is not valid UTF-8.
    """
    raw = await file.read()
    try:
        text = raw.decode("utf-8")
    except UnicodeDecodeError as err:
        raise HTTPException(
            status_code=400,
            detail="Uploaded file must be UTF-8 encoded text") from err

    aba_framework = build_aba_framework_from_text(text)
    aba_framework = prepare_aba_plus_framework(aba_framework)
    aba_framework.make_aba_plus()
    return {
        "assumptions": [str(a) for a in aba_framework.assumptions],
        "arguments": [str(arg) for arg in aba_framework.arguments],
        "attacks": [str(att) for att in aba_framework.attacks],
        "reverse_attacks": [str(ratt) for ratt in aba_framework.reverse_attacks],
    }
def list_aba_examples():
    """Return the names of the ``*.txt`` files directly inside ``ABA_EXAMPLES_DIR``."""
    return {
        "examples": [entry.name for entry in ABA_EXAMPLES_DIR.glob("*.txt")],
    }
def get_aba_example(filename: str):
    """Return one ABA example file from ``ABA_EXAMPLES_DIR``.

    Args:
        filename: Bare file name of the example; it must not contain any
            directory component.

    Returns:
        A ``FileResponse`` with media type ``text/plain`` (served under the
        requested file name), or ``{"error": "File not found"}`` when the
        name is not a bare file name or is not an existing regular file.
    """
    not_found = {"error": "File not found"}

    # `filename` comes from the client: refuse anything carrying a directory
    # part ("../x", "/etc/passwd", ...) so the lookup stays inside
    # ABA_EXAMPLES_DIR.
    if Path(filename).name != filename:
        return not_found

    file_path = ABA_EXAMPLES_DIR / filename
    if not file_path.is_file():
        return not_found
    return FileResponse(file_path, media_type="text/plain", filename=filename)
| # --- Gradual semantics --- # | |
| # @app.post("/gradual", response_model=GradualOutput) | |
| # def compute_gradual(input_data: GradualInput): | |
| # """API endpoint to compute Weighted h-Categorizer samples and convex hull.""" | |
| # return compute_gradual_semantics( | |
| # A=input_data.A, | |
| # R=input_data.R, | |
| # n_samples=input_data.n_samples, | |
| # max_iter=input_data.max_iter | |
| # ) | |
def compute_gradual(input_data: GradualInput):
    """
    Compute the gradual-semantics acceptability space for a request.

    The request fields ``num_args``, ``R``, ``n_samples``, ``axes`` and
    ``controlled_args`` are forwarded to ``compute_gradual_space``; the six
    values it returns are repackaged, in order, into a ``GradualOutput``.
    """
    request_kwargs = {
        "num_args": input_data.num_args,
        "R": input_data.R,
        "n_samples": input_data.n_samples,
        "axes": input_data.axes,
        "controlled_args": input_data.controlled_args,
    }
    (
        arg_count,
        volume,
        area,
        vertices,
        sampled_points,
        used_axes,
    ) = compute_gradual_space(**request_kwargs)

    return GradualOutput(
        num_args=arg_count,
        hull_volume=volume,
        hull_area=area,
        hull_points=vertices,
        samples=sampled_points,
        axes=used_axes,
    )
def list_gradual_examples():
    """
    List the gradual semantics example files in ``GRADUAL_EXAMPLES_DIR``.

    Each example is expected to be a JSON file shaped like:
        {
            "num_args": 3,
            "R": [["A", "B"], ["B", "C"], ["C", "A"]]
        }

    The files are not opened here: every entry carries the file stem as
    ``name``, the file name as ``path`` and ``content`` set to ``None``.
    A missing directory yields an empty list.
    """
    if GRADUAL_EXAMPLES_DIR.exists():
        entries = [
            {"name": json_file.stem, "path": json_file.name, "content": None}
            for json_file in GRADUAL_EXAMPLES_DIR.glob("*.json")
        ]
    else:
        entries = []
    return {"examples": entries}
def get_gradual_example(example_name: str):
    """
    Return the content of a specific gradual example file.
    Example: GET /gradual-examples/simple.json

    Args:
        example_name: Bare file name of the example inside
            ``GRADUAL_EXAMPLES_DIR``; it must not contain any directory
            component.

    Returns:
        A ``JSONResponse`` holding the parsed JSON content of the file.

    Raises:
        HTTPException: 404 when the name is not a bare file name or is not
            an existing regular file; 400 when the file is not valid JSON.
    """
    # `example_name` comes from the client: refuse anything carrying a
    # directory part so the lookup stays inside GRADUAL_EXAMPLES_DIR.
    if Path(example_name).name != example_name:
        raise HTTPException(status_code=404, detail="Example not found")

    file_path = GRADUAL_EXAMPLES_DIR / example_name
    # is_file() also rejects directories, which exists() would let through
    # only to fail later in open().
    if not file_path.is_file():
        raise HTTPException(status_code=404, detail="Example not found")

    try:
        with open(file_path, "r", encoding="utf-8") as f:
            content = json.load(f)
    except json.JSONDecodeError as err:
        raise HTTPException(
            status_code=400, detail="Invalid JSON format in example file") from err
    return JSONResponse(content=content)