"""Neutral Image App: generate unconditioned Stable Diffusion images, collect
human descriptions for them, and store the pairs in a HuggingFace dataset
(with a best-effort local backup)."""

import io
import json
import os
import random
import tempfile
import traceback
import uuid

import gradio as gr
from datasets import (
    Dataset,
    Features,
    Image as HFImage,
    Value,
    concatenate_datasets,
    load_dataset,
)
from PIL import Image

from stable_diffusion_demo import StableDiffusion

# Setup directories
BASE_DIR = os.path.abspath(os.path.dirname(__file__))
IMAGE_DIR = os.path.join(BASE_DIR, "neutral_images_storage")
os.makedirs(IMAGE_DIR, exist_ok=True)

# HuggingFace dataset configuration
DATASET_REPO = "willsh1997/neutral-sd-outputs"
# Use None (not "") when the variable is unset: an empty-string token may be
# passed through as a real token instead of meaning "no explicit token".
HF_TOKEN = os.environ.get("HF_TOKEN") or None

# Explicit schema shared by every dataset this module builds, so an empty
# dataset has the same column types as a populated one.
DATASET_FEATURES = Features({
    "image": HFImage(),
    "description": Value("string"),
    "uuid": Value("string"),
})


def generate_image():
    """Generate a neutral image using Stable Diffusion.

    Both the conditional and unconditional prompts are empty strings, and the
    seed is left as None. Returns whatever ``StableDiffusion`` returns; the UI
    displays it in a PIL-typed ``gr.Image``, so presumably a PIL image.
    """
    generated_image = StableDiffusion(
        uncond_embeddings=[''],
        text_embeddings=[''],
        height=512,
        width=512,
        num_inference_steps=25,
        guidance_scale=7.5,
        seed=None,
    )
    return generated_image


def _empty_dataset():
    """Return an empty dataset with the image/description/uuid schema."""
    return Dataset.from_dict(
        {"image": [], "description": [], "uuid": []},
        features=DATASET_FEATURES,
    )


def load_dataset_from_hf():
    """Load the ``train`` split of DATASET_REPO from the HuggingFace Hub.

    Never raises: on any error the error is printed and an empty dataset
    with the expected schema is returned (e.g. if the repo does not exist).
    """
    try:
        return load_dataset(DATASET_REPO, split="train")
    except Exception as e:
        print(f"Error loading dataset: {e}")
        return _empty_dataset()


def save_to_hf_dataset(image, description, image_id=None):
    """Append one image/description row to DATASET_REPO and push it.

    The existing ``train`` split is loaded, the new row is concatenated onto
    it, and the combined dataset is pushed back to the Hub.

    Args:
        image: PIL image to store (saved as PNG).
        description: Text stored alongside the image.
        image_id: Optional UUID string for the new row. A fresh one is
            generated when omitted.

    Returns:
        ``(success, message)``; ``success`` is False when any step failed.
    """
    if image_id is None:
        image_id = str(uuid.uuid4())

    try:
        try:
            existing_dataset = load_dataset(DATASET_REPO, split="train")
        except FileNotFoundError:
            # Repo missing or has no data yet: start from an empty dataset.
            # Any other error (e.g. network) aborts the save instead, so a
            # transient failure can't overwrite the hub copy with one row.
            # NOTE(review): assumes the installed `datasets` raises a
            # FileNotFoundError subclass for a missing/empty repo - confirm.
            existing_dataset = _empty_dataset()

        # TemporaryDirectory removes the PNG even if the push fails.
        with tempfile.TemporaryDirectory() as tmp_dir:
            tmp_path = os.path.join(tmp_dir, f"{image_id}.png")
            image.save(tmp_path, format='PNG')

            new_dataset = Dataset.from_dict(
                {
                    "image": [tmp_path],
                    "description": [description],
                    "uuid": [image_id],
                },
                features=DATASET_FEATURES,
            )

            if len(existing_dataset) > 0:
                combined_dataset = concatenate_datasets(
                    [existing_dataset, new_dataset]
                )
            else:
                combined_dataset = new_dataset

            # The new row references the image by path, so push before the
            # temporary directory is cleaned up.
            combined_dataset.push_to_hub(
                DATASET_REPO, private=False, token=HF_TOKEN
            )
    except Exception as e:
        traceback.print_exc()
        return False, f"Error saving to HuggingFace: {str(e)}"

    return True, "Successfully saved to HuggingFace dataset!"


def _save_local_backup(image, description, image_id):
    """Best-effort copy of ``image`` and ``description`` into IMAGE_DIR.

    Writes ``<image_id>.png`` and ``<image_id>.json``. Errors are printed,
    not raised.
    """
    try:
        image.save(os.path.join(IMAGE_DIR, f"{image_id}.png"))
        json_path = os.path.join(IMAGE_DIR, f"{image_id}.json")
        with open(json_path, "w", encoding="utf-8") as f:
            json.dump({"description": description}, f)
    except Exception as e:
        # Local save is just backup; don't fail the request over it.
        print(f"Error saving local backup: {e}")


