import os

cache_dir = "/tmp/hf_cache"
os.environ["TRANSFORMERS_CACHE"] = cache_dir
os.makedirs(cache_dir, exist_ok=True)

from gradual.models import GradualInput, GradualOutput
# from gradual.computations import compute_gradual_semantics
from gradual.computations import compute_gradual_space
from aba.aba_builder import prepare_aba_plus_framework, build_aba_framework_from_text
from relations.predict_bert import predict_relation
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from fastapi.responses import FileResponse, StreamingResponse, JSONResponse
from fastapi.middleware.cors import CORSMiddleware
from fastapi import FastAPI, UploadFile, File, Form, HTTPException
import torch
import pandas as pd
from pathlib import Path
import asyncio
import json
import io
from aba.models import (
    RuleDTO,
    FrameworkSnapshot,
    TransformationStep,
    ABAApiResponseModel,
    ABAPlusDTO,
    MetaInfo,
)
from copy import deepcopy
from datetime import datetime

# -------------------- Config -------------------- #
ABA_EXAMPLES_DIR = Path("./aba/examples")
SAMPLES_DIR = Path("./relations/examples/samples")
GRADUAL_EXAMPLES_DIR = Path("./gradual/examples")

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model_name = "edgar-demeude/bert-argument"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(model_name)
model.to(device)

# -------------------- App -------------------- #
app = FastAPI(title="Argument Mining API")

origins = ["http://localhost:3000", "http://127.0.0.1:3000"]

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# -------------------- Endpoints -------------------- #


@app.get("/")
def root():
    return {"message": "Argument Mining API is running..."}


# --- Predictions --- #
@app.post("/predict-text")
def predict_text(arg1: str = Form(...), arg2: str = Form(...)):
    """Predict relation between two text arguments using BERT."""
    result = predict_relation(arg1, arg2, model, tokenizer, device)
    return {"arg1": arg1, "arg2": arg2, "relation": result}


@app.post("/predict-csv-stream")
async def predict_csv_stream(file: UploadFile):
    """Stream CSV predictions progressively using SSE."""
    content = await file.read()
    df = pd.read_csv(io.StringIO(content.decode("utf-8")), quotechar='"')
    if len(df) > 250:
        df = df.head(250)

    async def event_generator():
        total = len(df)
        completed = 0
        for _, row in df.iterrows():
            try:
                result = predict_relation(
                    row["parent"], row["child"], model, tokenizer, device)
                completed += 1
                payload = {
                    "parent": row["parent"],
                    "child": row["child"],
                    "relation": result,
                    "progress": completed / total
                }
                yield f"data: {json.dumps(payload)}\n\n"
                # Force flush: yield to the event loop so the event is sent immediately
                await asyncio.sleep(0)
            except Exception as e:
                yield f"data: {json.dumps({'error': str(e), 'parent': row.get('parent'), 'child': row.get('child')})}\n\n"
                await asyncio.sleep(0)

    return StreamingResponse(event_generator(), media_type="text/event-stream")
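
# Illustrative input for /predict-csv-stream (column names taken from the handler
# above; anything beyond 250 rows is truncated):
#
#   parent,child
#   "<parent argument text>","<child argument text>"
#
# Each SSE event carries the parent, child, predicted relation and a progress
# ratio in [0, 1]; failing rows are reported as events with an "error" field.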


@app.get("/samples")
def list_samples():
    files = [f for f in os.listdir(SAMPLES_DIR) if f.endswith(".csv")]
    return {"samples": files}


@app.get("/samples/{filename}")
def get_sample(filename: str):
    file_path = os.path.join(SAMPLES_DIR, filename)
    if not os.path.exists(file_path):
        return {"error": "Sample not found"}
    return FileResponse(file_path, media_type="text/csv")


# --- ABA --- #
def _make_snapshot(fw) -> FrameworkSnapshot:
    return FrameworkSnapshot(
        language=[str(l) for l in sorted(fw.language, key=str)],
        assumptions=[str(a) for a in sorted(fw.assumptions, key=str)],
        rules=[
            RuleDTO(
                id=r.rule_name,
                head=str(r.head),
                body=[str(b) for b in sorted(r.body, key=str)],
            )
            for r in sorted(fw.rules, key=lambda r: r.rule_name)
        ],
        contraries=[
            (str(c.contraried_literal), str(c.contrary_attacker))
            for c in sorted(fw.contraries, key=str)
        ],
        preferences={
            str(k): [str(v) for v in sorted(vals, key=str)]
            for k, vals in (fw.preferences or {}).items()
        } if getattr(fw, "preferences", None) else None,
    )


def _format_set(s) -> str:
    # s may be a Python set/frozenset of Literal, or already a formatted string.
    if isinstance(s, str):
        # Fallback if s is already a string like "{a,b}"
        # (a plain string is iterable, so it must be handled before the loop below).
        return s
    try:
        items = sorted([str(x) for x in s], key=str)
    except Exception:
        return str(s)
    return "{" + ",".join(items) + "}"
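
# Illustrative behaviour of _format_set (lit_a / lit_b are hypothetical Literal
# objects whose str() values are "a" and "b"):
#   _format_set({lit_b, lit_a})  ->  "{a,b}"   # elements sorted by their string form
#   _format_set("{a,b}")         ->  "{a,b}"   # string inputs pass through unchanged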


async def _process_aba_framework(
    text: str,
    enable_aba_plus: bool = False,
) -> dict:
    """
    Core processing logic for ABA frameworks.

    Args:
        text: The uploaded file content as text
        enable_aba_plus: If True, compute ABA+ elements

    Returns:
        Complete response with before/after snapshots and all computations
    """
    # === 1. Build original framework ===
    base_framework = build_aba_framework_from_text(text)
    base_framework.generate_arguments()
    base_framework.generate_attacks()

    original_snapshot = _make_snapshot(base_framework)

    # --- Classical (argument-level) data ---
    original_arguments = [str(arg) for arg in sorted(base_framework.arguments, key=str)]
    original_attacks = [str(att) for att in sorted(base_framework.attacks, key=str)]
    original_reverse_attacks = []

    # === 2. Transform framework ===
    transformed_framework = deepcopy(base_framework).transform_aba()
    transformations = _detect_transformations(base_framework, transformed_framework)

    # --- Initialize containers ---
    original_assumption_sets = []
    final_assumption_sets = []
    original_aba_plus_attacks = []
    final_aba_plus_attacks = []
    original_reverse_attacks = []
    final_reverse_attacks = []
    warnings = []

    # === 3. ABA+ computations ===
    if enable_aba_plus:
        # --- ABA+ on original framework ---
        fw_plus_original = prepare_aba_plus_framework(deepcopy(base_framework))
        fw_plus_original.generate_arguments()
        fw_plus_original.generate_attacks()
        fw_plus_original.make_aba_plus()

        original_assumption_sets = sorted(
            [_format_set(s) for s in getattr(fw_plus_original, "assumption_combinations", [])],
            key=lambda x: (len(x), x),
        )
        original_aba_plus_attacks = [
            f"{_format_set(src)} → {_format_set(dst)}"
            for (src, dst) in sorted(
                getattr(fw_plus_original, "normal_attacks", []),
                key=lambda p: (str(p[0]), str(p[1])),
            )
        ]
        original_reverse_attacks = [
            f"{_format_set(src)} → {_format_set(dst)}"
            for (src, dst) in sorted(
                getattr(fw_plus_original, "reverse_attacks", []),
                key=lambda p: (str(p[0]), str(p[1])),
            )
        ]

        # --- Ensure transformed framework is consistent before ABA+ ---
        transformed_framework.generate_arguments()
        transformed_framework.generate_attacks()

        # --- Compute ABA+ on transformed framework ---
        fw_plus_transformed = prepare_aba_plus_framework(deepcopy(transformed_framework))
        fw_plus_transformed.generate_arguments()
        fw_plus_transformed.generate_attacks()
        fw_plus_transformed.make_aba_plus()

        final_assumption_sets = sorted(
            [_format_set(s) for s in getattr(fw_plus_transformed, "assumption_combinations", [])],
            key=lambda x: (len(x), x),
        )

        # Debug sanity checks
        print("DEBUG: fw_plus_transformed.assumptions =", getattr(fw_plus_transformed, "assumptions", []))
        print("DEBUG: fw_plus_transformed.normal_attacks =", getattr(fw_plus_transformed, "normal_attacks", []))

        final_aba_plus_attacks = [
            f"{_format_set(src)} → {_format_set(dst)}"
            for (src, dst) in sorted(
                getattr(fw_plus_transformed, "normal_attacks", []),
                key=lambda p: (str(p[0]), str(p[1])),
            )
        ]
        final_reverse_attacks = [
            f"{_format_set(src)} → {_format_set(dst)}"
            for (src, dst) in sorted(
                getattr(fw_plus_transformed, "reverse_attacks", []),
                key=lambda p: (str(p[0]), str(p[1])),
            )
        ]

        warnings = _validate_aba_plus_framework(fw_plus_transformed)
    else:
        warnings = _validate_framework(transformed_framework)

    # === 4. Classical ABA computations (arguments + attacks) ===
    base_framework.generate_arguments()
    base_framework.generate_attacks()
    transformed_framework.generate_arguments()
    transformed_framework.generate_attacks()

    original_arguments = [str(arg) for arg in sorted(base_framework.arguments, key=str)]
    original_arguments_attacks = [str(att) for att in sorted(base_framework.attacks, key=str)]
    final_arguments = [str(arg) for arg in sorted(transformed_framework.arguments, key=str)]
    final_arguments_attacks = [str(att) for att in sorted(transformed_framework.attacks, key=str)]

    # === 5. Snapshots ===
    original_snapshot = _make_snapshot(base_framework)
    final_snapshot = _make_snapshot(transformed_framework)

    # === 6. Build response with CORRECT structure ===
    before_state = {
        "framework": original_snapshot.dict(),
        "arguments": original_arguments,
        "arguments_attacks": original_arguments_attacks,
        "argument_attacks": original_arguments_attacks,  # same as arguments_attacks for classical ABA
        "assumption_set_attacks": original_aba_plus_attacks if enable_aba_plus else [],
        "reverse_attacks": original_reverse_attacks if enable_aba_plus else [],
        "assumption_sets": original_assumption_sets if enable_aba_plus else [],
    }

    after_state = {
        "framework": final_snapshot.dict(),
        "arguments": final_arguments,
        "arguments_attacks": final_arguments_attacks,
        "argument_attacks": final_arguments_attacks,  # same as arguments_attacks for classical ABA
        "assumption_set_attacks": final_aba_plus_attacks if enable_aba_plus else [],
        "reverse_attacks": final_reverse_attacks if enable_aba_plus else [],
        "assumption_sets": final_assumption_sets if enable_aba_plus else [],
    }

    response = {
        "meta": {
            "request_id": f"req-{datetime.utcnow().timestamp()}",
            "timestamp": datetime.utcnow().isoformat(),
            "transformed": any(
                t.get("applied", False) for t in [_transform_to_dict(t) for t in transformations]
            ),
            "transformations_applied": [
                t.get("step")
                for t in [_transform_to_dict(t) for t in transformations]
                if t.get("applied", False)
            ],
            "warnings": warnings,
            "errors": [],
        },
        "before_transformation": before_state,
        "after_transformation": after_state,
        "transformations": [_transform_to_dict(t) for t in transformations],
    }

    return response
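
# Sketch of the payload returned by _process_aba_framework (keys as built above,
# values illustrative):
# {
#     "meta": {"request_id": "...", "timestamp": "...", "transformed": true,
#              "transformations_applied": ["atomic"], "warnings": [], "errors": []},
#     "before_transformation": {"framework": {...}, "arguments": [...],
#                               "arguments_attacks": [...], "argument_attacks": [...],
#                               "assumption_set_attacks": [...], "reverse_attacks": [...],
#                               "assumption_sets": [...]},
#     "after_transformation": {... same keys, for the transformed framework ...},
#     "transformations": [{"step": "atomic", "applied": true, ...}]
# }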


def _detect_transformations(
    base_framework,
    transformed_framework,
) -> list:
    """
    Detect and describe which transformations were applied.
    """
    transformations = []

    if transformed_framework.language == base_framework.language and \
            transformed_framework.rules == base_framework.rules:
        # No transformation needed
        transformations.append({
            "step": "none",
            "applied": False,
            "reason": "The framework was already non-circular and atomic.",
            "description": "No transformation applied.",
            "result_snapshot": None,
        })
        return transformations

    # Determine transformation type
    was_circular = base_framework.is_aba_circular()
    was_atomic = base_framework.is_aba_atomic()
    step_name = "non_circular" if was_circular else "atomic"
    reason = "circular dependencies" if was_circular else "non-atomic rules"

    transformations.append({
        "step": step_name,
        "applied": True,
        "reason": f"The framework contained {reason}.",
        "description": f"Transformed into a {step_name.replace('_', '-')} version.",
        "result_snapshot": _make_snapshot(transformed_framework),
    })
    return transformations


def _transform_to_dict(t):
    """Convert TransformationStep to dict if needed."""
    if isinstance(t, dict):
        return t
    return {
        "step": t.step,
        "applied": t.applied,
        "reason": t.reason,
        "description": t.description,
        "result_snapshot": t.result_snapshot,
    }


def _validate_framework(framework) -> list[str]:
    """
    Validate framework and return any warnings.
    """
    warnings = []
    if hasattr(framework, "preferences") and framework.preferences:
        all_assumptions = {str(a) for a in framework.assumptions}
        pref_keys = {str(k) for k in framework.preferences.keys()}
        if not pref_keys.issubset(all_assumptions):
            warnings.append(
                "Incomplete preference relation: not all assumptions appear in the preference mapping."
            )
    return warnings


def _validate_aba_plus_framework(framework) -> list[str]:
    """
    Validate ABA+ framework and return any warnings.
    """
    return _validate_framework(framework)
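
# Illustrative usage of the upload endpoints below (file name and port are examples
# only, assuming the API is served locally on port 8000):
#   curl -F "file=@aba/examples/<framework>.txt" http://localhost:8000/aba-upload
#   curl -F "file=@aba/examples/<framework>.txt" http://localhost:8000/aba-plus-upload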
Returns: original & final frameworks with arguments and attacks (no ABA+ data) """ content = await file.read() text = content.decode("utf-8") return await _process_aba_framework(text, enable_aba_plus=False) @app.post("/aba-plus-upload") async def aba_plus_upload(file: UploadFile = File(...)): """ Handle ABA+ framework generation. Returns: original & final frameworks with arguments, attacks, AND reverse_attacks for both """ content = await file.read() text = content.decode("utf-8") return await _process_aba_framework(text, enable_aba_plus=True) @app.get("/aba-examples") def list_aba_examples(): examples = [f.name for f in ABA_EXAMPLES_DIR.glob("*.txt")] return {"examples": examples} @app.get("/aba-examples/{filename}") def get_aba_example(filename: str): file_path = ABA_EXAMPLES_DIR / filename if not file_path.exists() or not file_path.is_file(): return {"error": "File not found"} return FileResponse(file_path, media_type="text/plain", filename=filename) # --- Gradual semantics --- # # @app.post("/gradual", response_model=GradualOutput) # def compute_gradual(input_data: GradualInput): # """API endpoint to compute Weighted h-Categorizer samples and convex hull.""" # return compute_gradual_semantics( # A=input_data.A, # R=input_data.R, # n_samples=input_data.n_samples, # max_iter=input_data.max_iter # ) @app.post("/gradual", response_model=GradualOutput) def compute_gradual(input_data: GradualInput): """ API endpoint to compute Weighted h-Categorizer samples and their convex hull (acceptability degree space). """ num_args, hull_volume, hull_area, hull_points, samples, axes = compute_gradual_space( num_args=input_data.num_args, R=input_data.R, n_samples=input_data.n_samples, axes=input_data.axes, controlled_args=input_data.controlled_args, ) return GradualOutput( num_args=num_args, hull_volume=hull_volume, hull_area=hull_area, hull_points=hull_points, samples=samples, axes=axes, ) @app.get("/gradual-examples") def list_gradual_examples(): """ List all available gradual semantics example files. Each example must be a JSON file with structure: { # "args": ["A", "B", "C"], # "relations": [["A", "B"], ["B", "C"]] "num_args": 3, "R": [["A", "B"], ["B", "C"], ["C", "A"]], } """ if not GRADUAL_EXAMPLES_DIR.exists(): return {"examples": []} examples = [] for file in GRADUAL_EXAMPLES_DIR.glob("*.json"): examples.append({ "name": file.stem, "path": file.name, "content": None }) return {"examples": examples} @app.get("/gradual-examples/{example_name}") def get_gradual_example(example_name: str): """ Return the content of a specific gradual example file. Example: GET /gradual-examples/simple.json """ file_path = GRADUAL_EXAMPLES_DIR / example_name if not file_path.exists(): raise HTTPException(status_code=404, detail="Example not found") try: with open(file_path, "r", encoding="utf-8") as f: content = json.load(f) return JSONResponse(content=content) except json.JSONDecodeError: raise HTTPException( status_code=400, detail="Invalid JSON format in example file")