Stable-X commited on
Commit
eb02b04
·
verified ·
1 Parent(s): 4eb0f43

fix some bug in RVC

Browse files
trellis/pipelines/samplers/flow_euler.py CHANGED
@@ -183,6 +183,8 @@ class FlowEulerSampler(Sampler):
183
  model,
184
  slat_decoder_gs,
185
  slat_decoder_mesh,
 
 
186
  dreamsim_model,
187
  learning_rate,
188
  input_images,
@@ -222,10 +224,10 @@ class FlowEulerSampler(Sampler):
222
  for step in range(total_steps):
223
  optimizer.zero_grad()
224
  pred_x_0, _ = self._v_to_xstart_eps(x_t=x_t, t=t, v=pred_v_opt)
225
- pred_gs = slat_decoder_gs(pred_x_0)
226
- # pred_mesh = slat_decoder_mesh(pred_x_0)
227
  rend_gs = render_utils.render_frames(pred_gs[0], extrinsics, intrinsics, {'resolution': 259, 'bg_color': (0, 0, 0)}, need_depth=True, opt=True)['color']
228
- # rend_mesh = render_utils.render_frames_opt(pred_mesh[0], extrinsics, intrinsics, {'resolution': 518, 'bg_color': (0, 0, 0)}, need_depth=True)['color']
229
  rend_gs = torch.stack(rend_gs, dim=0)
230
  loss_gs = loss_utils.l1_loss(rend_gs, input_images, size_average=False).mean(dim=(1,2,3)) + \
231
  (1 - loss_utils.ssim(rend_gs, input_images, size_average=False)) + \
@@ -345,6 +347,8 @@ class FlowEulerSampler(Sampler):
345
  model,
346
  slat_decoder_gs,
347
  slat_decoder_mesh,
 
 
348
  dreamsim_model,
349
  apperance_learning_rate,
350
  start_t,
@@ -392,7 +396,7 @@ class FlowEulerSampler(Sampler):
392
  else:
393
  # learning_rate = cosine_anealing(i - int(np.where(t_seq <= start_t)[0].min()), int(steps - np.where(t_seq <= start_t)[0].min()), apperance_learning_rate, 1e-5)
394
  learning_rate = apperance_learning_rate
395
- out = self.sample_slat_once_opt_delta_v(model, slat_decoder_gs, slat_decoder_mesh, dreamsim_model, learning_rate, input_images, extrinsics, intrinsics, sample, t, t_prev, cond, **kwargs)
396
  sample = out.pred_x_prev
397
  ret.pred_x_t.append(out.pred_x_prev)
398
  ret.pred_x_0.append(out.pred_x_0)
@@ -865,6 +869,8 @@ class FlowEulerGuidanceIntervalSampler(GuidanceIntervalSamplerMixin, FlowEulerSa
865
  model,
866
  slat_decoder_gs,
867
  slat_decoder_mesh,
 
 
868
  dreamsim_model,
869
  apperance_learning_rate,
870
  start_t,
@@ -902,7 +908,7 @@ class FlowEulerGuidanceIntervalSampler(GuidanceIntervalSamplerMixin, FlowEulerSa
902
  - 'pred_x_t': a list of prediction of x_t.
903
  - 'pred_x_0': a list of prediction of x_0.
904
  """
905
- return super().sample_slat_opt_delta_v(model, slat_decoder_gs, slat_decoder_mesh, dreamsim_model, apperance_learning_rate, start_t, input_images, extrinsics, intrinsics,noise, cond, steps, rescale_t, verbose, neg_cond=neg_cond, cfg_strength=cfg_strength, cfg_interval=cfg_interval, **kwargs)
906
 
907
 
908
  class LatentMatchGuidanceIntervalSampler(GuidanceIntervalSamplerMixin, LatentMatchSampler):
 
183
  model,
184
  slat_decoder_gs,
185
  slat_decoder_mesh,
186
+ std,
187
+ mean,
188
  dreamsim_model,
189
  learning_rate,
190
  input_images,
 
224
  for step in range(total_steps):
225
  optimizer.zero_grad()
226
  pred_x_0, _ = self._v_to_xstart_eps(x_t=x_t, t=t, v=pred_v_opt)
227
+ pred_gs = slat_decoder_gs(pred_x_0 * std + mean)
228
+ # pred_mesh = slat_decoder_mesh(pred_x_0 * std + mean)
229
  rend_gs = render_utils.render_frames(pred_gs[0], extrinsics, intrinsics, {'resolution': 259, 'bg_color': (0, 0, 0)}, need_depth=True, opt=True)['color']
230
+ # rend_mesh = render_utils.render_frames_opt(pred_mesh[0], extrinsics, intrinsics, {'resolution': 518, 'bg_color': (0, 0, 0)}, need_depth=True, opt=True)['color']
231
  rend_gs = torch.stack(rend_gs, dim=0)
232
  loss_gs = loss_utils.l1_loss(rend_gs, input_images, size_average=False).mean(dim=(1,2,3)) + \
233
  (1 - loss_utils.ssim(rend_gs, input_images, size_average=False)) + \
 
347
  model,
348
  slat_decoder_gs,
349
  slat_decoder_mesh,
350
+ std,
351
+ mean,
352
  dreamsim_model,
353
  apperance_learning_rate,
354
  start_t,
 
396
  else:
397
  # learning_rate = cosine_anealing(i - int(np.where(t_seq <= start_t)[0].min()), int(steps - np.where(t_seq <= start_t)[0].min()), apperance_learning_rate, 1e-5)
398
  learning_rate = apperance_learning_rate
399
+ out = self.sample_slat_once_opt_delta_v(model, slat_decoder_gs, slat_decoder_mesh, std, mean, dreamsim_model, learning_rate, input_images, extrinsics, intrinsics, sample, t, t_prev, cond, **kwargs)
400
  sample = out.pred_x_prev
401
  ret.pred_x_t.append(out.pred_x_prev)
402
  ret.pred_x_0.append(out.pred_x_0)
 
869
  model,
870
  slat_decoder_gs,
871
  slat_decoder_mesh,
872
+ std,
873
+ mean,
874
  dreamsim_model,
875
  apperance_learning_rate,
876
  start_t,
 
908
  - 'pred_x_t': a list of prediction of x_t.
909
  - 'pred_x_0': a list of prediction of x_0.
910
  """
911
+ return super().sample_slat_opt_delta_v(model, slat_decoder_gs, slat_decoder_mesh, std, mean, dreamsim_model, apperance_learning_rate, start_t, input_images, extrinsics, intrinsics,noise, cond, steps, rescale_t, verbose, neg_cond=neg_cond, cfg_strength=cfg_strength, cfg_interval=cfg_interval, **kwargs)
912
 
913
 
914
  class LatentMatchGuidanceIntervalSampler(GuidanceIntervalSamplerMixin, LatentMatchSampler):