p2002814 committed on
Commit
f2dc286
·
1 Parent(s): c5f9116

now predicting csv progressively for a smooth animation

Browse files
Files changed (1) hide show
  1. app.py +60 -61
app.py CHANGED
@@ -1,98 +1,106 @@
1
  import os
 
 
 
 
2
 
3
- cache_dir = "/tmp/hf_cache"
4
- os.environ["TRANSFORMERS_CACHE"] = cache_dir
5
- os.makedirs(cache_dir, exist_ok=True)
6
-
7
- from fastapi import FastAPI, UploadFile, File, Form
8
- from fastapi.responses import FileResponse
9
- from fastapi.middleware.cors import CORSMiddleware
10
  import pandas as pd
11
- from pathlib import Path
12
  import torch
13
- import io
14
-
15
- from relations.predict_bert import predict_relation
16
-
17
  from transformers import AutoTokenizer, AutoModelForSequenceClassification
18
 
19
- # ABA imports
20
  from aba.aba_builder import prepare_aba_plus_framework, build_aba_framework_from_text
21
 
22
- app = FastAPI(title="Argument Mining API")
23
-
24
- # CORS middleware
25
- origins = [
26
- "http://localhost:3000",
27
- "http://127.0.0.1:3000",
28
- ]
29
-
30
- app.add_middleware(
31
- CORSMiddleware,
32
- allow_origins=["*"], # allow all origins
33
- allow_credentials=True,
34
- allow_methods=["*"],
35
- allow_headers=["*"],
36
- )
37
 
38
  EXAMPLES_DIR = Path("./aba/exemples")
39
  SAMPLES_DIR = Path("./relations/exemples/samples")
40
-
41
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
42
 
43
- # Load model at startup once
44
  model_name = "edgar-demeude/bert-argument"
45
  tokenizer = AutoTokenizer.from_pretrained(model_name)
46
  model = AutoModelForSequenceClassification.from_pretrained(model_name)
47
  model.to(device)
48
 
 
 
49
 
 
 
 
 
 
 
 
 
 
 
50
  @app.get("/")
51
  def root():
52
  return {"message": "Argument Mining API is running..."}
53
 
54
 
55
- # ---------------- BERT Prediction Endpoints ---------------- #
56
-
57
  @app.post("/predict-text")
58
  def predict_text(arg1: str = Form(...), arg2: str = Form(...)):
59
  """Predict relation between two text arguments using BERT."""
60
  result = predict_relation(arg1, arg2, model, tokenizer, device)
61
- return {
62
- "arg1": arg1,
63
- "arg2": arg2,
64
- "relation": result
65
- }
66
 
67
 
68
  @app.post("/predict-csv")
69
  async def predict_csv(file: UploadFile):
70
  """Predict relations for pairs of arguments from a CSV file (max 250 rows)."""
71
  content = await file.read()
72
- # Utiliser StringIO + quotechar='"'
73
  df = pd.read_csv(io.StringIO(content.decode("utf-8")), quotechar='"')
74
-
75
  if len(df) > 250:
76
  df = df.head(250)
77
 
78
  results = []
79
  for _, row in df.iterrows():
80
- result = predict_relation(
81
- row["parent"],
82
- row["child"],
83
- model,
84
- tokenizer,
85
- device
86
- )
87
- results.append({
88
- "parent": row["parent"],
89
- "child": row["child"],
90
- "relation": result
91
- })
92
 
93
  return {"results": results, "note": "Limited to 250 rows max"}
94
 
95
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
96
  @app.get("/samples")
97
  def list_samples():
98
  files = [f for f in os.listdir(SAMPLES_DIR) if f.endswith(".csv")]
@@ -106,19 +114,12 @@ def get_sample(filename: str):
106
  return {"error": "Sample not found"}
107
  return FileResponse(file_path, media_type="text/csv")
108
 
109
- # ---------------- ABA API ---------------- #
110
 
111
  @app.post("/aba-upload")
112
  async def aba_upload(file: UploadFile = File(...)):
113
- """
114
- Upload a .txt file containing an ABA framework definition
115
- and return the generated ABA+ framework.
116
- """
117
- # Read file contents
118
  content = await file.read()
119
- text = content.decode("utf-8") # assume UTF-8 encoding
120
 
121
- # Build ABA framework
122
  aba_framework = build_aba_framework_from_text(text)
