# neutral-sd-dev / gradio_neutral_input_func.py
# willsh1997's picture
# :bug: incorrect obj method - removed status printout
# 4edde26
import gradio as gr
import random
from PIL import Image
import io
import json
import uuid
import os
from stable_diffusion_demo import StableDiffusion
from datasets import Dataset, Features, Value, Image as HFImage, load_dataset, concatenate_datasets
import tempfile
# Local storage: backups of saved images live in a folder next to this script.
# The folder is created at import time if it is missing.
BASE_DIR = os.path.abspath(os.path.dirname(__file__))
IMAGE_DIR = os.path.join(BASE_DIR, "neutral_images_storage")
os.makedirs(IMAGE_DIR, exist_ok=True)

# HuggingFace Hub target repository and the access token read from the
# environment (empty string when HF_TOKEN is not set).
DATASET_REPO = "willsh1997/neutral-sd-outputs"
HF_TOKEN = os.getenv("HF_TOKEN", "")
def generate_image():
    """Run Stable Diffusion with empty prompts and return its result.

    Both the conditional and the unconditional text inputs are the empty
    string, so no prompt text steers the generation. The value returned by
    ``StableDiffusion`` is passed through unchanged (the UI binds it to a
    ``gr.Image(type="pil")`` component, so it is presumably a PIL image --
    confirm against ``stable_diffusion_demo``).
    """
    sd_settings = {
        "uncond_embeddings": [''],
        "text_embeddings": [''],
        "height": 512,
        "width": 512,
        "num_inference_steps": 25,
        "guidance_scale": 7.5,
        # No fixed seed is supplied; how None is handled is up to StableDiffusion.
        "seed": None,
    }
    return StableDiffusion(**sd_settings)
def load_dataset_from_hf():
    """Fetch the ``train`` split of ``DATASET_REPO`` from the HuggingFace Hub.

    Returns:
        The loaded dataset. If loading raises any ``Exception``, the error is
        printed and an empty dataset with ``image``, ``description`` and
        ``uuid`` columns is returned instead, with ``image`` cast to the
        HuggingFace ``Image`` feature.
    """
    try:
        return load_dataset(DATASET_REPO, split="train")
    except Exception as err:
        print(f"Error loading dataset: {err}")
        # Stand-in with the expected columns so callers can still iterate
        # (e.g. when the repository does not exist yet).
        empty_columns = {"image": [], "description": [], "uuid": []}
        placeholder = Dataset.from_dict(empty_columns)
        return placeholder.cast_column("image", HFImage())
def save_to_hf_dataset(image, description):
    """Append one image/description pair to the HuggingFace dataset.

    The current ``train`` split of ``DATASET_REPO`` is loaded, a single new
    row (image, description, freshly generated UUID) is appended, and the
    combined dataset is pushed back to the Hub.

    Args:
        image: Object exposing ``save(path, format=...)``; the UI passes the
            PIL image held by the image component.
        description: Text stored in the ``description`` column.

    Returns:
        ``(True, "Successfully saved to HuggingFace dataset!")`` on success.

    Raises:
        Exception: Errors from writing the temporary image, building the
            dataset or pushing to the Hub propagate to the caller. The
            temporary file is removed in every case.
    """
    image_id = str(uuid.uuid4())

    try:
        existing_dataset = load_dataset(DATASET_REPO, split="train")
    except Exception as e:
        # NOTE(review): every load failure is treated as "dataset does not
        # exist yet". A transient failure (e.g. network) therefore leads to a
        # push that contains only the new row -- consider narrowing this to
        # the library's "not found" errors once the installed version is known.
        print(f"Could not load existing dataset, starting a new one: {e}")
        existing_dataset = Dataset.from_dict({
            "image": [],
            "description": [],
            "uuid": []
        }).cast_column("image", HFImage())

    # Reserve a unique path and release the handle immediately: a named
    # temporary file cannot be reopened by name on Windows while it is open.
    with tempfile.NamedTemporaryFile(suffix='.png', delete=False) as tmp_file:
        tmp_path = tmp_file.name

    try:
        image.save(tmp_path, format='PNG')

        new_dataset = Dataset.from_dict({
            "image": [tmp_path],
            "description": [description],
            "uuid": [image_id]
        }).cast_column("image", HFImage())

        if len(existing_dataset) > 0:
            combined_dataset = concatenate_datasets([existing_dataset, new_dataset])
        else:
            combined_dataset = new_dataset

        # An unset HF_TOKEN is the empty string; pass None instead so the
        # library can fall back to its default credential lookup.
        combined_dataset.push_to_hub(
            DATASET_REPO, private=False, token=HF_TOKEN or None
        )
    finally:
        # The new row was built from the file path, so the file is kept until
        # the push has finished (or failed) and only then removed.
        if os.path.exists(tmp_path):
            os.unlink(tmp_path)

    return True, "Successfully saved to HuggingFace dataset!"
