import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
from rich.console import Console

# Initialize rich console for better logging
console = Console()
# Load the fine-tuned model and tokenizer with the same configuration used in training
console.print("[bold green]Loading model and tokenizer...[/bold green]")

# Load model with memory optimizations
model = AutoModelForCausalLM.from_pretrained(
    "./fine-tuned-model",
    device_map="auto",
    trust_remote_code=True,
    torch_dtype=torch.float16,  # Use float16 for memory efficiency
    low_cpu_mem_usage=True,  # Reduce peak CPU memory while loading weights
)

tokenizer = AutoTokenizer.from_pretrained("./fine-tuned-model")
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "left"  # Decoder-only models should be left-padded for generation

# Load the base model for before/after comparison
console.print("[bold green]Loading base model for comparison...[/bold green]")
base_model = AutoModelForCausalLM.from_pretrained(
    "microsoft/phi-2",
    device_map="auto",
    trust_remote_code=True,
    torch_dtype=torch.float16,
    low_cpu_mem_usage=True,
)
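
# Note: the code above assumes "./fine-tuned-model" holds a full (merged)
# checkpoint. If training saved only a QLoRA adapter, the adapter would
# instead be loaded on top of the base model with PEFT -- a minimal sketch,
# assuming the `peft` package is installed and the directory contains
# adapter weights:
#
#     from peft import PeftModel
#     model = PeftModel.from_pretrained(base_model, "./fine-tuned-model")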
def generate_response(
    prompt,
    max_new_tokens=128,  # Match training max_length
    temperature=0.7,
    top_p=0.9,
    num_generations=2,  # Match training num_generations
    repetition_penalty=1.1,
    do_sample=True,
    show_comparison=True,  # Toggle the base-model comparison
):
    try:
        # Get the device of the fine-tuned model
        device = next(model.parameters()).device

        # Tokenize the input and move it to the model's device
        inputs = tokenizer(prompt, return_tensors="pt", padding=True)
        inputs = {k: v.to(device) for k, v in inputs.items()}
        prompt_len = inputs["input_ids"].shape[1]

        # Generate response(s) from the fine-tuned model
        with torch.no_grad():  # Disable gradient computation
            outputs = model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                do_sample=do_sample,
                temperature=temperature,
                top_p=top_p,
                num_return_sequences=num_generations,
                repetition_penalty=repetition_penalty,
                pad_token_id=tokenizer.eos_token_id,
                eos_token_id=tokenizer.eos_token_id,
            )

        # Decode only the newly generated tokens (skip the echoed prompt)
        responses = [
            tokenizer.decode(output[prompt_len:], skip_special_tokens=True)
            for output in outputs
        ]
        fine_tuned_response = "\n\n---\n\n".join(responses)

        if show_comparison:
            # Generate a response from the base model; device_map="auto" may
            # place it on a different device, so move the inputs accordingly
            base_device = next(base_model.parameters()).device
            base_inputs = {k: v.to(base_device) for k, v in inputs.items()}
            with torch.no_grad():
                base_outputs = base_model.generate(
                    **base_inputs,
                    max_new_tokens=max_new_tokens,
                    do_sample=do_sample,
                    temperature=temperature,
                    top_p=top_p,
                    num_return_sequences=1,  # One sample is enough for comparison
                    repetition_penalty=repetition_penalty,
                    pad_token_id=tokenizer.eos_token_id,
                    eos_token_id=tokenizer.eos_token_id,
                )
            base_response = tokenizer.decode(
                base_outputs[0][prompt_len:], skip_special_tokens=True
            )
            return f"""
### Before Fine-tuning (Base Model)
{base_response}

### After Fine-tuning
{fine_tuned_response}
"""
        return fine_tuned_response
    except Exception as e:
        console.print(f"[bold red]Error during generation: {e}[/bold red]")
        return f"Error: {e}"
# Custom CSS for better UI
custom_css = """
.gradio-container {
    font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
}
.container {
    max-width: 800px;
    margin: auto;
    padding: 20px;
}
.title {
    text-align: center;
    color: #2c3e50;
    margin-bottom: 20px;
}
.description {
    color: #34495e;
    line-height: 1.6;
    margin-bottom: 20px;
}
.comparison {
    background-color: #f8f9fa;
    padding: 15px;
    border-radius: 8px;
    margin: 10px 0;
}
.prompt-box {
    background-color: #ffffff;
    border: 2px solid #3498db;
    border-radius: 8px;
    padding: 15px;
    margin-bottom: 20px;
    box-shadow: 0 2px 4px rgba(0,0,0,0.1);
}
.prompt-box label {
    font-size: 1.1em;
    font-weight: bold;
    color: #2c3e50;
    margin-bottom: 10px;
    display: block;
}
.prompt-box textarea {
    width: 100%;
    min-height: 100px;
    padding: 10px;
    border: 1px solid #bdc3c7;
    border-radius: 4px;
    font-size: 1em;
    line-height: 1.5;
}
.output-box {
    background-color: #ffffff;
    border: 2px solid #2ecc71;
    border-radius: 8px;
    padding: 20px;
    margin-top: 20px;
    box-shadow: 0 2px 4px rgba(0,0,0,0.1);
}
.output-box label {
    font-size: 1.1em;
    font-weight: bold;
    color: #2c3e50;
    margin-bottom: 15px;
    display: block;
}
.output-box .markdown {
    background-color: #f8f9fa;
    padding: 15px;
    border-radius: 6px;
    border: 1px solid #e9ecef;
}
.output-box h3 {
    color: #2c3e50;
    border-bottom: 2px solid #3498db;
    padding-bottom: 8px;
    margin-top: 20px;
}
.output-box p {
    line-height: 1.6;
    color: #34495e;
    margin: 10px 0;
}
.loading {
    display: flex;
    align-items: center;
    justify-content: center;
    padding: 20px;
    background-color: #f8f9fa;
    border-radius: 8px;
    margin: 10px 0;
}
.loading-spinner {
    width: 40px;
    height: 40px;
    border: 4px solid #f3f3f3;
    border-top: 4px solid #3498db;
    border-radius: 50%;
    animation: spin 1s linear infinite;
    margin-right: 15px;
}
@keyframes spin {
    0% { transform: rotate(0deg); }
    100% { transform: rotate(360deg); }
}
.loading-text {
    color: #2c3e50;
    font-size: 1.1em;
    font-weight: 500;
}
"""
# Create the Gradio interface with enhanced UI
with gr.Blocks(css=custom_css, theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        """
        # Phi-2 Fine-tuned with GRPO and QLoRA

        This model has been fine-tuned using GRPO (Group Relative Policy Optimization)
        with QLoRA (quantized low-rank adaptation) for memory-efficient training.
        Try it out with different prompts and generation parameters!
        """,
        elem_classes="title",
    )
    with gr.Row():
        with gr.Column(scale=2):
            with gr.Column(elem_classes="prompt-box"):
                prompt = gr.Textbox(
                    label="Enter Your Prompt Here",
                    placeholder="Type your prompt here... (e.g., 'What is machine learning?' or 'Write a story about a robot learning to paint')",
                    lines=5,
                    show_label=True,
                )
            with gr.Row():
                with gr.Column():
                    max_length = gr.Slider(
                        minimum=32,
                        maximum=256,
                        value=128,
                        step=32,
                        label="Max New Tokens",
                        info="Maximum number of new tokens to generate",
                    )
                    temperature = gr.Slider(
                        minimum=0.1,
                        maximum=1.0,
                        value=0.7,
                        step=0.1,
                        label="Temperature",
                        info="Higher values make output more random; lower values more deterministic",
                    )
                with gr.Column():
                    top_p = gr.Slider(
                        minimum=0.1,
                        maximum=1.0,
                        value=0.9,
                        step=0.1,
                        label="Top-p",
                        info="Nucleus sampling parameter",
                    )
                    num_generations = gr.Slider(
                        minimum=1,
                        maximum=4,
                        value=2,
                        step=1,
                        label="Number of Generations",
                        info="Number of different responses to generate",
                    )
            with gr.Row():
                with gr.Column():
                    repetition_penalty = gr.Slider(
                        minimum=1.0,
                        maximum=2.0,
                        value=1.1,
                        step=0.1,
                        label="Repetition Penalty",
                        info="Higher values discourage repetition",
                    )
                with gr.Column():
                    do_sample = gr.Checkbox(
                        value=True,
                        label="Enable Sampling",
                        info="Uncheck for deterministic (greedy) output",
                    )
                    show_comparison = gr.Checkbox(
                        value=True,
                        label="Show Before/After Comparison",
                        info="Show responses from both the base and fine-tuned models",
                    )
            generate_btn = gr.Button("Generate", variant="primary", size="large")
        with gr.Column(scale=3):
            with gr.Column(elem_classes="output-box"):
                output = gr.Markdown(
                    label="Generated Response(s)",
                    show_label=True,
                    value="Your generated responses will appear here...",
                )
                loading_status = gr.Markdown(
                    value="",
                    show_label=False,
                    elem_classes="loading",
                )

    gr.Markdown(
        """
        ### Example Prompts
        Try these example prompts to test the model:

        1. **Technical Questions**:
            - "What is machine learning?"
            - "What is deep learning?"
            - "What is the difference between supervised and unsupervised learning?"
        2. **Creative Writing**:
            - "Write a short story about a robot learning to paint."
            - "Write a story about a time-traveling smartphone."
            - "Write a fairy tale about a computer learning to dream."
            - "Create a story about an AI becoming an artist."
        3. **Technical Explanations**:
            - "How does neural network training work?"
            - "Explain quantum computing in simple terms."
            - "What is transfer learning?"
        4. **Creative Tasks**:
            - "Write a poem about artificial intelligence."
            - "Write a poem about the future of technology."
            - "Create a story about a robot learning to dream."
        """,
        elem_classes="description",
    )
    def generate_with_status(*args):
        # Assigning to `loading_status.value` inside a handler does not update
        # the UI; instead, yield intermediate updates (generator handlers
        # require the queue, enabled before launch below).
        loading_html = """
        <div class="loading">
            <div class="loading-spinner"></div>
            <div class="loading-text">Generating responses... Please wait...</div>
        </div>
        """
        # First update: clear the output and show the loading indicator
        yield "", loading_html
        # Second update: show the result and hide the indicator
        yield generate_response(*args), ""
    # Connect the interface
    generate_btn.click(
        fn=generate_with_status,
        inputs=[
            prompt,
            max_length,
            temperature,
            top_p,
            num_generations,
            repetition_penalty,
            do_sample,
            show_comparison,
        ],
        outputs=[output, loading_status],
    )
if __name__ == "__main__":
    console.print("[bold green]Starting Gradio interface...[/bold green]")
    demo.queue()  # Needed so the generator handler can stream status updates
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,  # Port expected by Hugging Face Spaces
        # share=True is unnecessary on Spaces, which already serves the app publicly
    )