def save_image_and_description(image, description):
    """Gradio handler for the Save button.

    Returns values for the ``[image_output, gallery]`` outputs: ``None`` to
    clear the saved image, plus the refreshed gallery. The handler has only
    those two outputs (no status box), so validation and upload failures
    are reported by raising ``gr.Error``. This leaves the current image in
    place so the user can retry.
    """
    if image is None:
        raise gr.Error("No image to save!")
    if not description or not description.strip():
        raise gr.Error("Please provide a description!")

    # One id for both the hub row and the local backup files.
    image_id = str(uuid.uuid4())
    success, message = save_to_hf_dataset(image, description, image_id=image_id)
    if not success:
        raise gr.Error(message)

    _save_local_backup(image, description, image_id)
    return None, load_previous_examples()


def load_previous_examples():
    """Return ``(image, description)`` pairs for the gallery.

    Reads from the HuggingFace dataset. ``load_dataset_from_hf`` returns an
    empty dataset instead of raising on failure, so the local fallback is
    used whenever the hub yields no usable examples.
    """
    try:
        dataset = load_dataset_from_hf()
        examples = [
            (item['image'], item['description'])
            for item in dataset
            if item['image'] is not None and item['description']
        ]
    except Exception as e:
        print(f"Error loading examples from HuggingFace: {e}")
        examples = []
    return examples or load_local_examples()


def _iter_local_entries():
    """Yield ``(image_id, image_path, description)`` for local backups.

    Only ``*.png`` files in IMAGE_DIR with a readable sibling ``.json`` file
    containing a ``"description"`` key are yielded. Files are visited in
    sorted order. Unreadable JSON files are reported and skipped.
    """
    suffix = ".png"
    for file_name in sorted(os.listdir(IMAGE_DIR)):
        if not file_name.endswith(suffix):
            continue
        image_id = file_name[:-len(suffix)]
        image_path = os.path.join(IMAGE_DIR, file_name)
        json_path = os.path.join(IMAGE_DIR, f"{image_id}.json")
        if not os.path.exists(json_path):
            continue
        try:
            with open(json_path, "r", encoding="utf-8") as f:
                description = json.load(f)["description"]
        except (OSError, ValueError, KeyError, TypeError) as e:
            print(f"Skipping {json_path}: {e}")
            continue
        yield image_id, image_path, description


def load_local_examples():
    """Fallback: load ``(PIL image, description)`` pairs from IMAGE_DIR."""
    examples = []
    try:
        for _, image_path, description in _iter_local_entries():
            try:
                # Copy inside `with` so the file handle is closed right away.
                with Image.open(image_path) as img:
                    examples.append((img.copy(), description))
            except OSError as e:
                print(f"Skipping unreadable image {image_path}: {e}")
    except Exception as e:
        print(f"Error loading local examples: {e}")
    return examples


def create_initial_dataset():
    """Seed DATASET_REPO from local backups when the hub copy is missing or empty.

    Does nothing if there are no local backups or the hub dataset already
    has rows. Errors are printed, not raised.
    """
    try:
        entries = list(_iter_local_entries())
        if not entries:
            return

        try:
            existing_dataset = load_dataset(DATASET_REPO, split="train")
        except FileNotFoundError:
            # Repo missing or empty: we'll create it. Other errors propagate
            # to the outer handler so we never overwrite data we couldn't read.
            # NOTE(review): relies on `datasets` raising a FileNotFoundError
            # subclass for a missing/empty repo - confirm for pinned version.
            existing_dataset = None
        if existing_dataset is not None and len(existing_dataset) > 0:
            return  # Dataset already exists with data

        image_ids, image_paths, descriptions = zip(*entries)
        dataset = Dataset.from_dict(
            {
                "image": list(image_paths),
                "description": list(descriptions),
                "uuid": list(image_ids),
            },
            features=DATASET_FEATURES,
        )
        dataset.push_to_hub(DATASET_REPO, private=False, token=HF_TOKEN)
        print(f"Uploaded {len(entries)} images to HuggingFace dataset")
    except Exception as e:
        print(f"Error creating initial dataset: {e}")


# Create the Gradio interface
with gr.Blocks(title="Neutral Image App") as demo:
    gr.Markdown("# Neutral Image App")
    gr.Markdown(
        f"*Images are saved to HuggingFace dataset: "
        f"[{DATASET_REPO}](https://huggingface.co/datasets/{DATASET_REPO})*"
    )

    with gr.Row():
        with gr.Column():
            generate_btn = gr.Button("Generate Image")
            image_output = gr.Image(
                type="pil", label="Generated Image", interactive=False
            )
            description_input = gr.Textbox(label="Describe the image", lines=3)
            save_btn = gr.Button("Save Image and Description")

    with gr.Accordion("Previous Examples", open=False):
        gallery = gr.Gallery(
            label="Previous Images from HuggingFace Dataset",
            show_label=True,
            elem_id="gallery",
        )
        refresh_btn = gr.Button("Refresh Gallery")

    # Set up event handlers
    generate_btn.click(
        fn=generate_image,
        outputs=[image_output],
    )
    save_btn.click(
        fn=save_image_and_description,
        inputs=[image_output, description_input],
        outputs=[image_output, gallery],
    )
    refresh_btn.click(
        fn=load_previous_examples,
        outputs=[gallery],
    )

    # Load previous examples on startup
    demo.load(
        fn=load_previous_examples,
        outputs=[gallery],
    )


# Launch the app
if __name__ == "__main__":
    # Create initial dataset from local files if needed
    create_initial_dataset()
    demo.launch()