def save_image_and_description(image, description):
    """Gradio handler: persist the current image and its description.

    The pair is saved to the HuggingFace dataset first; on success a
    best-effort local copy (PNG plus a JSON file holding the description) is
    written to ``IMAGE_DIR``.

    Args:
        image: Image from the image component, or ``None`` when nothing has
            been generated yet.
        description: Text typed by the user.

    Returns:
        A 2-tuple matching the handler's two outputs: ``None`` for the image
        component and the refreshed list of gallery examples.

    Raises:
        gr.Error: When there is no image, the description is empty or only
            whitespace, or the HuggingFace save reports failure. Raising
            (instead of returning a status string) lets Gradio display the
            message, since the handler has no status output.
    """
    if image is None:
        raise gr.Error("No image to save!")
    if not description or not description.strip():
        raise gr.Error("Please provide a description!")

    success, message = save_to_hf_dataset(image, description)
    if not success:
        raise gr.Error(message)

    # Local copy is only a backup/cache: report problems but never fail the
    # save that already succeeded on the Hub.
    # NOTE(review): this UUID is independent of the one stored in the Hub row.
    try:
        image_id = uuid.uuid4()
        save_path = os.path.join(IMAGE_DIR, f"{image_id}.png")
        json_path = os.path.join(IMAGE_DIR, f"{image_id}.json")
        image.save(save_path)
        with open(json_path, "w") as f:
            json.dump({"description": description}, f)
    except Exception as e:
        print(f"Local backup failed: {e}")

    return None, load_previous_examples()
def load_previous_examples():
    """Build the gallery contents from the HuggingFace dataset.

    Returns:
        A list of ``(image, description)`` tuples, one per dataset row that
        has a non-``None`` image and a non-empty description. If reading the
        dataset raises, the error is printed and the locally stored examples
        are returned instead.
    """
    try:
        return [
            (row['image'], row['description'])
            for row in load_dataset_from_hf()
            if row['image'] is not None and row['description']
        ]
    except Exception as err:
        print(f"Error loading examples from HuggingFace: {err}")
        # Fall back to whatever has been backed up on disk.
        return load_local_examples()
def load_local_examples(image_dir=None):
    """Fallback: load examples from local storage.

    Scans a directory for ``<name>.png`` files that have a sibling
    ``<name>.json`` file containing a ``"description"`` key.

    Args:
        image_dir: Directory to scan. Defaults to ``IMAGE_DIR`` when ``None``.

    Returns:
        A list of ``(image, description)`` tuples in file-name order. Files
        that cannot be read or parsed are reported with ``print`` and skipped;
        an unreadable directory yields an empty list.
    """
    directory = IMAGE_DIR if image_dir is None else image_dir
    examples = []

    try:
        # Sorted so the result does not depend on the OS listing order.
        filenames = sorted(os.listdir(directory))
    except OSError as e:
        print(f"Error loading local examples: {e}")
        return examples

    suffix = ".png"
    for filename in filenames:
        if not filename.endswith(suffix):
            continue
        # Strip only the trailing extension; str.replace would also remove
        # ".png" occurring earlier in the name.
        image_id = filename[:-len(suffix)]
        json_path = os.path.join(directory, f"{image_id}.json")
        if not os.path.exists(json_path):
            continue

        # Handle errors per file so one corrupt entry does not hide the rest.
        try:
            with open(json_path, "r") as f:
                desc = json.load(f)["description"]
            # Copy the pixels into memory and close the file right away
            # instead of keeping one open handle per example.
            with Image.open(os.path.join(directory, filename)) as img:
                image = img.copy()
        except (OSError, ValueError, KeyError, TypeError) as e:
            print(f"Error loading local examples: {e}")
            continue

        examples.append((image, desc))

    return examples
