Spaces:
Running
on
Zero
Running
on
Zero
File size: 2,107 Bytes
b22b80e a693cd4 b05966a 0f51018 f5a3617 ee02270 afa2559 b05966a b22b80e b05966a 937a94e 699b46e ee02270 b05966a ee02270 408c04e ee02270 b22b80e 9c2430d ee02270 b22b80e ee02270 9c2430d ee02270 9c2430d ee02270 9c2430d ee02270 9c2430d ee02270 b22b80e 9c2430d ee02270 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 |
import gradio as gr
import spaces
import torch
from diffusers import QwenImagePipeline
from qwenimage.qwen_fa3_processor import QwenDoubleStreamAttnProcessorFA3
from optimization import compile_transformer
from hub_utils import _push_compiled_graph_to_hub
from huggingface_hub import whoami
# --- Model Loading ---
# All weights are loaded in bfloat16.
dtype = torch.bfloat16

# Use the GPU when torch reports one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Load the model pipeline, then move it onto the selected device.
pipe = QwenImagePipeline.from_pretrained("Qwen/Qwen-Image", torch_dtype=dtype)
pipe = pipe.to(device)

# Install the QwenDoubleStreamAttnProcessorFA3 attention processor on the transformer.
pipe.transformer.set_attn_processor(QwenDoubleStreamAttnProcessorFA3())
@spaces.GPU(duration=120)
def push_to_hub(repo_id, filename, oauth_token: gr.OAuthToken):
    """Compile the pipeline's transformer ahead of time and upload it to the Hub.

    Args:
        repo_id: Target Hub repository, e.g. ``"username/repo-name"``.
        filename: Path of the uploaded archive inside the repository; must
            end with ``.pt2``.
        oauth_token: OAuth token of the signed-in user. It is not listed in
            the click handler's ``inputs``; Gradio supplies it based on the
            ``gr.OAuthToken`` annotation.

    Returns:
        A Markdown link to the commit when the upload helper returns a
        non-string object exposing ``commit_url``; otherwise the helper's
        return value unchanged.

    Raises:
        gr.Error: If ``repo_id`` is empty.
        NotImplementedError: If ``filename`` does not end with ``.pt2``.
    """
    # Textbox values often carry stray whitespace from copy/paste.
    repo_id = (repo_id or "").strip()
    filename = (filename or "").strip()

    # Validate the cheap inputs first so a typo does not cost a full
    # compilation before the upload fails.
    if not repo_id:
        raise gr.Error("Please provide a `repo_id`, e.g. `username/repo-name`.")
    if not filename.endswith(".pt2"):
        raise NotImplementedError("The filename must end with a `.pt2` extension.")

    token = oauth_token.token
    # this will throw if token is invalid
    whoami(token)

    # --- Ahead-of-time compilation ---
    compiled_transformer = compile_transformer(pipe, prompt="prompt")

    out = _push_compiled_graph_to_hub(
        compiled_transformer.archive_file,
        repo_id=repo_id,
        token=token,
        path_in_repo=filename,
    )

    # A non-string result carrying `commit_url` is rendered as a clickable
    # Markdown link; anything else is shown as returned.
    if not isinstance(out, str) and hasattr(out, "commit_url"):
        commit_url = out.commit_url
        return f"[{commit_url}]({commit_url})"
    return out
css = """
#app {max-width: 840px; margin: 0 auto;}
"""

with gr.Blocks(css=css) as demo:
    # The `#app` rule above only takes effect if some component carries that
    # id, so the page content is wrapped in a column with elem_id="app".
    with gr.Column(elem_id="app"):
        gr.Markdown("## Compile a model graph ahead of time & push to the Hub")
        gr.Markdown("Enter a **repo_id** and **filename**. This repo automatically compiles the [QwenImage](https://hf.co/Qwen/Qwen-Image) model on start.")

        # push_to_hub takes a gr.OAuthToken, which is only available for a
        # signed-in user, so the page needs a way to sign in.
        # NOTE(review): this assumes the Space enables OAuth
        # (`hf_oauth: true` in its README metadata) -- confirm.
        gr.LoginButton()

        with gr.Row():
            repo_id = gr.Textbox(label="repo_id", placeholder="e.g. sayakpaul/qwen-aot")
            filename = gr.Textbox(label="filename", placeholder="e.g. compiled.pt2")

        run = gr.Button("Push graph to Hub", variant="primary")
        markdown_out = gr.Markdown()

    # The OAuth token argument is not listed here; only the two textboxes are
    # passed as explicit inputs.
    run.click(push_to_hub, inputs=[repo_id, filename], outputs=[markdown_out])

if __name__ == "__main__":
    demo.launch(show_error=True)