Spaces:
Runtime error
Runtime error
| from huggingface_hub import hf_hub_download | |
# Pretrained style weights from the Hugging Face Hub. hf_hub_download returns a
# local file path, which is later handed to torch.load. (The "_512" suffix in
# the filenames presumably refers to the training resolution — unverified.)
_STYLE_REPO = "maze/FastStyleTransfer"
Rain_Princess = hf_hub_download(repo_id=_STYLE_REPO, filename="Rain_Princess_512.pth")
The_Scream = hf_hub_download(repo_id=_STYLE_REPO, filename="Scream_512.pth")
The_Mosaic = hf_hub_download(repo_id=_STYLE_REPO, filename="Mosaic_512.pth")
Starry_Night = hf_hub_download(repo_id=_STYLE_REPO, filename="Starry_Night_512.pth")
| import numpy as np | |
| from PIL import Image | |
| import gradio as gr | |
| import torch | |
| import torch.nn as nn | |
| import torchvision.transforms as transforms | |
# Use the GPU when PyTorch can see one; otherwise run everything on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
class TransformerNetwork(nn.Module):
    """Feed-forward image transformation network used for fast style transfer.

    Layout:
      * ConvBlock: 9x9 stride-1 conv to 32 channels, then two 3x3 stride-2
        convs (to 64, then 128 channels), each followed by ReLU.
      * ResidualBlock: five 128-channel ResidualLayers.
      * DeconvBlock: two 3x3 stride-2 transposed convs (to 64, then 32
        channels), each followed by ReLU, then a 9x9 conv back to 3 channels
        with no normalization.

    Args:
        tanh_multiplier: optional number. When it is an int or float the
            output becomes ``tanh_multiplier * tanh(x)``; when None (the
            default) the raw output of the last convolution is returned.
    """

    def __init__(self, tanh_multiplier=None):
        super().__init__()
        # Attribute names and Sequential indices make up the state_dict keys,
        # so they must keep matching the pretrained .pth files loaded into it.
        self.ConvBlock = nn.Sequential(
            ConvLayer(3, 32, 9, 1),
            nn.ReLU(),
            ConvLayer(32, 64, 3, 2),
            nn.ReLU(),
            ConvLayer(64, 128, 3, 2),
            nn.ReLU(),
        )
        self.ResidualBlock = nn.Sequential(
            *(ResidualLayer(128, 3) for _ in range(5))
        )
        self.DeconvBlock = nn.Sequential(
            DeconvLayer(128, 64, 3, 2, 1),
            nn.ReLU(),
            DeconvLayer(64, 32, 3, 2, 1),
            nn.ReLU(),
            ConvLayer(32, 3, 9, 1, norm="None"),
        )
        self.tanh_multiplier = tanh_multiplier

    def forward(self, x):
        x = self.ConvBlock(x)
        x = self.ResidualBlock(x)
        x = self.DeconvBlock(x)
        # torch.tanh rather than F.tanh: this module never imports
        # torch.nn.functional, so F.tanh raised NameError whenever a multiplier
        # was set. Floats are accepted as well as ints.
        if isinstance(self.tanh_multiplier, (int, float)):
            x = self.tanh_multiplier * torch.tanh(x)
        return x
class ConvLayer(nn.Module):
    """Reflection-padded 2-D convolution followed by an optional normalization.

    Args:
        in_channels, out_channels, kernel_size, stride: forwarded to nn.Conv2d
            (which gets no padding of its own; padding is done by reflection).
        norm: "instance" -> affine InstanceNorm2d, "batch" -> affine
            BatchNorm2d, any other value (the network passes the string
            "None") -> nn.Identity.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride, norm="instance"):
        super().__init__()
        self.pad = nn.ReflectionPad2d(kernel_size // 2)
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride)
        norm_choices = (("instance", nn.InstanceNorm2d), ("batch", nn.BatchNorm2d))
        norm_cls = next((cls for key, cls in norm_choices if norm == key), None)
        if norm_cls is None:
            self.norm = nn.Identity()
        else:
            self.norm = norm_cls(out_channels, affine=True)

    def forward(self, x):
        return self.norm(self.conv(self.pad(x)))
class ResidualLayer(nn.Module):
    """Residual block computing ``x + conv2(relu(conv1(x)))``.

    Both convolutions are stride-1 ConvLayers that keep ``channels`` fixed
    (with instance normalization, ConvLayer's default), so for odd
    ``kernel_size`` the output has the same shape as the input.
    """

    def __init__(self, channels=128, kernel_size=3):
        super().__init__()
        self.conv1 = ConvLayer(channels, channels, kernel_size, stride=1)
        self.relu = nn.ReLU()
        self.conv2 = ConvLayer(channels, channels, kernel_size, stride=1)

    def forward(self, x):
        residual = self.conv2(self.relu(self.conv1(x)))
        return x + residual
class DeconvLayer(nn.Module):
    """Transposed 2-D convolution followed by an optional normalization.

    Args:
        in_channels, out_channels, kernel_size, stride, output_padding:
            forwarded to nn.ConvTranspose2d, with padding = kernel_size // 2.
        norm: "instance" -> affine InstanceNorm2d, "batch" -> affine
            BatchNorm2d, any other value -> nn.Identity.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride, output_padding, norm="instance"):
        super().__init__()
        self.conv_transpose = nn.ConvTranspose2d(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=kernel_size // 2,
            output_padding=output_padding,
        )
        norm_choices = (("instance", nn.InstanceNorm2d), ("batch", nn.BatchNorm2d))
        norm_cls = next((cls for key, cls in norm_choices if norm == key), None)
        if norm_cls is None:
            self.norm = nn.Identity()
        else:
            self.norm = norm_cls(out_channels, affine=True)

    def forward(self, x):
        return self.norm(self.conv_transpose(x))
