# fluxhdupscaler / app.py
import logging
import random
import io
import os
import gradio as gr
import numpy as np
import spaces
import torch
from gradio_imageslider import ImageSlider
from PIL import Image
import requests
import sys
import subprocess
from huggingface_hub import hf_hub_download
import tempfile
os.environ["GIT_TERMINAL_PROMPT"] = "0"
# Setup ComfyUI and custom nodes
if not os.path.exists("ComfyUI"):
    subprocess.run(["git", "clone", "https://github.com/comfyanonymous/ComfyUI"], check=True)

custom_nodes_dir = os.path.join("ComfyUI", "custom_nodes")
os.makedirs(custom_nodes_dir, exist_ok=True)

# Clone UltimateSDUpscale
usd_dir = os.path.join(custom_nodes_dir, "ComfyUI_UltimateSDUpscale")
if not os.path.exists(usd_dir):
    subprocess.run(["git", "clone", "https://github.com/ssitu/ComfyUI_UltimateSDUpscale", usd_dir], check=True)

# Clone comfy_mtb and install its requirements
mtb_dir = os.path.join(custom_nodes_dir, "comfy_mtb")
if not os.path.exists(mtb_dir):
    subprocess.run(["git", "clone", "https://github.com/melMass/comfy_mtb", mtb_dir], check=True)
if os.path.exists(os.path.join(mtb_dir, "requirements.txt")):
    subprocess.run([sys.executable, "-m", "pip", "install", "-r", "requirements.txt"], cwd=mtb_dir, check=True)

# Clone KJNodes and install its requirements
kjn_dir = os.path.join(custom_nodes_dir, "ComfyUI-KJNodes")
if not os.path.exists(kjn_dir):
    subprocess.run(["git", "clone", "https://github.com/kijai/ComfyUI-KJNodes", kjn_dir], check=True)
if os.path.exists(os.path.join(kjn_dir, "requirements.txt")):
    subprocess.run([sys.executable, "-m", "pip", "install", "-r", "requirements.txt"], cwd=kjn_dir, check=True)
# Download models if not present
comfy_models_dir = os.path.join("ComfyUI", "models")
os.makedirs(comfy_models_dir, exist_ok=True)

# Diffusion model (Flux dev, FP8)
diffusion_dir = os.path.join(comfy_models_dir, "diffusion_models")
os.makedirs(diffusion_dir, exist_ok=True)
if not os.path.exists(os.path.join(diffusion_dir, "flux1-dev-fp8.safetensors")):
    hf_hub_download(repo_id="Kijai/flux-fp8", filename="flux1-dev-fp8.safetensors", local_dir=diffusion_dir)

# Text encoders (CLIP-L and T5-XXL FP8)
clip_dir = os.path.join(comfy_models_dir, "clip")
os.makedirs(clip_dir, exist_ok=True)
if not os.path.exists(os.path.join(clip_dir, "clip_l.safetensors")):
    hf_hub_download(repo_id="comfyanonymous/flux_text_encoders", filename="clip_l.safetensors", local_dir=clip_dir)
if not os.path.exists(os.path.join(clip_dir, "t5xxl_fp8_e4m3fn.safetensors")):
    hf_hub_download(repo_id="comfyanonymous/flux_text_encoders", filename="t5xxl_fp8_e4m3fn.safetensors", local_dir=clip_dir)

# VAE
vae_dir = os.path.join(comfy_models_dir, "vae")
os.makedirs(vae_dir, exist_ok=True)
if not os.path.exists(os.path.join(vae_dir, "ae.safetensors")):
    hf_hub_download(repo_id="black-forest-labs/FLUX.1-dev", filename="ae.safetensors", local_dir=vae_dir)

# Upscale models (streamed to disk so the full file is never held in memory)
upscale_dir = os.path.join(comfy_models_dir, "upscale_models")
os.makedirs(upscale_dir, exist_ok=True)
for model_name in ["RealESRGAN_x2.pth", "RealESRGAN_x4.pth"]:
    model_path = os.path.join(upscale_dir, model_name)
    if not os.path.exists(model_path):
        url = f"https://huggingface.co/ai-forever/Real-ESRGAN/resolve/main/{model_name}"
        with requests.get(url, stream=True) as response:
            response.raise_for_status()
            with open(model_path, "wb") as f:
                for chunk in response.iter_content(chunk_size=1 << 20):
                    f.write(chunk)
# Add ComfyUI to sys.path
sys.path.append(os.path.abspath("ComfyUI"))
# Import custom nodes
from nodes import NODE_CLASS_MAPPINGS, init_custom_nodes
init_custom_nodes()
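# init_custom_nodes() is what registers the node packs cloned above (e.g.
# "UltimateSDUpscale") into NODE_CLASS_MAPPINGS; without it only the built-in
# ComfyUI nodes would be available to the lookups in enhance_image below.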

# Helper from the ComfyUI-to-Python export script: node outputs are usually
# plain tuples, but some nodes return a dict carrying a "result" tuple instead.
def get_value_at_index(obj, index):
    try:
        return obj[index]
    except KeyError:
        return obj["result"][index]
# CSS and constants similar to original
css = """
#col-container {
margin: 0 auto;
max-width: 800px;
}
.main-header {
text-align: center;
margin-bottom: 2rem;
}
"""
power_device = "ZeroGPU"
MAX_SEED = 1000000
MAX_PIXEL_BUDGET = 8192 * 8192  # cap the output image at ~67.1 megapixels

def make_divisible_by_16(size):
    # Round to the nearest multiple of 16 (ties at remainder 8 round up).
    # Kept from the original script; not referenced in the current flow.
    return ((size // 16) * 16) if (size % 16) < 8 else ((size // 16 + 1) * 16)
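# Examples, computed by hand for reference:
#   make_divisible_by_16(1923) -> 1920   (remainder 3 rounds down)
#   make_divisible_by_16(1000) -> 1008   (remainder 8 rounds up)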

def process_input(input_image, upscale_factor):
    w, h = input_image.size
    w_original, h_original = w, h
    was_resized = False
    if w * h * upscale_factor**2 > MAX_PIXEL_BUDGET:
        gr.Info("Requested output too large. Resizing input.")
        # Shrink the input so the upscaled output fits the pixel budget,
        # snapping each side down to a multiple of 16.
        target_input_pixels = MAX_PIXEL_BUDGET / (upscale_factor ** 2)
        scale = (target_input_pixels / (w * h)) ** 0.5
        new_w = max(16, int(w * scale) // 16 * 16)
        new_h = max(16, int(h * scale) // 16 * 16)
        input_image = input_image.resize((new_w, new_h), resample=Image.LANCZOS)
        was_resized = True
    return input_image, w_original, h_original, was_resized
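# Worked example (hypothetical 5000x4000 input at 4x): the raw output would be
# 5000*4000*16 = 320 MP, over the 67.1 MP budget, so target_input_pixels is
# 67108864/16 = 4194304 and scale = sqrt(4194304/20000000) ~= 0.458, giving a
# 2288x1824 input and a 9152x7296 (~66.8 MP) output, just under the budget.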

def load_image_from_url(url):
    try:
        response = requests.get(url, timeout=30)
        response.raise_for_status()
        # Decode from the buffered bytes; Image.open(response.raw) can fail
        # on compressed transfer encodings.
        return Image.open(io.BytesIO(response.content))
    except Exception as e:
        raise gr.Error(f"Failed to load image: {e}")
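# Example: load_image_from_url("https://upload.wikimedia.org/wikipedia/commons/thumb/a/a7/Example.jpg/800px-Example.jpg")
# returns a PIL.Image on success; any network or decode failure surfaces as gr.Error.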

@spaces.GPU(duration=120)
def enhance_image(
    image_input,
    image_url,
    seed,
    randomize_seed,
    num_inference_steps,
    upscale_factor,
    denoising_strength,
    custom_prompt,
    tile_size,
    progress=gr.Progress(track_tqdm=True),
):
    with torch.inference_mode():
        # Handle input image
        if image_input is not None:
            true_input_image = image_input
        elif image_url:
            true_input_image = load_image_from_url(image_url)
        else:
            raise gr.Error("Provide an image or URL")
        input_image, w_original, h_original, was_resized = process_input(true_input_image, upscale_factor)
        if randomize_seed:
            seed = random.randint(0, MAX_SEED)
        # Save the (possibly resized) input into ComfyUI's input folder; a
        # collision-proof temp name keeps concurrent requests from clobbering
        # each other.
        input_dir = os.path.join("ComfyUI", "input")
        os.makedirs(input_dir, exist_ok=True)
        with tempfile.NamedTemporaryFile(suffix=".png", dir=input_dir, delete=False) as tmp:
            input_path = tmp.name
        temp_filename = os.path.basename(input_path)
        input_image.save(input_path)
        # Nodes
        load_image_node = NODE_CLASS_MAPPINGS["LoadImage"]()
        image_loaded = load_image_node.load_image(image=temp_filename)
        image = get_value_at_index(image_loaded, 0)
        text_multiline = NODE_CLASS_MAPPINGS["Text Multiline"]()
        text_out = text_multiline.text_multiline(text=custom_prompt if custom_prompt.strip() else "")
        prompt_text = get_value_at_index(text_out, 0)
        dualcliploader = NODE_CLASS_MAPPINGS["DualCLIPLoader"]()
        clip_out = dualcliploader.load_clip(
            clip_name1="clip_l.safetensors",
            clip_name2="t5xxl_fp8_e4m3fn.safetensors",
            type="flux",
        )
        clip = get_value_at_index(clip_out, 0)
        cliptextencode = NODE_CLASS_MAPPINGS["CLIPTextEncode"]()
        conditioning = get_value_at_index(cliptextencode.encode(text=prompt_text, clip=clip), 0)
        fluxguidance = NODE_CLASS_MAPPINGS["FluxGuidance"]()
        positive_out = fluxguidance.append(guidance=3.5, conditioning=conditioning)  # 3.5 as in the original app
        positive = get_value_at_index(positive_out, 0)
        # With cfg=1.0 the negative conditioning is effectively unused; a
        # zeroed-out copy is the standard Flux placeholder.
        conditioningzeroout = NODE_CLASS_MAPPINGS["ConditioningZeroOut"]()
        negative_out = conditioningzeroout.zero_out(conditioning=conditioning)
        negative = get_value_at_index(negative_out, 0)
        # The x2 model matches a 2x factor; other factors fall back to the x4
        # model and UltimateSDUpscale rescales to the exact upscale_by value.
        upscale_name = "RealESRGAN_x2.pth" if upscale_factor == 2 else "RealESRGAN_x4.pth"
        upscalemodelloader = NODE_CLASS_MAPPINGS["UpscaleModelLoader"]()
        upscale_model = get_value_at_index(upscalemodelloader.load_model(model_name=upscale_name), 0)
        vaeloader = NODE_CLASS_MAPPINGS["VAELoader"]()
        vae = get_value_at_index(vaeloader.load_vae(vae_name="ae.safetensors"), 0)
        # ComfyUI registers this loader under the key "UNETLoader" (display
        # name "Load Diffusion Model"); its entry point is load_unet.
        unetloader = NODE_CLASS_MAPPINGS["UNETLoader"]()
        model = get_value_at_index(unetloader.load_unet(unet_name="flux1-dev-fp8.safetensors", weight_dtype="fp8_e4m3fn"), 0)
        ultimatesdupscale = NODE_CLASS_MAPPINGS["UltimateSDUpscale"]()
        upscale_out = ultimatesdupscale.upscale(
            upscale_by=float(upscale_factor),
            seed=seed,
            steps=num_inference_steps,
            cfg=1.0,
            sampler_name="euler",
            scheduler="normal",
            denoise=denoising_strength,
            mode_type="Linear",
            tile_width=tile_size,
            tile_height=tile_size,
            mask_blur=8,
            tile_padding=32,
            seam_fix_mode="None",
            seam_fix_denoise=1.0,
            seam_fix_width=64,
            seam_fix_mask_blur=8,
            seam_fix_padding=16,
            force_uniform_tiles=True,
            tiled_decode=False,
            image=image,
            model=model,
            positive=positive,
            negative=negative,
            vae=vae,
            upscale_model=upscale_model,
        )
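        # Rough cost model (an assumption, based on how UltimateSDUpscale
        # tiles): the model-upscaled image is refined in about
        # ceil(W/tile) * ceil(H/tile) tiles, e.g. a 1024x1024 input at 2x with
        # tile_size=1024 becomes 2048x2048 and is diffused as 4 tiles, each
        # running num_inference_steps steps at the chosen denoise strength.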
        upscaled_tensor = get_value_at_index(upscale_out, 0)
        # ComfyUI images are [batch, height, width, channel] floats in [0, 1];
        # clip before the uint8 conversion to guard against slight overshoot.
        array = np.clip(upscaled_tensor[0].cpu().numpy() * 255.0, 0, 255).astype(np.uint8)
        upscaled_img = Image.fromarray(array)
        # Resize to the exact requested output size; this also restores the
        # full target size when the input was shrunk to fit the pixel budget.
        target_w, target_h = w_original * upscale_factor, h_original * upscale_factor
        if upscaled_img.size != (target_w, target_h):
            upscaled_img = upscaled_img.resize((target_w, target_h), resample=Image.LANCZOS)
        resized_input = true_input_image.resize(upscaled_img.size, resample=Image.LANCZOS)
        # Cleanup temp file
        os.remove(input_path)
        # The [before, after] pair feeds the ImageSlider comparison widget.
        return [resized_input, upscaled_img]

# Gradio interface similar to the original
with gr.Blocks(css=css, title="🎨 AI Image Upscaler - Flux FP8") as demo:
    gr.HTML("""
    <div class="main-header">
        <h1>🎨 AI Image Upscaler - Flux FP8</h1>
        <p>Upscale images using Flux FP8 with a ComfyUI workflow</p>
        <p>Running on <strong>{}</strong></p>
    </div>
    """.format(power_device))
    with gr.Row():
        with gr.Column(scale=1):
            gr.HTML("<h3>📤 Input</h3>")
            with gr.Tabs():
                with gr.TabItem("📁 Upload Image"):
                    input_image = gr.Image(label="Upload Image", type="pil", height=200)
                with gr.TabItem("🔗 Image URL"):
                    image_url = gr.Textbox(
                        label="Image URL",
                        placeholder="https://example.com/image.jpg",
                        value="https://upload.wikimedia.org/wikipedia/commons/thumb/a/a7/Example.jpg/800px-Example.jpg"
                    )
            gr.HTML("<h3>🎛️ Prompt Settings</h3>")
            custom_prompt = gr.Textbox(
                label="Custom Prompt (optional)",
                placeholder="Enter a custom prompt or leave empty",
                lines=2
            )
            gr.HTML("<h3>⚙️ Upscaling Settings</h3>")
            upscale_factor = gr.Slider(
                label="Upscale Factor",
                minimum=1,
                maximum=4,
                step=1,
                value=2
            )
            num_inference_steps = gr.Slider(
                label="Inference Steps",
                minimum=1,
                maximum=50,
                step=1,
                value=25
            )
            denoising_strength = gr.Slider(
                label="Denoising Strength",
                minimum=0.0,
                maximum=1.0,
                step=0.05,
                value=0.3
            )
            tile_size = gr.Slider(
                label="Tile Size",
                minimum=256,
                maximum=2048,
                step=64,
                value=1024
            )
            with gr.Row():
                randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
                seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=42)
            enhance_btn = gr.Button("🚀 Upscale Image", variant="primary", size="lg")
        with gr.Column(scale=2):
            gr.HTML("<h3>📊 Results</h3>")
            # elem_id must match the #result_slider CSS/JS selectors below.
            result_slider = ImageSlider(type="pil", interactive=False, height=600, label=None, elem_id="result_slider")
    enhance_btn.click(
        fn=enhance_image,
        inputs=[
            input_image,
            image_url,
            seed,
            randomize_seed,
            num_inference_steps,
            upscale_factor,
            denoising_strength,
            custom_prompt,
            tile_size
        ],
        outputs=[result_slider]
    )
gr.HTML("""
<div style="margin-top: 2rem; padding: 1rem; background: #f0f0f0; border-radius: 8px;">
<p><strong>Note:</strong> Uses Flux FP8 model. Ensure compliance with licenses for commercial use.</p>
</div>
""")
gr.HTML("""
<style>
#result_slider .slider { width: 100% !important; }
#result_slider img { object-fit: contain !important; width: 100% !important; height: auto !important; }
#result_slider .gr-button-tool, #result_slider .gr-button-undo, #result_slider .gr-button-clear { display: none !important; }
#result_slider .badge-container .badge { display: none !important; }
#result_slider .badge-container::before { content: "Before"; position: absolute; top: 10px; left: 10px; background: rgba(0,0,0,0.5); color: white; padding: 5px; border-radius: 5px; z-index: 10; }
#result_slider .badge-container::after { content: "After"; position: absolute; top: 10px; right: 10px; background: rgba(0,0,0,0.5); color: white; padding: 5px; border-radius: 5px; z-index: 10; }
</style>
""")
gr.HTML("""
<script>
document.addEventListener('DOMContentLoaded', function() {
const sliderInput = document.querySelector('#result_slider input[type="range"]');
if (sliderInput) { sliderInput.value = 50; sliderInput.dispatchEvent(new Event('input')); }
});
</script>
""")

if __name__ == "__main__":
    # share=True is ignored on Hugging Face Spaces but harmless when run locally.
    demo.queue().launch(share=True, server_name="0.0.0.0", server_port=7860)