Alexander Bagus committed
Commit af13ace · 1 Parent(s): e8db9a2
Files changed (1)
  1. app.py (+6 -5)
app.py CHANGED
@@ -27,13 +27,13 @@ weight_dtype = torch.bfloat16
 transformer = ZImageControlTransformer2DModel.from_pretrained(
     MODEL_LOCAL,
     subfolder="transformer",
-    low_cpu_mem_usage=False,
-    dtype=torch.bfloat16,
+    # low_cpu_mem_usage=False,
+    # torch_dtype=torch.bfloat16,
     transformer_additional_kwargs={
         "control_layers_places": [0, 5, 10, 15, 20, 25],
         "control_in_dim": 16
     },
-).to(torch.bfloat16)
+).to("cuda", torch.bfloat16)
 
 if TRANSFORMER_LOCAL is not None:
     print(f"From checkpoint: {TRANSFORMER_LOCAL}")
@@ -69,8 +69,9 @@ pipe = ZImageControlPipeline(
     transformer=transformer,
     scheduler=scheduler,
 )
-pipe.transformer = transformer
-pipe.to("cuda")
+pipe.to("cuda", torch.bfloat16)
+# pipe.transformer = transformer
+# pipe.to("cuda")
 
 # ======== AoTI compilation + FA3 ========
 pipe.transformer.layers._repeated_blocks = ["ZImageTransformerBlock"]
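
For reference, a minimal sketch of the load path after this commit, combining the "after" side of both hunks. It assumes the ZImageControlTransformer2DModel / ZImageControlPipeline classes, the project-specific transformer_additional_kwargs option, and the MODEL_LOCAL and scheduler objects defined earlier in app.py; it is not the full file, just the device/dtype placement the diff changes.

import torch

# Load the control transformer, then move it to the GPU and cast to
# bfloat16 in a single .to(device, dtype) call (replaces the earlier
# CPU-side .to(torch.bfloat16) cast).
transformer = ZImageControlTransformer2DModel.from_pretrained(
    MODEL_LOCAL,
    subfolder="transformer",
    transformer_additional_kwargs={
        "control_layers_places": [0, 5, 10, 15, 20, 25],
        "control_in_dim": 16,
    },
).to("cuda", torch.bfloat16)

# Assemble the pipeline and place all of its components on the GPU in
# bfloat16; the separate pipe.transformer reassignment is no longer needed.
pipe = ZImageControlPipeline(
    transformer=transformer,
    scheduler=scheduler,
)
pipe.to("cuda", torch.bfloat16)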