Lamina / app.py
Clemylia's picture
Update app.py
edcc5ed verified
import gradio as gr
from transformers import pipeline
import torch
# --- Dictionnaire de configuration des modèles ---
# On centralise les informations de tes modèles ici pour plus de clarté
MODELS = {
"Lamina-extend": {
"repo_id": "Clemylia/Lamina-extend",
"description": "✅ Le plus avancé. Peut répondre à 67 questions, ses phrases sont mieux construites que Lamina-yl1."
},
"Lamina-yl1": {
"repo_id": "Clemylia/Lamina-yl1",
"description": "👍 Intermédiaire. Peut répondre à 55 questions, essaie de formuler ses propres phrases mais avec des incohérences."
},
"Lamina-basic": {
"repo_id": "Clemylia/lamina-basic",
"description": "🧪 Expérimental. A tendance à générer du texte créatif et des mélanges de mots sans sens."
}
}
# Variable globale pour conserver le pipeline chargé et éviter de le recharger
current_pipeline = None
current_model_name = ""
# --- Fonction pour charger un modèle ---
def load_model(model_name_to_load):
global current_pipeline, current_model_name
# On ne recharge que si le modèle demandé est différent de celui en mémoire
if model_name_to_load != current_model_name:
print(f"Chargement du modèle : {model_name_to_load}...")
repo_id = MODELS[model_name_to_load]["repo_id"]
current_pipeline = pipeline(
'text-generation',
model=repo_id,
torch_dtype=torch.bfloat16,
device_map="auto"
)
current_model_name = model_name_to_load
print(f"✅ Modèle {model_name_to_load} prêt.")
# On retourne la description à afficher
return MODELS[model_name_to_load]["description"]
# --- Fonction de prédiction ---
def predict(message, history):
# On vérifie qu'un modèle est bien chargé
if current_pipeline is None:
return "Veuillez d'abord sélectionner un modèle."
# On formate le prompt avec l'historique de la conversation
prompt_parts = []
for user_msg, assistant_msg in history:
prompt_parts.append(f"### Instruction:\n{user_msg}\n\n### Response:\n{assistant_msg}")
prompt_parts.append(f"### Instruction:\n{message}\n\n### Response:\n")
prompt = "\n".join(prompt_parts)
# Génération de la réponse
output = current_pipeline(prompt, max_new_tokens=100, pad_token_id=current_pipeline.tokenizer.eos_token_id)
# Nettoyage de la réponse
full_response = output[0]['generated_text']
clean_response = full_response.split("### Response:")[-1].strip()
return clean_response
# --- Création de l'interface Gradio personnalisée ---
with gr.Blocks(theme="soft") as demo:
gr.Markdown("# 🤖 Lamina Chatbot - Comparateur de Versions")
gr.Markdown("Choisissez une version de Lamina pour commencer à discuter.")
with gr.Row():
with gr.Column(scale=1):
model_selector = gr.Radio(
choices=list(MODELS.keys()),
label="Choisir un modèle",
value="Lamina-extend" # Le modèle chargé par défaut
)
model_description = gr.Markdown(MODELS["Lamina-extend"]["description"])
with gr.Column(scale=3):
gr.ChatInterface(
fn=predict,
examples=[["Qui es-tu ?"], ["Que sais-tu sur les insectes ?"], ["Raconte-moi une blague."]]
)
# On lie le changement du sélecteur à la fonction de chargement de modèle
model_selector.change(load_model, inputs=model_selector, outputs=model_description)
# On charge le modèle par défaut au démarrage de l'application
demo.load(lambda: load_model("Lamina-extend"), inputs=None, outputs=model_description)
# --- Lancement de l'application ---
if __name__ == "__main__":
demo.launch()