Spaces:
Sleeping
Sleeping
| import os | |
| import re | |
| import torch | |
| from transformers import AutoModelForCausalLM, AutoTokenizer | |
| from repeng import ControlVector, ControlModel | |
| import gradio as gr | |
| # Initialize model and tokenizer | |
| from huggingface_hub import login | |
# ---------------------------------------------------------------------------
# Model / tokenizer initialisation (runs once at import time)
# ---------------------------------------------------------------------------
mistral_path = "mistralai/Mistral-7B-Instruct-v0.3"
# mistral_path = "E:/language_models/models/mistral"

# Authenticate against the Hugging Face Hub only when a token is configured.
# Passing None to login() would not authenticate anything useful in a
# non-interactive deployment, so skip the call and let any cached credentials
# be used instead.
access_token = os.getenv("mistralaccesstoken")
if access_token:
    login(access_token)
else:
    print("Environment variable 'mistralaccesstoken' is not set; skipping Hugging Face login.")

tokenizer = AutoTokenizer.from_pretrained(mistral_path)
tokenizer.pad_token_id = 0

# NOTE(review): trust_remote_code=True lets the model repository run its own
# Python code at load time -- confirm this is actually required here.
model = AutoModelForCausalLM.from_pretrained(
    mistral_path,
    torch_dtype=torch.float16,
    trust_remote_code=True,
    use_safetensors=True,
)
model = model.to("cuda:0" if torch.cuda.is_available() else "cpu")

# Wrap the model so control vectors can be applied to layers -5 .. -17.
model = ControlModel(model, list(range(-5, -18, -1)))

# Generation settings used as defaults for the UI and for generation.
default_generation_settings = {
    "pad_token_id": tokenizer.eos_token_id,  # Silence warning
    "do_sample": False,  # Deterministic output
    "max_new_tokens": 256,
    "repetition_penalty": 1.1,  # Reduce repetition
}

# Tags for prompt formatting
user_tag, asst_tag = "[INST]", "[/INST]"

# List available control vectors. os.listdir() returns entries in arbitrary
# order, so sort them to keep the UI ordering stable between runs.
control_vector_files = sorted(f for f in os.listdir('.') if f.endswith('.gguf'))
if not control_vector_files:
    raise FileNotFoundError("No .gguf control vector files found in the current directory.")
def toggle_slider(checked):
    """Build the Gradio update that shows or hides a weight slider.

    Args:
        checked: Current state of the checkbox paired with the slider.

    Returns:
        The result of ``gr.update`` with ``visible`` set to ``checked``.
    """
    visibility_update = gr.update(visible=checked)
    return visibility_update
# Captures the text between the final [/INST] tag and the next </s>
# (or the end of the string when no </s> follows).
_ASSISTANT_RESPONSE_PATTERN = re.compile(
    r'\[/INST\](?!.*\[/INST\])\s*(.*?)(?:</s>|$)', re.DOTALL
)


def _apply_control_vectors(checkboxes, sliders):
    """Reset the model, then apply the weighted sum of every ticked control vector."""
    # Reset any previous control vectors
    model.reset()

    combined = None
    for cv_file, checked, weight in zip(control_vector_files, checkboxes, sliders):
        if not checked:
            continue
        try:
            scaled = ControlVector.import_gguf(cv_file) * weight
            combined = scaled if combined is None else combined + scaled
        except Exception as e:
            # Best effort: a broken vector file must not block generation.
            print(f"Failed to set control vector {cv_file}: {e}")

    if combined is None:
        return

    # The vectors are summed and applied with a single set_control() call
    # because calling set_control() once per vector leaves only the last
    # vector active (each call replaces the previously set control).
    try:
        model.set_control(combined)
    except Exception as e:
        print(f"Failed to set control vector {combined}: {e}")


def _build_prompt(system_prompt, user_message, history):
    """Assemble the [INST]-tagged prompt from system prompt, past turns and the new message."""
    # NOTE(review): the whole history is wrapped in one <s>...</s> pair and the
    # system prompt is sent as its own [INST] turn -- confirm this matches the
    # chat template the model expects.
    parts = []
    if history:
        parts.append("<s>")
    # Append the system prompt if provided
    if system_prompt and system_prompt.strip():
        parts.append(f"[INST] {system_prompt} [/INST] ")
    for user_msg, asst_msg in history:
        parts.append(f"{user_tag} {user_msg} {asst_tag} {asst_msg}")
    if history:
        parts.append("</s>")
    # Append the new user message
    parts.append(f"{user_tag} {user_message} {asst_tag}")
    return "".join(parts)


