p2002814 commited on
Commit
bb34072
·
1 Parent(s): a1fcb70

now using the best fine tuned bert model, uploaded to hugging face hub

Browse files
.gitattributes CHANGED
@@ -1,2 +1 @@
1
- models/*.pth filter=lfs diff=lfs merge=lfs -text
2
- models/model.pth filter=lfs diff=lfs merge=lfs -text
 
1
+ models/** filter=lfs diff=lfs merge=lfs -text
 
app.py CHANGED
@@ -3,15 +3,16 @@ from fastapi.responses import FileResponse
3
  from fastapi.middleware.cors import CORSMiddleware
4
  import pandas as pd
5
  from pathlib import Path
 
6
 
7
  from relations.tests import run_tests
8
- from relations.predict import predict_relation
9
- from relations.mlp import load_model_and_metadata
10
- from relations.processor import ArgumentDataProcessor
11
  from exemples.claims import test_cases
12
 
 
 
13
  # ABA imports
14
- from aba.aba_builder import build_aba_framework, prepare_aba_plus_framework, build_aba_framework_from_text
15
 
16
  app = FastAPI(title="Argument Mining API")
17
 
@@ -31,59 +32,60 @@ app.add_middleware(
31
 
32
  EXAMPLES_DIR = Path("./aba/exemples")
33
 
34
- # Load ML model at startup
35
- PYTORCH_MODEL_PATH = "models/model.pth"
36
- model_type = "pytorch"
37
- model, embedding_model, best_threshold, label_encoder = load_model_and_metadata(
38
- PYTORCH_MODEL_PATH, model_type
39
- )
40
- processor = ArgumentDataProcessor()
 
41
 
42
  @app.get("/")
43
  def root():
44
  return {"message": "Argument Mining API is running..."}
45
 
46
 
47
- # ---------------- ML Prediction Endpoints ---------------- #
48
 
49
  @app.post("/predict-test")
50
  def predict_test():
51
- """Run predefined test cases for model validation."""
52
- run_tests(model, embedding_model, processor, best_threshold, label_encoder, model_type, test_cases)
53
  return {"message": "Test cases executed. Check server logs for details."}
54
 
55
 
56
  @app.post("/predict-text")
57
  def predict_text(arg1: str = Form(...), arg2: str = Form(...)):
58
- """Predict relation between two text arguments."""
59
- relation = predict_relation(arg1, arg2, model, embedding_model, processor, best_threshold, label_encoder, model_type)
60
- return {"arg1": arg1, "arg2": arg2, "relation": relation}
 
 
 
 
61
 
62
 
63
  @app.post("/predict-csv")
64
  async def predict_csv(file: UploadFile):
65
  """Predict relations for pairs of arguments from a CSV file (max 100 rows)."""
66
  df = pd.read_csv(file.file)
67
-
68
  if len(df) > 100:
69
  df = df.head(100)
70
 
71
  results = []
72
  for _, row in df.iterrows():
73
- relation = predict_relation(
74
  row["parent"],
75
  row["child"],
76
  model,
77
- embedding_model,
78
- processor,
79
- best_threshold,
80
- label_encoder,
81
- model_type
82
  )
83
  results.append({
84
  "parent": row["parent"],
85
  "child": row["child"],
86
- "relation": relation
87
  })
88
 
89
  return {"results": results, "note": "Limited to 100 rows max"}
@@ -114,16 +116,18 @@ async def aba_upload(file: UploadFile = File(...)):
114
  }
115
  return results
116
 
 
117
  @app.get("/aba-examples")
118
  def list_aba_examples():
119
  """Lists all sample files available on the server side."""
120
  examples = [f.name for f in EXAMPLES_DIR.glob("*.txt")]
121
  return {"examples": examples}
122
 
 
123
  @app.get("/aba-examples/{filename}")
124
  def get_aba_example(filename: str):
125
  """Returns the contents of a specific ABA sample file."""
126
  file_path = EXAMPLES_DIR / filename
127
  if not file_path.exists() or not file_path.is_file():
128
  return {"error": "File not found"}
129
- return FileResponse(file_path, media_type="text/plain", filename=filename)
 
3
  from fastapi.middleware.cors import CORSMiddleware
4
  import pandas as pd
5
  from pathlib import Path
6
+ import torch
7
 
8
  from relations.tests import run_tests
9
+ from relations.predict_bert import predict_relation
 
 
10
  from exemples.claims import test_cases
11
 
12
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
13
+
14
  # ABA imports
15
+ from aba.aba_builder import prepare_aba_plus_framework, build_aba_framework_from_text
16
 
17
  app = FastAPI(title="Argument Mining API")
18
 
 
32
 
33
  EXAMPLES_DIR = Path("./aba/exemples")
34
 
35
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
36
+
37
+ # Load model at startup once
38
+ model_name = "edgar-demeude/bert-argument"
39
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
40
+ model = AutoModelForSequenceClassification.from_pretrained(model_name)
41
+ model.to(device)
42
+
43
 
44
  @app.get("/")
45
  def root():
46
  return {"message": "Argument Mining API is running..."}
47
 
48
 
49
+ # ---------------- BERT Prediction Endpoints ---------------- #
50
 
51
  @app.post("/predict-test")
52
  def predict_test():
53
+ """Run predefined test cases for BERT model validation."""
54
+ run_tests(model, tokenizer, device, test_cases)
55
  return {"message": "Test cases executed. Check server logs for details."}
56
 
57
 
58
  @app.post("/predict-text")
59
  def predict_text(arg1: str = Form(...), arg2: str = Form(...)):
60
+ """Predict relation between two text arguments using BERT."""
61
+ result = predict_relation(arg1, arg2, model, tokenizer, device)
62
+ return {
63
+ "arg1": arg1,
64
+ "arg2": arg2,
65
+ "relation": result
66
+ }
67
 
68
 
69
  @app.post("/predict-csv")
70
  async def predict_csv(file: UploadFile):
71
  """Predict relations for pairs of arguments from a CSV file (max 100 rows)."""
72
  df = pd.read_csv(file.file)
 
73
  if len(df) > 100:
74
  df = df.head(100)
75
 
76
  results = []
77
  for _, row in df.iterrows():
78
+ result = predict_relation(
79
  row["parent"],
80
  row["child"],
81
  model,
82
+ tokenizer,
83
+ device
 
 
 
84
  )
85
  results.append({
86
  "parent": row["parent"],
87
  "child": row["child"],
88
+ "relation": result
89
  })
90
 
91
  return {"results": results, "note": "Limited to 100 rows max"}
 
116
  }
117
  return results
118
 
119
+
120
  @app.get("/aba-examples")
121
  def list_aba_examples():
122
  """Lists all sample files available on the server side."""
123
  examples = [f.name for f in EXAMPLES_DIR.glob("*.txt")]
124
  return {"examples": examples}
125
 
126
+
127
  @app.get("/aba-examples/{filename}")
128
  def get_aba_example(filename: str):
129
  """Returns the contents of a specific ABA sample file."""
130
  file_path = EXAMPLES_DIR / filename
131
  if not file_path.exists() or not file_path.is_file():
132
  return {"error": "File not found"}
133
+ return FileResponse(file_path, media_type="text/plain", filename=filename)
hugging_hub.py ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ # File to automatically push the model to hugging face hub
2
+
3
+ from huggingface_hub import upload_folder
4
+
5
+ upload_folder(
6
+ repo_id="edgar-demeude/bert-argument",
7
+ folder_path="./models/bert-argument",
8
+ repo_type="model"
9
+ )
models/model.pth DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:4a8a1413586023e10aadde9a86e01b6425814605767924e818f1e189926ad86b
3
- size 9667219
 
 
 
 
relations/predict.py CHANGED
@@ -1,7 +1,7 @@
1
  import torch
2
  from .embeddings import generate_embeddings
3
 
4
- def predict_relation(arg1, arg2, model, embedding_model, processor, best_threshold, label_encoder, model_type="pytorch"):
5
  embeddings = generate_embeddings(arg1, arg2, embedding_model, processor)
6
 
7
  if model_type == "pytorch":
 
1
  import torch
2
  from .embeddings import generate_embeddings
3
 
4
+ def predict_relation_old(arg1, arg2, model, embedding_model, processor, best_threshold, label_encoder, model_type="pytorch"):
5
  embeddings = generate_embeddings(arg1, arg2, embedding_model, processor)
6
 
7
  if model_type == "pytorch":
relations/predict_bert.py ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
3
+
4
+ def load_bert_model(model_path="../models/bert-argument", device=None):
5
+ if device is None:
6
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
7
+
8
+ tokenizer = AutoTokenizer.from_pretrained(model_path)
9
+ model = AutoModelForSequenceClassification.from_pretrained(model_path)
10
+ model.to(device)
11
+ model.eval()
12
+
13
+ return model, tokenizer, device
14
+
15
+ def predict_relation(parent_text, child_text, model, tokenizer, device, max_length=256):
16
+ """
17
+ Predicts whether the relation between parent and child is Support or Attack.
18
+ """
19
+ model.eval()
20
+
21
+ # Tokenization
22
+ encoding = tokenizer(
23
+ parent_text,
24
+ child_text,
25
+ add_special_tokens=True,
26
+ max_length=max_length,
27
+ padding='max_length',
28
+ truncation='only_second',
29
+ return_attention_mask=True,
30
+ return_tensors='pt'
31
+ )
32
+
33
+ input_ids = encoding['input_ids'].to(device)
34
+ attention_mask = encoding['attention_mask'].to(device)
35
+
36
+ with torch.no_grad():
37
+ outputs = model(input_ids=input_ids, attention_mask=attention_mask)
38
+ logits = outputs.logits
39
+ probs = torch.softmax(logits, dim=1)
40
+ pred = torch.argmax(probs, dim=1).item()
41
+ confidence = probs[0][pred].item()
42
+
43
+ relation = "Support" if pred == 1 else "Attack"
44
+
45
+ return {
46
+ "predicted_label": relation,
47
+ "probability": confidence,
48
+ "confidence": confidence
49
+ }
relations/tests.py CHANGED
@@ -1,12 +1,11 @@
1
  import time
2
- from .predict import predict_relation
3
 
4
- def print_pretty_prediction(result: dict, test_case_num: int, expected: str, claim1: str, claim2: str, best_threshold: float):
5
  """Pretty-print the result of a single test case."""
6
  prediction = result["predicted_label"]
7
  confidence = result["confidence"]
8
  probability = result["probability"]
9
-
10
  status = "✅ Correct" if prediction == expected else "❌ Incorrect"
11
 
12
  print(f"\n{'='*70}")
@@ -16,16 +15,13 @@ def print_pretty_prediction(result: dict, test_case_num: int, expected: str, cla
16
  print(f"Claim 2: {claim2}")
17
  print(f"\nExpected: {expected}")
18
  print(f"Predicted: {prediction}")
19
- print(f"Probability: {probability:.4f} (threshold: {best_threshold:.3f})")
20
  print(f"Confidence: {confidence:.2%}")
21
  print(f"Status: {status}")
22
 
23
 
24
- def run_tests(model, embedding_model, processor, best_threshold, label_encoder, model_type, test_cases):
25
- """
26
- Run a list of test cases and display results.
27
- Each test case must be a dict with keys: 'claim1', 'claim2', 'expected'
28
- """
29
  print("\n" + "="*70)
30
  print("RUNNING TEST CASES")
31
  print("="*70)
@@ -38,28 +34,17 @@ def run_tests(model, embedding_model, processor, best_threshold, label_encoder,
38
  case["claim1"],
39
  case["claim2"],
40
  model,
41
- embedding_model,
42
- processor,
43
- best_threshold,
44
- label_encoder,
45
- model_type,
46
- )
47
- print_pretty_prediction(
48
- result,
49
- test_case_num=i,
50
- expected=case["expected"],
51
- claim1=case["claim1"],
52
- claim2=case["claim2"],
53
- best_threshold=best_threshold
54
  )
 
55
 
56
  if result["predicted_label"] == case["expected"]:
57
  correct_predictions += 1
58
 
59
- # Final summary
60
  accuracy = (correct_predictions / len(test_cases)) * 100
61
  elapsed_time = time.time() - start_time
62
-
63
  print(f"\n{'='*70}")
64
  print("SUMMARY")
65
  print(f"{'='*70}")
 
1
  import time
2
+ from .predict_bert import predict_relation
3
 
4
+ def print_pretty_prediction(result: dict, test_case_num: int, expected: str, claim1: str, claim2: str):
5
  """Pretty-print the result of a single test case."""
6
  prediction = result["predicted_label"]
7
  confidence = result["confidence"]
8
  probability = result["probability"]
 
9
  status = "✅ Correct" if prediction == expected else "❌ Incorrect"
10
 
11
  print(f"\n{'='*70}")
 
15
  print(f"Claim 2: {claim2}")
16
  print(f"\nExpected: {expected}")
17
  print(f"Predicted: {prediction}")
18
+ print(f"Probability: {probability:.4f}")
19
  print(f"Confidence: {confidence:.2%}")
20
  print(f"Status: {status}")
21
 
22
 
23
+ def run_tests(model, tokenizer, device, test_cases):
24
+ """Run test cases using the BERT model."""
 
 
 
25
  print("\n" + "="*70)
26
  print("RUNNING TEST CASES")
27
  print("="*70)
 
34
  case["claim1"],
35
  case["claim2"],
36
  model,
37
+ tokenizer,
38
+ device
 
 
 
 
 
 
 
 
 
 
 
39
  )
40
+ print_pretty_prediction(result, i, case["expected"], case["claim1"], case["claim2"])
41
 
42
  if result["predicted_label"] == case["expected"]:
43
  correct_predictions += 1
44
 
45
+ # Summary
46
  accuracy = (correct_predictions / len(test_cases)) * 100
47
  elapsed_time = time.time() - start_time
 
48
  print(f"\n{'='*70}")
49
  print("SUMMARY")
50
  print(f"{'='*70}")