cavargas10 commited on
Commit
6e7dc72
verified
1 Parent(s): f7ef315
Files changed (1) hide show
  1. app.py +45 -22
app.py CHANGED
@@ -17,6 +17,10 @@ from PIL import Image
17
  from trellis.pipelines import TrellisImageTo3DPipeline
18
  from trellis.representations import Gaussian, MeshExtractResult
19
  from trellis.utils import render_utils, postprocessing_utils
 
 
 
 
20
  NUM_INFERENCE_STEPS = 8
21
  huggingface_token = os.getenv("HUGGINGFACE_TOKEN")
22
  # Constants
@@ -24,17 +28,29 @@ MAX_SEED = np.iinfo(np.int32).max
24
  TMP_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'tmp')
25
  os.makedirs(TMP_DIR, exist_ok=True)
26
 
27
- # Funciones auxiliares
28
  def start_session(req: gr.Request):
29
- user_dir = os.path.join(TMP_DIR, str(req.session_hash))
 
 
30
  os.makedirs(user_dir, exist_ok=True)
31
 
32
  def end_session(req: gr.Request):
33
- user_dir = os.path.join(TMP_DIR, str(req.session_hash))
34
- shutil.rmtree(user_dir)
 
 
 
 
 
 
 
 
 
35
 
36
  def preprocess_image(image: Image.Image) -> Image.Image:
 
37
  processed_image = trellis_pipeline.preprocess_image(image)
 
38
  return processed_image
39
 
40
  def pack_state(gs: Gaussian, mesh: MeshExtractResult) -> dict:
@@ -74,7 +90,9 @@ def unpack_state(state: dict) -> Tuple[Gaussian, edict]:
74
  return gs, mesh
75
 
76
  def get_seed(randomize_seed: bool, seed: int) -> int:
77
- return np.random.randint(0, MAX_SEED) if randomize_seed else seed
 
 
78
 
79
  @spaces.GPU
