TayebBou commited on
Commit
c804a35
·
verified ·
1 Parent(s): 2143736

Upload folder using huggingface_hub

Browse files
Files changed (1) hide show
  1. app.py +22 -5
app.py CHANGED
@@ -4,9 +4,9 @@ from typing import Dict, Tuple
4
  import gradio as gr
5
  import torch
6
  from fastapi import FastAPI
7
- from transformers import AutoModelForSequenceClassification, AutoTokenizer
8
  from fastapi.middleware.cors import CORSMiddleware
9
  from gradio.routes import mount_gradio_app
 
10
 
11
  # --- Model setup ---
12
  # Fine-tuned model (ton modèle entraîné)
@@ -25,11 +25,16 @@ mdl.eval()
25
 
26
  # Chargement du modèle "non entraîné" : corps pré-entraîné + tête aléatoire
27
  # On utilise from_pretrained(BASE_MODEL, num_labels=2) — la tête sera initialisée aléatoirement
28
- mdl_head_random = AutoModelForSequenceClassification.from_pretrained(BASE_MODEL, num_labels=2)
 
 
29
  mdl_head_random.eval()
30
 
 
31
  # --- Prediction utilities ---
32
- def predict_proba_from_model(model: AutoModelForSequenceClassification, text: str) -> Dict[str, float]:
 
 
33
  """Return probability distribution over labels for a given text and model."""
34
  inputs = tok(text, return_tensors="pt", truncation=True)
35
  # Si tu veux forcer CPU (par ex. sur un HF Space sans GPU), pas de .to(device) ici
