Spaces:
Running
on
Zero
Running
on
Zero
| import argparse | |
| import os | |
| import torch | |
| from diffusers import AutoencoderKL | |
| from torch import nn | |
| from torch.optim import Adam | |
| from utils import load_image, save_image | |
| def main(args): | |
| os.makedirs(args.out_dir, exist_ok=True) | |
| device = torch.device("cuda" if torch.cuda.is_available() else "cpu") | |
| vae = AutoencoderKL.from_pretrained(args.vae_model_path).to( | |
| device, dtype=torch.float32 | |
| ) | |
| vae.requires_grad_(False) | |
| image = load_image(args.image_path, size=(512, 512)).to(device, dtype=torch.float32) | |
| image = image * 2 - 1 | |
| save_image(image / 2 + 0.5, f"{args.out_dir}/ori_image.png") | |
| latents = vae.encode(image)["latent_dist"].mean | |
| save_image(latents, f"{args.out_dir}/latents.png") | |
| rec_image = vae.decode(latents, return_dict=False)[0] | |
| save_image(rec_image / 2 + 0.5, f"{args.out_dir}/rec_image.png") | |
| for param in vae.decoder.parameters(): | |
| param.requires_grad = True | |
| loss_fn = nn.L1Loss() | |
| optimizer = Adam(vae.decoder.parameters(), lr=args.learning_rate) | |
| # Training loop | |
| for epoch in range(args.num_epochs): | |
| reconstructed = vae.decode(latents, return_dict=False)[0] | |
| loss = loss_fn(reconstructed, image) | |
| optimizer.zero_grad() | |
| loss.backward() | |
| optimizer.step() | |
| print(f"Epoch {epoch+1}/{args.num_epochs}, Loss: {loss.item()}") | |
| rec_image = vae.decode(latents, return_dict=False)[0] | |
| save_image(rec_image / 2 + 0.5, f"{args.out_dir}/trained_rec_image.png") | |
| vae.save_pretrained( | |
| f"{args.out_dir}/trained_vae_{os.path.basename(args.image_path)}" | |
| ) | |
| if __name__ == "__main__": | |
| parser = argparse.ArgumentParser( | |
| description="Train a VAE with given image and settings." | |
| ) | |
| # Add arguments | |
| parser.add_argument( | |
| "--out_dir", | |
| type=str, | |
| default="./trained_vae/", | |
| help="Output directory to save results", | |
| ) | |
| parser.add_argument( | |
| "--vae_model_path", | |
| type=str, | |
| required=True, | |
| help="Path to the pretrained VAE model", | |
| ) | |
| parser.add_argument( | |
| "--image_path", type=str, required=True, help="Path to the input image" | |
| ) | |
| parser.add_argument( | |
| "--learning_rate", | |
| type=float, | |
| default=1e-4, | |
| help="Learning rate for the optimizer", | |
| ) | |
| parser.add_argument( | |
| "--num_epochs", type=int, default=75, help="Number of training epochs" | |
| ) | |
| args = parser.parse_args() | |
| main(args) | |