sayakpaul HF Staff committed on
Commit
c5db835
·
1 Parent(s): d936567
Files changed (2) hide show
  1. app.py +2 -3
  2. optimization.py +16 -14
app.py CHANGED
@@ -14,6 +14,8 @@ device = "cuda" if torch.cuda.is_available() else "cpu"
14
  # Load the model pipeline
15
  pipe = QwenImagePipeline.from_pretrained("Qwen/Qwen-Image", torch_dtype=dtype).to(device)
16
  pipe.transformer.set_attn_processor(QwenDoubleStreamAttnProcessorFA3())
 
 
17
 
18
  @spaces.GPU(duration=120)
19
  def push_to_hub(repo_id, filename, oauth_token: gr.OAuthToken):
@@ -23,9 +25,6 @@ def push_to_hub(repo_id, filename, oauth_token: gr.OAuthToken):
23
  # this will throw if token is invalid
24
  _ = whoami(oauth_token.token)
25
 
26
- # --- Ahead-of-time compilation ---
27
- compiled_transformer = compile_transformer(pipe, prompt="prompt")
28
-
29
  token = oauth_token.token
30
  out = _push_compiled_graph_to_hub(
31
  compiled_transformer.archive_file,
 
14
  # Load the model pipeline
15
  pipe = QwenImagePipeline.from_pretrained("Qwen/Qwen-Image", torch_dtype=dtype).to(device)
16
  pipe.transformer.set_attn_processor(QwenDoubleStreamAttnProcessorFA3())
17
+ # --- Ahead-of-time compilation ---
18
+ compiled_transformer = compile_transformer(pipe, prompt="prompt")
19
 
20
  @spaces.GPU(duration=120)
21
  def push_to_hub(repo_id, filename, oauth_token: gr.OAuthToken):
 
25
  # this will throw if token is invalid
26
  _ = whoami(oauth_token.token)
27
 
 
 
 
28
  token = oauth_token.token
29
  out = _push_compiled_graph_to_hub(
30
  compiled_transformer.archive_file,
optimization.py CHANGED
@@ -38,18 +38,20 @@ INDUCTOR_CONFIGS = {
38
  }
39
 
40
 
41
- @spaces.GPU(duration=1500)
42
  def compile_transformer(pipeline: Callable[P, Any], *args: P.args, **kwargs: P.kwargs):
43
- with spaces.aoti_capture(pipeline.transformer) as call:
44
- pipeline(*args, **kwargs)
45
-
46
- dynamic_shapes = tree_map(lambda t: None, call.kwargs)
47
- dynamic_shapes |= TRANSFORMER_DYNAMIC_SHAPES
48
-
49
- exported = torch.export.export(
50
- mod=pipeline.transformer,
51
- args=call.args,
52
- kwargs=call.kwargs,
53
- dynamic_shapes=dynamic_shapes,
54
- )
55
- return spaces.aoti_compile(exported, INDUCTOR_CONFIGS)
 
 
 
 
38
  }
39
 
40
 
 
41
  def compile_transformer(pipeline: Callable[P, Any], *args: P.args, **kwargs: P.kwargs):
42
+ @spaces.GPU(duration=1500)
43
+ def f():
44
+ with spaces.aoti_capture(pipeline.transformer) as call:
45
+ pipeline(*args, **kwargs)
46
+
47
+ dynamic_shapes = tree_map(lambda t: None, call.kwargs)
48
+ dynamic_shapes |= TRANSFORMER_DYNAMIC_SHAPES
49
+
50
+ exported = torch.export.export(
51
+ mod=pipeline.transformer,
52
+ args=call.args,
53
+ kwargs=call.kwargs,
54
+ dynamic_shapes=dynamic_shapes,
55
+ )
56
+ return spaces.aoti_compile(exported, INDUCTOR_CONFIGS)
57
+ return f