import os
import json
import math
import zipfile

import requests
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import matplotlib.pyplot as plt
from tqdm import tqdm
from transformers import BertTokenizer, BertModel
import gradio as gr


class Config:
    device = "cuda" if torch.cuda.is_available() else "cpu"
    image_size = 64
    batch_size = 32
    num_epochs = 50
    learning_rate = 1e-4
    timesteps = 1000
    text_embed_dim = 768  # hidden size of bert-base-uncased
    num_images_options = [1, 4, 6]

    coco_images_url = "http://images.cocodataset.org/zips/train2017.zip"
    coco_annotations_url = "http://images.cocodataset.org/annotations/annotations_trainval2017.zip"
    data_dir = "./coco_data"
    images_dir = os.path.join(data_dir, "train2017")
    # The dataset below reads captions, so this must point at the captions file;
    # instances_train2017.json holds object annotations and has no 'caption' field.
    annotations_path = os.path.join(data_dir, "annotations/captions_train2017.json")

    def __init__(self):
        os.makedirs(self.data_dir, exist_ok=True)


config = Config()


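# For a quick smoke test on limited hardware, the hyperparameters can be
# overridden on the config instance before training, e.g.:
#
#   config.num_epochs = 1
#   config.batch_size = 8

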
def download_and_extract_coco():
    """Download COCO train2017 images (~18 GB) and annotations, then extract them."""
    if os.path.exists(config.images_dir) and os.path.exists(config.annotations_path):
        print("COCO dataset already downloaded")
        return

    print("Downloading COCO dataset...")

    images_zip_path = os.path.join(config.data_dir, "train2017.zip")
    if not os.path.exists(images_zip_path):
        response = requests.get(config.coco_images_url, stream=True)
        with open(images_zip_path, "wb") as f:
            for chunk in tqdm(response.iter_content(chunk_size=1 << 20)):
                if chunk:
                    f.write(chunk)

    annotations_zip_path = os.path.join(config.data_dir, "annotations_trainval2017.zip")
    if not os.path.exists(annotations_zip_path):
        response = requests.get(config.coco_annotations_url, stream=True)
        with open(annotations_zip_path, "wb") as f:
            for chunk in tqdm(response.iter_content(chunk_size=1 << 20)):
                if chunk:
                    f.write(chunk)

    # Both archives are zip files, so extract them with zipfile (not tarfile).
    print("Extracting images...")
    with zipfile.ZipFile(images_zip_path) as zf:
        zf.extractall(config.data_dir)

    print("Extracting annotations...")
    with zipfile.ZipFile(annotations_zip_path) as zf:
        zf.extractall(config.data_dir)

    print("COCO dataset ready")


download_and_extract_coco()


class TextEncoder(nn.Module):
    """Frozen BERT encoder; the [CLS] embedding is used as the conditioning vector."""

    def __init__(self):
        super().__init__()
        self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
        self.model = BertModel.from_pretrained('bert-base-uncased')
        self.model.eval()  # the encoder is never trained, so keep dropout off
        for param in self.model.parameters():
            param.requires_grad = False

    @torch.no_grad()
    def forward(self, texts):
        inputs = self.tokenizer(texts, return_tensors="pt", padding=True, truncation=True, max_length=64)
        inputs = {k: v.to(config.device) for k, v in inputs.items()}
        outputs = self.model(**inputs)
        return outputs.last_hidden_state[:, 0, :]  # [CLS] token, shape (batch, 768)


text_encoder = TextEncoder().to(config.device)


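# Optional sanity check (never called by the pipeline): a batch of N captions
# should encode to an (N, 768) tensor matching config.text_embed_dim.
def _check_text_encoder():
    emb = text_encoder(["a cat on a sofa", "two dogs playing"])
    assert emb.shape == (2, config.text_embed_dim)

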
class ConditionalUNet(nn.Module):
    """A small encoder-decoder denoiser conditioned on the timestep and a text embedding."""

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, padding=1)
        self.down1 = DownBlock(64, 128)
        self.down2 = DownBlock(128, 256)

        # Both conditioning signals are projected to the bottleneck width:
        # the timestep embedding tells the model how noisy its input is, and
        # the text embedding tells it what to denoise toward.
        self.time_proj = nn.Linear(256, 256)
        self.text_proj = nn.Linear(config.text_embed_dim, 256)
        self.merge = nn.Conv2d(256 + 256, 256, kernel_size=1)

        self.up1 = UpBlock(256, 128)
        self.up2 = UpBlock(128, 64)
        self.final = nn.Conv2d(64, 3, kernel_size=3, padding=1)

    @staticmethod
    def timestep_embedding(t, dim=256):
        # Sinusoidal embedding of the integer timestep, as in DDPM; without it
        # the model cannot distinguish noise levels.
        half = dim // 2
        freqs = torch.exp(-math.log(10000.0) * torch.arange(half, device=t.device) / half)
        args = t.float()[:, None] * freqs[None, :]
        return torch.cat([torch.sin(args), torch.cos(args)], dim=-1)

    def forward(self, x, t, text_emb):
        x1 = F.relu(self.conv1(x))
        x2 = self.down1(x1)
        x3 = self.down2(x2)

        # Condition the bottleneck on the noise level.
        t_emb = self.time_proj(self.timestep_embedding(t))
        x3 = x3 + t_emb[:, :, None, None]

        # Broadcast the text embedding over the spatial grid and fuse it with
        # a 1x1 convolution (a per-pixel linear layer over the channels).
        text_emb = self.text_proj(text_emb)[:, :, None, None]
        text_emb = text_emb.expand(-1, -1, x3.size(2), x3.size(3))
        x = self.merge(torch.cat([x3, text_emb], dim=1))

        x = self.up1(x)
        x = self.up2(x)
        return self.final(x)


class DownBlock(nn.Module):
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(),
            nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(),
            nn.MaxPool2d(2)
        )

    def forward(self, x):
        return self.conv(x)


class UpBlock(nn.Module):
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        self.conv = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(),
            nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU()
        )

    def forward(self, x):
        x = self.up(x)
        return self.conv(x)


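# Optional sanity check (never called by the pipeline): the denoiser predicts
# per-pixel noise, so its output must match the input shape exactly.
def _check_unet_shapes():
    net = ConditionalUNet()
    x = torch.randn(2, 3, config.image_size, config.image_size)
    t = torch.randint(0, config.timesteps, (2,))
    emb = torch.randn(2, config.text_embed_dim)
    assert net(x, t, emb).shape == x.shape

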
def linear_beta_schedule(timesteps):
    """Linearly spaced noise variances from the DDPM paper (1e-4 to 0.02)."""
    beta_start = 0.0001
    beta_end = 0.02
    return torch.linspace(beta_start, beta_end, timesteps)


# Precompute the full schedule up front. This must come after the function
# definition above, since these lines execute at import time.
betas = linear_beta_schedule(config.timesteps).to(config.device)
alphas = 1. - betas
alphas_cumprod = torch.cumprod(alphas, dim=0)
sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod)
sqrt_one_minus_alphas_cumprod = torch.sqrt(1. - alphas_cumprod)


def forward_diffusion_sample(x_0, t):
    """Sample x_t ~ q(x_t | x_0) in closed form; also return the noise that was added."""
    noise = torch.randn_like(x_0)
    sqrt_alphas_cumprod_t = sqrt_alphas_cumprod[t].view(-1, 1, 1, 1)
    sqrt_one_minus_alphas_cumprod_t = sqrt_one_minus_alphas_cumprod[t].view(-1, 1, 1, 1)
    return sqrt_alphas_cumprod_t * x_0 + sqrt_one_minus_alphas_cumprod_t * noise, noise


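# Optional sanity check (never called by the pipeline): near t = timesteps - 1
# the signal coefficient sqrt(alpha_bar_t) is close to 0, so for a zero image
# x_t reduces exactly to the scaled noise term.
def _check_forward_diffusion():
    x0 = torch.zeros(1, 3, config.image_size, config.image_size, device=config.device)
    t = torch.tensor([config.timesteps - 1], device=config.device)
    x_t, noise = forward_diffusion_sample(x0, t)
    assert torch.allclose(x_t, sqrt_one_minus_alphas_cumprod[-1] * noise)

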
class CocoDataset(Dataset):
    """Pairs each COCO image with the first of its captions."""

    def __init__(self, root_dir, annotations_file, transform=None):
        self.root_dir = root_dir
        self.transform = transform

        with open(annotations_file, 'r') as f:
            data = json.load(f)

        self.images = []
        self.captions = []

        # Group captions by image id, then keep one caption per image.
        image_id_to_captions = {}
        for ann in data['annotations']:
            image_id_to_captions.setdefault(ann['image_id'], []).append(ann['caption'])

        for img in data['images']:
            if img['id'] in image_id_to_captions:
                self.images.append(img)
                self.captions.append(image_id_to_captions[img['id']][0])

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        img_path = os.path.join(self.root_dir, self.images[idx]['file_name'])
        image = Image.open(img_path).convert('RGB')
        caption = self.captions[idx]

        if self.transform:
            image = self.transform(image)

        return image, caption


transform = transforms.Compose([
    transforms.Resize((config.image_size, config.image_size)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])


model = ConditionalUNet().to(config.device)
optimizer = torch.optim.Adam(model.parameters(), lr=config.learning_rate)


def train():
    dataset = CocoDataset(config.images_dir, config.annotations_path, transform)
    dataloader = DataLoader(dataset, batch_size=config.batch_size, shuffle=True)

    model.train()
    for epoch in range(config.num_epochs):
        for batch_idx, (images, captions) in enumerate(tqdm(dataloader)):
            images = images.to(config.device)

            # Encode the captions with the frozen BERT encoder.
            text_emb = text_encoder(list(captions))

            # Sample a random timestep per image...
            t = torch.randint(0, config.timesteps, (images.size(0),), device=config.device)

            # ...noise the images to that timestep...
            x_noisy, noise = forward_diffusion_sample(images, t)

            # ...and train the U-Net to predict the added noise
            # (the standard DDPM objective).
            pred_noise = model(x_noisy, t, text_emb)

            loss = F.mse_loss(pred_noise, noise)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if batch_idx % 100 == 0:
                print(f"Epoch {epoch}, Batch {batch_idx}, Loss: {loss.item():.4f}")

        torch.save(model.state_dict(), f"model_epoch_{epoch}.pth")


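# To resume or serve a previously trained model instead of retraining, load a
# checkpoint first (the filename is just an example of the pattern saved above):
#
#   model.load_state_dict(torch.load("model_epoch_49.pth", map_location=config.device))

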
@torch.no_grad()
def generate(prompt, num_images=1):
    """Run the DDPM ancestral sampler from pure noise, conditioned on the prompt."""
    model.eval()
    num_images = int(num_images)

    text_emb = text_encoder([prompt] * num_images)
    x = torch.randn((num_images, 3, config.image_size, config.image_size), device=config.device)

    for t in reversed(range(config.timesteps)):
        t_tensor = torch.full((num_images,), t, device=config.device)
        pred_noise = model(x, t_tensor, text_emb)
        alpha_t = alphas[t].view(1, 1, 1, 1)
        alpha_cumprod_t = alphas_cumprod[t].view(1, 1, 1, 1)
        beta_t = betas[t].view(1, 1, 1, 1)

        # No noise is added at the final (t == 0) step.
        if t > 0:
            noise = torch.randn_like(x)
        else:
            noise = torch.zeros_like(x)

        # Standard DDPM posterior mean plus scaled noise.
        x = (1 / torch.sqrt(alpha_t)) * (
            x - ((1 - alpha_t) / torch.sqrt(1 - alpha_cumprod_t)) * pred_noise
        ) + torch.sqrt(beta_t) * noise

    # Map from the model's [-1, 1] range back to [0, 1] for display.
    x = torch.clamp(x, -1, 1)
    x = (x + 1) / 2

    images = []
    for img in x:
        images.append(transforms.ToPILImage()(img.cpu()))

    return images


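# Example usage outside the UI (the prompt text is arbitrary):
#
#   samples = generate("a red bus parked on a city street", num_images=4)
#   samples[0].save("sample.png")

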
def generate_and_display(prompt, num_images):
    images = generate(prompt, num_images)

    # plt.subplots returns a single Axes for one image and an array otherwise.
    fig, axes = plt.subplots(1, len(images), figsize=(5 * len(images), 5))
    if len(images) == 1:
        axes.imshow(images[0])
        axes.axis('off')
    else:
        for ax, img in zip(axes, images):
            ax.imshow(img)
            ax.axis('off')
    fig.tight_layout()
    return fig


with gr.Blocks() as demo:
    gr.Markdown("## GPUDiff-V1: a diffusion-based text-to-image generator")
    with gr.Row():
        prompt_input = gr.Textbox(label="Prompt", placeholder="Enter image description...")
        num_select = gr.Dropdown(choices=config.num_images_options, value=1, label="Number of images")
    generate_btn = gr.Button("Generate")
    output = gr.Plot()

    generate_btn.click(
        fn=generate_and_display,
        inputs=[prompt_input, num_select],
        outputs=output
    )


if __name__ == "__main__":
    # Train the model, then launch the Gradio demo on the trained weights.
    train()

    demo.launch()