File size: 4,498 Bytes
36c749c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
import gradio as gr
from diffusers import DiffusionPipeline
import torch
import os

# Ensure necessary libraries are installed
# pip install diffusers --upgrade
# pip install invisible_watermark transformers accelerate safetensors gradio torch

model_id = "stabilityai/stable-diffusion-xl-base-1.0"

# Select the compute backend: CUDA runs in half precision (float16), while the
# CPU fallback uses full precision (float32).
# On Apple Silicon, an extra branch guarded by torch.backends.mps.is_available()
# could select device "mps" with torch.float16.
if not torch.cuda.is_available():
    device, dtype = "cpu", torch.float32
    print("Using CPU.")
else:
    device, dtype = "cuda", torch.float16
    print("Using CUDA (GPU).")

# Load the Stable Diffusion XL pipeline.
# safetensors weights are always used; the "fp16" weight variant is requested
# only when running in half precision (i.e. not on CPU).
try:
    pipe = DiffusionPipeline.from_pretrained(
        model_id,
        torch_dtype=dtype,
        use_safetensors=True,
        variant="fp16" if dtype == torch.float16 else None,
    )

    # On low-VRAM CUDA devices, enable model CPU offload *instead of* moving the
    # whole pipeline to the GPU. Calling pipe.to("cuda") first would load every
    # component into VRAM at once, which can run out of memory on exactly the
    # GPUs that offloading is meant to help.
    use_offload = False
    if device == "cuda":
        try:
            # Rough VRAM estimate for GPU 0 - adjust the threshold as needed.
            total_vram_gb = torch.cuda.get_device_properties(0).total_memory / (1024**3)
            if total_vram_gb < 10:  # Example threshold: less than 10GB VRAM
                print(f"Low VRAM ({total_vram_gb:.2f}GB detected). Enabling model CPU offload.")
                pipe.enable_model_cpu_offload()
                use_offload = True
        except Exception as offload_err:
            # Best effort: fall back to placing the full pipeline on the device.
            print(f"Could not check VRAM or enable offload: {offload_err}")

    if not use_offload:
        pipe.to(device)

    # Optional: torch.compile the UNet for a speedup (requires torch >= 2.0), e.g.
    #     pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)

    print(f"SDXL pipeline loaded successfully on {device}.")

except Exception as e:
    # Keep the app usable: generate_image checks for a None pipeline and
    # returns an error placeholder instead.
    print(f"Error loading SDXL pipeline: {e}")
    pipe = None

def _error_image(message, size=(512, 512)):
    """Render *message* in red on a grey placeholder image and return it.

    Used so the ``gr.Image`` output always receives an image, even when the
    pipeline failed to load or generation raised.
    """
    import textwrap

    from PIL import Image, ImageDraw, ImageFont

    img = Image.new("RGB", size, color=(200, 200, 200))
    draw = ImageDraw.Draw(img)
    try:
        # Prefer a TrueType font; fall back to Pillow's built-in font if the
        # font file is not available on this system.
        font = ImageFont.truetype("arial.ttf", 15)
    except OSError:
        font = ImageFont.load_default()
    # Wrap long lines (e.g. exception text) so they stay inside the canvas.
    wrapped = "\n".join(
        textwrap.fill(line, width=60) for line in message.splitlines()
    )
    draw.text((10, 10), wrapped, fill=(255, 0, 0), font=font)
    return img


def generate_image(prompt):
    """Generate an image from a text prompt with the module-level SDXL pipeline.

    Args:
        prompt: Text prompt from the Gradio textbox.

    Returns:
        A single PIL image on success. A grey placeholder image with a red
        error message if the pipeline failed to load or generation raised.
        ``None`` if the prompt is empty or whitespace-only.
    """
    if pipe is None:
        return _error_image("Error: Model pipeline failed to load.")

    if not prompt or not prompt.strip():
        return None  # Nothing to generate for an empty prompt.

    print(f"Generating image for prompt: '{prompt}'")
    try:
        # Default guidance scale, 30 denoising steps; inference mode skips
        # autograd bookkeeping.
        with torch.inference_mode():
            # ``.images`` is a batch (list); the gr.Image(type="pil") output
            # expects a single image, so take the first one.
            image = pipe(prompt=prompt, num_inference_steps=30).images[0]
    except Exception as e:
        print(f"Error during image generation: {e}")
        return _error_image(f"Error generating image:\n{e}")
    print("Image generated successfully.")
    return image


# Build the web UI: one prompt textbox in, one PIL image out, handled by
# generate_image.
demo = gr.Interface(
    title="Text-to-Image Generation with Stable Diffusion XL",
    description=(
        f"Generate images from text prompts using the {model_id} model. "
        "Loading and inference might take a moment, especially on the first run or on CPU."
    ),
    fn=generate_image,
    inputs=gr.Textbox(
        label="Enter Text Prompt",
        placeholder="e.g., 'An astronaut riding a green horse'",
    ),
    outputs=gr.Image(type="pil", label="Generated Image"),
    examples=["A high-tech cityscape at sunset, cinematic lighting"],
)

def main():
    """Launch the Gradio app with debug mode enabled."""
    demo.launch(debug=True)


if __name__ == "__main__":
    main()