# Per-channel statistics used to normalize network inputs and undo it on outputs.
mean = np.array([0.485, 0.456, 0.406])
std = np.array([0.229, 0.224, 0.225])

# Shared network instance on `device`, switched to eval mode (Module.eval()
# returns the module itself). No weights are loaded here.
transformer = TransformerNetwork().to(device).eval()

# PIL image -> resized, normalized tensor.
transform = transforms.Compose(
    [
        transforms.Resize(512),
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ]
)

# Inverse of the Normalize above: Normalize computes (x - m') / s'; with
# m' = -m/s and s' = 1/s that equals x * s + m for each channel.
denormalize = transforms.Normalize(
    mean=[-channel_mean / channel_std for channel_mean, channel_std in zip(mean, std)],
    std=[1 / channel_std for channel_std in std],
)

tensor2Image = transforms.ToPILImage()
def process(image, model):
    """Stylize ``image`` with ``model`` and return the result as a PIL image.

    The image goes through the module-level ``transform`` (Resize(512),
    ToTensor, Normalize) and is run through ``model`` as a batch of one. The
    output is then de-normalized, clamped to [0, 1] and converted back with
    ``tensor2Image``.

    Args:
        image: PIL image (or anything ``transform`` accepts).
        model: network mapping a (1, 3, H, W) tensor to a (1, 3, H', W') tensor.
    """
    # The model and Normalize expect exactly 3 channels; RGBA / grayscale PIL
    # inputs would otherwise fail or be mis-broadcast.
    if isinstance(image, Image.Image) and image.mode != "RGB":
        image = image.convert("RGB")
    batch = transform(image).unsqueeze(dim=0).to(device)
    # Inference only: without no_grad an autograd graph is built and every
    # intermediate activation is kept alive for the duration of the call.
    with torch.no_grad():
        stylized = denormalize(model(batch)).cpu()
    stylized = torch.clamp(stylized.squeeze(dim=0), 0, 1)
    return tensor2Image(stylized)
# Pretrained weight file for each style offered in the UI. Unknown styles fall
# back to Rain Princess.
_STYLE_WEIGHTS = {
    "The Scream": The_Scream,
    "Rain Princess": Rain_Princess,
    "The Mosaic": The_Mosaic,
    "Starry Night": Starry_Night,
}

# One loaded network per weight file. A request therefore never overwrites
# weights that another in-flight request is using, and each .pth file is read
# from disk at most once per process (a rare duplicate load under concurrency
# is harmless).
_style_models = {}


def _get_style_model(weights_path):
    """Return an eval-mode TransformerNetwork on ``device`` loaded from ``weights_path``."""
    model = _style_models.get(weights_path)
    if model is None:
        model = TransformerNetwork()
        # NOTE(review): torch.load unpickles the file. That is acceptable only
        # because these paths come from hf_hub_download of a fixed repo, never
        # from user input.
        state_dict = torch.load(weights_path, map_location=torch.device('cpu'))
        model.load_state_dict(state_dict)
        model = model.to(device).eval()
        _style_models[weights_path] = model
    return model


def main(image, backbone, style):
    """Gradio callback: stylize ``image`` with the network for ``style``.

    Args:
        image: input image array (converted with ``Image.fromarray``), or
            None, which is answered with a message instead of an error.
        backbone: label chosen in the "Backbone" radio; only echoed in the
            status message.
        style: one of the dropdown style names; unknown values use
            Rain Princess.

    Returns:
        ``(stylized PIL image, HTML status message)``, or ``(None, message)``
        when no image was supplied.
    """
    if image is None:
        return None, "Please upload an input image first."
    model = _get_style_model(_STYLE_WEIGHTS.get(style, Rain_Princess))
    image = Image.fromarray(image)
    isize = image.size
    image = process(image, model)
    s = f"The output image {str(image.size)} is processed by {backbone} based on input image {str(isize)}. <br> Please <b>rate</b> the generated image through the <b>Flag</b> button below!"
    print(s)
    return image, s
# Build the Gradio UI around `main` and start the server.
# NOTE(review): gr.inputs / gr.outputs and allow_screenshot are legacy Gradio
# APIs (removed in newer releases). If the Space's pinned gradio is recent, they
# may be the cause of its "Runtime error" — confirm the gradio version before
# migrating to gr.Image / gr.Radio / gr.Dropdown(value=...).
gr.Interface(
    title="Stylize",
    description="Image generated based on Fast Style Transfer",
    fn=main,
    inputs=[
        gr.inputs.Image(),
        # Presumably other candidate backbones ("Standard ResNet50", "VGG19")
        # that are not served yet. The default avoids a None backbone, which
        # would otherwise read "processed by None" in the status message.
        gr.inputs.Radio(["Robust ResNet50"], default="Robust ResNet50", label="Backbone"),
        gr.inputs.Dropdown(
            ["The Scream", "Rain Princess", "Starry Night", "The Mosaic"],
            type="value",
            default="Rain Princess",
            label="style",
        ),
    ],
    outputs=[gr.outputs.Image(label="Stylized"), gr.outputs.HTML(label="Comment")],
    # live=True would recompute as soon as the user input changes.
    allow_flagging="manual",
    flagging_options=["Excellent", "Moderate", "Bad"],
    flagging_dir="flagged",
    allow_screenshot=False,
).launch()