import os

import gradio as gr
import numpy as np
import spaces
import torch
from diffusers.utils import check_min_version
from PIL import Image

from controlnet_flux import FluxControlNetModel
from pipeline_flux_cnet import FluxControlNetInpaintingPipeline
from transformer_flux import FluxTransformer2DModel

HF_TOKEN = os.getenv("HF_TOKEN")

# Ensure that the minimal required version of diffusers is installed.
check_min_version("0.30.2")

# Build the pipeline: the Alimama inpainting ControlNet on top of FLUX.1-dev.
controlnet = FluxControlNetModel.from_pretrained(
    "alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta",
    torch_dtype=torch.bfloat16,
    token=HF_TOKEN,
)
transformer = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    subfolder="transformer",
    torch_dtype=torch.bfloat16,
    token=HF_TOKEN,
)
pipe = FluxControlNetInpaintingPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    controlnet=controlnet,
    transformer=transformer,
    torch_dtype=torch.bfloat16,
    token=HF_TOKEN,
).to("cuda")
pipe.transformer.to(torch.bfloat16)
pipe.controlnet.to(torch.bfloat16)


def create_mask_from_editor(editor_value):
    """
    Create a binary inpainting mask from a gr.ImageEditor value.

    Args:
        editor_value: EditorValue dict with 'background', 'layers', and 'composite'.

    Returns:
        Single-channel PIL image, white where the user painted.
    """
    # The painted strokes live in the editor layers. Using the alpha channel of
    # the first layer makes the mask independent of the brush color and of any
    # white pixels already present in the background image.
    layer_array = np.array(editor_value["layers"][0])
    mask_array = (layer_array[..., 3] > 0).astype(np.uint8) * 255
    return Image.fromarray(mask_array, mode="L")


def create_diptych_image(image, mask):
    # Create a diptych with the original image on the left and the mask on the right.
    width, height = image.size
    diptych = Image.new("RGB", (width * 2, height), "black")
    diptych.paste(image, (0, 0))
    diptych.paste(mask, (width, 0))
    return diptych


@spaces.GPU()
def inpaint_image(image, prompt, editor_value):
    # Create the mask from the editor strokes.
    mask = create_mask_from_editor(editor_value)

    # Preprocess the inputs.
    image = image.convert("RGB").resize((768, 768))
    mask = mask.convert("L").resize((768, 768))  # single-channel (grayscale) mask

    # Create the 1536x768 diptych control image.
    diptych_image = create_diptych_image(image, mask)

    # Lift the mask to diptych size so it lines up with the control image;
    # nothing in the left reference panel gets inpainted. The pipeline
    # tokenizes the prompt and encodes the images internally, so the raw
    # string and PIL images are passed to it as-is below.
    diptych_mask = Image.new("L", (768 * 2, 768), 0)
    diptych_mask.paste(mask, (768, 0))

    generator = torch.Generator(device="cuda").manual_seed(24)

    # Build the attention scale mask for diptych prompting. FLUX prepends 512
    # text tokens to the image tokens, and each image token covers a 16x16
    # pixel patch of the latent canvas.
    attn_scale_factor = 1.5
    size = (1536, 768)  # (width, height) of the diptych canvas
    H, W = size[1] // 16, size[0] // 16
    attn_scale_mask = torch.zeros(size[1], size[0])  # (height, width)
    attn_scale_mask[:, 768:] = 1.0  # mark the right (generated) panel
    attn_scale_mask = torch.nn.functional.interpolate(
        attn_scale_mask[None, None, :, :], (H, W), mode="nearest-exact"
    ).flatten()
    attn_scale_mask = attn_scale_mask[None, None, :, None].repeat(1, 24, 1, H * W)
    transposed_inverted_attn_scale_mask = (1.0 - attn_scale_mask).transpose(-1, -2)
    # Scale only the cross-panel attention, i.e. right-panel queries attending
    # to the left reference panel; everything else keeps a scale of 1.0.
    cross_attn_region = torch.logical_and(attn_scale_mask, transposed_inverted_attn_scale_mask)
    cross_attn_region = cross_attn_region * attn_scale_factor
    cross_attn_region[cross_attn_region < 1.0] = 1.0
    full_attn_scale_mask = torch.ones(1, 24, 512 + H * W, 512 + H * W)
    full_attn_scale_mask[:, :, 512:, 512:] = cross_attn_region
    full_attn_scale_mask = full_attn_scale_mask.to(
        device=pipe.transformer.device, dtype=torch.bfloat16
    )

    # Inpaint.
    result = pipe(
        prompt=prompt,
        height=size[1],
        width=size[0],
        control_image=diptych_image,
        control_mask=diptych_mask,
        num_inference_steps=20,
        generator=generator,
        controlnet_conditioning_scale=0.95,
        guidance_scale=3.5,
        negative_prompt="",
        true_guidance_scale=1.0,
        attn_scale_mask=full_attn_scale_mask,
    ).images[0]

    return result, diptych_image


# Create the Gradio interface.
iface = gr.Interface(
    fn=inpaint_image,
    inputs=[
        gr.Image(type="pil", label="Upload Image"),
        gr.Textbox(
            lines=1,
            placeholder="Enter your prompt here (e.g., 'wearing a christmas hat, in a busy street')",
            label="Prompt",
        ),
        gr.ImageEditor(type="pil", label="Image with Mask", sources=["upload"], interactive=True),
    ],
    outputs=[
        gr.Image(type="pil", label="Inpainted Image"),
        gr.Image(type="pil", label="Diptych Image"),
    ],
    title="FLUX Inpainting with Diptych Prompting",
    description="Upload an image, enter a prompt, and paint a mask over the region to replace. The app then generates the inpainted image.",
)

# Launch the app.
iface.launch()