willsh1997 commited on
Commit
0c4a319
·
1 Parent(s): 57c6da8

:clown_face: silly errors - dataset persistent storage

Browse files
Files changed (1) hide show
  1. gradio_neutral_input_func.py +181 -31
gradio_neutral_input_func.py CHANGED
@@ -6,12 +6,18 @@ import json
6
  import uuid
7
  import os
8
  from stable_diffusion_demo import StableDiffusion
 
 
9
 
10
  # Setup directories
11
  BASE_DIR = os.path.abspath(os.path.dirname(__file__))
12
  IMAGE_DIR = os.path.join(BASE_DIR, "neutral_images_storage")
13
  os.makedirs(IMAGE_DIR, exist_ok=True)
14
 
 
 
 
 
15
  def generate_image():
16
  """Generate a neutral image using Stable Diffusion"""
17
  generated_image = StableDiffusion(
@@ -25,56 +31,193 @@ def generate_image():
25
  )
26
  return generated_image
27
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28
  def save_image_and_description(image, description):
29
- """Save the generated image and its description"""
30
  if image is None:
31
  return "No image to save!", None, None
32
 
33
  if not description:
34
  return "Please provide a description!", None, None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
35
 
 
 
 
 
 
 
36
  try:
37
- image_id = uuid.uuid4()
38
- save_path = os.path.join(IMAGE_DIR, f"{image_id}.png")
39
- json_path = os.path.join(IMAGE_DIR, f"{image_id}.json")
40
 
41
- # Save image
42
- image.save(save_path)
 
 
43
 
44
- # Save description
45
- desc_json = {"description": description}
46
- with open(json_path, "w") as f:
47
- json.dump(desc_json, f)
48
 
49
- # Return success message, clear the image output, and return updated gallery
50
- return "Saved successfully!", None, load_previous_examples()
51
  except Exception as e:
52
- return f"Error saving: {str(e)}", None, None
 
 
53
 
54
- def load_previous_examples():
55
- """Load all previously saved images and descriptions"""
56
  examples = []
57
- for file in os.listdir(IMAGE_DIR):
58
- if file.endswith(".png"):
59
- image_id = file.replace(".png", "")
60
- image_path = os.path.join(IMAGE_DIR, f"{image_id}.png")
61
- json_path = os.path.join(IMAGE_DIR, f"{image_id}.json")
62
-
63
- if os.path.exists(json_path):
64
- image = Image.open(image_path)
65
- with open(json_path, "r") as f:
66
- desc = json.load(f)["description"]
67
- examples.append((image, desc))
 
 
 
 
68
  return examples
69
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
70
  # Create the Gradio interface
71
  with gr.Blocks(title="Neutral Image App") as demo:
72
  gr.Markdown("# Neutral Image App")
 
73
 
74
  with gr.Row():
75
  with gr.Column():
76
  generate_btn = gr.Button("Generate Image")
77
- # Disable image upload by setting interactive=False
78
  image_output = gr.Image(type="pil", label="Generated Image", interactive=False)
79
  description_input = gr.Textbox(label="Describe the image", lines=3)
80
  save_btn = gr.Button("Save Image and Description")
@@ -82,10 +225,11 @@ with gr.Blocks(title="Neutral Image App") as demo:
82
 
83
  with gr.Accordion("Previous Examples", open=False):
84
  gallery = gr.Gallery(
85
- label="Previous Images",
86
  show_label=True,
87
  elem_id="gallery"
88
- )#.style(grid=2, height="auto")
 
89
 
90
  # Set up event handlers
91
  generate_btn.click(
@@ -93,11 +237,15 @@ with gr.Blocks(title="Neutral Image App") as demo:
93
  outputs=[image_output]
94
  )
95
 
96
- # Updated to include gallery refresh in outputs
97
  save_btn.click(
98
  fn=save_image_and_description,
99
  inputs=[image_output, description_input],
100
- outputs=[status_output, image_output, gallery] # Added gallery to outputs
 
 
 
 
 
101
  )
102
 
103
  # Load previous examples on startup
@@ -108,4 +256,6 @@ with gr.Blocks(title="Neutral Image App") as demo:
108
 
109
  # Launch the app
110
  if __name__ == "__main__":
 
 
111
  demo.launch()
 
6
  import uuid
7
  import os
8
  from stable_diffusion_demo import StableDiffusion
9
+ from datasets import Dataset, Features, Value, Image as HFImage, load_dataset
10
+ import tempfile
11
 
12
  # Setup directories