80
  def generate_flux_image(
@@ -87,13 +105,15 @@ def generate_flux_image(
87
  req: gr.Request,
88
  progress: gr.Progress = gr.Progress(track_tqdm=True),
89
  ) -> Image.Image:
90
- """Generate image using Flux pipeline"""
 
91
  if randomize_seed:
92
  seed = random.randint(0, MAX_SEED)
 
93
  generator = torch.Generator(device=device).manual_seed(seed)
94
- prompt = "wbgmsst, " + prompt + ", 3D isometric, white background"
95
  image = flux_pipeline(
96
- prompt=prompt,
97
  guidance_scale=guidance_scale,
98
  num_inference_steps=NUM_INFERENCE_STEPS,
99
  width=width,
@@ -101,13 +121,12 @@ def generate_flux_image(
101
  generator=generator,
102
  ).images[0]
103
 
104
- user_dir = os.path.join(TMP_DIR, str(req.session_hash))
105
- os.makedirs(user_dir, exist_ok=True)
106
- timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
107
- unique_id = str(uuid.uuid4())[:8]
108
- filename = f"{timestamp}_{unique_id}.png"
109
- filepath = os.path.join(user_dir, filename)
110
  image.save(filepath)
 
111
  return image
112
 
113
  @spaces.GPU
@@ -120,7 +139,9 @@ def image_to_3d(
120
  slat_sampling_steps: int,
121
  req: gr.Request,
122
  ) -> Tuple[dict, str]:
123
- user_dir = os.path.join(TMP_DIR, str(req.session_hash))
 
 
124
  outputs = trellis_pipeline.run(
125
  image,
126
  seed=seed,
@@ -135,6 +156,7 @@ def image_to_3d(
135
  "cfg_strength": slat_guidance_strength,
136
  },
137
  )
 
138
  video = render_utils.render_video(outputs['gaussian'][0], num_frames=120)['color']
139
  video_geo = render_utils.render_video(outputs['mesh'][0], num_frames=120)['normal']
140
  video = [np.concatenate([video[i], video_geo[i]], axis=1) for i in range(len(video))]
@@ -142,6 +164,7 @@ def image_to_3d(
142
  imageio.mimsave(video_path, video, fps=15)
143
  state = pack_state(outputs['gaussian'][0], outputs['mesh'][0])
144
  torch.cuda.empty_cache()
 
145
  return state, video_path
146
 
147
  @spaces.GPU(duration=90)
@@ -151,12 +174,15 @@ def extract_glb(
151
  texture_size: int,
152
  req: gr.Request,
153
  ) -> Tuple[str, str]:
154
- user_dir = os.path.join(TMP_DIR, str(req.session_hash))
 
 
155
  gs, mesh = unpack_state(state)
156
  glb = postprocessing_utils.to_glb(gs, mesh, simplify=mesh_simplify, texture_size=texture_size, verbose=False)
157
  glb_path = os.path.join(user_dir, 'sample.glb')
158
  glb.export(glb_path)
159
  torch.cuda.empty_cache()
 
160
  return glb_path, glb_path
161
 
162
  # Interfaz Gradio
@@ -194,8 +220,7 @@ with gr.Blocks() as demo:
194
 
195
  with gr.Row():
196
  download_glb = gr.DownloadButton(label="Download GLB", interactive=False)
197
-
198
- # Variables adicionales para la generaci贸n 3D
199
  with gr.Accordion("3D Generation Settings", open=False):
200
  gr.Markdown("Stage 1: Sparse Structure Generation")
201
  with gr.Row():
@@ -205,15 +230,13 @@ with gr.Blocks() as demo:
205
  with gr.Row():
206
  slat_guidance_strength = gr.Slider(0.0, 10.0, label="Guidance Strength", value=3.0, step=0.1)
207
  slat_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=12, step=1)
208
-
209
- # Variables para la extracci贸n de GLB
210
  with gr.Accordion("GLB Extraction Settings", open=False):
211
  mesh_simplify = gr.Slider(0.9, 0.98, label="Simplify", value=0.95, step=0.01)
212
  texture_size = gr.Slider(512, 2048, label="Texture Size", value=1024, step=512)
213
 
214
  output_buf = gr.State()
215
-
216
- # Event handlers
217
  demo.load(start_session)
218
  demo.unload(end_session)
219
 
 
17
  from trellis.pipelines import TrellisImageTo3DPipeline
18
  from trellis.representations import Gaussian, MeshExtractResult
19
  from trellis.utils import render_utils, postprocessing_utils
20
+ import logging
21
+
22
+ logging.basicConfig(level=logging.INFO, format='%(asctime)s - HF_SPACE_IMG - %(levelname)s - %(message)s')
23
+
24
  NUM_INFERENCE_STEPS = 8
25
  huggingface_token = os.getenv("HUGGINGFACE_TOKEN")
26
  # Constants
 
28
  TMP_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'tmp')
29
  os.makedirs(TMP_DIR, exist_ok=True)
30
 
 
31
  def start_session(req: gr.Request):
32
+ session_hash = str(req.session_hash)
33
+ user_dir = os.path.join(TMP_DIR, session_hash)
34
+ logging.info(f"START SESSION: Creando directorio para la sesi贸n {session_hash} en {user_dir}")
35
  os.makedirs(user_dir, exist_ok=True)
36
 
37
  def end_session(req: gr.Request):
38
+ session_hash = str(req.session_hash)
39
+ user_dir = os.path.join(TMP_DIR, session_hash)
40
+ logging.info(f"END SESSION: Intentando eliminar el directorio de la sesi贸n {session_hash} en {user_dir}")
41
+ if os.path.exists(user_dir):
42
+ try:
43
+ shutil.rmtree(user_dir)
44
+ logging.info(f"Directorio de la sesi贸n {session_hash} eliminado correctamente.")
45
+ except Exception as e:
46
+ logging.error(f"Error al eliminar el directorio de la sesi贸n {session_hash}: {e}")
47
+ else:
48
+ logging.warning(f"El directorio de la sesi贸n {session_hash} no fue encontrado al intentar eliminarlo. Es posible que ya haya sido limpiado.")
49
 
50
  def preprocess_image(image: Image.Image) -> Image.Image:
51
+ logging.info("Preprocesando imagen para Trellis...")
52
  processed_image = trellis_pipeline.preprocess_image(image)
53
+ logging.info("Preprocesamiento de imagen completado.")
54
  return processed_image
55
 
56
  def pack_state(gs: Gaussian, mesh: MeshExtractResult) -> dict:
 
90
  return gs, mesh
91
 
92
  def get_seed(randomize_seed: bool, seed: int) -> int:
93
+ new_seed = np.random.randint(0, MAX_SEED) if randomize_seed else seed
94
+ logging.info(f"Usando seed: {new_seed}")
95
+ return new_seed
96
 
97
  @spaces.GPU
98
  def generate_flux_image(
 
105
  req: gr.Request,
106
  progress: gr.Progress = gr.Progress(track_tqdm=True),
107
  ) -> Image.Image:
108
+ session_hash = str(req.session_hash)
109
+ logging.info(f"[{session_hash}] Iniciando generate_flux_image con prompt: '{prompt[:50]}...'")
110
  if randomize_seed:
111
  seed = random.randint(0, MAX_SEED)
112
+ logging.info(f"[{session_hash}] Seed aleatorizado a: {seed}")
113
  generator = torch.Generator(device=device).manual_seed(seed)
114
+ full_prompt = "wbgmsst, " + prompt + ", 3D isometric, white background"
115
  image = flux_pipeline(
116
+ prompt=full_prompt,
117
  guidance_scale=guidance_scale,
118
  num_inference_steps=NUM_INFERENCE_STEPS,
119
  width=width,
 
121
  generator=generator,
122
  ).images[0]
123
 
124
+ user_dir = os.path.join(TMP_DIR, session_hash)
125
+ os.makedirs(user_dir, exist_ok=True)
126
+
127
+ filepath = os.path.join(user_dir, "generated_2d_image.png")
 
 
128
  image.save(filepath)
129
+ logging.info(f"[{session_hash}] Imagen 2D guardada en: {filepath}")
130
  return image
131
 
132
  @spaces.GPU
 
139
  slat_sampling_steps: int,
140
  req: gr.Request,
141
  ) -> Tuple[dict, str]:
