Commit bb34072 · p2002814 committed · 1 parent: a1fcb70

now using the best fine tuned bert model, uploaded to hugging face hub

Files changed:
- .gitattributes +1 -2
- app.py +30 -26
- hugging_hub.py +9 -0
- models/model.pth +0 -3
- relations/predict.py +1 -1
- relations/predict_bert.py +49 -0
- relations/tests.py +9 -24
.gitattributes CHANGED

@@ -1,2 +1 @@
-models
-models/model.pth filter=lfs diff=lfs merge=lfs -text
+models/** filter=lfs diff=lfs merge=lfs -text
app.py CHANGED

@@ -3,15 +3,16 @@ from fastapi.responses import FileResponse
 from fastapi.middleware.cors import CORSMiddleware
 import pandas as pd
 from pathlib import Path
+import torch
 
 from relations.tests import run_tests
-from relations.predict import predict_relation
-from relations.mlp import load_model_and_metadata
-from relations.processor import ArgumentDataProcessor
+from relations.predict_bert import predict_relation
 from exemples.claims import test_cases
 
+from transformers import AutoTokenizer, AutoModelForSequenceClassification
+
 # ABA imports
-from aba.aba_builder import
+from aba.aba_builder import prepare_aba_plus_framework, build_aba_framework_from_text
 
 app = FastAPI(title="Argument Mining API")

@@ -31,59 +32,60 @@ app.add_middleware(
 
 EXAMPLES_DIR = Path("./aba/exemples")
 
-# (old model-loading block; lines truncated in the source page)
-)
-
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+# Load model at startup once
+model_name = "edgar-demeude/bert-argument"
+tokenizer = AutoTokenizer.from_pretrained(model_name)
+model = AutoModelForSequenceClassification.from_pretrained(model_name)
+model.to(device)
+
 
 @app.get("/")
 def root():
     return {"message": "Argument Mining API is running..."}
 
 
-# ----------------
+# ---------------- BERT Prediction Endpoints ---------------- #
 
 @app.post("/predict-test")
 def predict_test():
-    """Run predefined test cases for model validation."""
-    run_tests(model,
+    """Run predefined test cases for BERT model validation."""
+    run_tests(model, tokenizer, device, test_cases)
     return {"message": "Test cases executed. Check server logs for details."}
 
 
 @app.post("/predict-text")
 def predict_text(arg1: str = Form(...), arg2: str = Form(...)):
-    """Predict relation between two text arguments."""
-
-    return {
-        # (old response fields; lines truncated in the source page)
+    """Predict relation between two text arguments using BERT."""
+    result = predict_relation(arg1, arg2, model, tokenizer, device)
+    return {
+        "arg1": arg1,
+        "arg2": arg2,
+        "relation": result
+    }
 
 
 @app.post("/predict-csv")
 async def predict_csv(file: UploadFile):
    """Predict relations for pairs of arguments from a CSV file (max 100 rows)."""
    df = pd.read_csv(file.file)
-
    if len(df) > 100:
        df = df.head(100)
 
    results = []
    for _, row in df.iterrows():
+        result = predict_relation(
             row["parent"],
             row["child"],
             model,
-            embedding_model,
-            processor,
-            best_threshold,
-            label_encoder,
-            model_type
+            tokenizer,
+            device
         )
         results.append({
             "parent": row["parent"],
             "child": row["child"],
-            "relation":
+            "relation": result
         })
 
    return {"results": results, "note": "Limited to 100 rows max"}

@@ -114,16 +116,18 @@ async def aba_upload(file: UploadFile = File(...)):
     }
     return results
 
+
 @app.get("/aba-examples")
 def list_aba_examples():
     """Lists all sample files available on the server side."""
     examples = [f.name for f in EXAMPLES_DIR.glob("*.txt")]
     return {"examples": examples}
 
+
 @app.get("/aba-examples/{filename}")
 def get_aba_example(filename: str):
     """Returns the contents of a specific ABA sample file."""
     file_path = EXAMPLES_DIR / filename
     if not file_path.exists() or not file_path.is_file():
         return {"error": "File not found"}
-    return FileResponse(file_path, media_type="text/plain", filename=filename)
+    return FileResponse(file_path, media_type="text/plain", filename=filename)
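For reference, a minimal client sketch against the updated endpoints (the base URL and CSV filename are placeholders; the endpoint paths and field names come from the handlers above):

import requests

BASE = "http://localhost:8000"  # placeholder; substitute the deployed Space URL

# /predict-text expects form fields (arg1/arg2 are Form(...) parameters, not JSON)
resp = requests.post(
    f"{BASE}/predict-text",
    data={"arg1": "Smoking should be banned in parks.",
          "arg2": "Second-hand smoke harms bystanders."},
)
print(resp.json())

# /predict-csv expects a multipart file upload with 'parent' and 'child' columns
with open("pairs.csv", "rb") as f:  # pairs.csv is an illustrative file name
    resp = requests.post(f"{BASE}/predict-csv", files={"file": f})
print(resp.json()["note"])  # "Limited to 100 rows max"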
hugging_hub.py ADDED

@@ -0,0 +1,9 @@
+# File to automatically push the model to hugging face hub
+
+from huggingface_hub import upload_folder
+
+upload_folder(
+    repo_id="edgar-demeude/bert-argument",
+    folder_path="./models/bert-argument",
+    repo_type="model"
+)
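Note that upload_folder needs write access to the repo, so the script assumes the caller is already authenticated. A minimal sketch of logging in first (the HF_TOKEN variable name is illustrative; any source of a write token works):

import os
from huggingface_hub import login, upload_folder

# Authenticate with a write token before pushing (assumes HF_TOKEN is set)
login(token=os.environ["HF_TOKEN"])

upload_folder(
    repo_id="edgar-demeude/bert-argument",
    folder_path="./models/bert-argument",
    repo_type="model",
    commit_message="upload fine-tuned BERT argument model",  # optional
)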
models/model.pth DELETED

@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:4a8a1413586023e10aadde9a86e01b6425814605767924e818f1e189926ad86b
-size 9667219
relations/predict.py CHANGED

@@ -1,7 +1,7 @@
 import torch
 from .embeddings import generate_embeddings
 
-def predict_relation(arg1, arg2, model, embedding_model, processor, best_threshold, label_encoder, model_type="pytorch"):
+def predict_relation_old(arg1, arg2, model, embedding_model, processor, best_threshold, label_encoder, model_type="pytorch"):
     embeddings = generate_embeddings(arg1, arg2, embedding_model, processor)
 
     if model_type == "pytorch":
relations/predict_bert.py ADDED

@@ -0,0 +1,49 @@
+import torch
+from transformers import AutoTokenizer, AutoModelForSequenceClassification
+
+def load_bert_model(model_path="../models/bert-argument", device=None):
+    if device is None:
+        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+    tokenizer = AutoTokenizer.from_pretrained(model_path)
+    model = AutoModelForSequenceClassification.from_pretrained(model_path)
+    model.to(device)
+    model.eval()
+
+    return model, tokenizer, device
+
+def predict_relation(parent_text, child_text, model, tokenizer, device, max_length=256):
+    """
+    Predicts whether the relation between parent and child is Support or Attack.
+    """
+    model.eval()
+
+    # Tokenization
+    encoding = tokenizer(
+        parent_text,
+        child_text,
+        add_special_tokens=True,
+        max_length=max_length,
+        padding='max_length',
+        truncation='only_second',
+        return_attention_mask=True,
+        return_tensors='pt'
+    )
+
+    input_ids = encoding['input_ids'].to(device)
+    attention_mask = encoding['attention_mask'].to(device)
+
+    with torch.no_grad():
+        outputs = model(input_ids=input_ids, attention_mask=attention_mask)
+        logits = outputs.logits
+        probs = torch.softmax(logits, dim=1)
+        pred = torch.argmax(probs, dim=1).item()
+        confidence = probs[0][pred].item()
+
+    relation = "Support" if pred == 1 else "Attack"
+
+    return {
+        "predicted_label": relation,
+        "probability": confidence,
+        "confidence": confidence
+    }
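A minimal usage sketch of the two new helpers (the claim texts are made up; load_bert_model also accepts a Hub repo id, since from_pretrained resolves both local paths and repo ids):

from relations.predict_bert import load_bert_model, predict_relation

model, tokenizer, device = load_bert_model("edgar-demeude/bert-argument")

result = predict_relation(
    "City centers should be car-free.",            # parent claim (illustrative)
    "Shops depend on customers arriving by car.",  # child claim (illustrative)
    model, tokenizer, device,
)
print(result["predicted_label"], f"{result['confidence']:.2%}")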
relations/tests.py CHANGED

@@ -1,12 +1,11 @@
 import time
-from .predict import predict_relation
+from .predict_bert import predict_relation
 
-def print_pretty_prediction(result: dict, test_case_num: int, expected: str, claim1: str, claim2: str
+def print_pretty_prediction(result: dict, test_case_num: int, expected: str, claim1: str, claim2: str):
     """Pretty-print the result of a single test case."""
     prediction = result["predicted_label"]
     confidence = result["confidence"]
     probability = result["probability"]
-
     status = "✅ Correct" if prediction == expected else "❌ Incorrect"
 
     print(f"\n{'='*70}")

@@ -16,16 +15,13 @@ def print_pretty_prediction(result: dict, test_case_num: int, expected: str, cla
     print(f"Claim 2: {claim2}")
     print(f"\nExpected: {expected}")
     print(f"Predicted: {prediction}")
-    print(f"Probability: {probability:.4f}
+    print(f"Probability: {probability:.4f}")
     print(f"Confidence: {confidence:.2%}")
     print(f"Status: {status}")
 
 
-def run_tests(model, embedding_model, processor, best_threshold, label_encoder,
-    """
-    Run a list of test cases and display results.
-    Each test case must be a dict with keys: 'claim1', 'claim2', 'expected'
-    """
+def run_tests(model, tokenizer, device, test_cases):
+    """Run test cases using the BERT model."""
     print("\n" + "="*70)
     print("RUNNING TEST CASES")
     print("="*70)

@@ -38,28 +34,17 @@ def run_tests(model, embedding_model, processor, best_threshold, label_encoder,
             case["claim1"],
             case["claim2"],
             model,
-            embedding_model,
-            processor,
-            best_threshold,
-            label_encoder,
-            model_type,
+            tokenizer,
+            device
         )
-        print_pretty_prediction(
-            result,
-            test_case_num=i,
-            expected=case["expected"],
-            claim1=case["claim1"],
-            claim2=case["claim2"],
-            best_threshold=best_threshold
-        )
+        print_pretty_prediction(result, i, case["expected"], case["claim1"], case["claim2"])
 
         if result["predicted_label"] == case["expected"]:
             correct_predictions += 1
 
-    #
+    # Summary
     accuracy = (correct_predictions / len(test_cases)) * 100
     elapsed_time = time.time() - start_time
-
     print(f"\n{'='*70}")
     print("SUMMARY")
     print(f"{'='*70}")
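For reference, run_tests still expects test_cases shaped the way the old docstring described (dicts with keys 'claim1', 'claim2', 'expected'). A sketch of what exemples/claims.py might contain, with made-up texts:

# Shape inferred from run_tests / print_pretty_prediction; contents are illustrative
test_cases = [
    {
        "claim1": "School uniforms should be mandatory.",
        "claim2": "Uniforms reduce bullying over clothing brands.",
        "expected": "Support",
    },
    {
        "claim1": "Remote work boosts productivity.",
        "claim2": "Home offices are full of distractions.",
        "expected": "Attack",
    },
]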