adrianrm commited on
Commit
66c7512
·
verified ·
1 Parent(s): c936d1f

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +23 -13
README.md CHANGED
@@ -51,36 +51,46 @@ At low diffusion times, the model exploits locality properties of natural images
51
  ## Usage
52
 
53
  ```python
54
- from micro_diffusion.models.model import create_latent_diffusion
55
  import torch
 
 
 
56
 
57
-
58
  params = {
59
  'latent_res': 64,
60
  'in_channels': 4,
61
  'pos_interp_scale': 2.0,
62
  }
63
  model = create_latent_diffusion(**params).to('cuda')
64
- checkpoint = torch.load(ckpt_path, map_location='cuda', weights_only=False)
65
- model_dict = checkpoint['state']['model']
66
- # Convert parameters to float32
 
 
 
 
 
 
67
  float_model_params = {
68
- k.replace('dit.', ''): v.to(torch.float32) for k, v in model_dict.items() if 'dit' in k
69
  }
70
  model.dit.load_state_dict(float_model_params)
71
 
 
 
72
 
73
-
74
-
75
-
76
  prompts = [
77
  "Pirate ship trapped in a cosmic maelstrom nebula, rendered in cosmic beach whirlpool engine, volumet",
78
  "A illustration from a graphic novel. A bustling city street under the shine of a full moon.",
 
 
 
 
 
79
  ]
80
-
81
- model = model.eval()
82
- gen_images = model.generate(prompt=prompts, num_inference_steps=30,
83
- guidance_scale=5.0, seed=42)
84
  ```
85
 
86
  ## Citation
 
51
  ## Usage
52
 
53
  ```python
 
54
  import torch
55
+ from micro_diffusion.models.model import create_latent_diffusion
56
+ from huggingface_hub import hf_hub_download
57
+ from safetensors import safe_open
58
 
59
+ # Init model
60
  params = {
61
  'latent_res': 64,
62
  'in_channels': 4,
63
  'pos_interp_scale': 2.0,
64
  }
65
  model = create_latent_diffusion(**params).to('cuda')
66
+
67
+ # Download weights from HF
68
+ model_dict_path = hf_hub_download(repo_id="giannisdaras/ambient-o", filename="model.safetensors")
69
+ model_dict = {}
70
+ with safe_open(model_dict_path, framework="pt", device="cpu") as f:
71
+ for key in f.keys():
72
+ model_dict[key] = f.get_tensor(key)
73
+
74
+ # Convert parameters to float32 + load
75
  float_model_params = {
76
+ k: v.to(torch.float32) for k, v in model_dict.items()
77
  }
78
  model.dit.load_state_dict(float_model_params)
79
 
80
+ # Eval mode
81
+ model = model.eval()
82
 
83
+ # Generate images
 
 
84
  prompts = [
85
  "Pirate ship trapped in a cosmic maelstrom nebula, rendered in cosmic beach whirlpool engine, volumet",
86
  "A illustration from a graphic novel. A bustling city street under the shine of a full moon.",
87
+ "A giant cobra snake made from corn",
88
+ "A fierce garden gnome warrior, clad in armor crafted from leaves and bark, brandishes a tiny sword.",
89
+ "A capybara made of lego sitting in a realistic, natural field",
90
+ "a close-up of a fire spitting dragon, cinematic shot.",
91
+ "Panda mad scientist mixing sparkling chemicals, artstation"
92
  ]
93
+ images = model.generate(prompt=prompts, num_inference_steps=30, guidance_scale=5.0, seed=42)
 
 
 
94
  ```
95
 
96
  ## Citation