File size: 2,107 Bytes
b22b80e
a693cd4
b05966a
 
0f51018
f5a3617
ee02270
 
afa2559
b05966a
 
 
b22b80e
b05966a
 
937a94e
 
699b46e
ee02270
 
 
b05966a
ee02270
 
 
408c04e
 
 
ee02270
 
 
 
 
 
 
 
 
 
 
 
 
b22b80e
9c2430d
ee02270
b22b80e
 
 
ee02270
 
9c2430d
ee02270
 
 
9c2430d
ee02270
9c2430d
ee02270
9c2430d
ee02270
b22b80e
9c2430d
ee02270
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
import gradio as gr
import spaces
import torch
from diffusers import QwenImagePipeline
from qwenimage.qwen_fa3_processor import QwenDoubleStreamAttnProcessorFA3
from optimization import compile_transformer
from hub_utils import _push_compiled_graph_to_hub
from huggingface_hub import whoami

# --- Model Loading ---
# bfloat16 weights; use the GPU when torch reports one, otherwise the CPU.
dtype = torch.bfloat16
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Fetch the QwenImage pipeline, move it to `device`, then replace the
# transformer's attention processor with `QwenDoubleStreamAttnProcessorFA3`
# (presumably a FlashAttention-3 based implementation — confirm in qwenimage).
pipe = QwenImagePipeline.from_pretrained("Qwen/Qwen-Image", torch_dtype=dtype)
pipe = pipe.to(device)
pipe.transformer.set_attn_processor(QwenDoubleStreamAttnProcessorFA3())

@spaces.GPU(duration=120)
def _compile_and_push(repo_id, filename, token):
    """Compile the transformer ahead of time and upload the archive.

    This is the only GPU-bound step, so it alone carries the `spaces.GPU`
    allocation.

    Args:
        repo_id: Target Hub repository id.
        filename: Path inside the repo for the compiled archive.
        token: Hugging Face access token used for the upload.

    Returns:
        Whatever `_push_compiled_graph_to_hub` returns.
    """
    compiled_transformer = compile_transformer(pipe, prompt="prompt")
    return _push_compiled_graph_to_hub(
        compiled_transformer.archive_file,
        repo_id=repo_id,
        token=token,
        path_in_repo=filename,
    )


def push_to_hub(repo_id, filename, oauth_token: gr.OAuthToken):
    """Compile the QwenImage transformer ahead of time and push it to the Hub.

    Gradio injects `oauth_token` from the signed-in user's session because of
    its `gr.OAuthToken` annotation; it is not part of the event's `inputs`.

    Args:
        repo_id: Target Hub repository id (e.g. "user/repo").
        filename: Path inside the repo for the archive; must end in ".pt2".
        oauth_token: OAuth token of the signed-in user.

    Returns:
        A Markdown link to the commit when the upload helper returns an object
        exposing `commit_url`; otherwise the helper's return value unchanged.

    Raises:
        gr.Error: If `repo_id` is empty or `filename` lacks the ".pt2"
            extension. These checks run before any GPU is requested.
    """
    # Cheap validation first, outside the GPU allocation, so bad input does not
    # consume GPU time or quota on the compile step.
    repo_id = (repo_id or "").strip()
    filename = (filename or "").strip()
    if not repo_id:
        raise gr.Error("Please enter a `repo_id`.")
    if not filename.endswith(".pt2"):
        raise gr.Error("The filename must end with a `.pt2` extension.")

    token = oauth_token.token
    # this will throw if token is invalid
    whoami(token)

    out = _compile_and_push(repo_id, filename, token)
    if not isinstance(out, str) and hasattr(out, "commit_url"):
        commit_url = out.commit_url
        return f"[{commit_url}]({commit_url})"
    return out
     

css = """
#app {max-width: 840px; margin: 0 auto;}
"""

with gr.Blocks(css=css) as demo:
    gr.Markdown("## Compile a model graph ahead of time & push to the Hub")
    gr.Markdown(
        "Sign in with Hugging Face, then enter a **repo_id** and **filename**. "
        "Clicking the button compiles the "
        "[QwenImage](https://hf.co/Qwen/Qwen-Image) transformer ahead of time "
        "and pushes the compiled graph to that repo."
    )

    # `push_to_hub` takes a `gr.OAuthToken`, which Gradio only provides for a
    # signed-in user; without a login button users have no way to sign in.
    gr.LoginButton()

    with gr.Row():
        repo_id = gr.Textbox(label="repo_id", placeholder="e.g. sayakpaul/qwen-aot")
        filename = gr.Textbox(label="filename", placeholder="e.g. compiled.pt2")

    run = gr.Button("Push graph to Hub", variant="primary")

    markdown_out = gr.Markdown()

    # The OAuth token is injected by Gradio from the annotation, so only the
    # two text inputs are listed here.
    run.click(push_to_hub, inputs=[repo_id, filename], outputs=[markdown_out])

if __name__ == "__main__":
    demo.launch(show_error=True)