13
  BASE_DIR = os.path.abspath(os.path.dirname(__file__))
14
  IMAGE_DIR = os.path.join(BASE_DIR, "neutral_images_storage")
15
  os.makedirs(IMAGE_DIR, exist_ok=True)
16
 
17
+ # HuggingFace dataset configuration
18
+ DATASET_REPO = "willsh1997/neutral-sd-outputs"
19
+ HF_TOKEN = os.environ.get("HF_TOKEN", "")
20
+
21
  def generate_image():
22
  """Generate a neutral image using Stable Diffusion"""
23
  generated_image = StableDiffusion(
 
31
  )
32
  return generated_image
33
 
34
+ def load_dataset_from_hf():
35
+ """Load dataset from HuggingFace Hub"""
36
+ try:
37
+ dataset = load_dataset(DATASET_REPO, split="train")
38
+ return dataset
39
+ except Exception as e:
40
+ print(f"Error loading dataset: {e}")
41
+ # Return empty dataset with correct schema if repo doesn't exist
42
+ return Dataset.from_dict({
43
+ "image": [],
44
+ "description": [],
45
+ "uuid": []
46
+ }).cast_column("image", HFImage())
47
+
48
+ def save_to_hf_dataset(image, description):
49
+ """Save new image and description to HuggingFace dataset"""
50
+ try:
51
+ # Generate UUID for the new entry
52
+ image_id = str(uuid.uuid4())
53
+
54
+ # Load existing dataset
55
+ try:
56
+ existing_dataset = load_dataset(DATASET_REPO, split="train")
57
+ except:
58
+ # Create empty dataset if it doesn't exist
59
+ existing_dataset = Dataset.from_dict({
60
+ "image": [],
61
+ "description": [],
62
+ "uuid": []
63
+ }).cast_column("image", HFImage())
64
+
65
+ # Create temporary file for the image
66
+ with tempfile.NamedTemporaryFile(suffix='.png', delete=False) as tmp_file:
67
+ image.save(tmp_file.name, format='PNG')
68
+
69
+ # Create new entry
70
+ new_entry = {
71
+ "image": [tmp_file.name],
72
+ "description": [description],
73
+ "uuid": [image_id]
74
+ }
75
+
76
+ # Create new dataset from the entry
77
+ new_dataset = Dataset.from_dict(new_entry).cast_column("image", HFImage())
78
+
79
+ # Concatenate with existing dataset
80
+ if len(existing_dataset) > 0:
81
+ combined_dataset = existing_dataset.concatenate(new_dataset)
82
+ else:
83
+ combined_dataset = new_dataset
84
+
85
+ # Push to HuggingFace Hub
86
+ combined_dataset.push_to_hub(DATASET_REPO, private=False, token=HF_TOKEN)
87
+
88
+ # Clean up temporary file
89
+ os.unlink(tmp_file.name)
90
+
91
+ return True, "Successfully saved to HuggingFace dataset!"
92
+
93
+ except Exception as e:
94
+ return False, f"Error saving to HuggingFace: {str(e)}"
95
+
96
  def save_image_and_description(image, description):
97
+ """Save the generated image and its description to HuggingFace dataset"""
98
  if image is None:
99
  return "No image to save!", None, None
100
 
101
  if not description:
102
  return "Please provide a description!", None, None
103
+
104
+ # Save to HuggingFace dataset
105
+ success, message = save_to_hf_dataset(image, description)
106
+
107
+ if success:
108
+ # Also save locally for backup/caching
109
+ try:
110
+ image_id = uuid.uuid4()
111
+ save_path = os.path.join(IMAGE_DIR, f"{image_id}.png")
112
+ json_path = os.path.join(IMAGE_DIR, f"{image_id}.json")
113
+
114
+ image.save(save_path)
115
+ desc_json = {"description": description}
116
+ with open(json_path, "w") as f:
117
+ json.dump(desc_json, f)
118
+ except:
119
+ pass # Local save is just backup, don't fail if it doesn't work
120
 
121
+ return message, None, load_previous_examples()
122
+ else:
123
+ return message, None, None
124
+
125
+ def load_previous_examples():
126
+ """Load examples from HuggingFace dataset"""
127
  try:
128
+ dataset = load_dataset_from_hf()
129
+ examples = []
 
130
 
131
+ # Convert dataset to gallery format
132
+ for item in dataset:
133
+ if item['image'] is not None and item['description']:
134
+ examples.append((item['image'], item['description']))
135
 
136
+ return examples
 
 
 
