Update app.py
Browse files
app.py
CHANGED
|
@@ -6,6 +6,8 @@ from diffusers import StableDiffusionPipeline
|
|
| 6 |
from transformers import CLIPTokenizer
|
| 7 |
import os
|
| 8 |
import zipfile
|
|
|
|
|
|
|
| 9 |
import gradio as gr
|
| 10 |
|
| 11 |
# Define the device
|
|
@@ -104,9 +106,24 @@ def zip_model(model_path):
|
|
| 104 |
|
| 105 |
# Gradio interface functions
|
| 106 |
def start_fine_tuning(uploaded_files, prompts, num_epochs):
|
| 107 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 108 |
model_save_path = "fine_tuned_model"
|
| 109 |
fine_tune_model(images, prompts, model_save_path, num_epochs=int(num_epochs))
|
|
|
|
|
|
|
|
|
|
|
|
|
| 110 |
return "Fine-tuning completed! Model is ready for download."
|
| 111 |
|
| 112 |
def download_model():
|
|
@@ -156,4 +173,4 @@ with gr.Blocks() as demo:
|
|
| 156 |
|
| 157 |
generate_button.click(generate_new_image, [prompt_input], generated_image)
|
| 158 |
|
| 159 |
-
demo.launch()
|
|
|
|
| 6 |
from transformers import CLIPTokenizer
|
| 7 |
import os
|
| 8 |
import zipfile
|
| 9 |
+
import tempfile
|
| 10 |
+
import shutil
|
| 11 |
import gradio as gr
|
| 12 |
|
| 13 |
# Define the device
|
|
|
|
| 106 |
|
| 107 |
# Gradio interface functions
|
| 108 |
def start_fine_tuning(uploaded_files, prompts, num_epochs):
|
| 109 |
+
# Create a temporary directory for storing files
|
| 110 |
+
temp_dir = tempfile.mkdtemp()
|
| 111 |
+
print("Temporary directory:", temp_dir)
|
| 112 |
+
|
| 113 |
+
images = []
|
| 114 |
+
for file in uploaded_files:
|
| 115 |
+
# Store the uploaded file in the temp directory
|
| 116 |
+
image_path = os.path.join(temp_dir, file.name)
|
| 117 |
+
with open(image_path, 'wb') as f:
|
| 118 |
+
f.write(file.read()) # Save file content
|
| 119 |
+
images.append(Image.open(image_path).convert("RGB"))
|
| 120 |
+
|
| 121 |
model_save_path = "fine_tuned_model"
|
| 122 |
fine_tune_model(images, prompts, model_save_path, num_epochs=int(num_epochs))
|
| 123 |
+
|
| 124 |
+
# Clean up the temporary directory after fine-tuning
|
| 125 |
+
shutil.rmtree(temp_dir)
|
| 126 |
+
|
| 127 |
return "Fine-tuning completed! Model is ready for download."
|
| 128 |
|
| 129 |
def download_model():
|
|
|
|
| 173 |
|
| 174 |
generate_button.click(generate_new_image, [prompt_input], generated_image)
|
| 175 |
|
| 176 |
+
demo.launch()
|