import gradio as gr
from transformers import pipeline
import torch

# Available Lamina checkpoints on the Hugging Face Hub and their UI descriptions.
MODELS = {
    "Lamina-extend": {
        "repo_id": "Clemylia/Lamina-extend",
        "description": "✅ Le plus avancé. Peut répondre à 67 questions, ses phrases sont mieux construites que Lamina-yl1."
    },
    "Lamina-yl1": {
        "repo_id": "Clemylia/Lamina-yl1",
        "description": "👍 Intermédiaire. Peut répondre à 55 questions, essaie de formuler ses propres phrases mais avec des incohérences."
    },
    "Lamina-basic": {
        "repo_id": "Clemylia/lamina-basic",
        "description": "🧪 Expérimental. A tendance à générer du texte créatif et des mélanges de mots sans sens."
    }
}

# Lazily loaded pipeline and the name of the model it currently serves.
current_pipeline = None
current_model_name = ""


def load_model(model_name_to_load):
    """Load the selected model if it is not already in memory, then return its description."""
    global current_pipeline, current_model_name

    # Only reload when the selection actually changes.
    if model_name_to_load != current_model_name:
        print(f"Chargement du modèle : {model_name_to_load}...")
        repo_id = MODELS[model_name_to_load]["repo_id"]

        current_pipeline = pipeline(
            'text-generation',
            model=repo_id,
            torch_dtype=torch.bfloat16,
            device_map="auto"
        )
        current_model_name = model_name_to_load
        print(f"✅ Modèle {model_name_to_load} prêt.")

    return MODELS[model_name_to_load]["description"]


def predict(message, history):
    """Generate a reply to `message` with the currently loaded pipeline."""
    if current_pipeline is None:
        return "Veuillez d'abord sélectionner un modèle."

    # Rebuild the conversation as an Alpaca-style prompt.
    # `history` is assumed to be a list of (user, assistant) message pairs.
    prompt_parts = []
    for user_msg, assistant_msg in history:
        prompt_parts.append(f"### Instruction:\n{user_msg}\n\n### Response:\n{assistant_msg}")
    prompt_parts.append(f"### Instruction:\n{message}\n\n### Response:\n")
    prompt = "\n".join(prompt_parts)

    output = current_pipeline(prompt, max_new_tokens=100, pad_token_id=current_pipeline.tokenizer.eos_token_id)

    # The pipeline returns the prompt followed by the completion;
    # keep only the text after the last "### Response:" marker.
    full_response = output[0]['generated_text']
    clean_response = full_response.split("### Response:")[-1].strip()

    return clean_response


with gr.Blocks(theme="soft") as demo:
    gr.Markdown("# 🤖 Lamina Chatbot - Comparateur de Versions")
    gr.Markdown("Choisissez une version de Lamina pour commencer à discuter.")

    with gr.Row():
        with gr.Column(scale=1):
            model_selector = gr.Radio(
                choices=list(MODELS.keys()),
                label="Choisir un modèle",
                value="Lamina-extend"
            )
            model_description = gr.Markdown(MODELS["Lamina-extend"]["description"])
        with gr.Column(scale=3):
            gr.ChatInterface(
                fn=predict,
                examples=[["Qui es-tu ?"], ["Que sais-tu sur les insectes ?"], ["Raconte-moi une blague."]]
            )

    # Swap models when the radio selection changes and refresh the description text.
    model_selector.change(load_model, inputs=model_selector, outputs=model_description)

    # Load the default model as soon as the app starts.
    demo.load(lambda: load_model("Lamina-extend"), inputs=None, outputs=model_description)


if __name__ == "__main__":
    demo.launch()
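
# Quick manual sanity check (a minimal sketch, assuming the Hub repos above are reachable
# and that you call these helpers from a Python session instead of the Gradio UI):
#
#     load_model("Lamina-extend")
#     print(predict("Qui es-tu ?", history=[]))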