123
  aba_framework = prepare_aba_plus_framework(aba_framework)
124
  aba_framework.make_aba_plus()
@@ -134,14 +135,12 @@ async def aba_upload(file: UploadFile = File(...)):
134
 
135
  @app.get("/aba-examples")
136
  def list_aba_examples():
137
- """Lists all sample files available on the server side."""
138
  examples = [f.name for f in EXAMPLES_DIR.glob("*.txt")]
139
  return {"examples": examples}
140
 
141
 
142
  @app.get("/aba-examples/{filename}")
143
  def get_aba_example(filename: str):
144
- """Returns the contents of a specific ABA sample file."""
145
  file_path = EXAMPLES_DIR / filename
146
  if not file_path.exists() or not file_path.is_file():
147
  return {"error": "File not found"}
 
1
  import os
2
+ import io
3
+ import json
4
+ import asyncio
5
+ from pathlib import Path
6
 
 
 
 
 
 
 
 
7
  import pandas as pd
 
8
  import torch
9
+ from fastapi import FastAPI, UploadFile, File, Form
10
+ from fastapi.middleware.cors import CORSMiddleware
11
+ from fastapi.responses import FileResponse, StreamingResponse
 
12
  from transformers import AutoTokenizer, AutoModelForSequenceClassification
13
 
14
+ from relations.predict_bert import predict_relation
15
  from aba.aba_builder import prepare_aba_plus_framework, build_aba_framework_from_text
16
 
17
+ # -------------------- Config -------------------- #
18
+ cache_dir = "/tmp/hf_cache"
19
+ os.environ["TRANSFORMERS_CACHE"] = cache_dir
20
+ os.makedirs(cache_dir, exist_ok=True)
 
 
 
 
 
 
 
 
 
 
 
21
 
22
  EXAMPLES_DIR = Path("./aba/exemples")
23
  SAMPLES_DIR = Path("./relations/exemples/samples")
 
24
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
25
 
 
26
  model_name = "edgar-demeude/bert-argument"
27
  tokenizer = AutoTokenizer.from_pretrained(model_name)
28
  model = AutoModelForSequenceClassification.from_pretrained(model_name)
29
  model.to(device)
30
 
31
+ # -------------------- App -------------------- #
32
+ app = FastAPI(title="Argument Mining API")
33
 
34
+ origins = ["http://localhost:3000", "http://127.0.0.1:3000"]
35
+ app.add_middleware(
36
+ CORSMiddleware,
37
+ allow_origins=["*"],
38
+ allow_credentials=True,
39
+ allow_methods=["*"],
40
+ allow_headers=["*"],
41
+ )
42
+
43
+ # -------------------- Endpoints -------------------- #
44
@app.get("/")
def root():
    """Return a fixed status message showing that the API is reachable."""
    status = {"message": "Argument Mining API is running..."}
    return status
47
 
48
 
 
 
49
@app.post("/predict-text")
def predict_text(arg1: str = Form(...), arg2: str = Form(...)):
    """Predict the relation between two form-submitted arguments.

    Both arguments are required form fields. The response echoes them back
    together with whatever ``predict_relation`` returns for the pair.
    """
    relation = predict_relation(arg1, arg2, model, tokenizer, device)
    response = {"arg1": arg1, "arg2": arg2}
    response["relation"] = relation
    return response
 
 
 
 
54
 
55
 
56
@app.post("/predict-csv")
async def predict_csv(file: UploadFile):
    """Predict relations for pairs of arguments from a CSV file (max 250 rows).

    The upload is decoded as UTF-8 and parsed with pandas; it must provide
    ``parent`` and ``child`` columns. Only the first 250 rows are processed.

    Returns:
        A dict with ``results`` (one ``{"parent", "child", "relation"}`` entry
        per processed row) and a ``note`` about the row limit.
    """
    content = await file.read()
    df = pd.read_csv(io.StringIO(content.decode("utf-8")), quotechar='"')
    # `head` already returns the frame unchanged when it has fewer rows.
    df = df.head(250)

    def json_safe(value):
        """Turn a CSV cell into something the JSON encoder accepts."""
        # Empty cells are parsed as NaN, which is not valid JSON.
        if pd.isna(value):
            return None
        # NumPy scalars expose `.item()` to obtain the plain Python value.
        return value.item() if hasattr(value, "item") else value

    loop = asyncio.get_running_loop()
    results = []
    for _, row in df.iterrows():
        parent, child = row["parent"], row["child"]
        # `predict_relation` is synchronous; run it in the default executor so
        # the event loop keeps serving other requests during inference.
        relation = await loop.run_in_executor(
            None, predict_relation, parent, child, model, tokenizer, device
        )
        results.append(
            {
                "parent": json_safe(parent),
                "child": json_safe(child),
                "relation": relation,
            }
        )

    return {"results": results, "note": "Limited to 250 rows max"}