def generate_response(system_prompt, user_message, history, max_new_tokens, repitition_penalty, *args):
    """Generate the assistant's reply and append it to the conversation history.

    Args:
        system_prompt: Optional system prompt text; ignored when blank.
        user_message: The new user message to answer.
        history: Sequence of ``(user, assistant)`` turns; ``None`` or empty
            starts a new conversation.
        max_new_tokens: Maximum number of tokens to generate; falls back to
            the module default when ``None``.
        repitition_penalty: Repetition penalty for generation (the parameter
            name keeps its original spelling for interface compatibility);
            falls back to the module default when ``None``.
        *args: One checkbox value per control vector file, followed by one
            slider value per control vector file, in ``control_vector_files``
            order.

    Returns:
        The history list with ``(user_message, assistant_response)`` appended.
        When ``args`` does not carry a checkbox and a slider value for every
        control vector file, the history is returned unchanged.
    """
    history = history if history else []

    # Validate the variadic inputs *before* indexing into them.
    vector_count = len(control_vector_files)
    if len(args) < 2 * vector_count:
        return history
    checkboxes = args[:vector_count]
    sliders = args[vector_count:2 * vector_count]

    _apply_control_vectors(checkboxes, sliders)

    formatted_prompt = _build_prompt(system_prompt, user_message, history)

    # Tokenize the input
    input_ids = tokenizer(formatted_prompt, return_tensors="pt").to(model.device)

    # Use the values passed in by the caller rather than the UI components'
    # initial values, so changes made in the interface take effect.
    if max_new_tokens is None:
        max_new_tokens = default_generation_settings["max_new_tokens"]
    if repitition_penalty is None:
        repitition_penalty = default_generation_settings["repetition_penalty"]
    generation_settings = {
        "pad_token_id": tokenizer.eos_token_id,
        "do_sample": default_generation_settings["do_sample"],
        "max_new_tokens": int(max_new_tokens),
        "repetition_penalty": float(repitition_penalty),
    }

    # Generate the response
    output_ids = model.generate(**input_ids, **generation_settings)
    response = tokenizer.decode(output_ids.squeeze(), skip_special_tokens=False)

    # Special tokens are kept in the decoded text so the reply can be located
    # after the final [/INST] tag.
    match = _ASSISTANT_RESPONSE_PATTERN.search(response)
    assistant_response = match.group(1).strip() if match else ""

    # Update conversation history
    history.append((user_message, assistant_response))
    return history
def reset_chat():
    """Reset the conversation.

    Returns:
        A tuple ``(history, user_text)``: a new empty conversation history
        and an empty string for the message input. The second value is a
        string, not a list, so that a text field receiving it is cleared
        rather than shown the text of an empty list.
    """
    return [], ""
# Build the Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("# 🧠 Mistral v3 Language Model Interface")

    with gr.Row():
        # Left Column: Settings and Control Vectors
        with gr.Column(scale=1):
            gr.Markdown("### ⚙️ Settings")

            # System Prompt Input
            system_prompt = gr.Textbox(
                label="System Prompt",
                lines=2,
                placeholder="Respond to the user concisely",
            )

            gr.Markdown("### 📊 Control Vectors")

            # One checkbox + weight slider per control vector file. Both lists
            # follow control_vector_files order, which generate_response
            # relies on when it splits its variadic arguments.
            control_checks = []
            control_sliders = []
            for cv_file in control_vector_files:
                with gr.Row():
                    # Checkbox to select the control vector
                    checkbox = gr.Checkbox(label=cv_file, value=False)
                    control_checks.append(checkbox)

                    # Slider to adjust the control vector's weight; hidden
                    # until its checkbox is ticked.
                    slider = gr.Slider(
                        minimum=-2.5,
                        maximum=2.5,
                        value=0.0,
                        step=0.1,
                        label=f"{cv_file} Weight",
                        visible=False,
                    )
                    control_sliders.append(slider)

                    # Link the checkbox to toggle slider visibility
                    checkbox.change(
                        toggle_slider,
                        inputs=checkbox,
                        outputs=slider,
                    )

            # Advanced Settings Section (collapsed by default)
            with gr.Accordion("🔧 Advanced Settings", open=False):
                with gr.Row():
                    max_new_tokens = gr.Number(
                        label="Max New Tokens",
                        value=default_generation_settings["max_new_tokens"],
                        precision=0,
                        step=10,
                    )
                    repetition_penalty = gr.Number(
                        label="Repetition Penalty",
                        value=default_generation_settings["repetition_penalty"],
                        precision=2,
                        step=0.1,
                    )

        # Right Column: Chat Interface
        with gr.Column(scale=2):
            gr.Markdown("### 🗨️ Conversation")

            # Chatbot to display conversation
            chatbot = gr.Chatbot(label="Conversation")

            # User Message Input
            user_input = gr.Textbox(
                label="Your Message",
                lines=2,
                placeholder="Type your message here...",
            )

            with gr.Row():
                # Submit and New Chat buttons
                submit_button = gr.Button("💬 Submit")
                new_chat_button = gr.Button("🆕 New Chat")

    # Positional order must match generate_response's signature: the five
    # fixed inputs, then every checkbox, then every slider.
    inputs_list = [system_prompt, user_input, chatbot, max_new_tokens, repetition_penalty] + control_checks + control_sliders

    # Define button actions
    submit_button.click(
        generate_response,
        inputs=inputs_list,
        outputs=[chatbot],
    )
    new_chat_button.click(
        reset_chat,
        inputs=[],
        outputs=[chatbot, user_input],
    )

# Launch the Gradio app
if __name__ == "__main__":
    demo.launch()