# -*- coding: utf-8 -*-
"""interior pipeline.ipynb

Automatically generated by Colab.

Original file is located at
    https://colab.research.google.com/drive/1Gz-u-NWthPK2XhiaHA1vILalcsYdmpHP
"""
import random
from typing import Union

import torch
import cv2
import numpy as np
from PIL import Image, ImageFilter
from torchvision import transforms
from diffusers import (
    StableDiffusionInpaintPipeline,
    StableDiffusionControlNetInpaintPipeline,
    ControlNetModel,
    AutoencoderKL,
    DiffusionPipeline,
    StableDiffusionPipeline,
)
from transformers import CLIPTextModel, CLIPTokenizer

from colors import COLOR_MAPPING_, ade_palette
# --- Determine device ---
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
DTYPE = torch.float16 if DEVICE == "cuda" else torch.float32  # float16 on GPU, float32 on CPU
print(f"Using device: {DEVICE} with dtype: {DTYPE}")
print("\n---------------------------------------------------------------------\n")
# Load ControlNet for segmentation conditioning
controlnet_seg = ControlNetModel.from_pretrained(
    "BertChristiaens/controlnet-seg-room",
    torch_dtype=DTYPE,
).to(DEVICE)

# Load ControlNet for depth conditioning
controlnet_depth = ControlNetModel.from_pretrained(
    "lllyasviel/control_v11f1p_sd15_depth",
    torch_dtype=DTYPE,
).to(DEVICE)
# Load a second Stable Diffusion pipeline intended as a final realism pass.
# Note: this checkpoint is the base SD 2.1 model (not Realistic Vision), and it is
# not used elsewhere in this script.
realistic_vision = StableDiffusionInpaintPipeline.from_pretrained(
    "stabilityai/stable-diffusion-2-1",
    torch_dtype=DTYPE,
).to(DEVICE)
# Load the Stable Diffusion inpainting pipeline with both ControlNets (segmentation + depth)
inpaint_model = StableDiffusionControlNetInpaintPipeline.from_pretrained(
    "runwayml/stable-diffusion-inpainting",
    controlnet=[controlnet_seg, controlnet_depth],
    torch_dtype=DTYPE,
).to(DEVICE)
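# Optional memory optimizations (a sketch, not part of the original pipeline): on smaller GPUs,
# attention slicing and VAE slicing in diffusers reduce peak VRAM at a small speed cost.
# Uncomment if the Space runs out of GPU memory.
# inpaint_model.enable_attention_slicing()
# inpaint_model.enable_vae_slicing()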
from transformers import AutoImageProcessor, AutoModelForDepthEstimation
from transformers import Mask2FormerImageProcessor, Mask2FormerForUniversalSegmentation
# Load Mask2Former for semantic segmentation
processor = Mask2FormerImageProcessor.from_pretrained("facebook/mask2former-swin-large-ade-semantic")
model = Mask2FormerForUniversalSegmentation.from_pretrained("facebook/mask2former-swin-large-ade-semantic")
model = model.to(DEVICE)

# Load depth estimation model
depth_image_processor = AutoImageProcessor.from_pretrained("LiheYoung/depth-anything-large-hf", torch_dtype=DTYPE)
depth_model = AutoModelForDepthEstimation.from_pretrained("LiheYoung/depth-anything-large-hf", torch_dtype=DTYPE)
depth_model = depth_model.to(DEVICE)
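# Note: Mask2Former here predicts ADE20K semantic classes, which ade_palette() maps to colors below;
# the Depth Anything model produces a relative depth map that is later used as ControlNet conditioning.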
def to_rgb(color: str) -> tuple:
    return tuple(int(color[i:i + 2], 16) for i in (1, 3, 5))

COLOR_MAPPING_RGB = {to_rgb(k): v for k, v in COLOR_MAPPING_.items()}

def map_colors_rgb(color: tuple) -> str:
    return COLOR_MAPPING_RGB[color]
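# Example: to_rgb converts a hex string such as "#ff0000" into an (R, G, B) tuple,
# e.g. to_rgb("#ff0000") == (255, 0, 0), which is then used as a key into COLOR_MAPPING_RGB.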
def get_segmentation_of_room(image: Image.Image) -> tuple[np.ndarray, Image.Image]:
    # Ensure the input is a valid PIL image before processing
    if not isinstance(image, Image.Image):
        raise TypeError("Input 'image' must be a PIL Image object.")

    # Semantic segmentation
    with torch.inference_mode():
        semantic_inputs = processor(images=image, return_tensors="pt", size={"height": 256, "width": 256})
        semantic_inputs = {key: value.to(DEVICE) for key, value in semantic_inputs.items()}
        semantic_outputs = model(**semantic_inputs)
        # Post-process back to a per-pixel label map at the original resolution
        segmentation_maps = processor.post_process_semantic_segmentation(semantic_outputs, target_sizes=[image.size[::-1]])
        predicted_semantic_map = segmentation_maps[0]

    predicted_semantic_map = predicted_semantic_map.cpu()
    color_seg = np.zeros((predicted_semantic_map.shape[0], predicted_semantic_map.shape[1], 3), dtype=np.uint8)
    palette = np.array(ade_palette())
    for label, color in enumerate(palette):
        color_seg[predicted_semantic_map == label, :] = color
    color_seg = color_seg.astype(np.uint8)
    seg_image = Image.fromarray(color_seg).convert('RGB')
    return color_seg, seg_image
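# Usage sketch (assumes a local file "room.jpg"): the function returns both a NumPy
# color-coded label map of shape (H, W, 3) and the same map as a PIL image.
# room = Image.open("room.jpg").convert("RGB")
# color_map, seg_image = get_segmentation_of_room(room)
# print(color_map.shape, seg_image.size)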
def filter_items(colors_list: Union[list, np.ndarray], items_list: Union[list, np.ndarray], items_to_remove: Union[list, np.ndarray]):
    filtered_colors = []
    filtered_items = []
    for color, item in zip(colors_list, items_list):
        if item not in items_to_remove:
            filtered_colors.append(color)
            filtered_items.append(item)
    return filtered_colors, filtered_items
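# Example: keep only colors whose labels are not in the removal list.
# filter_items(colors_list=[(1, 1, 1), (2, 2, 2)], items_list=["wall", "door;double;door"],
#              items_to_remove=["door;double;door"]) returns ([(1, 1, 1)], ["wall"])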
def get_inpainting_mask(segmentation_mask: np.ndarray) -> Image.Image:
    unique_colors = np.unique(segmentation_mask.reshape(-1, segmentation_mask.shape[2]), axis=0)
    unique_colors = [tuple(color) for color in unique_colors]
    segment_items = [map_colors_rgb(i) for i in unique_colors]
    # Items that should be preserved (not repainted): windows, doors, stairs, escalators
    control_items = ["windowpane;window", "door;double;door", "stairs;steps", "escalator;moving;staircase;moving;stairway"]
    chosen_colors, segment_items = filter_items(colors_list=unique_colors, items_list=segment_items, items_to_remove=control_items)
    mask = np.zeros_like(segmentation_mask)
    for color in chosen_colors:
        color_matches = (segmentation_mask == color).all(axis=2)
        mask[color_matches] = 1
    mask_image = Image.fromarray((mask * 255).astype(np.uint8)).convert("RGB")
    # Enlarge the mask region so it also covers the neighborhood of the masked objects
    mask_image = mask_image.filter(ImageFilter.MaxFilter(25))
    return mask_image
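# Usage sketch: turn a color-coded segmentation map into a dilated binary inpainting mask.
# mask_image = get_inpainting_mask(color_map)   # white = repaint, black = keep (windows, doors, stairs)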
def get_depth_image(image: Image.Image) -> Image.Image:
    image_to_depth = depth_image_processor(images=image, return_tensors="pt").to(DEVICE)
    with torch.no_grad():
        if DEVICE == "cuda" and DTYPE == torch.float16:
            # Run the depth model under autocast so float16 weights and float32 inputs mix safely
            with torch.autocast(device_type="cuda", dtype=torch.float16):
                depth_map = depth_model(**image_to_depth).predicted_depth
        else:
            depth_map = depth_model(**image_to_depth).predicted_depth

    # Resize the predicted depth back to the input resolution and normalize to [0, 1]
    width, height = image.size
    depth_map = torch.nn.functional.interpolate(
        depth_map.unsqueeze(1).float(),
        size=(height, width),
        mode="bicubic",
        align_corners=False,
    )
    depth_min = torch.amin(depth_map, dim=[1, 2, 3], keepdim=True)
    depth_max = torch.amax(depth_map, dim=[1, 2, 3], keepdim=True)
    depth_map = (depth_map - depth_min) / (depth_max - depth_min)
    image = torch.cat([depth_map] * 3, dim=1)
    image = image.permute(0, 2, 3, 1).cpu().numpy()[0]
    image = Image.fromarray((image * 255.0).clip(0, 255).astype(np.uint8))
    return image
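# Usage sketch: the returned depth map is a 3-channel PIL image scaled to [0, 255] at the
# input resolution, suitable as conditioning for the depth ControlNet.
# depth_mask = get_depth_image(room)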
# color_map, seg_mask = get_segmentation_of_room(input_image)
# inpainting_mask = get_inpainting_mask(color_map)
# depth_mask = get_depth_image(input_image)
def generate_with_controlnet(image, inpainting_mask, depth_mask, prompt, strength=0.85, seed=None):
    # Use the provided seed when given, otherwise fall back to a fixed seed for reproducibility
    generator = torch.Generator(device=DEVICE).manual_seed(seed if seed is not None else 42)

    # Conditioning images in the same order as the ControlNets passed at load time (segmentation, depth)
    control_images = [inpainting_mask, depth_mask]

    generated_image = inpaint_model(
        prompt=prompt,
        image=image,
        mask_image=inpainting_mask,   # inpainting mask: white regions are repainted
        control_image=control_images,
        num_inference_steps=50,
        strength=strength,
        generator=generator,
    ).images[0]
    return generated_image
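# Usage sketch (assumes input_image, inpainting_mask and depth_mask were produced as in the
# commented lines above): pass an explicit seed for reproducible runs.
# gen_image = generate_with_controlnet(input_image, inpainting_mask, depth_mask,
#                                      prompt="a scandinavian style living room", seed=123)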
import gradio as gr

def process_image_and_prompt(input_image: Image.Image, prompt: str) -> Image.Image:
    """
    Process the input image and prompt and generate an output image.

    This function orchestrates segmentation, mask generation, depth estimation,
    and the ControlNet-guided inpainting step.

    Args:
        input_image (PIL.Image.Image): The input image from Gradio.
        prompt (str): The text prompt from Gradio.

    Returns:
        PIL.Image.Image: The generated output image.
    """
    # Ensure the input image is in RGB format
    if input_image.mode != 'RGB':
        input_image = input_image.convert('RGB')

    # Resize the input image to a consistent size for processing
    input_image = input_image.resize((512, 512))

    # Get the segmentation map and inpainting mask
    color_map, _ = get_segmentation_of_room(input_image)
    inpainting_mask = get_inpainting_mask(color_map)

    # Get the depth map
    depth_mask = get_depth_image(input_image)

    # Generate the final image with ControlNet guidance
    gen_image = generate_with_controlnet(input_image, inpainting_mask, depth_mask, prompt)
    return gen_image
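# Quick local test without the Gradio UI (a sketch; assumes a local file "room.jpg"):
# result = process_image_and_prompt(Image.open("room.jpg"), "a minimalist bedroom with warm lighting")
# result.save("result.png")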
# Create the Gradio interface
iface = gr.Interface(
    fn=process_image_and_prompt,
    inputs=[
        gr.Image(type="pil", label="Input Image"),
        gr.Textbox(label="Prompt", placeholder="Describe the desired interior design...")
    ],
    outputs=gr.Image(type="pil", label="Generated Image"),
    title="Interior Design Generation with ControlNet",
    description="Upload an image of a room and provide a text prompt to generate a new interior design. The model will use segmentation and depth information to guide the generation."
)

# Launch the Gradio app
if __name__ == "__main__":
    iface.launch()
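# Note (assumption): on Hugging Face Spaces, long-running GPU calls can time out in the browser;
# enabling Gradio's request queue before launching may help, e.g. iface.queue().launch()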
# prompt = """A contemporary living room with soft gray walls and a polished hardwood floor.
# A sleek black entertainment unit sits beneath a wall-mounted flat-screen TV on the right side of the image.
# A plush beige sofa faces the TV, with a glass coffee table in the center.
# A large window allows natural light to stream in, while decorative shelves and potted plants add a cozy ambiance."""
# gen_image = generate_with_controlnet(input_image, inpainting_mask, depth_mask, prompt)
# from IPython.display import display
# display(gen_image, input_image)