import gradio as gr
import json
import uuid
import os
import tempfile

from PIL import Image
from stable_diffusion_demo import StableDiffusion
from datasets import Dataset, Image as HFImage, load_dataset, concatenate_datasets

# Setup directories
BASE_DIR = os.path.abspath(os.path.dirname(__file__))
IMAGE_DIR = os.path.join(BASE_DIR, "neutral_images_storage")
os.makedirs(IMAGE_DIR, exist_ok=True)

# HuggingFace dataset configuration
DATASET_REPO = "willsh1997/neutral-sd-outputs"
HF_TOKEN = os.environ.get("HF_TOKEN", "")
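# Note: HF_TOKEN needs write access to DATASET_REPO for push_to_hub to succeed;
# on a HuggingFace Space it is typically supplied as a Space secret.
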
def generate_image():
    """Generate a neutral image using Stable Diffusion"""
    generated_image = StableDiffusion(
        uncond_embeddings=[''],
        text_embeddings=[''],
        height=512,
        width=512,
        num_inference_steps=25,
        guidance_scale=7.5,
        seed=None,
    )
    return generated_image

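# Note: StableDiffusion comes from the local stable_diffusion_demo module; the
# rest of this app assumes it returns a PIL image, since gr.Image(type="pil")
# displays the result and image.save() writes it out below.
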
def load_dataset_from_hf():
    """Load dataset from HuggingFace Hub"""
    try:
        dataset = load_dataset(DATASET_REPO, split="train")
        return dataset
    except Exception as e:
        print(f"Error loading dataset: {e}")
        # Return empty dataset with correct schema if repo doesn't exist
        return Dataset.from_dict({
            "image": [],
            "description": [],
            "uuid": []
        }).cast_column("image", HFImage())
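
# Note: load_dataset caches downloads, so a just-pushed entry may not appear
# immediately on reload; passing download_mode="force_redownload" to
# load_dataset is one way to force a fresh copy if staleness becomes a problem.
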
def save_to_hf_dataset(image, description):
    """Save new image and description to HuggingFace dataset"""
    try:
        # Generate UUID for the new entry
        image_id = str(uuid.uuid4())

        # Load existing dataset
        try:
            existing_dataset = load_dataset(DATASET_REPO, split="train")
        except Exception:
            # Create empty dataset if it doesn't exist
            existing_dataset = Dataset.from_dict({
                "image": [],
                "description": [],
                "uuid": []
            }).cast_column("image", HFImage())

        # Write the image to a temporary file so the HFImage feature can ingest it
        with tempfile.NamedTemporaryFile(suffix='.png', delete=False) as tmp_file:
            image.save(tmp_file.name, format='PNG')

        # Create new entry
        new_entry = {
            "image": [tmp_file.name],
            "description": [description],
            "uuid": [image_id]
        }

        # Create new dataset from the entry
        new_dataset = Dataset.from_dict(new_entry).cast_column("image", HFImage())

        # Concatenate with existing dataset
        if len(existing_dataset) > 0:
            combined_dataset = concatenate_datasets([existing_dataset, new_dataset])
        else:
            combined_dataset = new_dataset

        # Push to HuggingFace Hub
        combined_dataset.push_to_hub(DATASET_REPO, private=False, token=HF_TOKEN)

        # Clean up temporary file
        os.unlink(tmp_file.name)

        return True, "Successfully saved to HuggingFace dataset!"
    except Exception as e:
        # Report failure instead of raising, so callers always get a
        # (success, message) pair
        return False, f"Error saving to HuggingFace: {str(e)}"
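
# Note: this load-concatenate-push pattern re-uploads the full dataset on every
# save, so the cost of each save grows with dataset size. For higher-volume
# apps, something like huggingface_hub's CommitScheduler (batched background
# commits) is a common alternative, though it is not used here.
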
def save_image_and_description(image, description):
    """Save the generated image and its description to HuggingFace dataset"""
    # The save_btn click handler has two outputs (image, gallery), so every
    # branch must return exactly two values
    if image is None:
        print("No image to save!")
        return image, gr.update()
    if not description:
        print("Please provide a description!")
        return image, gr.update()

    # Save to HuggingFace dataset
    success, message = save_to_hf_dataset(image, description)

    if success:
        # Also save locally for backup/caching
        try:
            image_id = uuid.uuid4()
            save_path = os.path.join(IMAGE_DIR, f"{image_id}.png")
            json_path = os.path.join(IMAGE_DIR, f"{image_id}.json")
            image.save(save_path)
            desc_json = {"description": description}
            with open(json_path, "w") as f:
                json.dump(desc_json, f)
        except Exception:
            pass  # Local save is just backup, don't fail if it doesn't work
        # Clear the image slot and show the refreshed gallery
        return None, load_previous_examples()
    else:
        print(message)
        return image, gr.update()

def load_previous_examples():
    """Load examples from HuggingFace dataset"""
    try:
        dataset = load_dataset_from_hf()
        examples = []
        # Convert dataset to gallery format
        for item in dataset:
            if item['image'] is not None and item['description']:
                examples.append((item['image'], item['description']))
        return examples
    except Exception as e:
        print(f"Error loading examples from HuggingFace: {e}")
        # Fallback to local examples
        return load_local_examples()

def load_local_examples():
    """Fallback: Load examples from local storage"""
    examples = []
    try:
        for file in os.listdir(IMAGE_DIR):
            if file.endswith(".png"):
                image_id = file.replace(".png", "")
                image_path = os.path.join(IMAGE_DIR, f"{image_id}.png")
                json_path = os.path.join(IMAGE_DIR, f"{image_id}.json")
                if os.path.exists(json_path):
                    image = Image.open(image_path)
                    with open(json_path, "r") as f:
                        desc = json.load(f)["description"]
                    examples.append((image, desc))
    except Exception as e:
        print(f"Error loading local examples: {e}")
    return examples

def create_initial_dataset():
    """Create initial dataset from local files if HF dataset doesn't exist"""
    try:
        # Check if we have local files to upload
        local_examples = load_local_examples()
        if not local_examples:
            return

        # Try to load existing dataset
        try:
            existing_dataset = load_dataset(DATASET_REPO, split="train")
            if len(existing_dataset) > 0:
                return  # Dataset already exists with data
        except Exception:
            pass  # Dataset doesn't exist, we'll create it

        # Create dataset from local files
        images = []
        descriptions = []
        uuids = []
        for file in os.listdir(IMAGE_DIR):
            if file.endswith(".png"):
                image_id = file.replace(".png", "")
                image_path = os.path.join(IMAGE_DIR, f"{image_id}.png")
                json_path = os.path.join(IMAGE_DIR, f"{image_id}.json")
                if os.path.exists(json_path):
                    with open(json_path, "r") as f:
                        desc = json.load(f)["description"]
                    images.append(image_path)
                    descriptions.append(desc)
                    uuids.append(image_id)

        if images:
            # Create dataset
            dataset_dict = {
                "image": images,
                "description": descriptions,
                "uuid": uuids
            }
            dataset = Dataset.from_dict(dataset_dict).cast_column("image", HFImage())
            # Pass the token here as well so the initial push can authenticate
            dataset.push_to_hub(DATASET_REPO, private=False, token=HF_TOKEN)
            print(f"Uploaded {len(images)} images to HuggingFace dataset")
    except Exception as e:
        print(f"Error creating initial dataset: {e}")

# Create the Gradio interface
with gr.Blocks(title="Neutral Image App") as demo:
    gr.Markdown("# Neutral Image App")
    gr.Markdown(f"*Images are saved to HuggingFace dataset: [{DATASET_REPO}](https://huggingface.co/datasets/{DATASET_REPO})*")

    with gr.Row():
        with gr.Column():
            generate_btn = gr.Button("Generate Image")
            image_output = gr.Image(type="pil", label="Generated Image", interactive=False)
            description_input = gr.Textbox(label="Describe the image", lines=3)
            save_btn = gr.Button("Save Image and Description")
            # status_output = gr.Textbox(label="Status")

    with gr.Accordion("Previous Examples", open=False):
        # The gallery takes the (image, caption) tuples that
        # load_previous_examples returns
        gallery = gr.Gallery(
            label="Previous Images from HuggingFace Dataset",
            show_label=True,
            elem_id="gallery"
        )
        refresh_btn = gr.Button("Refresh Gallery")
    # Set up event handlers
    generate_btn.click(
        fn=generate_image,
        outputs=[image_output]
    )

    save_btn.click(
        fn=save_image_and_description,
        inputs=[image_output, description_input],
        outputs=[image_output, gallery]
    )

    refresh_btn.click(
        fn=load_previous_examples,
        outputs=[gallery]
    )

    # Load previous examples on startup
    demo.load(
        fn=load_previous_examples,
        outputs=[gallery]
    )

# Launch the app
if __name__ == "__main__":
    # Create initial dataset from local files if needed
    create_initial_dataset()
    demo.launch()
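
# Example (hypothetical consumer script): once published, the dataset can be
# read back anywhere with the standard datasets API, e.g.:
#     from datasets import load_dataset
#     ds = load_dataset("willsh1997/neutral-sd-outputs", split="train")
#     ds[0]["image"]  # decodes back to a PIL image via the HFImage feature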