File size: 2,109 Bytes
b22b80e
b05966a
 
0f51018
ee02270
 
 
699b46e
afa2559
b05966a
 
 
b22b80e
b05966a
 
937a94e
 
 
ee02270
9c2430d
699b46e
ee02270
 
 
b05966a
ee02270
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b22b80e
9c2430d
ee02270
b22b80e
 
 
ee02270
 
9c2430d
ee02270
 
 
9c2430d
ee02270
9c2430d
ee02270
9c2430d
ee02270
b22b80e
9c2430d
ee02270
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
import gradio as gr
import torch
from diffusers import QwenImagePipeline
from qwenimage.qwen_fa3_processor import QwenDoubleStreamAttnProcessorFA3
from optimization import get_compiled_transformer
from hub_utils import _push_compiled_graph_to_hub
from huggingface_hub import whoami
import spaces

# --- Model Loading ---
# Module-level setup: everything below runs once at import time (app start-up),
# before the Gradio UI further down is built.
dtype = torch.bfloat16
# Falls back to CPU when no CUDA device is visible.
# NOTE(review): the FA3 attention processor installed below presumably targets
# Flash Attention 3 on a CUDA GPU, so the CPU fallback is unlikely to work end
# to end — confirm.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load the model pipeline
pipe = QwenImagePipeline.from_pretrained("Qwen/Qwen-Image", torch_dtype=dtype).to(device)
# Install the FA3 attention processor on the pipeline's transformer. This is done
# before compilation, presumably so the compiled graph uses this processor.
pipe.transformer.set_attn_processor(QwenDoubleStreamAttnProcessorFA3())

# --- Ahead-of-time compilation ---
# The returned object's `archive_file` is what `push_to_hub` uploads.
# NOTE(review): the literal "prompt" is presumably only a placeholder input used
# to trace/compile the transformer — TODO confirm in `optimization.get_compiled_transformer`.
compiled_transformer = get_compiled_transformer(pipe, prompt="prompt")

@spaces.GPU(duration=120)
def push_to_hub(repo_id, filename, oauth_token: gr.OAuthToken):
    """Upload the ahead-of-time compiled transformer archive to a Hub repo.

    Args:
        repo_id: Target Hub repository id (e.g. ``user/repo``) typed in the UI.
        filename: Path inside the repo for the archive; must end with ``.pt2``.
        oauth_token: OAuth token that Gradio injects for the signed-in user.

    Returns:
        A Markdown link to the commit when the upload result exposes a
        ``commit_url`` attribute; otherwise whatever
        ``_push_compiled_graph_to_hub`` returned, unchanged.

    Raises:
        gr.Error: If ``repo_id`` is empty or ``filename`` lacks a ``.pt2`` suffix.
        Exception: Whatever ``whoami`` raises when the token is rejected.
    """
    # Textbox values may be None or carry stray whitespace; normalize before validating.
    repo_id = (repo_id or "").strip()
    filename = (filename or "").strip()
    if not repo_id:
        raise gr.Error("Please provide a `repo_id`.")
    if not filename.endswith(".pt2"):
        raise gr.Error("The filename must end with a `.pt2` extension.")

    token = oauth_token.token
    # Fail fast: this will throw if the token is invalid, before any upload starts.
    whoami(token=token)

    out = _push_compiled_graph_to_hub(
        compiled_transformer.archive_file,
        repo_id=repo_id,
        token=token,
        path_in_repo=filename,
    )
    # Upload helpers may return a commit-info object or a plain message string.
    if not isinstance(out, str) and hasattr(out, "commit_url"):
        commit_url = out.commit_url
        return f"[{commit_url}]({commit_url})"
    return out
     

css = """
#app {max-width: 840px; margin: 0 auto;}
"""

with gr.Blocks(css=css) as demo:
    gr.Markdown("## Compile a model graph ahead of time & push to the Hub")
    gr.Markdown("Enter a **repo_id** and **filename**. This repo automatically compiles the [QwenImage](https://hf.co/Qwen/Qwen-Image) model on start.")

    with gr.Row():
        repo_id = gr.Textbox(label="repo_id", placeholder="e.g. sayakpaul/qwen-aot")
        filename = gr.Textbox(label="filename", placeholder="e.g. compiled.pt2")

    run = gr.Button("Push graph to Hub", variant="primary")

    markdown_out = gr.Markdown()

    run.click(push_to_hub, inputs=[repo_id, filename], outputs=[markdown_out])

if __name__ == "__main__":
    demo.launch(show_error=True)