# fluxhdupscaler / app.py
# comrender's picture
# Update app.py
# fd70ade verified
import logging
import random
import warnings
import os
import gradio as gr
import numpy as np
import spaces
import torch
from gradio_imageslider import ImageSlider
from PIL import Image
from huggingface_hub import hf_hub_download
import subprocess
import sys
import tempfile
from typing import Sequence, Mapping, Any, Union
import asyncio
import shutil
# Copy functions from FluxSimpleUpscaler.txt
def get_value_at_index(obj: Union[Sequence, Mapping], index: int) -> Any:
    """Return the item stored at ``index`` in ``obj``.

    Plain subscripting is attempted first. When that raises ``KeyError``,
    the item is read from the sequence stored under ``obj["result"]``
    instead.

    Args:
        obj: A sequence, or a mapping that may carry a ``"result"`` entry.
        index: Position (or key) of the wanted item.

    Returns:
        The item found by direct lookup, or by the ``"result"`` fallback.
    """
    try:
        value = obj[index]
    except KeyError:
        value = obj["result"][index]
    return value
def find_path(name: str, path: str = None) -> str:
    """Search ``path`` and each of its ancestors for an entry called ``name``.

    Args:
        name: File or directory name to look for.
        path: Directory to start from; defaults to the current working
            directory.

    Returns:
        The joined path of the first match (which is also printed), or
        ``None`` once the filesystem root has been checked without a match.
    """
    current_dir = os.getcwd() if path is None else path
    while True:
        if name in os.listdir(current_dir):
            located = os.path.join(current_dir, name)
            print(f"{name} found: {located}")
            return located
        parent_dir = os.path.dirname(current_dir)
        # dirname() of the root is the root itself: nothing left to search.
        if parent_dir == current_dir:
            return None
        current_dir = parent_dir
def add_comfyui_directory_to_sys_path() -> None:
    """Put the located ``ComfyUI`` directory at the front of ``sys.path``.

    Nothing happens when ``find_path`` finds no entry named ``ComfyUI`` or
    when the match is not a directory.
    """
    comfyui_path = find_path("ComfyUI")
    if comfyui_path is None or not os.path.isdir(comfyui_path):
        return
    sys.path.insert(0, comfyui_path)
    print(f"'{comfyui_path}' inserted to sys.path")
def add_extra_model_paths() -> None:
    """Load ``extra_model_paths.yaml`` if one can be found.

    ``load_extra_path_config`` is imported from ``main`` when available and
    from ``utils.extra_config`` otherwise. A message is printed when no
    config file is found.
    """
    try:
        from main import load_extra_path_config
    except ImportError:
        print("Could not import load_extra_path_config from main.py. Looking in utils.extra_config instead.")
        from utils.extra_config import load_extra_path_config

    config_file = find_path("extra_model_paths.yaml")
    if config_file is None:
        print("Could not find the extra_model_paths config file.")
        return
    load_extra_path_config(config_file)
def import_custom_nodes() -> None:
    """Initialise ComfyUI's extra nodes outside of a running server.

    A new asyncio event loop is created and installed as the current loop,
    a ``PromptServer`` and a ``PromptQueue`` are instantiated on it, and
    ``init_extra_nodes()`` is then run to completion on that loop.
    """
    # Imported here because these modules are only importable once the
    # ComfyUI directory has been added to sys.path.
    import asyncio
    import execution
    from nodes import init_extra_nodes
    import server

    event_loop = asyncio.new_event_loop()
    asyncio.set_event_loop(event_loop)

    prompt_server = server.PromptServer(event_loop)
    execution.PromptQueue(prompt_server)

    event_loop.run_until_complete(init_extra_nodes())
# Setup ComfyUI and custom nodes
def _run_setup_command(command, *, required=False):
    """Run one setup command and report a non-zero exit status.

    Args:
        command: Argument list handed to ``subprocess.run``.
        required: When true, a failing command raises ``RuntimeError``
            because nothing after it can work; otherwise a warning is
            printed and start-up continues.

    Returns:
        The command's exit code.
    """
    completed = subprocess.run(command)
    if completed.returncode != 0:
        message = (
            f"Setup command failed with exit code {completed.returncode}: "
            f"{' '.join(command)}"
        )
        if required:
            raise RuntimeError(message)
        print(f"WARNING: {message}")
    return completed.returncode


# Install with the interpreter that is running this script, so packages land
# in the environment that will import them (a bare `pip` may belong to
# another one).
_PIP_INSTALL = [sys.executable, "-m", "pip", "install"]

