broadfield-dev committed on
Commit
6f5c595
·
verified ·
1 Parent(s): ea45df8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +14 -11
app.py CHANGED
@@ -181,14 +181,14 @@ def stage_5_package_and_upload(model_id: str, optimized_model_path: str, pipelin
181
 
182
  def run_amop_pipeline(model_id: str, pipeline_type: str, do_prune: bool, prune_percent: float, onnx_quant_type: str, calibration_file, gguf_quant_type: str):
183
  if not model_id:
184
- yield {log_output: "Please enter a Model ID.", final_output: gr.Label(value="Idle", label="Status")}
185
  return
186
 
187
  initial_log = f"[START] AMOP {pipeline_type} Pipeline Initiated.\n"
188
  yield {
189
  run_button: gr.Button(interactive=False, value="🚀 Running..."),
190
  analyze_button: gr.Button(interactive=False),
191
- final_output: gr.Label(value={"label": f"RUNNING ({pipeline_type})"}, show_label=True),
192
  log_output: initial_log
193
  }
194
 
@@ -196,16 +196,19 @@ def run_amop_pipeline(model_id: str, pipeline_type: str, do_prune: bool, prune_p
196
  temp_model_dir = None
197
  try:
198
  repo_name_suffix = f"-amop-cpu-{pipeline_type.lower()}"
199
- repo_id_for_link = f"{api.whoami()['name']}/{model_id.split('/')[-1]}{repo_name_suffix}"
 
 
 
200
 
201
  if pipeline_type == "ONNX":
202
  full_log += "Loading base model for pruning...\n"
203
- yield {final_output: gr.Label(value="Loading model (1/5)"), log_output: full_log}
204
  model = AutoModel.from_pretrained(model_id, trust_remote_code=True)
205
  tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
206
  full_log += f"Successfully loaded base model '{model_id}'.\n"
207
 
208
- yield {final_output: gr.Label(value="Pruning model (2/5)"), log_output: full_log}
209
  if do_prune:
210
  model, log = stage_2_prune_model(model, prune_percent)
211
  full_log += log
@@ -217,7 +220,7 @@ def run_amop_pipeline(model_id: str, pipeline_type: str, do_prune: bool, prune_p
217
  tokenizer.save_pretrained(temp_model_dir)
218
  full_log += f"Saved intermediate model to temporary directory: {temp_model_dir}\n"
219
 
220
- yield {final_output: gr.Label(value="Converting to ONNX (3/5)"), log_output: full_log}
221
  calib_path = calibration_file.name if onnx_quant_type == "Static" and calibration_file else None
222
  optimized_path, log = stage_3_4_onnx_quantize(temp_model_dir, calib_path)
223
  full_log += log
@@ -225,7 +228,7 @@ def run_amop_pipeline(model_id: str, pipeline_type: str, do_prune: bool, prune_p
225
 
226
  elif pipeline_type == "GGUF":
227
  full_log += "[STAGE 1 & 2] Loading and Pruning are skipped for GGUF pipeline.\n"
228
- yield {final_output: gr.Label(value="Converting to GGUF (3/5)"), log_output: full_log}
229
  optimized_path, log = stage_3_4_gguf_quantize(model_id, gguf_quant_type)
230
  full_log += log
231
  options = {'pipeline_type': 'GGUF', 'quant_type': gguf_quant_type}
@@ -233,12 +236,12 @@ def run_amop_pipeline(model_id: str, pipeline_type: str, do_prune: bool, prune_p
233
  else:
234
  raise ValueError("Invalid pipeline type selected.")
235
 
236
- yield {final_output: gr.Label(value="Packaging & Uploading (4/5)"), log_output: full_log}
237
  final_message, log = stage_5_package_and_upload(model_id, optimized_path, full_log, options)
238
  full_log += log
239
 
240
  yield {
241
- final_output: gr.Label(value="SUCCESS", label="Status"),
242
  log_output: full_log,
243
  success_box: gr.Markdown(f"✅ **Success!** Your optimized model is available here: [{repo_id_for_link}](https://huggingface.co/{repo_id_for_link})", visible=True),
244
  run_button: gr.Button(interactive=True, value="Run Optimization Pipeline", variant="primary"),
@@ -249,7 +252,7 @@ def run_amop_pipeline(model_id: str, pipeline_type: str, do_prune: bool, prune_p
249
  logging.error(f"AMOP Pipeline failed. Error: {e}", exc_info=True)
250
  full_log += f"\n[ERROR] Pipeline failed: {e}"
251
  yield {
252
- final_output: gr.Label(value="ERROR", label="Status"),
253
  log_output: full_log,
254
  success_box: gr.Markdown(f"❌ **An error occurred.** Check the logs for details.", visible=True),
255
  run_button: gr.Button(interactive=True, value="Run Optimization Pipeline", variant="primary"),
@@ -273,7 +276,7 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
273
  gr.Markdown("### 1. Select a Model")
274
  model_id_input = gr.Textbox(
275
  label="Hugging Face Model ID",
276
- placeholder="e.g., gpt2, meta-llama/Llama-2-7b-chat-hf",
277
  )
278
  analyze_button = gr.Button("🔍 Analyze Model", variant="secondary")
279
 
 
181
 
182
  def run_amop_pipeline(model_id: str, pipeline_type: str, do_prune: bool, prune_percent: float, onnx_quant_type: str, calibration_file, gguf_quant_type: str):
183
  if not model_id:
184
+ yield {log_output: "Please enter a Model ID.", final_output: "Idle"}
185
  return
186
 
187
  initial_log = f"[START] AMOP {pipeline_type} Pipeline Initiated.\n"
188
  yield {
189
  run_button: gr.Button(interactive=False, value="🚀 Running..."),
190
  analyze_button: gr.Button(interactive=False),
191
+ final_output: f"RUNNING ({pipeline_type})",
192
  log_output: initial_log
193
  }
194
 
 
196
  temp_model_dir = None
197
  try:
198
  repo_name_suffix = f"-amop-cpu-{pipeline_type.lower()}"
199
+ whoami = api.whoami()
200
+ if not whoami:
201
+ raise RuntimeError("Could not authenticate with Hugging Face Hub. Check your HF_TOKEN.")
202
+ repo_id_for_link = f"{whoami['name']}/{model_id.split('/')[-1]}{repo_name_suffix}"
203
 
204
  if pipeline_type == "ONNX":
205
  full_log += "Loading base model for pruning...\n"
206
+ yield {final_output: "Loading model (1/5)", log_output: full_log}
207
  model = AutoModel.from_pretrained(model_id, trust_remote_code=True)
208
  tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
209
  full_log += f"Successfully loaded base model '{model_id}'.\n"
210
 
211
+ yield {final_output: "Pruning model (2/5)", log_output: full_log}
212
  if do_prune:
213
  model, log = stage_2_prune_model(model, prune_percent)
214
  full_log += log
 
220
  tokenizer.save_pretrained(temp_model_dir)
221
  full_log += f"Saved intermediate model to temporary directory: {temp_model_dir}\n"
222
 
223
+ yield {final_output: "Converting to ONNX (3/5)", log_output: full_log}
224
  calib_path = calibration_file.name if onnx_quant_type == "Static" and calibration_file else None
225
  optimized_path, log = stage_3_4_onnx_quantize(temp_model_dir, calib_path)
226
  full_log += log
 
228
 
229
  elif pipeline_type == "GGUF":
230
  full_log += "[STAGE 1 & 2] Loading and Pruning are skipped for GGUF pipeline.\n"
231
+ yield {final_output: "Converting to GGUF (3/5)", log_output: full_log}
232
  optimized_path, log = stage_3_4_gguf_quantize(model_id, gguf_quant_type)
233
  full_log += log
234
  options = {'pipeline_type': 'GGUF', 'quant_type': gguf_quant_type}
 
236
  else:
237
  raise ValueError("Invalid pipeline type selected.")
238
 
239
+ yield {final_output: "Packaging & Uploading (4/5)", log_output: full_log}
240
  final_message, log = stage_5_package_and_upload(model_id, optimized_path, full_log, options)
241
  full_log += log
242
 
243
  yield {
244
+ final_output: gr.update(value="SUCCESS", label="Status"),
245
  log_output: full_log,
246
  success_box: gr.Markdown(f"✅ **Success!** Your optimized model is available here: [{repo_id_for_link}](https://huggingface.co/{repo_id_for_link})", visible=True),
247
  run_button: gr.Button(interactive=True, value="Run Optimization Pipeline", variant="primary"),
 
252
  logging.error(f"AMOP Pipeline failed. Error: {e}", exc_info=True)
253
  full_log += f"\n[ERROR] Pipeline failed: {e}"
254
  yield {
255
+ final_output: gr.update(value="ERROR", label="Status"),
256
  log_output: full_log,
257
  success_box: gr.Markdown(f"❌ **An error occurred.** Check the logs for details.", visible=True),
258
  run_button: gr.Button(interactive=True, value="Run Optimization Pipeline", variant="primary"),
 
276
  gr.Markdown("### 1. Select a Model")
277
  model_id_input = gr.Textbox(
278
  label="Hugging Face Model ID",
279
+ placeholder="e.g., gpt2, google/gemma-2b",
280
  )
281
  analyze_button = gr.Button("🔍 Analyze Model", variant="secondary")
282