codemichaeld committed on
Commit
7f615fd
Β·
verified Β·
1 Parent(s): 06c835d

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +223 -0
app.py ADDED
@@ -0,0 +1,223 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import os
3
+ import tempfile
4
+ import shutil
5
+ import re
6
+ import json
7
+ import datetime
8
+ from pathlib import Path
9
+ from huggingface_hub import HfApi, hf_hub_download
10
+ from safetensors.torch import load_file, save_file
11
+ import torch
12
+
13
# --- Conversion Function: Safetensors → FP8 Safetensors (E4M3FN or E5M2) ---
def convert_safetensors_to_fp8(safetensors_path, output_dir, fp8_format, progress=None):
    """
    Load a .safetensors file and write an FP8-cast copy into output_dir.

    Parameters
    ----------
    safetensors_path : str
        Path to the source .safetensors file.
    output_dir : str
        Directory that receives "<name>-fp8-<fp8_format>.safetensors".
    fp8_format : str
        'e5m2' selects torch.float8_e5m2; anything else falls back to
        torch.float8_e4m3fn.
    progress : callable, optional
        Gradio-style callback ``progress(fraction, desc=...)``.  Defaults to a
        no-op so the converter no longer requires Gradio at definition time
        (the original default ``gr.Progress()`` was built once at import).

    Returns
    -------
    (bool, str)
        (True, success message) on success, (False, error text) on failure.
    """
    if progress is None:
        # No-op fallback so the function is usable outside a Gradio event.
        def progress(*_args, **_kwargs):
            pass

    progress(0.1, desc="Starting FP8 conversion...")

    try:
        # safetensors layout: 8-byte little-endian header length, then a JSON
        # header; the optional '__metadata__' entry holds user metadata that
        # we preserve on the converted file.
        def read_safetensors_metadata(path):
            with open(path, 'rb') as f:
                header_size = int.from_bytes(f.read(8), 'little')
                header = json.loads(f.read(header_size).decode('utf-8'))
            return header.get('__metadata__', {})

        metadata = read_safetensors_metadata(safetensors_path)
        progress(0.3, desc="Loaded model metadata.")

        state_dict = load_file(safetensors_path)
        progress(0.5, desc="Loaded model weights.")

        # E5M2: wider dynamic range; E4M3FN (default): more mantissa bits.
        fp8_dtype = torch.float8_e5m2 if fp8_format == "e5m2" else torch.float8_e4m3fn

        # Cast every float tensor to FP8.  Fixed: float64 is now included
        # (the original list missed it, leaving fp64 weights unconverted),
        # and each tensor is looked up once via .items() instead of three
        # dict indexings per key.
        float_dtypes = (torch.float16, torch.float32, torch.bfloat16, torch.float64)
        sd_pruned = {}
        total = len(state_dict)
        for i, (key, tensor) in enumerate(state_dict.items()):
            progress(0.5 + 0.4 * (i / total), desc=f"Converting tensor {i+1}/{total} to FP8 ({fp8_format})...")
            if tensor.dtype in float_dtypes:
                sd_pruned[key] = tensor.to(fp8_dtype)
            else:
                sd_pruned[key] = tensor  # keep non-float as-is (e.g., int for embeddings)

        # Save with the original metadata plus our own format markers.
        base_name = os.path.splitext(os.path.basename(safetensors_path))[0]
        output_path = os.path.join(output_dir, f"{base_name}-fp8-{fp8_format}.safetensors")
        save_file(sd_pruned, output_path, metadata={"format": "pt", "fp8_format": fp8_format, **metadata})
        progress(0.9, desc="Saved FP8 safetensors file.")

        progress(1.0, desc="FP8 conversion complete!")
        return True, f"Model successfully pruned to FP8 ({fp8_format})."

    except Exception as e:
        # Report the failure to the caller instead of raising into the UI.
        return False, str(e)
