padmanabhbosamia committed on
Commit
e9026bb
·
verified ·
1 Parent(s): fd6b90a

Changed before after

Browse files
Files changed (1) hide show
  1. app.py +65 -12
app.py CHANGED
@@ -21,6 +21,15 @@ tokenizer = AutoTokenizer.from_pretrained("./fine-tuned-model")
21
  tokenizer.pad_token = tokenizer.eos_token
22
  tokenizer.padding_side = 'left'
23
 
 
 
 
 
 
 
 
 
 
24
  def generate_response(
25
  prompt,
26
  max_length=128, # Match training max_length
@@ -29,6 +38,7 @@ def generate_response(
29
  num_generations=2, # Match training num_generations
30
  repetition_penalty=1.1,
31
  do_sample=True,
 
32
  ):
33
  try:
34
  # Get the device of the model
@@ -40,7 +50,7 @@ def generate_response(
40
  # Move inputs to the same device as the model
41
  inputs = {k: v.to(device) for k, v in inputs.items()}
42
 
43
- # Generate response
44
  with torch.no_grad(): # Disable gradient computation
45
  outputs = model.generate(
46
  **inputs,
@@ -60,7 +70,35 @@ def generate_response(
60
  response = tokenizer.decode(output, skip_special_tokens=True)
61
  responses.append(response)
62
 
63
- return "\n\n---\n\n".join(responses)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
64
  except Exception as e:
65
  console.print(f"[bold red]Error during generation: {str(e)}[/bold red]")
66
  return f"Error: {str(e)}"
@@ -85,6 +123,12 @@ custom_css = """
85
  line-height: 1.6;
86
  margin-bottom: 20px;
87
  }
 
 
 
 
 
 
88
  """
89
 
90
  # Create the Gradio interface with enhanced UI
@@ -160,12 +204,17 @@ with gr.Blocks(css=custom_css, theme=gr.themes.Soft()) as demo:
160
  info="Enable/disable sampling for deterministic output"
161
  )
162
 
 
 
 
 
 
 
163
  generate_btn = gr.Button("Generate", variant="primary")
164
 
165
  with gr.Column(scale=3):
166
- output = gr.Textbox(
167
  label="Generated Response(s)",
168
- lines=10,
169
  show_label=True,
170
  )
171
 
@@ -177,22 +226,23 @@ with gr.Blocks(css=custom_css, theme=gr.themes.Soft()) as demo:
177
  1. **Technical Questions**:
178
  - "What is machine learning?"
179
  - "What is deep learning?"
180
- - "Explain quantum computing in simple terms."
181
 
182
  2. **Creative Writing**:
183
  - "Write a short story about a robot learning to paint."
184
  - "Write a story about a time-traveling smartphone."
185
- - "Write a poem about artificial intelligence."
 
186
 
187
  3. **Technical Explanations**:
188
  - "How does neural network training work?"
189
- - "What is the difference between supervised and unsupervised learning?"
190
- - "Explain the concept of transfer learning."
191
 
192
  4. **Creative Tasks**:
193
- - "Write a fairy tale about a computer learning to dream."
194
- - "Create a story about an AI becoming an artist."
195
  - "Write a poem about the future of technology."
 
196
  """,
197
  elem_classes="description"
198
  )
@@ -207,7 +257,9 @@ with gr.Blocks(css=custom_css, theme=gr.themes.Soft()) as demo:
207
  ["What is deep learning?"],
208
  ["Write a story about a time-traveling smartphone."],
209
  ["How does neural network training work?"],
210
- ["Write a fairy tale about a computer learning to dream."]
 
 
211
  ],
212
  inputs=prompt
213
  )
@@ -222,7 +274,8 @@ with gr.Blocks(css=custom_css, theme=gr.themes.Soft()) as demo:
222
  top_p,
223
  num_generations,
224
  repetition_penalty,
225
- do_sample
 
226
  ],
227
  outputs=output
228
  )
 
21
  tokenizer.pad_token = tokenizer.eos_token
22
  tokenizer.padding_side = 'left'
23
 
24
+ # Load base model for before/after comparison
25
+ console.print("[bold green]Loading base model for comparison...[/bold green]")
26
+ base_model = AutoModelForCausalLM.from_pretrained(
27
+ "microsoft/phi-2",
28
+ device_map="auto",
29
+ trust_remote_code=True,
30
+ torch_dtype=torch.float16,
31
+ )
32
+
33
  def generate_response(
34
  prompt,
35
  max_length=128, # Match training max_length
 
38
  num_generations=2, # Match training num_generations
39
  repetition_penalty=1.1,
40
  do_sample=True,
41
+ show_comparison=True, # New parameter for comparison toggle
42
  ):
43
  try:
44
  # Get the device of the model
 
50
  # Move inputs to the same device as the model
51
  inputs = {k: v.to(device) for k, v in inputs.items()}
52
 
53
+ # Generate response from fine-tuned model
54
  with torch.no_grad(): # Disable gradient computation
55
  outputs = model.generate(
56
  **inputs,
 
70
  response = tokenizer.decode(output, skip_special_tokens=True)
71
  responses.append(response)
72
 
73
+ fine_tuned_response = "\n\n---\n\n".join(responses)
74
+
75
+ if show_comparison:
76
+ # Generate response from base model
77
+ with torch.no_grad():
78
+ base_outputs = base_model.generate(
79
+ **inputs,
80
+ max_new_tokens=max_length,
81
+ do_sample=do_sample,
82
+ temperature=temperature,
83
+ top_p=top_p,
84
+ num_return_sequences=1, # Only one for comparison
85
+ repetition_penalty=repetition_penalty,
86
+ pad_token_id=tokenizer.eos_token_id,
87
+ eos_token_id=tokenizer.eos_token_id,
88
+ )
89
+
90
+ base_response = tokenizer.decode(base_outputs[0], skip_special_tokens=True)
91
+
92
+ return f"""
93
+ ### Before Fine-tuning (Base Model)
94
+ {base_response}
95
+
96
+ ### After Fine-tuning
97
+ {fine_tuned_response}
98
+ """
99
+ else:
100
+ return fine_tuned_response
101
+
102
  except Exception as e:
103
  console.print(f"[bold red]Error during generation: {str(e)}[/bold red]")
104
  return f"Error: {str(e)}"
 
123
  line-height: 1.6;
124
  margin-bottom: 20px;
125
  }
126
+ .comparison {
127
+ background-color: #f8f9fa;
128
+ padding: 15px;
129
+ border-radius: 8px;
130
+ margin: 10px 0;
131
+ }
132
  """
133
 
134
  # Create the Gradio interface with enhanced UI
 
204
  info="Enable/disable sampling for deterministic output"
205
  )
206
 
207
+ show_comparison = gr.Checkbox(
208
+ value=True,
209
+ label="Show Before/After Comparison",
210
+ info="Toggle to show responses from both base and fine-tuned models"
211
+ )
212
+
213
  generate_btn = gr.Button("Generate", variant="primary")
214
 
215
  with gr.Column(scale=3):
216
+ output = gr.Markdown(
217
  label="Generated Response(s)",
 
218
  show_label=True,
219
  )
220
 
 
226
  1. **Technical Questions**:
227
  - "What is machine learning?"
228
  - "What is deep learning?"
229
+ - "What is the difference between supervised and unsupervised learning?"
230
 
231
  2. **Creative Writing**:
232
  - "Write a short story about a robot learning to paint."
233
  - "Write a story about a time-traveling smartphone."
234
+ - "Write a fairy tale about a computer learning to dream."
235
+ - "Create a story about an AI becoming an artist."
236
 
237
  3. **Technical Explanations**:
238
  - "How does neural network training work?"
239
+ - "Explain quantum computing in simple terms."
240
+ - "What is transfer learning?"
241
 
242
  4. **Creative Tasks**:
243
+ - "Write a poem about artificial intelligence."
 
244
  - "Write a poem about the future of technology."
245
+ - "Create a story about a robot learning to dream."
246
  """,
247
  elem_classes="description"
248
  )
 
257
  ["What is deep learning?"],
258
  ["Write a story about a time-traveling smartphone."],
259
  ["How does neural network training work?"],
260
+ ["Write a fairy tale about a computer learning to dream."],
261
+ ["What is the difference between supervised and unsupervised learning?"],
262
+ ["Create a story about an AI becoming an artist."]
263
  ],
264
  inputs=prompt
265
  )
 
274
  top_p,
275
  num_generations,
276
  repetition_penalty,
277
+ do_sample,
278
+ show_comparison
279
  ],
280
  outputs=output
281
  )