File size: 3,853 Bytes
85fd04c
 
 
 
9a68925
 
 
786407b
9a68925
 
 
 
 
 
 
 
 
 
 
 
85fd04c
9a68925
 
 
85fd04c
9a68925
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
85fd04c
 
 
9a68925
 
 
 
 
85fd04c
 
 
 
 
 
 
9a68925
85fd04c
 
 
 
 
 
 
9a68925
 
 
 
85fd04c
9a68925
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
85fd04c
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
import gradio as gr
from transformers import pipeline
import torch

# --- Dictionnaire de configuration des modèles ---
# On centralise les informations de tes modèles ici pour plus de clarté
MODELS = {
"Lamina-extend": {
        "repo_id": "Clemylia/Lamina-extend",
        "description": "✅ Le plus avancé. Peut répondre à 67 questions, ses phrases sont mieux construites que Lamina-yl1."
    },
    "Lamina-yl1": {
        "repo_id": "Clemylia/Lamina-yl1",
        "description": "👍 Intermédiaire. Peut répondre à 55 questions, essaie de formuler ses propres phrases mais avec des incohérences."
    },
    "Lamina-basic": {
        "repo_id": "Clemylia/lamina-basic",
        "description": "🧪 Expérimental. A tendance à générer du texte créatif et des mélanges de mots sans sens."
    }
}

# Variable globale pour conserver le pipeline chargé et éviter de le recharger
current_pipeline = None
current_model_name = ""

# --- Fonction pour charger un modèle ---
def load_model(model_name_to_load):
    global current_pipeline, current_model_name
    
    # On ne recharge que si le modèle demandé est différent de celui en mémoire
    if model_name_to_load != current_model_name:
        print(f"Chargement du modèle : {model_name_to_load}...")
        repo_id = MODELS[model_name_to_load]["repo_id"]
        
        current_pipeline = pipeline(
            'text-generation', 
            model=repo_id, 
            torch_dtype=torch.bfloat16,
            device_map="auto"
        )
        current_model_name = model_name_to_load
        print(f"✅ Modèle {model_name_to_load} prêt.")
        
    # On retourne la description à afficher
    return MODELS[model_name_to_load]["description"]

# --- Fonction de prédiction ---
def predict(message, history):
    # On vérifie qu'un modèle est bien chargé
    if current_pipeline is None:
        return "Veuillez d'abord sélectionner un modèle."

    # On formate le prompt avec l'historique de la conversation
    prompt_parts = []
    for user_msg, assistant_msg in history:
        prompt_parts.append(f"### Instruction:\n{user_msg}\n\n### Response:\n{assistant_msg}")
    prompt_parts.append(f"### Instruction:\n{message}\n\n### Response:\n")
    prompt = "\n".join(prompt_parts)
    
    # Génération de la réponse
    output = current_pipeline(prompt, max_new_tokens=100, pad_token_id=current_pipeline.tokenizer.eos_token_id)
    
    # Nettoyage de la réponse
    full_response = output[0]['generated_text']
    clean_response = full_response.split("### Response:")[-1].strip()
    
    return clean_response

# --- Création de l'interface Gradio personnalisée ---
with gr.Blocks(theme="soft") as demo:
    gr.Markdown("# 🤖 Lamina Chatbot - Comparateur de Versions")
    gr.Markdown("Choisissez une version de Lamina pour commencer à discuter.")

    with gr.Row():
        with gr.Column(scale=1):
            model_selector = gr.Radio(
                choices=list(MODELS.keys()),
                label="Choisir un modèle",
                value="Lamina-extend" # Le modèle chargé par défaut
            )
            model_description = gr.Markdown(MODELS["Lamina-extend"]["description"])
        with gr.Column(scale=3):
            gr.ChatInterface(
                fn=predict,
                examples=[["Qui es-tu ?"], ["Que sais-tu sur les insectes ?"], ["Raconte-moi une blague."]]
            )
            
    # On lie le changement du sélecteur à la fonction de chargement de modèle
    model_selector.change(load_model, inputs=model_selector, outputs=model_description)
    
    # On charge le modèle par défaut au démarrage de l'application
    demo.load(lambda: load_model("Lamina-extend"), inputs=None, outputs=model_description)

# --- Lancement de l'application ---
if __name__ == "__main__":
    demo.launch()