app.py (CHANGED)

@@ -21,6 +21,15 @@ tokenizer = AutoTokenizer.from_pretrained("./fine-tuned-model")
 tokenizer.pad_token = tokenizer.eos_token
 tokenizer.padding_side = 'left'
 
+# Load base model for before/after comparison
+console.print("[bold green]Loading base model for comparison...[/bold green]")
+base_model = AutoModelForCausalLM.from_pretrained(
+    "microsoft/phi-2",
+    device_map="auto",
+    trust_remote_code=True,
+    torch_dtype=torch.float16,
+)
+
 def generate_response(
     prompt,
     max_length=128,  # Match training max_length
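The hunk above keeps a second full copy of phi-2 in memory next to the fine-tuned model, which roughly doubles GPU usage at startup (phi-2 has about 2.7B parameters, so around 5 GB extra in float16). A minimal sketch of one way to soften that, deferring the base-model load until the first comparison request; the cache variable and helper name are hypothetical, not part of the diff:

```python
import torch
from transformers import AutoModelForCausalLM

_base_model = None  # hypothetical module-level cache, not in the diff

def get_base_model():
    """Load microsoft/phi-2 on first use and reuse it afterwards."""
    global _base_model
    if _base_model is None:
        _base_model = AutoModelForCausalLM.from_pretrained(
            "microsoft/phi-2",
            device_map="auto",
            trust_remote_code=True,
            torch_dtype=torch.float16,
        )
        _base_model.eval()  # inference only; no training here
    return _base_model
```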
@@ -29,6 +38,7 @@ def generate_response(
     num_generations=2,  # Match training num_generations
     repetition_penalty=1.1,
     do_sample=True,
+    show_comparison=True,  # New parameter for comparison toggle
 ):
     try:
         # Get the device of the model
@@ -40,7 +50,7 @@ def generate_response(
         # Move inputs to the same device as the model
         inputs = {k: v.to(device) for k, v in inputs.items()}
 
-        # Generate response
+        # Generate response from fine-tuned model
         with torch.no_grad():  # Disable gradient computation
             outputs = model.generate(
                 **inputs,
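The `tokenizer.padding_side = 'left'` set near the top of the file is what makes this generation step work on padded batches: a decoder-only model continues from the last position, so padding has to sit before the prompt rather than after it. A self-contained illustration, using the public phi-2 tokenizer since the fine-tuned one is local to the Space:

```python
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("microsoft/phi-2", trust_remote_code=True)
tok.pad_token = tok.eos_token
tok.padding_side = "left"

batch = tok(
    ["short prompt", "a somewhat longer prompt"],
    return_tensors="pt",
    padding=True,
)
# Pad tokens (id == eos) now precede the shorter prompt, and attention_mask
# is 0 on those positions, so generate() starts from real text in every row.
print(batch["input_ids"][0])
print(batch["attention_mask"][0])
```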
@@ -60,7 +70,35 @@ def generate_response(
             response = tokenizer.decode(output, skip_special_tokens=True)
             responses.append(response)
 
-        return "\n\n---\n\n".join(responses)
+        fine_tuned_response = "\n\n---\n\n".join(responses)
+
+        if show_comparison:
+            # Generate response from base model
+            with torch.no_grad():
+                base_outputs = base_model.generate(
+                    **inputs,
+                    max_new_tokens=max_length,
+                    do_sample=do_sample,
+                    temperature=temperature,
+                    top_p=top_p,
+                    num_return_sequences=1,  # Only one for comparison
+                    repetition_penalty=repetition_penalty,
+                    pad_token_id=tokenizer.eos_token_id,
+                    eos_token_id=tokenizer.eos_token_id,
+                )
+
+            base_response = tokenizer.decode(base_outputs[0], skip_special_tokens=True)
+
+            return f"""
+### Before Fine-tuning (Base Model)
+{base_response}
+
+### After Fine-tuning
+{fine_tuned_response}
+"""
+        else:
+            return fine_tuned_response
+
     except Exception as e:
         console.print(f"[bold red]Error during generation: {str(e)}[/bold red]")
         return f"Error: {str(e)}"
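With this change, `generate_response` returns one Markdown string: base-model output under one heading, fine-tuned output under another, falling back to the joined responses when the toggle is off. A quick smoke test, assuming app.py's globals (model, base_model, tokenizer) are loaded; only parameters confirmed by the diff are passed:

```python
result = generate_response(
    "What is machine learning?",
    max_length=64,
    num_generations=1,
    show_comparison=True,
)
print(result)
# Expected shape:
# ### Before Fine-tuning (Base Model)
# ...
# ### After Fine-tuning
# ...
```

One caveat that predates this diff: when the sampling toggle is off (`do_sample=False`), recent transformers releases warn that `temperature` and `top_p` are then unused by `generate()`.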
@@ -85,6 +123,12 @@ custom_css = """
     line-height: 1.6;
     margin-bottom: 20px;
 }
+.comparison {
+    background-color: #f8f9fa;
+    padding: 15px;
+    border-radius: 8px;
+    margin: 10px 0;
+}
 """
 
 # Create the Gradio interface with enhanced UI
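The new `.comparison` rule only takes effect if some component carries that class, and no `elem_classes="comparison"` appears in the visible hunks. A sketch of the likely intent, attaching the class to the output component; treat the placement as an assumption:

```python
import gradio as gr

output = gr.Markdown(
    label="Generated Response(s)",
    show_label=True,
    elem_classes="comparison",  # assumption: opts this component into the new CSS rule
)
```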
@@ -160,12 +204,17 @@ with gr.Blocks(css=custom_css, theme=gr.themes.Soft()) as demo:
                 info="Enable/disable sampling for deterministic output"
             )
 
+            show_comparison = gr.Checkbox(
+                value=True,
+                label="Show Before/After Comparison",
+                info="Toggle to show responses from both base and fine-tuned models"
+            )
+
             generate_btn = gr.Button("Generate", variant="primary")
 
         with gr.Column(scale=3):
-            output = gr.Textbox(
+            output = gr.Markdown(
                 label="Generated Response(s)",
-                lines=10,
                 show_label=True,
             )
 
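Two details in this hunk follow from the new return value: `gr.Textbox` would display the `###` headings as literal text, while `gr.Markdown` renders them, and `lines=10` had to go because it is a Textbox-specific argument. A self-contained side-by-side check:

```python
import gradio as gr

sample = (
    "### Before Fine-tuning (Base Model)\n(base output)\n\n"
    "### After Fine-tuning\n(fine-tuned output)"
)

with gr.Blocks() as check:
    gr.Markdown(sample)                  # renders the headings
    gr.Textbox(value=sample, lines=6)    # shows the raw "###" markup

# check.launch()  # uncomment to compare the two visually
```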
@@ -177,22 +226,23 @@ with gr.Blocks(css=custom_css, theme=gr.themes.Soft()) as demo:
     1. **Technical Questions**:
        - "What is machine learning?"
        - "What is deep learning?"
-       - "…
+       - "What is the difference between supervised and unsupervised learning?"
 
     2. **Creative Writing**:
        - "Write a short story about a robot learning to paint."
        - "Write a story about a time-traveling smartphone."
-       - "Write a …
+       - "Write a fairy tale about a computer learning to dream."
+       - "Create a story about an AI becoming an artist."
 
     3. **Technical Explanations**:
        - "How does neural network training work?"
-       - "…
-       - "…
+       - "Explain quantum computing in simple terms."
+       - "What is transfer learning?"
 
     4. **Creative Tasks**:
-       - "Write a …
-       - "Create a story about an AI becoming an artist."
+       - "Write a poem about artificial intelligence."
        - "Write a poem about the future of technology."
+       - "Create a story about a robot learning to dream."
     """,
     elem_classes="description"
 )
@@ -207,7 +257,9 @@ with gr.Blocks(css=custom_css, theme=gr.themes.Soft()) as demo:
             ["What is deep learning?"],
             ["Write a story about a time-traveling smartphone."],
             ["How does neural network training work?"],
-            ["Write a fairy tale about a computer learning to dream."]
+            ["Write a fairy tale about a computer learning to dream."],
+            ["What is the difference between supervised and unsupervised learning?"],
+            ["Create a story about an AI becoming an artist."]
         ],
         inputs=prompt
     )
@@ -222,7 +274,8 @@ with gr.Blocks(css=custom_css, theme=gr.themes.Soft()) as demo:
             top_p,
             num_generations,
             repetition_penalty,
-            do_sample
+            do_sample,
+            show_comparison
         ],
         outputs=output
     )
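The inputs list must line up positionally with `generate_response`'s parameters, with the new checkbox last so its boolean feeds `show_comparison`. Reconstructed wiring for reference; the entries before `top_p` fall outside the hunk, so those names are assumptions based on the function signature:

```python
generate_btn.click(
    fn=generate_response,
    inputs=[
        prompt,               # assumed: defined earlier in app.py
        max_length,           # assumed
        temperature,          # assumed
        top_p,
        num_generations,
        repetition_penalty,
        do_sample,
        show_comparison,      # new: maps the checkbox onto the new parameter
    ],
    outputs=output,
)
```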