broadfield-dev committed on
Commit f074b57 · verified · 1 Parent(s): b0fc6c5

Update app.py

Files changed (1)
  1. app.py +49 -77
app.py CHANGED
@@ -3,7 +3,7 @@ import torch
 import os
 import logging
 from datetime import datetime
-from huggingface_hub import HfApi, HfFolder
+from huggingface_hub import HfApi
 from transformers import AutoConfig, AutoModel, AutoTokenizer
 from optimum.onnxruntime import ORTQuantizer, ORTModelForCausalLM
 from optimum.onnxruntime.configuration import AutoQuantizationConfig
@@ -13,10 +13,8 @@ import time
 
 # --- 1. SETUP AND CONFIGURATION ---
 
-# Setup basic logging
 logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
 
-# Ensure the user has set their Hugging Face token in the Space secrets
 HF_TOKEN = os.getenv("HF_TOKEN")
 if not HF_TOKEN:
     logging.warning("HF_TOKEN environment variable not set. Packaging and uploading will fail.")
@@ -35,15 +33,13 @@ def stage_1_analyze_model(model_id: str):
     """
     log_stream = "[STAGE 1] Analyzing model...\n"
     try:
-        config = AutoConfig.from_pretrained(model_id)
+        config = AutoConfig.from_pretrained(model_id, trust_remote_code=True)
         model_type = config.model_type
-        num_params = getattr(config, "num_hidden_layers", "N/A") * getattr(config, "hidden_size", 0) / 1e6 # A rough estimate
 
         analysis_report = f"""
 ### Model Analysis Report
 - **Model ID:** `{model_id}`
 - **Architecture:** `{model_type}`
-- **Estimated Parameters:** ~{num_params:.2f}M
 """
 
         recommendation = ""
@@ -57,45 +53,30 @@ def stage_1_analyze_model(model_id: str):
         recommendation = "**Recommendation:** Unrecognized architecture. The standard path of **Quantization -> ONNX Conversion** is a safe starting point."
 
         log_stream += f"Analysis complete. Architecture: {model_type}.\n"
-        return log_stream, analysis_report + "\n" + recommendation, gr.update(visible=True)
+        # GRADIO 5 UPDATE: Instead of gr.update(), return a new component object.
+        return log_stream, analysis_report + "\n" + recommendation, gr.Group(visible=True)
     except Exception as e:
         error_msg = f"Failed to analyze model '{model_id}'. Error: {e}"
         logging.error(error_msg)
-        return log_stream + error_msg, "Could not analyze model. Please check the model ID and try again.", gr.update(visible=False)
+        return log_stream + error_msg, "Could not analyze model. Please check the model ID and try again.", gr.Group(visible=False)
 
 
-def stage_2_prune_model(model, prune_percentage: float, progress):
-    """
-    Performs Stage 2: Structural Reduction via one-shot unstructured pruning.
-    """
+def stage_2_prune_model(model, prune_percentage: float):
     if prune_percentage == 0:
         return model, "Skipped pruning as percentage was 0."
 
     log_stream = "[STAGE 2] Pruning model...\n"
-    progress(0.25, desc="Applying Unstructured Pruning")
-
-    total_params = sum(p.numel() for p in model.parameters())
-
     for name, module in model.named_modules():
         if isinstance(module, torch.nn.Linear):
             prune.l1_unstructured(module, name='weight', amount=prune_percentage / 100.0)
-            prune.remove(module, 'weight') # Makes the pruning permanent
-
-    pruned_params = sum(p.numel() for p in model.parameters())
-    reduction = (total_params - pruned_params) / total_params * 100
+            prune.remove(module, 'weight')
 
-    log_stream += f"Pruning complete. Parameter reduction: ~{reduction:.2f}%\n"
+    log_stream += f"Pruning complete. Note: This version exports the original model to ONNX for maximum compatibility.\n"
     return model, log_stream
 
 
-def stage_3_and_4_quantize_and_onnx(model_id: str, progress):
-    """
-    Performs Stage 3 (Quantization) and Stage 4 (ONNX Conversion).
-    This version uses post-training dynamic quantization.
-    """
+def stage_3_and_4_quantize_and_onnx(model_id: str):
     log_stream = "[STAGE 3 & 4] Converting to ONNX and Quantizing...\n"
-    progress(0.5, desc="Exporting to ONNX")
-
     try:
         run_id = datetime.now().strftime("%Y%m%d-%H%M%S")
         onnx_path = os.path.join(OUTPUT_DIR, f"{model_id.replace('/', '_')}-{run_id}-onnx")
@@ -104,16 +85,14 @@ def stage_3_and_4_quantize_and_onnx(model_id: str, progress):
         main_export(model_id, output=onnx_path, task="auto", trust_remote_code=True)
         log_stream += f"Successfully exported base model to ONNX at: {onnx_path}\n"
 
-        progress(0.7, desc="Applying Dynamic Quantization")
         quantizer = ORTQuantizer.from_pretrained(onnx_path)
-        dqconfig = AutoQuantizationConfig.avx512_vnni(is_static=False, per_channel=False) # Dynamic quantization for CPUs
+        dqconfig = AutoQuantizationConfig.avx512_vnni(is_static=False, per_channel=False)
 
         quantized_path = os.path.join(onnx_path, "quantized")
         quantizer.quantize(save_dir=quantized_path, quantization_config=dqconfig)
 
         log_stream += f"Successfully quantized model to: {quantized_path}\n"
         return quantized_path, log_stream
-
     except Exception as e:
         error_msg = f"Failed during ONNX conversion/quantization. Error: {e}"
         logging.error(error_msg, exc_info=True)
@@ -124,15 +103,9 @@ def stage_5_evaluate_and_package(
     model_id: str,
     optimized_model_path: str,
     pipeline_log: str,
-    options: dict,
-    progress
+    options: dict
 ):
-    """
-    Performs Stage 5: Evaluation, Packaging, and Uploading.
-    """
     log_stream = "[STAGE 5] Evaluating and Packaging...\n"
-    progress(0.9, desc="Evaluating performance")
-
     try:
         ort_model = ORTModelForCausalLM.from_pretrained(optimized_model_path)
         tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
@@ -145,18 +118,16 @@ def stage_5_evaluate_and_package(
         end_time = time.time()
 
         latency = (end_time - start_time) * 1000
-        num_tokens = len(gen_tokens[0])
-        ms_per_token = latency / num_tokens
+        num_tokens = len(gen_tokens[0]) - inputs.input_ids.shape[1]
+        ms_per_token = latency / num_tokens if num_tokens > 0 else float('inf')
 
         eval_report = f"- **Inference Latency:** {latency:.2f} ms\n"
         eval_report += f"- **Speed:** {ms_per_token:.2f} ms/token\n"
         log_stream += "Evaluation complete.\n"
     except Exception as e:
-        eval_report = f"- **Evaluation Failed:** Could not load and test the ONNX model. This often happens if the base model is not a text-generation model. Error: {e}\n"
+        eval_report = f"- **Evaluation Failed:** Could not run generation. This often happens if the base model is not a text-generation model. Error: {e}\n"
         log_stream += f"Warning: Evaluation failed. {e}\n"
 
-    progress(0.95, desc="Uploading to Hugging Face Hub")
-
     if not HF_TOKEN:
         return "Skipping upload: HF_TOKEN not found.", log_stream + "Skipping upload: HF_TOKEN not found."
 
@@ -164,39 +135,29 @@ def stage_5_evaluate_and_package(
         repo_name = f"{model_id.split('/')[-1]}-amop-cpu"
         repo_url = api.create_repo(repo_id=repo_name, exist_ok=True, token=HF_TOKEN)
 
-        # --- THIS IS THE UPDATED SECTION ---
-        # Read the template file
         with open("model_card_template.md", "r", encoding="utf-8") as f:
             template_content = f.read()
 
-        # Fill in the placeholders
         model_card_content = template_content.format(
-            repo_name=repo_name,
-            model_id=model_id,
+            repo_name=repo_name, model_id=model_id,
             optimization_date=datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
-            eval_report=eval_report,
-            pruning_status="Enabled" if options['prune'] else "Disabled",
-            pruning_percent=options['prune_percent'],
-            repo_id=repo_url.repo_id,
+            eval_report=eval_report, pruning_status="Enabled" if options['prune'] else "Disabled",
+            pruning_percent=options['prune_percent'], repo_id=repo_url.repo_id,
             pipeline_log=pipeline_log
         )
-        # --- END OF UPDATED SECTION ---
 
         readme_path = os.path.join(optimized_model_path, "README.md")
-        with open(readme_path, "w", encoding="utf-8") as f:
-            f.write(model_card_content)
+        with open(readme_path, "w", encoding="utf-8") as f: f.write(model_card_content)
 
         tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
         tokenizer.save_pretrained(optimized_model_path)
 
         api.upload_folder(
-            folder_path=optimized_model_path,
-            repo_id=repo_url.repo_id,
-            repo_type="model",
-            token=HF_TOKEN
+            folder_path=optimized_model_path, repo_id=repo_url.repo_id,
+            repo_type="model", token=HF_TOKEN
         )
 
-        final_message = f"✅ Success! Your optimized model is available at: {repo_url}"
+        final_message = f"✅ Success! Your optimized model is available at: [{repo_url.repo_id}](https://huggingface.co/{repo_url.repo_id})"
         log_stream += "Upload complete.\n"
         return final_message, log_stream
     except Exception as e:
@@ -205,44 +166,56 @@ def stage_5_evaluate_and_package(
         return f"❌ Error: {error_msg}", log_stream + error_msg
 
 
-# --- 3. MAIN WORKFLOW FUNCTION ---
+# --- 3. MAIN WORKFLOW FUNCTION (GENERATOR FOR GRADIO 5+) ---
 
-def run_amop_pipeline(model_id: str, do_prune: bool, prune_percent: float, progress=gr.Progress(track_tqdm=True)):
+def run_amop_pipeline(model_id: str, do_prune: bool, prune_percent: float):
+    """
+    This is now a generator function. It 'yields' updates to the UI
+    at each step, providing a real-time log.
+    """
     if not model_id:
-        return "Please enter a Model ID.", ""
-
+        yield "Please enter a Model ID.", ""
+        return
+
     full_log = "[START] AMOP Pipeline Initiated.\n"
-    progress(0, desc="Loading Base Model")
+    yield gr.Markdown("🚀 Pipeline is running... Check logs for real-time updates."), full_log
 
     try:
+        # Step 1: Load Model
+        full_log += "Loading base model...\n"
+        yield gr.Markdown("🚀 Pipeline is running... (1/5) Loading model"), full_log
         model = AutoModel.from_pretrained(model_id, trust_remote_code=True)
         full_log += f"Successfully loaded base model '{model_id}'.\n"
 
+        # Step 2: Pruning
+        yield gr.Markdown("🚀 Pipeline is running... (2/5) Pruning model"), full_log
         if do_prune:
-            model, log = stage_2_prune_model(model, prune_percent, progress)
+            model, log = stage_2_prune_model(model, prune_percent)
             full_log += log
         else:
            full_log += "[STAGE 2] Pruning skipped by user.\n"
 
-        # We re-export the pruned model, so it needs to be saved and reloaded by optimum
-        # For simplicity in V1, we will export the original model from the hub
-        # A future version could handle the pruned model state_dict
-        optimized_path, log = stage_3_and_4_quantize_and_onnx(model_id, progress)
+        # Step 3 & 4: ONNX Conversion
+        yield gr.Markdown("🚀 Pipeline is running... (3/5) Converting to ONNX & Quantizing"), full_log
+        optimized_path, log = stage_3_and_4_quantize_and_onnx(model_id)
         full_log += log
 
+        # Step 5: Packaging
+        yield gr.Markdown("🚀 Pipeline is running... (4/5) Evaluating and Packaging"), full_log
         options = {'prune': do_prune, 'prune_percent': prune_percent}
-        final_status, log = stage_5_evaluate_and_package(model_id, optimized_path, full_log, options, progress)
+        final_status_msg, log = stage_5_evaluate_and_package(model_id, optimized_path, full_log, options)
         full_log += log
-
-        return final_status, full_log
+
+        # Final Step: Done
+        yield gr.Markdown(final_status_msg), full_log
 
     except Exception as e:
         logging.error(f"AMOP Pipeline failed. Error: {e}", exc_info=True)
         full_log += f"\n[ERROR] Pipeline failed: {e}"
-        return f"❌ An error occurred during the pipeline. Check the logs for details.", full_log
+        yield f"❌ An error occurred during the pipeline. Check the logs for details.", full_log
 
 
-# --- 4. GRADIO USER INTERFACE ---
+# --- 4. GRADIO USER INTERFACE (for Gradio 5+) ---
 
 with gr.Blocks(theme=gr.themes.Soft()) as demo:
     gr.Markdown("# AMOP: Adaptive Model Optimization Pipeline")
@@ -266,12 +239,11 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
             prune_slider = gr.Slider(minimum=0, maximum=90, value=20, step=5, label="Pruning Percentage (%)")
 
             gr.Checkbox(label="Enable Quantization & ONNX (Stages 3 & 4)", value=True, interactive=False)
-
             run_button = gr.Button("3. Run Optimization Pipeline", variant="primary")
 
         with gr.Column(scale=2):
             gr.Markdown("### Pipeline Status & Logs")
-            final_output = gr.Markdown(label="Final Result")
+            final_output = gr.Markdown(value="*Pipeline has not been run yet.*", label="Final Result")
             log_output = gr.Textbox(label="Live Logs", lines=20, interactive=False)
 
     analyze_button.click(
 
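
Note on the pattern this diff adopts: run_amop_pipeline becomes a generator, and Gradio streams each yield to the outputs wired up in run_button.click(...), replacing the old gr.Progress callbacks. Below is a minimal sketch of that generator-handler pattern, assuming only that gradio is installed; the component names and the fake_pipeline function are illustrative stand-ins, not taken from app.py.

import time
import gradio as gr

def fake_pipeline(model_id: str):
    # Illustrative generator handler: each yield updates (status, logs) in the UI.
    log = ""
    if not model_id:
        yield "Please enter a Model ID.", log
        return
    for i, step in enumerate(["Loading model", "Pruning", "Exporting to ONNX"], start=1):
        log += f"[{i}/3] {step}...\n"
        yield f"Running: {step}", log   # streamed to the outputs listed in .click()
        time.sleep(0.5)                 # stand-in for real work
    yield "Done", log

with gr.Blocks() as demo:
    model_box = gr.Textbox(label="Model ID")
    run_btn = gr.Button("Run")
    status = gr.Markdown()
    logs = gr.Textbox(label="Live Logs", lines=8)
    # Because fake_pipeline is a generator, every yield refreshes these two outputs.
    run_btn.click(fake_pipeline, inputs=model_box, outputs=[status, logs])

demo.launch()

Each yield replaces the previous values of the two outputs, which is what gives the app a live status line and a growing log without per-stage progress callbacks.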