65
+
66
# --- Main Processing Function ---
def process_and_upload_fp8(repo_url, safetensors_filename, fp8_format, hf_token, new_repo_id, private_repo, progress=gr.Progress()):
    """
    Download a .safetensors file from a Hugging Face repo, convert it to FP8,
    and upload the result (plus a generated README) to a new model repo.

    Returns
    -------
    (gr.HTML | None, str)
        Exactly two values, matching the two Gradio outputs wired to this
        handler (repo_link_output, status_output).  BUGFIX: the original
        returned 3-tuples, which crashed the 2-output click handler on
        every invocation.
    """
    if not all([repo_url, safetensors_filename, fp8_format, hf_token, new_repo_id]):
        return None, "❌ Error: Please fill in all fields."

    if not re.match(r"^[a-zA-Z0-9._-]+/[a-zA-Z0-9._-]+$", new_repo_id):
        return None, "❌ Error: Invalid repository ID format. Use 'username/model-name'."

    temp_dir = tempfile.mkdtemp()
    output_dir = tempfile.mkdtemp()

    try:
        # Authenticate and confirm the token works before downloading.
        progress(0.05, desc="Logging into Hugging Face...")
        api = HfApi(token=hf_token)
        user_name = api.whoami()['name']
        progress(0.1, desc=f"Logged in as {user_name}.")

        # Parse the source URL into an "owner/name" repo id.
        clean_url = repo_url.strip().rstrip("/")
        if "huggingface.co" not in clean_url:
            return None, "❌ Source must be a Hugging Face model repo."
        src_repo_id = clean_url.replace("https://huggingface.co/", "")

        # Download the requested safetensors file into a throwaway cache.
        progress(0.15, desc=f"Downloading {safetensors_filename}...")
        safetensors_path = hf_hub_download(
            repo_id=src_repo_id,
            filename=safetensors_filename,
            cache_dir=temp_dir,
            token=hf_token,
        )
        progress(0.25, desc="Download complete.")

        # Convert to FP8 (writes the converted file into output_dir).
        success, msg = convert_safetensors_to_fp8(safetensors_path, output_dir, fp8_format, progress)
        if not success:
            return None, f"❌ Conversion failed: {msg}"

        # Create (or reuse) the destination repository.
        progress(0.92, desc="Creating new repository...")
        api.create_repo(
            repo_id=new_repo_id,
            private=private_repo,
            repo_type="model",
            exist_ok=True,
        )

        # Generate a README; the filename must mirror the converter's naming
        # scheme ("<base>-fp8-<format>.safetensors").
        base_name = os.path.splitext(safetensors_filename)[0]
        fp8_filename = f"{base_name}-fp8-{fp8_format}.safetensors"
        readme = f"""---
library_name: diffusers
tags:
- fp8
- safetensors
- pruned
- diffusion
- converted-by-gradio
- fp8-{fp8_format}
---

# FP8 Pruned Model ({fp8_format.upper()})

Converted from: [`{src_repo_id}`](https://huggingface.co/{src_repo_id})
File: `{safetensors_filename}` → `{fp8_filename}`

Quantization: **FP8 ({fp8_format.upper()})**
Converted by: {user_name}
Date: {datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')}

> ⚠️ FP8 models require PyTorch ≥ 2.1 and compatible hardware (e.g., NVIDIA Ada/Hopper) for full acceleration. May fall back to FP16 on older GPUs.
"""
        with open(os.path.join(output_dir, "README.md"), "w") as f:
            f.write(readme)

        # Upload the FP8 file and README together in one commit.
        progress(0.95, desc="Uploading to Hugging Face Hub...")
        api.upload_folder(
            repo_id=new_repo_id,
            folder_path=output_dir,
            repo_type="model",
            token=hf_token,
            commit_message=f"Upload FP8 ({fp8_format}) pruned safetensors model",
        )

        progress(1.0, desc="✅ Done!")
        result_html = f"""
✅ Success!
Your FP8 ({fp8_format}) model is uploaded to: [{new_repo_id}](https://huggingface.co/{new_repo_id})
Visibility: {'Private' if private_repo else 'Public'}
"""
        return gr.HTML(result_html), "✅ FP8 conversion and upload successful!"

    except Exception as e:
        return None, f"❌ Error: {str(e)}"
    finally:
        # Always remove the temp dirs, on success or failure.
        shutil.rmtree(temp_dir, ignore_errors=True)
        shutil.rmtree(output_dir, ignore_errors=True)
166
+
167
# --- Gradio UI ---
# Two-column form: source repo / file / format / token on the left,
# destination repo settings on the right, results rendered below.
with gr.Blocks(title="Safetensors → FP8 Pruner") as demo:
    gr.Markdown("# 🔄 Safetensors to FP8 Pruner")
    gr.Markdown(
        "Converts any `.safetensors` file from a Hugging Face model repo to "
        "**FP8 (E4M3FN or E5M2)** for compact storage and faster inference."
    )

    with gr.Row():
        with gr.Column():
            repo_url = gr.Textbox(
                label="Source Model Repository URL",
                placeholder="https://huggingface.co/Yabo/FramePainter",
                info="Hugging Face model repo containing your safetensors file",
            )
            safetensors_filename = gr.Textbox(
                label="Safetensors Filename",
                placeholder="unet_diffusion_pytorch_model.safetensors",
                info="Name of the .safetensors file in the repo",
            )
            fp8_format = gr.Radio(
                choices=["e4m3fn", "e5m2"],
                value="e5m2",
                label="FP8 Format",
                info="E5M2 has wider dynamic range; E4M3FN has higher precision near zero.",
            )
            hf_token = gr.Textbox(
                label="Hugging Face Token",
                type="password",
                info="Write-access token from https://huggingface.co/settings/tokens",
            )
        with gr.Column():
            new_repo_id = gr.Textbox(
                label="New Repository ID",
                placeholder="your-username/my-model-fp8",
                info="Format: username/model-name",
            )
            private_repo = gr.Checkbox(label="Make Private", value=False)

    convert_btn = gr.Button("🚀 Convert & Upload", variant="primary")

    # Result area: status text plus a link block to the new repo.
    with gr.Row():
        status_output = gr.Markdown()
        repo_link_output = gr.HTML()

    convert_btn.click(
        fn=process_and_upload_fp8,
        inputs=[repo_url, safetensors_filename, fp8_format, hf_token, new_repo_id, private_repo],
        outputs=[repo_link_output, status_output],
        show_progress=True,
    )

    # One-click example wiring the public FramePainter UNet.
    gr.Examples(
        examples=[
            ["https://huggingface.co/Yabo/FramePainter", "unet_diffusion_pytorch_model.safetensors", "e5m2"]
        ],
        inputs=[repo_url, safetensors_filename, fp8_format],
    )

demo.launch()