if not os.path.exists("ComfyUI"):
    _run_setup_command(
        ["git", "clone", "https://github.com/comfyanonymous/ComfyUI.git"],
        required=True,
    )
    _run_setup_command(_PIP_INSTALL + ["-r", "ComfyUI/requirements.txt"])
custom_node_path = "ComfyUI/custom_nodes/ComfyUI_UltimateSDUpscale"
if not os.path.exists(custom_node_path):
    _run_setup_command(
        ["git", "clone", "https://github.com/ssitu/ComfyUI_UltimateSDUpscale.git", custom_node_path],
        required=True,
    )
    _run_setup_command(_PIP_INSTALL + ["spandrel", "kornia"])

# Model files and the Hugging Face repositories they are downloaded from.
diffusion_path = "ComfyUI/models/diffusion_models/flux1-dev-fp8.safetensors"
clip_l_path = "ComfyUI/models/clip/clip_l.safetensors"
t5_path = "ComfyUI/models/clip/t5xxl_fp8_e4m3fn.safetensors"
vae_path = "ComfyUI/models/vae/ae.safetensors"
esrgan_x2_path = "ComfyUI/models/upscale_models/RealESRGAN_x2.pth"
esrgan_x4_path = "ComfyUI/models/upscale_models/RealESRGAN_x4.pth"
_MODEL_SOURCES = {
    diffusion_path: "Kijai/flux-fp8",
    clip_l_path: "comfyanonymous/flux_text_encoders",
    t5_path: "comfyanonymous/flux_text_encoders",
    vae_path: "black-forest-labs/FLUX.1-dev",
    esrgan_x2_path: "ai-forever/Real-ESRGAN",
    esrgan_x4_path: "ai-forever/Real-ESRGAN",
}

# Create model directories and download models if not present
for _model_path, _repo_id in _MODEL_SOURCES.items():
    _model_dir, _model_file = os.path.split(_model_path)
    os.makedirs(_model_dir, exist_ok=True)
    if not os.path.exists(_model_path):
        hf_hub_download(_repo_id, _model_file, local_dir=_model_dir)

# Add ComfyUI to path and import custom nodes.
# These imports must stay below add_comfyui_directory_to_sys_path(): the
# modules only become importable once ComfyUI is on sys.path.
add_comfyui_directory_to_sys_path()
add_extra_model_paths()
from folder_paths import add_model_folder_path
comfy_dir = find_path("ComfyUI")
if comfy_dir is None:
    raise RuntimeError("Could not locate the ComfyUI directory after setup.")
# Register the diffusion_models directory under the "unet" folder key as well.
add_model_folder_path("unet", os.path.join(comfy_dir, "models", "diffusion_models"))
import_custom_nodes()
from nodes import NODE_CLASS_MAPPINGS
# Stylesheet passed to gr.Blocks(css=...).
css = """
#col-container {
margin: 0 auto;
max-width: 800px;
}
.main-header {
text-align: center;
margin-bottom: 2rem;
}
"""

# Inclusive upper bound used when a random seed is drawn.
MAX_SEED = 1_000_000
# Largest output area, in pixels, that process_input allows before it
# shrinks the input image.
MAX_PIXEL_BUDGET = 8192 ** 2
def make_divisible_by_16(size):
    """Round ``size`` to the nearest multiple of 16.

    A remainder below 8 rounds down; a remainder of 8 or more rounds up.
    """
    blocks, remainder = divmod(size, 16)
    if remainder < 8:
        return blocks * 16
    return (blocks + 1) * 16
