TayebBou's picture
Déploiement automatique depuis GitHub Actions 🚀
97e91c0 verified
import os
from typing import Dict, Tuple
import gradio as gr
import torch
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from gradio.routes import mount_gradio_app
from transformers import AutoModelForSequenceClassification, AutoTokenizer
# --- Model setup ---
# Fine-tuned model (ton modèle entraîné)
MODEL_ID = os.getenv("MODEL_ID", "TayebBou/sentiment-fr-allocine")
# Model de base (corps pré-entraîné + tête random) pour comparaison
BASE_MODEL = os.getenv("BASE_MODEL", "distilbert-base-multilingual-cased")
LABELS = ["neg", "pos"]
# Tokenizer (on peut réutiliser le même tokenizer si compatible)
tok = AutoTokenizer.from_pretrained(MODEL_ID)
# Chargement du modèle fine-tuné (celui que tu déploies)
mdl = AutoModelForSequenceClassification.from_pretrained(MODEL_ID)
mdl.eval()
# Chargement du modèle "non entraîné" : corps pré-entraîné + tête aléatoire
# On utilise from_pretrained(BASE_MODEL, num_labels=2) — la tête sera initialisée aléatoirement
mdl_head_random = AutoModelForSequenceClassification.from_pretrained(
BASE_MODEL, num_labels=2
)
mdl_head_random.eval()
# --- Prediction utilities ---
def predict_proba_from_model(
model: AutoModelForSequenceClassification, text: str
) -> Dict[str, float]:
"""Return probability distribution over labels for a given text and model."""
inputs = tok(text, return_tensors="pt", truncation=True)
# Si tu veux forcer CPU (par ex. sur un HF Space sans GPU), pas de .to(device) ici
with torch.no_grad():
logits = model(**inputs).logits
probs = torch.softmax(logits, dim=-1).squeeze().tolist()
# Si modèle renvoie scalaire pour un seul exemple, ensure list
if isinstance(probs, float):
probs = [probs]
return {LABELS[i]: float(probs[i]) for i in range(len(LABELS))}
def top_label_phrase(probs: Dict[str, float]) -> str:
"""
Transforme les probabilités en phrase demandée.
Exemple de sortie :
"Avis positif à 95.00% de probabilité"
"""
pos_prob = probs.get("pos", 0.0)
neg_prob = probs.get("neg", 0.0)
if pos_prob >= neg_prob:
return f"Avis positif à {pos_prob * 100:.2f}% de probabilité"
else:
return f"Avis négatif à {neg_prob * 100:.2f}% de probabilité"
# Fonctions exposées
def predict_label_only(text: str) -> str:
"""Fonction legacy qui renvoie juste le label du modèle fine-tuné (compatibilité)."""
probs = predict_proba_from_model(mdl, text)
return max(probs.keys(), key=lambda k: probs[k])
def predict_both_phrases(text: str) -> Tuple[str, str]:
"""
Renvoie deux phrases formatées :
- phrase pour le modèle fine-tuné (MODEL_ID)
- phrase pour le modèle non-entraîné (BASE_MODEL avec tête random)
"""
probs_ft = predict_proba_from_model(mdl, text)
probs_head_random = predict_proba_from_model(mdl_head_random, text)
phrase_ft = top_label_phrase(probs_ft)
phrase_head_random = top_label_phrase(probs_head_random)
return phrase_ft, phrase_head_random
# --- Gradio interface ---
demo = gr.Interface(
fn=predict_both_phrases,
inputs=gr.Textbox(label="Texte (FR)", lines=4, value="Ce film est bon"),
outputs=[
gr.Textbox(
label=f"Modèle {BASE_MODEL} fine-tuné ({MODEL_ID})", interactive=False
),
gr.Textbox(
label=f"Modèle {BASE_MODEL} non-entrainé (tête random)", interactive=False
),
],
examples=[
["Ce film est une merveille, j'ai adoré !"],
["Vraiment décevant, perte de temps."],
],
title="Comparaison : modèle fine-tuné vs modèle non-entrainé",
description="Affiche pour chaque modèle la prédiction principale sous forme 'Avis positif/negatif à X% de probabilité'.",
)
# --- FastAPI app ---
app = FastAPI(title="Sentiment FR API")
# Allow CORS for demo/testing purposes
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_methods=["*"],
allow_headers=["*"],
)
@app.get("/healthz")
def healthz():
return {"status": "ok", "model": MODEL_ID, "base_model_for_compare": BASE_MODEL}
@app.post("/predict")
def predict_api(item: dict):
"""
Endpoint qui renvoie les probabilités pour les deux modèles.
JSON attendu: {"text": "..." }
Réponse:
{
"fine_tuned": {"neg": 0.12, "pos": 0.88},
"head_random": {"neg": 0.51, "pos": 0.49}
}
"""
text = item.get("text", "")
probs_ft = predict_proba_from_model(mdl, text)
probs_head_random = predict_proba_from_model(mdl_head_random, text)
return {"fine_tuned": probs_ft, "head_random": probs_head_random}
# --- Mount Gradio in HF Space friendly way ---
IS_HF_SPACE = os.getenv("SYSTEM") == "spaces"
if IS_HF_SPACE:
# In HF Space → launch Gradio only
if __name__ == "__main__":
demo.launch()
else:
# In local or else → FastAPI + Gradio mounted together
mount_gradio_app(app, demo, path="/")
if __name__ == "__main__":
import uvicorn
uvicorn.run("app:app", host="0.0.0.0", port=7860, reload=True)