File size: 5,107 Bytes
91dedb2
4793ddb
91dedb2
7991d8b
dd37d28
91dedb2
 
b95f222
c804a35
7991d8b
91dedb2
4793ddb
91dedb2
4793ddb
 
 
1a07a5a
7991d8b
4793ddb
7991d8b
4793ddb
 
7991d8b
 
 
4793ddb
 
c804a35
 
 
4793ddb
91dedb2
c804a35
4793ddb
c804a35
 
 
4793ddb
7991d8b
4793ddb
7991d8b
4793ddb
7991d8b
4793ddb
 
 
1a07a5a
 
c804a35
4793ddb
 
 
 
 
 
 
 
 
 
 
 
 
c804a35
4793ddb
 
 
 
1a07a5a
7991d8b
c804a35
4793ddb
 
 
 
 
 
 
 
 
 
 
 
 
91dedb2
c804a35
91dedb2
 
4793ddb
 
 
c804a35
 
 
 
 
 
4793ddb
91dedb2
 
 
 
4793ddb
 
91dedb2
 
 
 
 
 
 
 
 
 
 
 
 
c804a35
91dedb2
1a07a5a
4793ddb
1a07a5a
c804a35
91dedb2
 
4793ddb
 
 
 
 
 
 
 
 
91dedb2
4793ddb
 
 
91dedb2
c804a35
9e4ddce
fc27bb3
 
9e4ddce
97e91c0
9e4ddce
 
 
97e91c0
9e4ddce
 
 
 
 
97e91c0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
import os
from typing import Dict, Tuple

import gradio as gr
import torch
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from gradio.routes import mount_gradio_app
from transformers import AutoModelForSequenceClassification, AutoTokenizer

# --- Model setup ---
# Fine-tuned model (ton modèle entraîné)
MODEL_ID = os.getenv("MODEL_ID", "TayebBou/sentiment-fr-allocine")
# Model de base (corps pré-entraîné + tête random) pour comparaison
BASE_MODEL = os.getenv("BASE_MODEL", "distilbert-base-multilingual-cased")

LABELS = ["neg", "pos"]

# Tokenizer (on peut réutiliser le même tokenizer si compatible)
tok = AutoTokenizer.from_pretrained(MODEL_ID)

# Chargement du modèle fine-tuné (celui que tu déploies)
mdl = AutoModelForSequenceClassification.from_pretrained(MODEL_ID)
mdl.eval()

# Chargement du modèle "non entraîné" : corps pré-entraîné + tête aléatoire
# On utilise from_pretrained(BASE_MODEL, num_labels=2) — la tête sera initialisée aléatoirement
mdl_head_random = AutoModelForSequenceClassification.from_pretrained(
    BASE_MODEL, num_labels=2
)
mdl_head_random.eval()


# --- Prediction utilities ---
def predict_proba_from_model(
    model: AutoModelForSequenceClassification, text: str
) -> Dict[str, float]:
    """Return probability distribution over labels for a given text and model."""
    inputs = tok(text, return_tensors="pt", truncation=True)
    # Si tu veux forcer CPU (par ex. sur un HF Space sans GPU), pas de .to(device) ici
    with torch.no_grad():
        logits = model(**inputs).logits
    probs = torch.softmax(logits, dim=-1).squeeze().tolist()
    # Si modèle renvoie scalaire pour un seul exemple, ensure list
    if isinstance(probs, float):
        probs = [probs]
    return {LABELS[i]: float(probs[i]) for i in range(len(LABELS))}


def top_label_phrase(probs: Dict[str, float]) -> str:
    """
    Transforme les probabilités en phrase demandée.
    Exemple de sortie :
      "Avis positif à 95.00% de probabilité"
    """
    pos_prob = probs.get("pos", 0.0)
    neg_prob = probs.get("neg", 0.0)
    if pos_prob >= neg_prob:
        return f"Avis positif à {pos_prob * 100:.2f}% de probabilité"
    else:
        return f"Avis négatif à {neg_prob * 100:.2f}% de probabilité"


# Fonctions exposées
def predict_label_only(text: str) -> str:
    """Fonction legacy qui renvoie juste le label du modèle fine-tuné (compatibilité)."""
    probs = predict_proba_from_model(mdl, text)
    return max(probs.keys(), key=lambda k: probs[k])


def predict_both_phrases(text: str) -> Tuple[str, str]:
    """
    Renvoie deux phrases formatées :
      - phrase pour le modèle fine-tuné (MODEL_ID)
      - phrase pour le modèle non-entraîné (BASE_MODEL avec tête random)
    """
    probs_ft = predict_proba_from_model(mdl, text)
    probs_head_random = predict_proba_from_model(mdl_head_random, text)

    phrase_ft = top_label_phrase(probs_ft)
    phrase_head_random = top_label_phrase(probs_head_random)

    return phrase_ft, phrase_head_random


# --- Gradio interface ---
demo = gr.Interface(
    fn=predict_both_phrases,
    inputs=gr.Textbox(label="Texte (FR)", lines=4, value="Ce film est bon"),
    outputs=[
        gr.Textbox(
            label=f"Modèle {BASE_MODEL} fine-tuné ({MODEL_ID})", interactive=False
        ),
        gr.Textbox(
            label=f"Modèle {BASE_MODEL} non-entrainé (tête random)", interactive=False
        ),
    ],
    examples=[
        ["Ce film est une merveille, j'ai adoré !"],
        ["Vraiment décevant, perte de temps."],
    ],
    title="Comparaison : modèle fine-tuné vs modèle non-entrainé",
    description="Affiche pour chaque modèle la prédiction principale sous forme 'Avis positif/negatif à X% de probabilité'.",
)

# --- FastAPI app ---
app = FastAPI(title="Sentiment FR API")

# Allow CORS for demo/testing purposes
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)


@app.get("/healthz")
def healthz():
    return {"status": "ok", "model": MODEL_ID, "base_model_for_compare": BASE_MODEL}


@app.post("/predict")
def predict_api(item: dict):
    """
    Endpoint qui renvoie les probabilités pour les deux modèles.
    JSON attendu: {"text": "..." }
    Réponse:
    {
      "fine_tuned": {"neg": 0.12, "pos": 0.88},
      "head_random": {"neg": 0.51, "pos": 0.49}
    }
    """
    text = item.get("text", "")
    probs_ft = predict_proba_from_model(mdl, text)
    probs_head_random = predict_proba_from_model(mdl_head_random, text)
    return {"fine_tuned": probs_ft, "head_random": probs_head_random}


# --- Mount Gradio in HF Space friendly way ---
IS_HF_SPACE = os.getenv("SYSTEM") == "spaces"

if IS_HF_SPACE:
    # In HF Space → launch Gradio only
    if __name__ == "__main__":
        demo.launch()
else:
    # In local or else → FastAPI + Gradio mounted together
    mount_gradio_app(app, demo, path="/")

    if __name__ == "__main__":
        import uvicorn

        uvicorn.run("app:app", host="0.0.0.0", port=7860, reload=True)