# fluxhdupscaler / app.py
import logging
import random
import warnings
import os
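# Disable git's interactive credential prompting so the clone steps below
# fail fast instead of hanging on a hidden prompt.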
os.environ["GIT_TERMINAL_PROMPT"] = "0"
import gradio as gr
import numpy as np
import spaces
import torch
from gradio_imageslider import ImageSlider
from PIL import Image
import requests
import sys
import subprocess
from huggingface_hub import hf_hub_download
import tempfile
# Setup ComfyUI and custom nodes
if not os.path.exists("ComfyUI"):
    subprocess.run(["git", "clone", "https://github.com/comfyanonymous/ComfyUI"])
custom_nodes_dir = os.path.join("ComfyUI", "custom_nodes")
os.makedirs(custom_nodes_dir, exist_ok=True)
# Clone Ultimate SD Upscale
usd_dir = os.path.join(custom_nodes_dir, "ComfyUI_UltimateSDUpscaler")
if not os.path.exists(usd_dir):
    subprocess.run(["git", "clone", "https://github.com/ssitu/ComfyUI_UltimateSDUpscale", usd_dir])
# Clone comfy_mtb
mtb_dir = os.path.join(custom_nodes_dir, "comfy_mtb")
if not os.path.exists(mtb_dir):
    subprocess.run(["git", "clone", "https://github.com/melMass/comfy_mtb", mtb_dir])
# Install requirements
if os.path.exists(os.path.join(mtb_dir, "requirements.txt")):
    subprocess.run([sys.executable, "-m", "pip", "install", "-r", "requirements.txt"], cwd=mtb_dir)
# Clone KJNodes
kjn_dir = os.path.join(custom_nodes_dir, "ComfyUI-KJNodes")
if not os.path.exists(kjn_dir):
    subprocess.run(["git", "clone", "https://github.com/kijai/ComfyUI-KJNodes", kjn_dir])
# Install requirements
if os.path.exists(os.path.join(kjn_dir, "requirements.txt")):
    subprocess.run([sys.executable, "-m", "pip", "install", "-r", "requirements.txt"], cwd=kjn_dir)
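# These packs are not imported here; init_custom_nodes() further below registers
# their node classes (e.g. "UltimateSDUpscale") into NODE_CLASS_MAPPINGS.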
# Download models if not present
comfy_models_dir = os.path.join("ComfyUI", "models")
os.makedirs(comfy_models_dir, exist_ok=True)
# UNET (Flux FP8)
unet_dir = os.path.join(comfy_models_dir, "unet")
os.makedirs(unet_dir, exist_ok=True)
if not os.path.exists(os.path.join(unet_dir, "flux1-dev-fp8.safetensors")):
    hf_hub_download(repo_id="Kijai/flux-fp8", filename="flux1-dev-fp8.safetensors", local_dir=unet_dir)
# CLIP models (Flux needs both the CLIP-L and T5-XXL text encoders)
clip_dir = os.path.join(comfy_models_dir, "clip")
os.makedirs(clip_dir, exist_ok=True)
if not os.path.exists(os.path.join(clip_dir, "clip_l.safetensors")):
    hf_hub_download(repo_id="comfyanonymous/flux_text_encoders", filename="clip_l.safetensors", local_dir=clip_dir)
if not os.path.exists(os.path.join(clip_dir, "t5xxl_fp8_e4m3fn.safetensors")):
    hf_hub_download(repo_id="comfyanonymous/flux_text_encoders", filename="t5xxl_fp8_e4m3fn.safetensors", local_dir=clip_dir)
# VAE
vae_dir = os.path.join(comfy_models_dir, "vae")
os.makedirs(vae_dir, exist_ok=True)
if not os.path.exists(os.path.join(vae_dir, "ae.safetensors")):
    hf_hub_download(repo_id="black-forest-labs/FLUX.1-dev", filename="ae.safetensors", local_dir=vae_dir)
# Upscale models
upscale_dir = os.path.join(comfy_models_dir, "upscale_models")
os.makedirs(upscale_dir, exist_ok=True)
for model_name in ["RealESRGAN_x2.pth", "RealESRGAN_x4.pth"]:
    model_path = os.path.join(upscale_dir, model_name)
    if not os.path.exists(model_path):
        url = f"https://huggingface.co/ai-forever/Real-ESRGAN/resolve/main/{model_name}"
        response = requests.get(url)
        response.raise_for_status()  # fail loudly rather than saving an HTML error page as a .pth
        with open(model_path, "wb") as f:
            f.write(response.content)
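# Note: black-forest-labs/FLUX.1-dev is a gated repo, so fetching ae.safetensors
# requires the Space to carry an HF token with access granted. Without
# persistent storage, all of these downloads re-run on each cold start.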
# Add ComfyUI to sys.path
sys.path.append(os.path.abspath("ComfyUI"))
# Import custom nodes
from nodes import NODE_CLASS_MAPPINGS, init_custom_nodes
init_custom_nodes()
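# init_custom_nodes() imports every pack under ComfyUI/custom_nodes and merges
# their node classes into NODE_CLASS_MAPPINGS (recent ComfyUI versions expose
# this as init_extra_nodes instead).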
# Helper from ComfyUI's "workflow to Python" export: node outputs are normally
# indexable tuples, but some nodes return a dict carrying a "result" tuple.
def get_value_at_index(obj, index):
    try:
        return obj[index]
    except KeyError:
        return obj["result"][index]
# CSS and constants similar to the original app
css = """
#col-container {
    margin: 0 auto;
    max-width: 800px;
}
.main-header {
    text-align: center;
    margin-bottom: 2rem;
}
"""
power_device = "ZeroGPU"
MAX_SEED = 1000000
MAX_PIXEL_BUDGET = 8192 * 8192
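# MAX_PIXEL_BUDGET caps the output at 8192 x 8192 (~67 megapixels);
# process_input() below shrinks any input whose upscaled size would exceed it.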
def make_divisible_by_16(size):
    # Round to the nearest multiple of 16 (ties at remainder 8 round up).
    return ((size // 16) * 16) if (size % 16) < 8 else ((size // 16 + 1) * 16)
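# Example: make_divisible_by_16(999) == 992 (remainder 7 rounds down), while
# make_divisible_by_16(1000) == 1008 (remainder 8 rounds up).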
def process_input(input_image, upscale_factor):
    w, h = input_image.size
    w_original, h_original = w, h
    was_resized = False
    if w * h * upscale_factor**2 > MAX_PIXEL_BUDGET:
        gr.Info("Requested output too large. Resizing input.")
        # Shrink so that input_pixels * factor^2 fits the budget, keeping both
        # dimensions multiples of 16 for the diffusion model.
        target_input_pixels = MAX_PIXEL_BUDGET / (upscale_factor ** 2)
        scale = (target_input_pixels / (w * h)) ** 0.5
        new_w = max(16, int(w * scale) // 16 * 16)
        new_h = max(16, int(h * scale) // 16 * 16)
        input_image = input_image.resize((new_w, new_h), resample=Image.LANCZOS)
        was_resized = True
    return input_image, w_original, h_original, was_resized
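# Example: a 4000x4000 input at 4x would need 256 MP, well over the budget, so
# it is first shrunk to 2048x2048 (sqrt(MAX_PIXEL_BUDGET / 16) = 2048).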
def load_image_from_url(url):
    try:
        response = requests.get(url, stream=True)
        response.raise_for_status()
        return Image.open(response.raw)
    except Exception as e:
        raise gr.Error(f"Failed to load image: {e}")
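# @spaces.GPU requests a ZeroGPU allocation for each call; duration=120 tells
# the scheduler to allow up to 120 seconds per invocation.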
@spaces.GPU(duration=120)
def enhance_image(
    image_input,
    image_url,
    seed,
    randomize_seed,
    num_inference_steps,
    upscale_factor,
    denoising_strength,
    custom_prompt,
    tile_size,
    progress=gr.Progress(track_tqdm=True),
):
    with torch.inference_mode():
        # Handle input image: an uploaded file takes precedence over the URL
        if image_input is not None:
            true_input_image = image_input
        elif image_url:
            true_input_image = load_image_from_url(image_url)
        else:
            raise gr.Error("Provide an image or URL")
        input_image, w_original, h_original, was_resized = process_input(true_input_image, upscale_factor)
        if randomize_seed:
            seed = random.randint(0, MAX_SEED)
        # Save the input where ComfyUI's LoadImage node expects it
        input_dir = os.path.join("ComfyUI", "input")
        os.makedirs(input_dir, exist_ok=True)
        temp_filename = f"input_{random.randint(0, 1000000)}.png"
        input_path = os.path.join(input_dir, temp_filename)
        input_image.save(input_path)
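        # The block below instantiates ComfyUI node classes directly and calls
        # them in workflow order (load image -> encode prompt -> load models ->
        # tiled upscale) instead of going through the ComfyUI server API.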
        # Nodes
        load_image_node = NODE_CLASS_MAPPINGS["LoadImage"]()
        image_loaded = load_image_node.load_image(image=temp_filename)
        image = get_value_at_index(image_loaded, 0)
        text_multiline = NODE_CLASS_MAPPINGS["Text Multiline"]()
        text_out = text_multiline.text_multiline(text=custom_prompt if custom_prompt.strip() else "")
        prompt_text = get_value_at_index(text_out, 0)
        dualcliploader = NODE_CLASS_MAPPINGS["DualCLIPLoader"]()
        clip_out = dualcliploader.load_clip(
            clip_name1="clip_l.safetensors",
            clip_name2="t5xxl_fp8_e4m3fn.safetensors",
            type="flux",
        )
        clip = get_value_at_index(clip_out, 0)
        cliptextencode = NODE_CLASS_MAPPINGS["CLIPTextEncode"]()
        conditioning = get_value_at_index(cliptextencode.encode(text=prompt_text, clip=clip), 0)
        fluxguidance = NODE_CLASS_MAPPINGS["FluxGuidance"]()
        positive_out = fluxguidance.append(guidance=3.5, conditioning=conditioning)  # 3.5, as in the original app
        positive = get_value_at_index(positive_out, 0)
        # Flux has no separate negative prompt; zeroed-out conditioning stands in for one
        conditioningzeroout = NODE_CLASS_MAPPINGS["ConditioningZeroOut"]()
        negative_out = conditioningzeroout.zero_out(conditioning=conditioning)
        negative = get_value_at_index(negative_out, 0)
        # Pick the Real-ESRGAN checkpoint for the requested factor (x4 also
        # covers factors 1 and 3; UltimateSDUpscale rescales to upscale_by)
        upscale_name = "RealESRGAN_x2.pth" if upscale_factor == 2 else "RealESRGAN_x4.pth"
        upscalemodelloader = NODE_CLASS_MAPPINGS["UpscaleModelLoader"]()
        upscale_model = get_value_at_index(upscalemodelloader.load_model(model_name=upscale_name), 0)
        vaeloader = NODE_CLASS_MAPPINGS["VAELoader"]()
        vae = get_value_at_index(vaeloader.load_vae(vae_name="ae.safetensors"), 0)
        unetloader = NODE_CLASS_MAPPINGS["UNETLoader"]()
        model = get_value_at_index(unetloader.load_unet(unet_name="flux1-dev-fp8.safetensors", weight_dtype="fp8_e4m3fn"), 0)
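        # Flux-dev is guidance-distilled: steering comes from the FluxGuidance
        # value baked into the positive conditioning, which is why the sampler
        # below runs with cfg=1.0 (no classifier-free guidance pass).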
        ultimatesdupscale = NODE_CLASS_MAPPINGS["UltimateSDUpscale"]()
        upscale_out = ultimatesdupscale.upscale(
            upscale_by=float(upscale_factor),
            seed=seed,
            steps=num_inference_steps,
            cfg=1.0,
            sampler_name="euler",
            scheduler="normal",
            denoise=denoising_strength,
            mode_type="Linear",
            tile_width=tile_size,
            tile_height=tile_size,
            mask_blur=8,
            tile_padding=32,
            seam_fix_mode="None",
            seam_fix_denoise=1.0,
            seam_fix_width=64,
            seam_fix_mask_blur=8,
            seam_fix_padding=16,
            force_uniform_tiles=True,
            tiled_decode=False,
            image=image,
            model=model,
            positive=positive,
            negative=negative,
            vae=vae,
            upscale_model=upscale_model,
        )
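        # UltimateSDUpscale first runs the Real-ESRGAN model, then re-samples
        # the result in tile_size x tile_size patches (32 px padding, 8 px
        # blurred mask for seam blending; seam-fix passes disabled). At denoise
        # ~0.3 the output stays close to the ESRGAN pass; higher values let
        # Flux invent more detail per tile.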
        upscaled_tensor = get_value_at_index(upscale_out, 0)
        # Convert to PIL: ComfyUI image tensors are BHWC floats in [0, 1]
        upscaled_img = Image.fromarray((upscaled_tensor[0].cpu().numpy() * 255).astype(np.uint8))
        # Resize to the exact requested output size; this also restores the
        # original dimensions if the input was shrunk to fit the pixel budget
        target_w, target_h = w_original * upscale_factor, h_original * upscale_factor
        if upscaled_img.size != (target_w, target_h):
            upscaled_img = upscaled_img.resize((target_w, target_h), resample=Image.LANCZOS)
        # Match the "before" image to the "after" size for the comparison slider
        resized_input = true_input_image.resize(upscaled_img.size, resample=Image.LANCZOS)
        # Clean up temp file
        os.remove(input_path)
        return [resized_input, upscaled_img]
# Gradio interface similar to the original app
with gr.Blocks(css=css, title="🎨 AI Image Upscaler - Flux FP8") as demo:
    gr.HTML("""
        <div class="main-header">
            <h1>🎨 AI Image Upscaler - Flux FP8</h1>
            <p>Upscale images using Flux FP8 with ComfyUI workflow</p>
            <p>Running on <strong>{}</strong></p>
        </div>
    """.format(power_device))
    with gr.Row():
        with gr.Column(scale=1):
            gr.HTML("<h3>📤 Input</h3>")
            with gr.Tabs():
                with gr.TabItem("📁 Upload Image"):
                    input_image = gr.Image(label="Upload Image", type="pil", height=200)
                with gr.TabItem("🔗 Image URL"):
                    image_url = gr.Textbox(
                        label="Image URL",
                        placeholder="https://example.com/image.jpg",
                        value="https://upload.wikimedia.org/wikipedia/commons/thumb/a/a7/Example.jpg/800px-Example.jpg",
                    )
            gr.HTML("<h3>🎛️ Prompt Settings</h3>")
            custom_prompt = gr.Textbox(
                label="Custom Prompt (optional)",
                placeholder="Enter custom prompt or leave empty",
                lines=2,
            )
            gr.HTML("<h3>⚙️ Upscaling Settings</h3>")
            upscale_factor = gr.Slider(
                label="Upscale Factor",
                minimum=1,
                maximum=4,
                step=1,
                value=2,
            )
            num_inference_steps = gr.Slider(
                label="Inference Steps",
                minimum=1,
                maximum=50,
                step=1,
                value=25,
            )
            denoising_strength = gr.Slider(
                label="Denoising Strength",
                minimum=0.0,
                maximum=1.0,
                step=0.05,
                value=0.3,
            )
            tile_size = gr.Slider(
                label="Tile Size",
                minimum=256,
                maximum=2048,
                step=64,
                value=1024,
            )
            with gr.Row():
                randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
                seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=42)
            enhance_btn = gr.Button("🚀 Upscale Image", variant="primary", size="lg")
        with gr.Column(scale=2):
            gr.HTML("<h3>📊 Results</h3>")
            # elem_id is required so the #result_slider CSS/JS below can target this component
            result_slider = ImageSlider(type="pil", interactive=False, height=600, label=None, elem_id="result_slider")
    enhance_btn.click(
        fn=enhance_image,
        inputs=[
            input_image,
            image_url,
            seed,
            randomize_seed,
            num_inference_steps,
            upscale_factor,
            denoising_strength,
            custom_prompt,
            tile_size,
        ],
        outputs=[result_slider],
    )
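    # enhance_image returns [before, after]; ImageSlider renders the pair as a
    # draggable side-by-side comparison.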
    gr.HTML("""
        <div style="margin-top: 2rem; padding: 1rem; background: #f0f0f0; border-radius: 8px;">
            <p><strong>Note:</strong> Uses Flux FP8 model. Ensure compliance with licenses for commercial use.</p>
        </div>
    """)
    gr.HTML("""
        <style>
            #result_slider .slider { width: 100% !important; }
            #result_slider img { object-fit: contain !important; width: 100% !important; height: auto !important; }
            #result_slider .gr-button-tool, #result_slider .gr-button-undo, #result_slider .gr-button-clear { display: none !important; }
            #result_slider .badge-container .badge { display: none !important; }
            #result_slider .badge-container::before { content: "Before"; position: absolute; top: 10px; left: 10px; background: rgba(0,0,0,0.5); color: white; padding: 5px; border-radius: 5px; z-index: 10; }
            #result_slider .badge-container::after { content: "After"; position: absolute; top: 10px; right: 10px; background: rgba(0,0,0,0.5); color: white; padding: 5px; border-radius: 5px; z-index: 10; }
        </style>
    """)
    # Initialize the comparison slider at the 50% position. Note: depending on
    # the Gradio version, <script> tags inside gr.HTML may be sanitized; if this
    # does not run, the js= argument of gr.Blocks is the more reliable route.
    gr.HTML("""
        <script>
            document.addEventListener('DOMContentLoaded', function() {
                const sliderInput = document.querySelector('#result_slider input[type="range"]');
                if (sliderInput) { sliderInput.value = 50; sliderInput.dispatchEvent(new Event('input')); }
            });
        </script>
    """)
if __name__ == "__main__":
    demo.queue().launch(share=True, server_name="0.0.0.0", server_port=7860)