@@ -41,6 +46,7 @@ def predict_proba_from_model(model: AutoModelForSequenceClassification, text: st
41
  probs = [probs]
42
  return {LABELS[i]: float(probs[i]) for i in range(len(LABELS))}
43
 
 
44
  def top_label_phrase(probs: Dict[str, float]) -> str:
45
  """
46
  Transforme les probabilités en phrase demandée.
@@ -54,12 +60,14 @@ def top_label_phrase(probs: Dict[str, float]) -> str:
54
  else:
55
  return f"Avis négatif à {neg_prob * 100:.2f}% de probabilité"
56
 
 
57
  # Fonctions exposées
58
  def predict_label_only(text: str) -> str:
59
  """Fonction legacy qui renvoie juste le label du modèle fine-tuné (compatibilité)."""
60
  probs = predict_proba_from_model(mdl, text)
61
  return max(probs.keys(), key=lambda k: probs[k])
62
 
 
63
  def predict_both_phrases(text: str) -> Tuple[str, str]:
64
  """
65
  Renvoie deux phrases formatées :
@@ -74,13 +82,18 @@ def predict_both_phrases(text: str) -> Tuple[str, str]:
74
 
75
  return phrase_ft, phrase_head_random
76
 
 
77
  # --- Gradio interface ---
78
  demo = gr.Interface(
79
  fn=predict_both_phrases,
80
  inputs=gr.Textbox(label="Texte (FR)", lines=4, value="Ce film est bon"),
81
  outputs=[
82
- gr.Textbox(label=f"Modèle {BASE_MODEL} fine-tuné ({MODEL_ID})", interactive=False),
83
- gr.Textbox(label=f"Modèle {BASE_MODEL} non-entrainé (tête random)", interactive=False),
 
 
 
 
84
  ],
85
  examples=[
86
  ["Ce film est une merveille, j'ai adoré !"],
@@ -101,10 +114,12 @@ app.add_middleware(
101
  allow_headers=["*"],
102
  )
103
 
 
104
  @app.get("/healthz")
105
  def healthz():
106
  return {"status": "ok", "model": MODEL_ID, "base_model_for_compare": BASE_MODEL}
107
 
 
108
  @app.post("/predict")
109
  def predict_api(item: dict):
110
  """
@@ -121,10 +136,12 @@ def predict_api(item: dict):
121
  probs_head_random = predict_proba_from_model(mdl_head_random, text)
122
  return {"fine_tuned": probs_ft, "head_random": probs_head_random}
123
 
 
124
  # --- Mount Gradio at root (for HF Space) ---
125
  mount_gradio_app(app, demo, path="/")
126
 
127
  # --- Optional: run locally ---
128
  if __name__ == "__main__":
129
  import uvicorn
 
130
  uvicorn.run("app:app", host="0.0.0.0", port=7860, reload=True)
 
4
  import gradio as gr
5
  import torch
6
  from fastapi import FastAPI
 
7
  from fastapi.middleware.cors import CORSMiddleware
8
  from gradio.routes import mount_gradio_app
9
+ from transformers import AutoModelForSequenceClassification, AutoTokenizer
10
 
11
  # --- Model setup ---
12
  # Fine-tuned model (ton modèle entraîné)
 
25
 
26
  # Chargement du modèle "non entraîné" : corps pré-entraîné + tête aléatoire
27
  # On utilise from_pretrained(BASE_MODEL, num_labels=2) — la tête sera initialisée aléatoirement
28
+ mdl_head_random = AutoModelForSequenceClassification.from_pretrained(
29
+ BASE_MODEL, num_labels=2
30
+ )
31
  mdl_head_random.eval()
32
 
33
+
34
  # --- Prediction utilities ---
35
+ def predict_proba_from_model(
36
+ model: AutoModelForSequenceClassification, text: str
37
+ ) -> Dict[str, float]:
38
  """Return probability distribution over labels for a given text and model."""
39
  inputs = tok(text, return_tensors="pt", truncation=True)
40
  # Si tu veux forcer CPU (par ex. sur un HF Space sans GPU), pas de .to(device) ici
 
46
  probs = [probs]
47
  return {LABELS[i]: float(probs[i]) for i in range(len(LABELS))}
48
 
49
+
50
  def top_label_phrase(probs: Dict[str, float]) -> str:
51
  """
52
  Transforme les probabilités en phrase demandée.
 
60
  else:
61
  return f"Avis négatif à {neg_prob * 100:.2f}% de probabilité"
62
 
63
+
64
  # Fonctions exposées
65
  def predict_label_only(text: str) -> str:
66
  """Fonction legacy qui renvoie juste le label du modèle fine-tuné (compatibilité)."""
67
  probs = predict_proba_from_model(mdl, text)
68
  return max(probs.keys(), key=lambda k: probs[k])
69
 
70
+
71
  def predict_both_phrases(text: str) -> Tuple[str, str]:
72
  """
73
  Renvoie deux phrases formatées :
 
82
 
83
  return phrase_ft, phrase_head_random
84
 
85
+
86
  # --- Gradio interface ---
87
  demo = gr.Interface(
88
  fn=predict_both_phrases,
89
  inputs=gr.Textbox(label="Texte (FR)", lines=4, value="Ce film est bon"),
90
  outputs=[
91
+ gr.Textbox(
92
+ label=f"Modèle {BASE_MODEL} fine-tuné ({MODEL_ID})", interactive=False
93
+ ),
94
+ gr.Textbox(
95
+ label=f"Modèle {BASE_MODEL} non-entrainé (tête random)", interactive=False
96
+ ),
97
  ],
98
  examples=[
99
  ["Ce film est une merveille, j'ai adoré !"],
 
114
  allow_headers=["*"],
115
  )
116
 
117
+
118
  @app.get("/healthz")
119
  def healthz():
120
  return {"status": "ok", "model": MODEL_ID, "base_model_for_compare": BASE_MODEL}
121
 
122
+
123
  @app.post("/predict")
124
  def predict_api(item: dict):
125
  """
 
136
  probs_head_random = predict_proba_from_model(mdl_head_random, text)
137
  return {"fine_tuned": probs_ft, "head_random": probs_head_random}
138
 
139
+
140
  # --- Mount Gradio at root (for HF Space) ---
141
  mount_gradio_app(app, demo, path="/")
142
 
143
  # --- Optional: run locally ---
144
  if __name__ == "__main__":
145
  import uvicorn
146
+
147
  uvicorn.run("app:app", host="0.0.0.0", port=7860, reload=True)