Spaces:
Build error
Build error
| import torch | |
| import torchvision.transforms as transforms | |
| from PIL import Image | |
| import gradio as gr | |
| from tqdm import tqdm | |
def optimize_latent_vector(G, target_image, num_iterations=1000, lr=0.1):
    """Fit a latent vector ``z`` so that ``G(z, None)`` reconstructs ``target_image``.

    Runs ``num_iterations`` steps of Adam on a randomly initialised latent,
    minimising the pixel-wise MSE between the generated image and the target.

    Args:
        G: Generator module. Assumed (not verified here) to expose
            ``img_resolution`` and ``z_dim``, to accept ``G(z, c)`` with
            ``c=None``, to already live on the selected device, and to emit
            3-channel images in [-1, 1] — TODO confirm against the checkpoint.
        target_image: PIL image to reconstruct. Any mode is accepted; it is
            converted to RGB before use.
        num_iterations: Number of optimisation steps.
        lr: Adam learning rate (defaults to the previously hard-coded 0.1).

    Returns:
        The optimised latent, shape ``(1, G.z_dim)``, detached from the graph.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Uploads may be RGBA, greyscale or palette images; ToTensor would then
    # produce 4 or 1 channels, which either fails or silently broadcasts
    # against the generator's 3-channel output inside mse_loss.
    target_image = target_image.convert('RGB')
    target_image = transforms.Resize((G.img_resolution, G.img_resolution))(target_image)
    target_tensor = transforms.ToTensor()(target_image).unsqueeze(0).to(device)
    target_tensor = (target_tensor * 2) - 1  # ToTensor gives [0, 1]; map to [-1, 1]

    latent_vector = torch.randn((1, G.z_dim), device=device, requires_grad=True)
    optimizer = torch.optim.Adam([latent_vector], lr=lr)

    for i in tqdm(range(num_iterations), desc="Optimizing latent vector"):
        optimizer.zero_grad()
        generated_image = G(latent_vector, None)
        loss = torch.nn.functional.mse_loss(generated_image, target_tensor)
        # Only the latent is optimised. Restricting backward() to it keeps
        # gradients from piling up in the generator parameters' .grad, which
        # this optimizer never zeroes because it does not own them.
        loss.backward(inputs=[latent_vector])
        optimizer.step()
        if (i + 1) % 100 == 0:
            print(f'Iteration {i+1}/{num_iterations}, Loss: {loss.item()}')
    return latent_vector.detach()
def generate_from_upload(uploaded_image):
    """Gradio callback: return four StyleGAN variations of an uploaded image.

    The upload is first inverted to a latent ``z`` with
    ``optimize_latent_vector``. Four random perturbations of that latent are
    then rendered in a single batch.

    Relies on a module-level generator ``G`` that is not defined in the code
    shown here. NOTE(review): confirm it is loaded (and placed on the same
    device) before the interface launches.

    Args:
        uploaded_image: PIL image from the ``gr.Image(type="pil")`` input, or
            ``None`` when the user submits without uploading anything.

    Returns:
        A 4-tuple of PIL images, one per output slot of the interface.

    Raises:
        gr.Error: If no image was uploaded.
    """
    if uploaded_image is None:
        # Show a readable message in the UI instead of an obscure exception
        # from inside the optimisation step.
        raise gr.Error("Please upload an image first.")

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    optimized_z = optimize_latent_vector(G, uploaded_image)

    # Must stay equal to the number of gr.Image outputs in the interface.
    num_variations = 4
    variation_strength = 0.1
    # optimized_z is (1, z_dim); broadcasting yields (num_variations, z_dim).
    varied_z = optimized_z + torch.randn((num_variations, G.z_dim), device=device) * variation_strength

    with torch.no_grad():
        # NOTE(review): the latent was fitted without truncation but is rendered
        # here with truncation_psi=0.7, so the variations may drift away from
        # the uploaded image. Confirm this is intended.
        imgs = G(varied_z, c=None, truncation_psi=0.7, noise_mode='const')
    # Map generator output (assumed [-1, 1]) to uint8, then NCHW -> NHWC for PIL.
    imgs = (imgs * 127.5 + 128).clamp(0, 255).to(torch.uint8)
    imgs = imgs.permute(0, 2, 3, 1).cpu().numpy()

    return tuple(Image.fromarray(img) for img in imgs)
# Single upload slot feeding the variation generator.
upload_input = gr.Image(type="pil")
# Four image slots, one per value returned by generate_from_upload.
variation_outputs = [gr.Image(type="pil") for _ in range(4)]

# Wire the upload -> variations callback into a Gradio UI.
iface = gr.Interface(
    title="StyleGAN Image Variation Generator",
    fn=generate_from_upload,
    inputs=upload_input,
    outputs=variation_outputs,
)

# Start the web app with a share link requested and debug mode enabled.
iface.launch(debug=True, share=True)