Clemylia commited on
Commit
9a68925
·
verified ·
1 Parent(s): 85fd04c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +68 -26
app.py CHANGED
@@ -2,52 +2,94 @@ import gradio as gr
2
  from transformers import pipeline
3
  import torch
4
 
5
- # --- Configuration ---
6
- # Le nom de ton modèle sur le Hub Hugging Face
7
- repo_id = "Clemylia/Lamina-extend"
 
 
 
 
 
 
 
 
 
 
 
 
 
8
 
9
- # --- Chargement du modèle ---
10
- print("Chargement du chatbot Lamina...")
11
- chatbot = pipeline(
12
- 'text-generation',
13
- model=repo_id,
14
- torch_dtype=torch.bfloat16,
15
- device_map="auto"
16
- )
17
- print("✅ Lamina est prête.")
18
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
 
20
  # --- Fonction de prédiction ---
21
- # C'est la fonction que Gradio appellera à chaque message de l'utilisateur.
22
  def predict(message, history):
23
- # On formate l'historique et la nouvelle question dans le format que Lamina a appris
 
 
 
 
24
  prompt_parts = []
25
  for user_msg, assistant_msg in history:
26
  prompt_parts.append(f"### Instruction:\n{user_msg}\n\n### Response:\n{assistant_msg}")
27
-
28
  prompt_parts.append(f"### Instruction:\n{message}\n\n### Response:\n")
29
-
30
  prompt = "\n".join(prompt_parts)
31
 
32
  # Génération de la réponse
33
- output = chatbot(prompt, max_new_tokens=100, pad_token_id=chatbot.tokenizer.eos_token_id)
34
 
35
  # Nettoyage de la réponse
36
  full_response = output[0]['generated_text']
37
- # On ne garde que la dernière réponse générée
38
  clean_response = full_response.split("### Response:")[-1].strip()
39
 
40
  return clean_response
41
 
 
 
 
 
42
 
43
- # --- Création de l'interface Gradio ---
44
- demo = gr.ChatInterface(
45
- fn=predict,
46
- title="🤖 Lamina Chatbot",
47
- description="Discutez avec Lamina, une IA entraînée par Clemylia. Posez-moi une question !",
48
- examples=[["Qui es-tu ?"], ["Que sais-tu sur les insectes ?"], ["Raconte-moi une blague."]],
49
- theme="soft"
50
- )
 
 
 
 
 
 
 
 
 
 
 
51
 
52
  # --- Lancement de l'application ---
53
  if __name__ == "__main__":
 
2
  from transformers import pipeline
3
  import torch
4
 
5
+ # --- Dictionnaire de configuration des modèles ---
6
+ # On centralise les informations de tes modèles ici pour plus de clarté
7
+ MODELS = {
8
+ "Lamina-extend": {
9
+ "repo_id": "Clemylia/Lamina-extend",
10
+ "description": "✅ Le plus avancé. Peut répondre à 67 questions, ses phrases sont mieux construites que Lamina-yl1."
11
+ },
12
+ "Lamina-yl1": {
13
+ "repo_id": "Clemylia/Lamina-yl1",
14
+ "description": "👍 Intermédiaire. Peut répondre à 55 questions, essaie de formuler ses propres phrases mais avec des incohérences."
15
+ },
16
+ "Lamina-basic": {
17
+ "repo_id": "Clemylia/lamina-basic",
18
+ "description": "🧪 Expérimental. A tendance à générer du texte créatif et des mélanges de mots sans sens."
19
+ }
20
+ }
21
 
22
+ # Variable globale pour conserver le pipeline chargé et éviter de le recharger
23
+ current_pipeline = None
24
+ current_model_name = ""
 
 
 
 
 
 
25
 
26
+ # --- Fonction pour charger un modèle ---
27
+ def load_model(model_name_to_load):
28
+ global current_pipeline, current_model_name
29
+
30
+ # On ne recharge que si le modèle demandé est différent de celui en mémoire
31
+ if model_name_to_load != current_model_name:
32
+ print(f"Chargement du modèle : {model_name_to_load}...")
33
+ repo_id = MODELS[model_name_to_load]["repo_id"]
34
+
35
+ current_pipeline = pipeline(
36
+ 'text-generation',
37
+ model=repo_id,
38
+ torch_dtype=torch.bfloat16,
39
+ device_map="auto"
40
+ )
41
+ current_model_name = model_name_to_load
42
+ print(f"✅ Modèle {model_name_to_load} prêt.")
43
+
44
+ # On retourne la description à afficher
45
+ return MODELS[model_name_to_load]["description"]
46
 
47
  # --- Fonction de prédiction ---
 
48
  def predict(message, history):
49
+ # On vérifie qu'un modèle est bien chargé
50
+ if current_pipeline is None:
51
+ return "Veuillez d'abord sélectionner un modèle."
52
+
53
+ # On formate le prompt avec l'historique de la conversation
54
  prompt_parts = []
55
  for user_msg, assistant_msg in history:
56
  prompt_parts.append(f"### Instruction:\n{user_msg}\n\n### Response:\n{assistant_msg}")
 
57
  prompt_parts.append(f"### Instruction:\n{message}\n\n### Response:\n")
 
58
  prompt = "\n".join(prompt_parts)
59
 
60
  # Génération de la réponse
61
+ output = current_pipeline(prompt, max_new_tokens=100, pad_token_id=current_pipeline.tokenizer.eos_token_id)
62
 
63
  # Nettoyage de la réponse
64
  full_response = output[0]['generated_text']
 
65
  clean_response = full_response.split("### Response:")[-1].strip()
66
 
67
  return clean_response
68
 
69
+ # --- Création de l'interface Gradio personnalisée ---
70
+ with gr.Blocks(theme="soft") as demo:
71
+ gr.Markdown("# 🤖 Lamina Chatbot - Comparateur de Versions")
72
+ gr.Markdown("Choisissez une version de Lamina pour commencer à discuter.")
73
 
74
+ with gr.Row():
75
+ with gr.Column(scale=1):
76
+ model_selector = gr.Radio(
77
+ choices=list(MODELS.keys()),
78
+ label="Choisir un modèle",
79
+ value="Lamina-extend" # Le modèle chargé par défaut
80
+ )
81
+ model_description = gr.Markdown(MODELS["Lamina-extend"]["description"])
82
+ with gr.Column(scale=3):
83
+ gr.ChatInterface(
84
+ fn=predict,
85
+ examples=[["Qui es-tu ?"], ["Que sais-tu sur les insectes ?"], ["Raconte-moi une blague."]]
86
+ )
87
+
88
+ # On lie le changement du sélecteur à la fonction de chargement de modèle
89
+ model_selector.change(load_model, inputs=model_selector, outputs=model_description)
90
+
91
+ # On charge le modèle par défaut au démarrage de l'application
92
+ demo.load(lambda: load_model("Lamina-extend"), inputs=None, outputs=model_description)
93
 
94
  # --- Lancement de l'application ---
95
  if __name__ == "__main__":