70
 
71
 
72
@app.post("/predict-csv-stream")
async def predict_csv_stream(file: UploadFile):
    """Stream CSV predictions progressively using Server-Sent Events.

    The upload is decoded as UTF-8 and parsed with pandas; it must provide
    ``parent`` and ``child`` columns. Only the first 250 rows are processed.

    One ``data:`` event is emitted per row, carrying a JSON object:

    * on success: ``parent``, ``child``, ``relation`` and ``progress``;
    * on failure: ``error``, ``parent``, ``child`` and ``progress``.

    ``progress`` is the fraction of rows processed so far (successful or not),
    so it always reaches 1.0 on the final event.
    """
    content = await file.read()
    df = pd.read_csv(io.StringIO(content.decode("utf-8")), quotechar='"')
    # `head` already returns the frame unchanged when it has fewer rows.
    df = df.head(250)

    def json_safe(value):
        """Turn a CSV cell into something ``json.dumps`` can encode."""
        # Empty cells are parsed as NaN, which `json.dumps` would emit as the
        # non-standard token `NaN`; missing columns come through as None.
        if pd.isna(value):
            return None
        # NumPy scalars expose `.item()` to obtain the plain Python value.
        return value.item() if hasattr(value, "item") else value

    async def event_generator():
        total = len(df)
        loop = asyncio.get_running_loop()
        for processed, (_, row) in enumerate(df.iterrows(), start=1):
            progress = processed / total
            try:
                # `predict_relation` is synchronous; run it in the default
                # executor so the event loop can flush earlier events and
                # serve other requests while the model is busy.
                relation = await loop.run_in_executor(
                    None,
                    predict_relation,
                    row["parent"],
                    row["child"],
                    model,
                    tokenizer,
                    device,
                )
                message = json.dumps(
                    {
                        "parent": json_safe(row["parent"]),
                        "child": json_safe(row["child"]),
                        "relation": relation,
                        "progress": progress,
                    }
                )
            except Exception as e:
                # Per-row boundary: report the failure to the client and keep
                # streaming the remaining rows instead of aborting the response.
                message = json.dumps(
                    {
                        "error": str(e),
                        "parent": json_safe(row.get("parent")),
                        "child": json_safe(row.get("child")),
                        "progress": progress,
                    }
                )
            yield f"data: {message}\n\n"

    return StreamingResponse(
        event_generator(),
        media_type="text/event-stream",
        # Ask caches and buffering reverse proxies to pass events through
        # as soon as they are produced.
        headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"},
    )
102
+
103
+
104
  @app.get("/samples")
105
  def list_samples():
106
  files = [f for f in os.listdir(SAMPLES_DIR) if f.endswith(".csv")]
 
114
  return {"error": "Sample not found"}
115
  return FileResponse(file_path, media_type="text/csv")
116
 
 
117
 
118
  @app.post("/aba-upload")
119
  async def aba_upload(file: UploadFile = File(...)):
 
 
 
 
 
120
  content = await file.read()
121
+ text = content.decode("utf-8")
122
 
 
123
  aba_framework = build_aba_framework_from_text(text)
124
  aba_framework = prepare_aba_plus_framework(aba_framework)
125
  aba_framework.make_aba_plus()
 
135
 
136
@app.get("/aba-examples")
def list_aba_examples():
    """Return the names of the ``.txt`` files located directly in EXAMPLES_DIR."""
    names = []
    for path in EXAMPLES_DIR.glob("*.txt"):
        names.append(path.name)
    return {"examples": names}
140
 
141
 
142
  @app.get("/aba-examples/{filename}")
143
  def get_aba_example(filename: str):
 
144
  file_path = EXAMPLES_DIR / filename
145
  if not file_path.exists() or not file_path.is_file():
146
  return {"error": "File not found"}