Spaces:
Running
on
Zero
Running
on
Zero
File size: 2,109 Bytes
b22b80e b05966a 0f51018 ee02270 699b46e afa2559 b05966a b22b80e b05966a 937a94e ee02270 9c2430d 699b46e ee02270 b05966a ee02270 b22b80e 9c2430d ee02270 b22b80e ee02270 9c2430d ee02270 9c2430d ee02270 9c2430d ee02270 9c2430d ee02270 b22b80e 9c2430d ee02270 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 |
import gradio as gr
import torch
from diffusers import QwenImagePipeline
from qwenimage.qwen_fa3_processor import QwenDoubleStreamAttnProcessorFA3
from optimization import get_compiled_transformer
from hub_utils import _push_compiled_graph_to_hub
from huggingface_hub import whoami
import spaces
# --- Model Loading ---
dtype = torch.bfloat16

# Prefer the GPU; fall back to CPU when CUDA is not available.
# NOTE(review): the FA3 attention processor set below presumably needs a CUDA
# GPU, so the CPU fallback may not be usable in practice — confirm.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Load the Qwen-Image pipeline in bfloat16, move it to the selected device,
# then swap the transformer's attention for QwenDoubleStreamAttnProcessorFA3.
pipe = QwenImagePipeline.from_pretrained("Qwen/Qwen-Image", torch_dtype=dtype)
pipe = pipe.to(device)
pipe.transformer.set_attn_processor(QwenDoubleStreamAttnProcessorFA3())

# --- Ahead-of-time compilation ---
# Runs once at import time, after the attention processor is installed, so the
# compiled graph (whose `archive_file` push_to_hub uploads) exists before the
# UI is launched.
# NOTE(review): "prompt" looks like a placeholder input for compilation — confirm.
compiled_transformer = get_compiled_transformer(pipe, prompt="prompt")
@spaces.GPU(duration=120)
def push_to_hub(repo_id, filename, oauth_token: gr.OAuthToken):
    """Upload the ahead-of-time compiled transformer archive to a Hub repo.

    Args:
        repo_id: Target Hub repository id (e.g. ``"user/repo"``), taken from
            the UI textbox.
        filename: Path of the archive inside the repo; must end with ``.pt2``.
        oauth_token: OAuth token of the signed-in user, supplied by Gradio
            because of the ``gr.OAuthToken`` type hint.

    Returns:
        A Markdown link to the commit when the upload result exposes a
        ``commit_url`` attribute; otherwise the upload result unchanged.

    Raises:
        gr.Error: If ``repo_id`` is empty or ``filename`` does not end with
            ``.pt2``. The message is shown to the user in the UI.
    """
    # Textbox values can carry stray whitespace (or be None when called via
    # the API); normalize them before validating.
    repo_id = (repo_id or "").strip()
    filename = (filename or "").strip()
    if not repo_id:
        raise gr.Error("Please provide a `repo_id`, e.g. `username/repo-name`.")
    if not filename.endswith(".pt2"):
        raise gr.Error("The filename must end with a `.pt2` extension.")

    token = oauth_token.token
    # Fail fast before uploading: this will throw if the token is invalid.
    whoami(token)

    out = _push_compiled_graph_to_hub(
        compiled_transformer.archive_file,
        repo_id=repo_id,
        token=token,
        path_in_repo=filename,
    )
    # Render the commit as a clickable Markdown link when one is available.
    if not isinstance(out, str) and hasattr(out, "commit_url"):
        commit_url = out.commit_url
        return f"[{commit_url}]({commit_url})"
    return out
css = """
#app {max-width: 840px; margin: 0 auto;}
"""
# Form UI: the user signs in, enters a target repo and filename, and pushes the
# graph that was compiled at startup.
with gr.Blocks(css=css) as demo:
    gr.Markdown("## Compile a model graph ahead of time & push to the Hub")
    gr.Markdown("Enter a **repo_id** and **filename**. This repo automatically compiles the [QwenImage](https://hf.co/Qwen/Qwen-Image) model on start.")
    # push_to_hub requires a gr.OAuthToken, which Gradio only supplies for a
    # signed-in user, so the UI must offer a way to sign in.
    # NOTE(review): on Spaces this also needs `hf_oauth: true` in the README
    # metadata — confirm.
    gr.LoginButton()
    with gr.Row():
        repo_id = gr.Textbox(label="repo_id", placeholder="e.g. sayakpaul/qwen-aot")
        filename = gr.Textbox(label="filename", placeholder="e.g. compiled.pt2")
    run = gr.Button("Push graph to Hub", variant="primary")
    markdown_out = gr.Markdown()
    # The OAuth token is deliberately absent from `inputs`: Gradio injects it
    # based on the `gr.OAuthToken` type hint on push_to_hub.
    run.click(push_to_hub, inputs=[repo_id, filename], outputs=[markdown_out])
if __name__ == "__main__":
    # show_error=True displays exceptions raised in event handlers in the UI.
    demo.launch(show_error=True)