Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -6,8 +6,6 @@ from diffusers import StableDiffusionPipeline
|
|
| 6 |
from transformers import CLIPTokenizer
|
| 7 |
import os
|
| 8 |
import zipfile
|
| 9 |
-
import tempfile
|
| 10 |
-
import shutil
|
| 11 |
import gradio as gr
|
| 12 |
|
| 13 |
# Define the device
|
|
@@ -104,26 +102,18 @@ def zip_model(model_path):
|
|
| 104 |
zipf.write(os.path.join(root, file), os.path.relpath(os.path.join(root, file), model_path))
|
| 105 |
return zip_path
|
| 106 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 107 |
# Gradio interface functions
|
| 108 |
def start_fine_tuning(uploaded_files, prompts, num_epochs):
|
| 109 |
-
|
| 110 |
-
temp_dir = tempfile.mkdtemp()
|
| 111 |
-
print("Temporary directory:", temp_dir)
|
| 112 |
-
|
| 113 |
-
images = []
|
| 114 |
-
for file in uploaded_files:
|
| 115 |
-
# Store the uploaded file in the temp directory
|
| 116 |
-
image_path = os.path.join(temp_dir, file.name)
|
| 117 |
-
with open(image_path, 'wb') as f:
|
| 118 |
-
f.write(file.read()) # Save file content
|
| 119 |
-
images.append(Image.open(image_path).convert("RGB"))
|
| 120 |
-
|
| 121 |
model_save_path = "fine_tuned_model"
|
| 122 |
fine_tune_model(images, prompts, model_save_path, num_epochs=int(num_epochs))
|
| 123 |
-
|
| 124 |
-
# Clean up the temporary directory after fine-tuning
|
| 125 |
-
shutil.rmtree(temp_dir)
|
| 126 |
-
|
| 127 |
return "Fine-tuning completed! Model is ready for download."
|
| 128 |
|
| 129 |
def download_model():
|
|
@@ -173,4 +163,4 @@ with gr.Blocks() as demo:
|
|
| 173 |
|
| 174 |
generate_button.click(generate_new_image, [prompt_input], generated_image)
|
| 175 |
|
| 176 |
-
demo.launch()
|
|
|
|
| 6 |
from transformers import CLIPTokenizer
|
| 7 |
import os
|
| 8 |
import zipfile
|
|
|
|
|
|
|
| 9 |
import gradio as gr
|
| 10 |
|
| 11 |
# Define the device
|
|
|
|
| 102 |
zipf.write(os.path.join(root, file), os.path.relpath(os.path.join(root, file), model_path))
|
| 103 |
return zip_path
|
| 104 |
|
| 105 |
+
# Function to save uploaded files
|
| 106 |
+
def save_uploaded_file(uploaded_file, save_path):
|
| 107 |
+
# Open the file in binary write mode
|
| 108 |
+
with open(save_path, 'wb') as f:
|
| 109 |
+
f.write(uploaded_file.data) # Use .data for the file content
|
| 110 |
+
return f"File saved at {save_path}"
|
| 111 |
+
|
| 112 |
# Gradio interface functions
|
| 113 |
def start_fine_tuning(uploaded_files, prompts, num_epochs):
|
| 114 |
+
images = [Image.open(file).convert("RGB") for file in uploaded_files]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 115 |
model_save_path = "fine_tuned_model"
|
| 116 |
fine_tune_model(images, prompts, model_save_path, num_epochs=int(num_epochs))
|
|
|
|
|
|
|
|
|
|
|
|
|
| 117 |
return "Fine-tuning completed! Model is ready for download."
|
| 118 |
|
| 119 |
def download_model():
|
|
|
|
| 163 |
|
| 164 |
generate_button.click(generate_new_image, [prompt_input], generated_image)
|
| 165 |
|
| 166 |
+
demo.launch()
|