Update app.py
Browse files
app.py
CHANGED
|
@@ -2,52 +2,94 @@ import gradio as gr
|
|
| 2 |
from transformers import pipeline
|
| 3 |
import torch
|
| 4 |
|
| 5 |
-
# ---
|
| 6 |
-
#
|
| 7 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 8 |
|
| 9 |
-
#
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
'text-generation',
|
| 13 |
-
model=repo_id,
|
| 14 |
-
torch_dtype=torch.bfloat16,
|
| 15 |
-
device_map="auto"
|
| 16 |
-
)
|
| 17 |
-
print("✅ Lamina est prête.")
|
| 18 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 19 |
|
| 20 |
# --- Fonction de prédiction ---
|
| 21 |
-
# C'est la fonction que Gradio appellera à chaque message de l'utilisateur.
|
| 22 |
def predict(message, history):
|
| 23 |
-
# On
|
|
|
|
|
|
|
|
|
|
|
|
|
| 24 |
prompt_parts = []
|
| 25 |
for user_msg, assistant_msg in history:
|
| 26 |
prompt_parts.append(f"### Instruction:\n{user_msg}\n\n### Response:\n{assistant_msg}")
|
| 27 |
-
|
| 28 |
prompt_parts.append(f"### Instruction:\n{message}\n\n### Response:\n")
|
| 29 |
-
|
| 30 |
prompt = "\n".join(prompt_parts)
|
| 31 |
|
| 32 |
# Génération de la réponse
|
| 33 |
-
output =
|
| 34 |
|
| 35 |
# Nettoyage de la réponse
|
| 36 |
full_response = output[0]['generated_text']
|
| 37 |
-
# On ne garde que la dernière réponse générée
|
| 38 |
clean_response = full_response.split("### Response:")[-1].strip()
|
| 39 |
|
| 40 |
return clean_response
|
| 41 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 42 |
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 51 |
|
| 52 |
# --- Lancement de l'application ---
|
| 53 |
if __name__ == "__main__":
|
|
|
|
| 2 |
from transformers import pipeline
|
| 3 |
import torch
|
| 4 |
|
| 5 |
+
# --- Dictionnaire de configuration des modèles ---
|
| 6 |
+
# On centralise les informations de tes modèles ici pour plus de clarté
|
| 7 |
+
MODELS = {
|
| 8 |
+
"Lamina-extend": {
|
| 9 |
+
"repo_id": "Clemylia/Lamina-extend",
|
| 10 |
+
"description": "✅ Le plus avancé. Peut répondre à 67 questions, ses phrases sont mieux construites que Lamina-yl1."
|
| 11 |
+
},
|
| 12 |
+
"Lamina-yl1": {
|
| 13 |
+
"repo_id": "Clemylia/Lamina-yl1",
|
| 14 |
+
"description": "👍 Intermédiaire. Peut répondre à 55 questions, essaie de formuler ses propres phrases mais avec des incohérences."
|
| 15 |
+
},
|
| 16 |
+
"Lamina-basic": {
|
| 17 |
+
"repo_id": "Clemylia/lamina-basic",
|
| 18 |
+
"description": "🧪 Expérimental. A tendance à générer du texte créatif et des mélanges de mots sans sens."
|
| 19 |
+
}
|
| 20 |
+
}
|
| 21 |
|
| 22 |
+
# Variable globale pour conserver le pipeline chargé et éviter de le recharger
|
| 23 |
+
current_pipeline = None
|
| 24 |
+
current_model_name = ""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 25 |
|
| 26 |
+
# --- Fonction pour charger un modèle ---
|
| 27 |
+
def load_model(model_name_to_load):
|
| 28 |
+
global current_pipeline, current_model_name
|
| 29 |
+
|
| 30 |
+
# On ne recharge que si le modèle demandé est différent de celui en mémoire
|
| 31 |
+
if model_name_to_load != current_model_name:
|
| 32 |
+
print(f"Chargement du modèle : {model_name_to_load}...")
|
| 33 |
+
repo_id = MODELS[model_name_to_load]["repo_id"]
|
| 34 |
+
|
| 35 |
+
current_pipeline = pipeline(
|
| 36 |
+
'text-generation',
|
| 37 |
+
model=repo_id,
|
| 38 |
+
torch_dtype=torch.bfloat16,
|
| 39 |
+
device_map="auto"
|
| 40 |
+
)
|
| 41 |
+
current_model_name = model_name_to_load
|
| 42 |
+
print(f"✅ Modèle {model_name_to_load} prêt.")
|
| 43 |
+
|
| 44 |
+
# On retourne la description à afficher
|
| 45 |
+
return MODELS[model_name_to_load]["description"]
|
| 46 |
|
| 47 |
# --- Fonction de prédiction ---
|
|
|
|
| 48 |
def predict(message, history):
|
| 49 |
+
# On vérifie qu'un modèle est bien chargé
|
| 50 |
+
if current_pipeline is None:
|
| 51 |
+
return "Veuillez d'abord sélectionner un modèle."
|
| 52 |
+
|
| 53 |
+
# On formate le prompt avec l'historique de la conversation
|
| 54 |
prompt_parts = []
|
| 55 |
for user_msg, assistant_msg in history:
|
| 56 |
prompt_parts.append(f"### Instruction:\n{user_msg}\n\n### Response:\n{assistant_msg}")
|
|
|
|
| 57 |
prompt_parts.append(f"### Instruction:\n{message}\n\n### Response:\n")
|
|
|
|
| 58 |
prompt = "\n".join(prompt_parts)
|
| 59 |
|
| 60 |
# Génération de la réponse
|
| 61 |
+
output = current_pipeline(prompt, max_new_tokens=100, pad_token_id=current_pipeline.tokenizer.eos_token_id)
|
| 62 |
|
| 63 |
# Nettoyage de la réponse
|
| 64 |
full_response = output[0]['generated_text']
|
|
|
|
| 65 |
clean_response = full_response.split("### Response:")[-1].strip()
|
| 66 |
|
| 67 |
return clean_response
|
| 68 |
|
| 69 |
+
# --- Création de l'interface Gradio personnalisée ---
|
| 70 |
+
with gr.Blocks(theme="soft") as demo:
|
| 71 |
+
gr.Markdown("# 🤖 Lamina Chatbot - Comparateur de Versions")
|
| 72 |
+
gr.Markdown("Choisissez une version de Lamina pour commencer à discuter.")
|
| 73 |
|
| 74 |
+
with gr.Row():
|
| 75 |
+
with gr.Column(scale=1):
|
| 76 |
+
model_selector = gr.Radio(
|
| 77 |
+
choices=list(MODELS.keys()),
|
| 78 |
+
label="Choisir un modèle",
|
| 79 |
+
value="Lamina-extend" # Le modèle chargé par défaut
|
| 80 |
+
)
|
| 81 |
+
model_description = gr.Markdown(MODELS["Lamina-extend"]["description"])
|
| 82 |
+
with gr.Column(scale=3):
|
| 83 |
+
gr.ChatInterface(
|
| 84 |
+
fn=predict,
|
| 85 |
+
examples=[["Qui es-tu ?"], ["Que sais-tu sur les insectes ?"], ["Raconte-moi une blague."]]
|
| 86 |
+
)
|
| 87 |
+
|
| 88 |
+
# On lie le changement du sélecteur à la fonction de chargement de modèle
|
| 89 |
+
model_selector.change(load_model, inputs=model_selector, outputs=model_description)
|
| 90 |
+
|
| 91 |
+
# On charge le modèle par défaut au démarrage de l'application
|
| 92 |
+
demo.load(lambda: load_model("Lamina-extend"), inputs=None, outputs=model_description)
|
| 93 |
|
| 94 |
# --- Lancement de l'application ---
|
| 95 |
if __name__ == "__main__":
|