137
 
 
 
138
  except Exception as e:
139
+ print(f"Error loading examples from HuggingFace: {e}")
140
+ # Fallback to local examples
141
+ return load_local_examples()
142
 
143
+ def load_local_examples():
144
+ """Fallback: Load examples from local storage"""
145
  examples = []
146
+ try:
147
+ for file in os.listdir(IMAGE_DIR):
148
+ if file.endswith(".png"):
149
+ image_id = file.replace(".png", "")
150
+ image_path = os.path.join(IMAGE_DIR, f"{image_id}.png")
151
+ json_path = os.path.join(IMAGE_DIR, f"{image_id}.json")
152
+
153
+ if os.path.exists(json_path):
154
+ image = Image.open(image_path)
155
+ with open(json_path, "r") as f:
156
+ desc = json.load(f)["description"]
157
+ examples.append((image, desc))
158
+ except Exception as e:
159
+ print(f"Error loading local examples: {e}")
160
+
161
  return examples
162
 
163
+ def create_initial_dataset():
164
+ """Create initial dataset from local files if HF dataset doesn't exist"""
165
+ try:
166
+ # Check if we have local files to upload
167
+ local_examples = load_local_examples()
168
+ if not local_examples:
169
+ return
170
+
171
+ # Try to load existing dataset
172
+ try:
173
+ existing_dataset = load_dataset(DATASET_REPO, split="train")
174
+ if len(existing_dataset) > 0:
175
+ return # Dataset already exists with data
176
+ except:
177
+ pass # Dataset doesn't exist, we'll create it
178
+
179
+ # Create dataset from local files
180
+ images = []
181
+ descriptions = []
182
+ uuids = []
183
+
184
+ for file in os.listdir(IMAGE_DIR):
185
+ if file.endswith(".png"):
186
+ image_id = file.replace(".png", "")
187
+ image_path = os.path.join(IMAGE_DIR, f"{image_id}.png")
188
+ json_path = os.path.join(IMAGE_DIR, f"{image_id}.json")
189
+
190
+ if os.path.exists(json_path):
191
+ with open(json_path, "r") as f:
192
+ desc = json.load(f)["description"]
193
+
194
+ images.append(image_path)
195
+ descriptions.append(desc)
196
+ uuids.append(image_id)
197
+
198
+ if images:
199
+ # Create dataset
200
+ dataset_dict = {
201
+ "image": images,
202
+ "description": descriptions,
203
+ "uuid": uuids
204
+ }
205
+
206
+ dataset = Dataset.from_dict(dataset_dict).cast_column("image", HFImage())
207
+ dataset.push_to_hub(DATASET_REPO, private=False)
208
+ print(f"Uploaded {len(images)} images to HuggingFace dataset")
209
+
210
+ except Exception as e:
211
+ print(f"Error creating initial dataset: {e}")
212
+
213
  # Create the Gradio interface
214
  with gr.Blocks(title="Neutral Image App") as demo:
215
  gr.Markdown("# Neutral Image App")
216
+ gr.Markdown(f"*Images are saved to HuggingFace dataset: [{DATASET_REPO}](https://huggingface.co/datasets/{DATASET_REPO})*")
217
 
218
  with gr.Row():
219
  with gr.Column():
220
  generate_btn = gr.Button("Generate Image")
 
221
  image_output = gr.Image(type="pil", label="Generated Image", interactive=False)
222
  description_input = gr.Textbox(label="Describe the image", lines=3)
223
  save_btn = gr.Button("Save Image and Description")
 
225
 
226
  with gr.Accordion("Previous Examples", open=False):
227
  gallery = gr.Gallery(
228
+ label="Previous Images from HuggingFace Dataset",
229
  show_label=True,
230
  elem_id="gallery"
231
+ )
232
+ refresh_btn = gr.Button("Refresh Gallery")
233
 
234
  # Set up event handlers
235
  generate_btn.click(
 
237
  outputs=[image_output]
238
  )
239
 
 
240
  save_btn.click(
241
  fn=save_image_and_description,
242
  inputs=[image_output, description_input],
243
+ outputs=[status_output, image_output, gallery]
244
+ )
245
+
246
+ refresh_btn.click(
247
+ fn=load_previous_examples,
248
+ outputs=[gallery]
249
  )
250
 
251
  # Load previous examples on startup
 
256
 
257
  # Launch the app
258
  if __name__ == "__main__":
259
+ # Create initial dataset from local files if needed
260
+ create_initial_dataset()
261
  demo.launch()