def create_initial_dataset():
    """Create the initial Hub dataset from local files if it has no data yet.

    Collects every ``<uuid>.png`` in ``IMAGE_DIR`` that has a matching
    ``<uuid>.json`` description file. If there are any, and the ``train``
    split of ``DATASET_REPO`` cannot be loaded or is empty, they are pushed
    to the Hub as a new dataset.

    Returns:
        None. Any error is printed rather than raised, so application
        start-up is never blocked by this step.
    """
    try:
        images = []
        descriptions = []
        uuids = []
        suffix = ".png"

        # Only paths and descriptions are needed here, so the images are not
        # decoded just to find out whether any local examples exist.
        for filename in sorted(os.listdir(IMAGE_DIR)):
            if not filename.endswith(suffix):
                continue
            # Strip only the trailing extension; str.replace would also
            # remove ".png" occurring earlier in the name.
            image_id = filename[:-len(suffix)]
            json_path = os.path.join(IMAGE_DIR, f"{image_id}.json")
            if not os.path.exists(json_path):
                continue
            with open(json_path, "r") as f:
                desc = json.load(f)["description"]
            images.append(os.path.join(IMAGE_DIR, filename))
            descriptions.append(desc)
            uuids.append(image_id)

        if not images:
            return

        try:
            existing_dataset = load_dataset(DATASET_REPO, split="train")
        except Exception as e:
            # NOTE(review): any load failure is taken to mean "no dataset
            # yet"; a transient error would let the push below replace
            # existing remote data.
            print(f"Could not load existing dataset, creating it: {e}")
        else:
            if len(existing_dataset) > 0:
                return  # Dataset already exists with data

        dataset = Dataset.from_dict({
            "image": images,
            "description": descriptions,
            "uuid": uuids
        }).cast_column("image", HFImage())
        # Use the same credentials as save_to_hf_dataset; an unset HF_TOKEN
        # (empty string) becomes None so the library's default lookup applies.
        dataset.push_to_hub(DATASET_REPO, private=False, token=HF_TOKEN or None)
        print(f"Uploaded {len(images)} images to HuggingFace dataset")
    except Exception as e:
        print(f"Error creating initial dataset: {e}")
# Create the Gradio interface
# NOTE(review): the nesting below was reconstructed from a copy whose
# indentation had been flattened -- confirm that the accordion belongs at the
# top level of the Blocks rather than inside the column.
with gr.Blocks(title="Neutral Image App") as demo:
    gr.Markdown("# Neutral Image App")
    gr.Markdown(f"*Images are saved to HuggingFace dataset: [{DATASET_REPO}](https://huggingface.co/datasets/{DATASET_REPO})*")
    with gr.Row():
        with gr.Column():
            # Generation and saving controls. The image component is
            # read-only for the user; it is filled by generate_image.
            generate_btn = gr.Button("Generate Image")
            image_output = gr.Image(type="pil", label="Generated Image", interactive=False)
            description_input = gr.Textbox(label="Describe the image", lines=3)
            save_btn = gr.Button("Save Image and Description")
            # There is no status textbox: handlers below only write to
            # image_output and gallery.
            # status_output = gr.Textbox(label="Status")
    # Collapsed-by-default gallery of previously saved examples.
    with gr.Accordion("Previous Examples", open=False):
        gallery = gr.Gallery(
            label="Previous Images from HuggingFace Dataset",
            show_label=True,
            elem_id="gallery"
        )
        refresh_btn = gr.Button("Refresh Gallery")
    # Set up event handlers
    # Generate: takes no inputs, writes the new image to image_output.
    generate_btn.click(
        fn=generate_image,
        outputs=[image_output]
    )
    # Save: reads the current image and description; the handler must return
    # exactly two values, one for image_output and one for gallery.
    save_btn.click(
        fn=save_image_and_description,
        inputs=[image_output, description_input],
        outputs=[image_output, gallery]
    )
    # Refresh: reloads the gallery contents on demand.
    refresh_btn.click(
        fn=load_previous_examples,
        outputs=[gallery]
    )
    # Load previous examples on startup
    demo.load(
        fn=load_previous_examples,
        outputs=[gallery]
    )
# Launch the app
if __name__ == "__main__":
    # Create initial dataset from local files if needed; this runs before the
    # server starts and prints (rather than raises) any error it encounters.
    create_initial_dataset()
    demo.launch()