import torch
import gradio as gr
from torchvision import transforms
from PIL import Image
import numpy as np
from utils.utils import load_restore_ckpt, load_embedder_ckpt
import os
from gradio_imageslider import ImageSlider

# Use the GPU if one is available, otherwise fall back to the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

embedder_model_path = "ckpts/embedder_model.tar"  # Update with actual path to embedder checkpoint
restorer_model_path = "ckpts/onerestore_cdd-11.tar"  # Update with actual path to restorer checkpoint

# Load the frozen embedder and restorer checkpoints onto the selected device
embedder = load_embedder_ckpt(device, freeze_model=True, ckpt_name=embedder_model_path)
restorer = load_restore_ckpt(device, freeze_model=True, ckpt_name=restorer_model_path)

# Define image preprocessing and postprocessing
transform_resize = transforms.Compose([
    transforms.Resize([224, 224]),
    transforms.ToTensor()
])
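
# Note: transform_resize feeds only the embedder, which takes 224x224 inputs;
# the restorer itself runs on the full-resolution image tensor (see enhance_image below).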

def postprocess_image(tensor):
    image = tensor.squeeze(0).cpu().detach().numpy()
    image = image * 255  # Assuming output in [0, 1], rescale to [0, 255]
    image = np.clip(image, 0, 255).astype("uint8")  # Clip values to [0, 255]
    return Image.fromarray(image.transpose(1, 2, 0))  # Reorder to (H, W, C)

# Define the enhancement function
def enhance_image(image, degradation_type=None):
    # Preprocess: full-resolution tensor for the restorer, resized copy for the embedder
    input_tensor = torch.Tensor((np.array(image) / 255).transpose(2, 0, 1)).unsqueeze(0).to(device)
    lq_em = transform_resize(image).unsqueeze(0).to(device)

    # Generate the degradation embedding, either estimated from the image or taken from the chosen label
    if degradation_type == "auto" or degradation_type is None:
        text_embedding, _, [text] = embedder(lq_em, 'image_encoder')
    else:
        text_embedding, _, [text] = embedder([degradation_type], 'text_encoder')
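    # 'text' is the degradation label (estimated or user-selected) that the UI displays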

    # Model inference
    with torch.no_grad():
        enhanced_tensor = restorer(input_tensor, text_embedding)

    # Postprocess the output; return (input, restored) for the slider, plus the label
    return (image, postprocess_image(enhanced_tensor)), text

# Define the Gradio interface
def inference(image, degradation_type=None):
    return enhance_image(image, degradation_type)
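
# A minimal sketch of calling the pipeline outside Gradio (assumes the example
# image listed below exists locally):
#   img = Image.open("image/low_haze_rain_00469_01_lq.png").convert("RGB")
#   (before, after), label = enhance_image(img, "auto")
#   after.save("restored.png")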

# Example images
examples = [
    ['image/low_haze_rain_00469_01_lq.png'],
    ['image/low_haze_snow_00337_01_lq.png'],
]

# Create the Gradio app interface
interface = gr.Interface(
    fn=inference,
    inputs=[
        gr.Image(type="pil", value="image/low_haze_rain_00469_01_lq.png"),  # Image input
        gr.Dropdown(['auto', 'low', 'haze', 'rain', 'snow',
                     'low_haze', 'low_rain', 'low_snow', 'haze_rain',
                     'haze_snow', 'low_haze_rain', 'low_haze_snow'],
                    label="Degradation Type", value="auto")  # Manual or automatic degradation selection
    ],
    outputs=[
        ImageSlider(label="Restored Image",
                    type="pil",
                    show_download_button=True),  # Before/after comparison slider
        gr.Textbox(label="Degradation Type")  # Display the estimated degradation type
    ],
    title="Image Restoration with OneRestore",
    description="Upload an image and enhance it using the OneRestore model. You can let the model automatically estimate the degradation type or set it manually.",
    examples=examples,
)

# Launch the app
if __name__ == "__main__":
    interface.launch()
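    # Tip: interface.launch(share=True) also serves a temporary public URL (standard Gradio option)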