def process_input(input_image, upscale_factor):
    """Shrink the input when the requested output would exceed the budget.

    The requested output area is ``width * height * upscale_factor**2``.
    When that is larger than ``MAX_PIXEL_BUDGET`` the image is resized
    (LANCZOS) so the upscaled result fits, with each side floored to a
    multiple of 16 and never below 16.

    Args:
        input_image: Image exposing ``.size`` and ``.resize``.
        upscale_factor: Requested linear upscale factor.

    Returns:
        ``(image, original_width, original_height, was_resized)``.
    """
    original_width, original_height = input_image.size
    input_pixels = original_width * original_height

    if input_pixels * upscale_factor**2 <= MAX_PIXEL_BUDGET:
        return input_image, original_width, original_height, False

    gr.Info("Requested output image is too large. Resizing input to fit within pixel budget.")

    def floor_to_16(length):
        # Truncate, floor to a multiple of 16, and keep at least 16.
        return max(16, int(length) // 16 * 16)

    allowed_input_pixels = MAX_PIXEL_BUDGET / (upscale_factor ** 2)
    shrink = (allowed_input_pixels / input_pixels) ** 0.5
    shrunk_size = (
        floor_to_16(original_width * shrink),
        floor_to_16(original_height * shrink),
    )
    shrunk_image = input_image.resize(shrunk_size, resample=Image.LANCZOS)
    return shrunk_image, original_width, original_height, True
import requests
def load_image_from_url(url, timeout=30):
    """Download an image over HTTP(S) and return it as a decoded PIL image.

    Args:
        url: Address of the image to fetch.
        timeout: Seconds passed to ``requests.get`` as its timeout, so a
            stalled server cannot block the request indefinitely.

    Returns:
        A ``PIL.Image.Image`` whose pixel data is already loaded.

    Raises:
        gr.Error: If the download fails, the server answers with an error
            status, or the payload cannot be decoded as an image.
    """
    from io import BytesIO

    # NOTE(review): the URL comes straight from the UI and is fetched
    # server-side; consider restricting schemes/hosts if this app can reach
    # internal services.
    try:
        # The context manager releases the connection. Reading `.content`
        # (rather than `.raw`) applies any content-encoding decoding.
        with requests.get(url, timeout=timeout) as response:
            response.raise_for_status()
            payload = response.content
        image = Image.open(BytesIO(payload))
        # Decode now so a corrupt file is reported here, not later on.
        image.load()
        return image
    except Exception as e:
        raise gr.Error(f"Failed to load image from URL: {e}") from e
def tensor_to_pil(tensor):
    """Convert the first entry of ``tensor`` into a PIL image.

    Values are clamped to ``[0, 1]``, multiplied by 255 and cast to
    ``uint8`` (the cast truncates rather than rounds).
    """
    # presumably a (batch, height, width, channels) image tensor -- confirm
    first_entry = tensor.cpu()[0]
    scaled = first_entry.clamp(0, 1) * 255
    pixels = scaled.numpy().astype(np.uint8)
    return Image.fromarray(pixels)
def _resolve_seed(seed, randomize_seed):
    """Return the integer seed to sample with, drawing a random one if asked."""
    if randomize_seed:
        return random.randint(0, MAX_SEED)
    try:
        # The UI delivers the seed as text; the sampler needs an int.
        return int(seed)
    except (TypeError, ValueError):
        raise gr.Error(f"Seed must be an integer, got: {seed!r}") from None


@spaces.GPU(duration=120)
def enhance_image(
    image_input,
    image_url,
    seed,
    randomize_seed,
    upscale_factor,
    denoising_strength,
    custom_prompt,
    progress=gr.Progress(track_tqdm=True),
):
    """Upscale an image with FLUX through the UltimateSDUpscale node.

    Args:
        image_input: Uploaded PIL image, or ``None``.
        image_url: URL used when no image was uploaded.
        seed: Seed value; may arrive as text and is converted to ``int``.
        randomize_seed: When true, ``seed`` is ignored and a random seed in
            ``[0, MAX_SEED]`` is drawn.
        upscale_factor: Linear upscale factor passed as ``upscale_by``.
        denoising_strength: Passed to the node as ``denoise``.
        custom_prompt: Text for the positive conditioning.
        progress: Gradio progress tracker (not referenced in the body).

    Returns:
        ``[resized_input, image]``: the original input resized to the output
        size, and the upscaled output.

    Raises:
        gr.Error: If neither an image nor a URL is given, the URL cannot be
            loaded, or the seed is not an integer.
    """
    if image_input is not None:
        true_input_image = image_input
    elif image_url:
        true_input_image = load_image_from_url(image_url)
    else:
        raise gr.Error("Please provide an image (upload or URL)")

    seed = _resolve_seed(seed, randomize_seed)
    input_image, w_original, h_original, was_resized = process_input(true_input_image, upscale_factor)

    # Only a factor of exactly 2 uses the x2 model; every other factor uses x4.
    if upscale_factor == 2:
        upscale_model_name = "RealESRGAN_x2.pth"
    else:
        upscale_model_name = "RealESRGAN_x4.pth"

    # LoadImage is given a bare file name, so the file is created directly
    # inside ComfyUI's input directory under a unique name.
    comfy_dir = find_path("ComfyUI")
    input_dir = os.path.join(comfy_dir, "input")
    os.makedirs(input_dir, exist_ok=True)
    with tempfile.NamedTemporaryFile(suffix=".png", dir=input_dir, delete=False) as tmp:
        input_image_path = tmp.name
    image_base = os.path.basename(input_image_path)

    try:
        input_image.save(input_image_path)
        with torch.inference_mode():
            dualcliploader = NODE_CLASS_MAPPINGS["DualCLIPLoader"]()
            dualcliploader_res = dualcliploader.load_clip(
                clip_name1="clip_l.safetensors",
                clip_name2="t5xxl_fp8_e4m3fn.safetensors",
                type="flux",
            )
            clip = get_value_at_index(dualcliploader_res, 0)
            cliptextencode = NODE_CLASS_MAPPINGS["CLIPTextEncode"]()
            positive_res = cliptextencode.encode(
                text=custom_prompt or "",
                clip=clip
            )
            negative_res = cliptextencode.encode(
                text="",
                clip=clip
            )
            upscalemodelloader = NODE_CLASS_MAPPINGS["UpscaleModelLoader"]()
            upscalemodelloader_res = upscalemodelloader.load_model(
                model_name=upscale_model_name
            )
            vaeloader = NODE_CLASS_MAPPINGS["VAELoader"]()
            vaeloader_res = vaeloader.load_vae(vae_name="ae.safetensors")
            unetloader = NODE_CLASS_MAPPINGS["UNETLoader"]()
            unetloader_res = unetloader.load_unet(
                unet_name="flux1-dev-fp8.safetensors", weight_dtype="fp8_e4m3fn"
            )
            loadimage = NODE_CLASS_MAPPINGS["LoadImage"]()
            loadimage_res = loadimage.load_image(image=image_base)
            fluxguidance = NODE_CLASS_MAPPINGS["FluxGuidance"]()
            fluxguidance_res = fluxguidance.append(
                guidance=30, conditioning=get_value_at_index(positive_res, 0)
            )
            ultimatesdupscale = NODE_CLASS_MAPPINGS["UltimateSDUpscale"]()
            usd_res = ultimatesdupscale.upscale(
                upscale_by=upscale_factor,
                seed=seed,
                steps=25,
                cfg=1,
                sampler_name="euler",
                scheduler="normal",
                denoise=denoising_strength,
                mode_type="Linear",
                tile_width=1024,
                tile_height=1024,
                mask_blur=8,
                tile_padding=32,
                seam_fix_mode="None",
                seam_fix_denoise=1,
                seam_fix_width=64,
                seam_fix_mask_blur=8,
                seam_fix_padding=16,
                force_uniform_tiles=True,
                tiled_decode=False,
                image=get_value_at_index(loadimage_res, 0),
                model=get_value_at_index(unetloader_res, 0),
                positive=get_value_at_index(fluxguidance_res, 0),
                negative=get_value_at_index(negative_res, 0),
                vae=get_value_at_index(vaeloader_res, 0),
                upscale_model=get_value_at_index(upscalemodelloader_res, 0),
            )
            output_tensor = get_value_at_index(usd_res, 0)
            image = tensor_to_pil(output_tensor)
    finally:
        # Remove the staged input even when the pipeline raised.
        try:
            os.unlink(input_image_path)
        except FileNotFoundError:
            pass

    # int() keeps the size valid for PIL if the slider delivers a float.
    target_w = int(w_original * upscale_factor)
    target_h = int(h_original * upscale_factor)
    if was_resized:
        gr.Info(f"Resizing output to target size: {target_w}x{target_h}")
    if image.size != (target_w, target_h):
        image = image.resize((target_w, target_h), resample=Image.LANCZOS)
    resized_input = true_input_image.resize(image.size, resample=Image.LANCZOS)
    return [resized_input, image]
# Build the Gradio UI: a header, the input controls (upload or URL, prompt,
# upscale settings, seed), a before/after result slider, and the click
# handler that runs enhance_image.
with gr.Blocks(css=css, title="🎨 AI Image Upscaler - FLUX ComfyUI") as demo:
# Page header; the .main-header class is styled in `css`.
gr.HTML("""
<div class="main-header">
<h1>🎨 Flux Dev Image Upscaler (FP8)</h1>
<p>Upload an image or provide a URL to upscale it using FLUX FP8 with Ultimate SD Upscale</p>
<p>Using FLUX.1-dev FP8 model</p>
</div>
""")
with gr.Row():
with gr.Column(scale=1):
gr.HTML("<h3>📤 Input</h3>")
# Two alternative image sources: enhance_image uses the upload when one is
# present and falls back to the URL otherwise.
with gr.Tabs():
with gr.TabItem("📁 Upload Image"):
input_image = gr.Image(
label="Upload Image",
type="pil",
height=200
)
with gr.TabItem("🔗 Image URL"):
image_url = gr.Textbox(
label="Image URL",
placeholder="https://example.com/image.jpg",
value="https://upload.wikimedia.org/wikipedia/commons/thumb/a/a7/Example.jpg/800px-Example.jpg"
)
gr.HTML("<h3>🎛️ Prompt Settings</h3>")
# Passed to enhance_image as the text of the positive conditioning.
custom_prompt = gr.Textbox(
label="Custom Prompt (optional)",
placeholder="Enter custom prompt or leave empty",
lines=2
)
gr.HTML("<h3>⚙️ Upscaling Settings</h3>")
# enhance_image loads RealESRGAN_x2 when this is exactly 2 and RealESRGAN_x4
# for every other value.
upscale_factor = gr.Slider(
label="Upscale Factor",
minimum=1,
maximum=4,
step=1,
value=2,
info="How much to upscale the image"
)
# Forwarded to the UltimateSDUpscale node as `denoise`.
denoising_strength = gr.Slider(
label="Denoising Strength",
minimum=0.0,
maximum=1.0,
step=0.05,
value=0.3,
info="Controls how much the image is transformed"
)
with gr.Row():
randomize_seed = gr.Checkbox(
label="Randomize seed",
value=True
)
# NOTE(review): the seed is a Textbox, so the handler presumably receives it
# as text rather than a number -- confirm the handler converts it.
seed = gr.Textbox(
label="Seed",
value="42",
placeholder="Enter seed value",
interactive=True
)
enhance_btn = gr.Button(
"🚀 Upscale Image",
variant="primary",
size="lg"
)
with gr.Column(scale=2):
gr.HTML("<h3>📊 Results</h3>")
# Receives the two-image list returned by enhance_image.
result_slider = ImageSlider(
type="pil",
interactive=False,
height=600,
elem_id="result_slider",
label=None
)
# The order of `inputs` must match enhance_image's positional parameters.
enhance_btn.click(
fn=enhance_image,
inputs=[
input_image,
image_url,
seed,
randomize_seed,
upscale_factor,
denoising_strength,
custom_prompt
],
outputs=[result_slider]
)
# Licensing note shown below the controls.
gr.HTML("""
<div style="margin-top: 2rem; padding: 1rem; background: #f0f0f0; border-radius: 8px;">
<p><strong>Note:</strong> This upscaler uses the Flux.1-dev model. Users are responsible for obtaining commercial rights if used commercially under their license.</p>
</div>
""")
# Extra CSS scoped to #result_slider: sizes the images, hides the tool,
# undo and clear buttons and the badges, and adds "Before" / "After"
# labels through pseudo-elements.
gr.HTML("""
<style>
#result_slider .slider {
width: 100% !important;
max-width: inherit !important;
}
#result_slider img {
object-fit: contain !important;
width: 100% !important;
height: auto !important;
}
#result_slider .gr-button-tool {
display: none !important;
}
#result_slider .gr-button-undo {
display: none !important;
}
#result_slider .gr-button-clear {
display: none !important;
}
#result_slider .badge-container .badge {
display: none !important;
}
#result_slider .badge-container::before {
content: "Before";
position: absolute;
top: 10px;
left: 10px;
background: rgba(0,0,0,0.5);
color: white;
padding: 5px;
border-radius: 5px;
z-index: 10;
}
#result_slider .badge-container::after {
content: "After";
position: absolute;
top: 10px;
right: 10px;
background: rgba(0,0,0,0.5);
color: white;
padding: 5px;
border-radius: 5px;
z-index: 10;
}
#result_slider .fullscreen img {
object-fit: contain !important;
width: 100vw !important;
height: 100vh !important;
position: absolute;
top: 0;
left: 0;
}
</style>
""")
# Script meant to set the slider's range input to 50 on DOMContentLoaded.
# NOTE(review): <script> markup injected as HTML is typically not executed by
# browsers -- confirm this actually runs in the deployed Gradio version.
gr.HTML("""
<script>
document.addEventListener('DOMContentLoaded', function() {
const sliderInput = document.querySelector('#result_slider input[type="range"]');
if (sliderInput) {
sliderInput.value = 50;
sliderInput.dispatchEvent(new Event('input'));
}
});
</script>
""")
# Start the queued app only when this file is executed as a script.
if __name__ == "__main__":
    launch_options = {
        "share": True,
        "server_name": "0.0.0.0",
        "server_port": 7860,
    }
    demo.queue().launch(**launch_options)