142
+ session_hash = str(req.session_hash)
143
+ logging.info(f"[{session_hash}] Iniciando image_to_3d...")
144
+ user_dir = os.path.join(TMP_DIR, session_hash)
145
  outputs = trellis_pipeline.run(
146
  image,
147
  seed=seed,
 
156
  "cfg_strength": slat_guidance_strength,
157
  },
158
  )
159
+ logging.info(f"[{session_hash}] Generaci贸n 3D completada. Renderizando video...")
160
  video = render_utils.render_video(outputs['gaussian'][0], num_frames=120)['color']
161
  video_geo = render_utils.render_video(outputs['mesh'][0], num_frames=120)['normal']
162
  video = [np.concatenate([video[i], video_geo[i]], axis=1) for i in range(len(video))]
 
164
  imageio.mimsave(video_path, video, fps=15)
165
  state = pack_state(outputs['gaussian'][0], outputs['mesh'][0])
166
  torch.cuda.empty_cache()
167
+ logging.info(f"[{session_hash}] Video renderizado y estado empaquetado. Devolviendo: {video_path}")
168
  return state, video_path
169
 
170
  @spaces.GPU(duration=90)
 
174
  texture_size: int,
175
  req: gr.Request,
176
  ) -> Tuple[str, str]:
177
+ session_hash = str(req.session_hash)
178
+ logging.info(f"[{session_hash}] Iniciando extract_glb...")
179
+ user_dir = os.path.join(TMP_DIR, session_hash)
180
  gs, mesh = unpack_state(state)
181
  glb = postprocessing_utils.to_glb(gs, mesh, simplify=mesh_simplify, texture_size=texture_size, verbose=False)
182
  glb_path = os.path.join(user_dir, 'sample.glb')
183
  glb.export(glb_path)
184
  torch.cuda.empty_cache()
185
+ logging.info(f"[{session_hash}] GLB extra铆do. Devolviendo: {glb_path}")
186
  return glb_path, glb_path
187
 
188
  # Interfaz Gradio
 
220
 
221
  with gr.Row():
222
  download_glb = gr.DownloadButton(label="Download GLB", interactive=False)
223
+
 
224
  with gr.Accordion("3D Generation Settings", open=False):
225
  gr.Markdown("Stage 1: Sparse Structure Generation")
226
  with gr.Row():
 
230
  with gr.Row():
231
  slat_guidance_strength = gr.Slider(0.0, 10.0, label="Guidance Strength", value=3.0, step=0.1)
232
  slat_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=12, step=1)
233
+
 
234
  with gr.Accordion("GLB Extraction Settings", open=False):
235
  mesh_simplify = gr.Slider(0.9, 0.98, label="Simplify", value=0.95, step=0.01)
236
  texture_size = gr.Slider(512, 2048, label="Texture Size", value=1024, step=512)
237
 
238
  output_buf = gr.State()
239
+
 
240
  demo.load(start_session)
241
  demo.unload(end_session)
242