Spaces: Running on Zero
fix some bug in RVC
trellis/pipelines/samplers/flow_euler.py (CHANGED)
@@ -183,6 +183,8 @@ class FlowEulerSampler(Sampler):
         model,
         slat_decoder_gs,
         slat_decoder_mesh,
+        std,
+        mean,
         dreamsim_model,
         learning_rate,
         input_images,
@@ -222,10 +224,10 @@ class FlowEulerSampler(Sampler):
         for step in range(total_steps):
             optimizer.zero_grad()
             pred_x_0, _ = self._v_to_xstart_eps(x_t=x_t, t=t, v=pred_v_opt)
-            pred_gs = slat_decoder_gs(pred_x_0)
-            # pred_mesh = slat_decoder_mesh(pred_x_0)
+            pred_gs = slat_decoder_gs(pred_x_0 * std + mean)
+            # pred_mesh = slat_decoder_mesh(pred_x_0 * std + mean)
             rend_gs = render_utils.render_frames(pred_gs[0], extrinsics, intrinsics, {'resolution': 259, 'bg_color': (0, 0, 0)}, need_depth=True, opt=True)['color']
-            # rend_mesh = render_utils.render_frames_opt(pred_mesh[0], extrinsics, intrinsics, {'resolution': 518, 'bg_color': (0, 0, 0)}, need_depth=True)['color']
+            # rend_mesh = render_utils.render_frames_opt(pred_mesh[0], extrinsics, intrinsics, {'resolution': 518, 'bg_color': (0, 0, 0)}, need_depth=True, opt=True)['color']
             rend_gs = torch.stack(rend_gs, dim=0)
             loss_gs = loss_utils.l1_loss(rend_gs, input_images, size_average=False).mean(dim=(1,2,3)) + \
                       (1 - loss_utils.ssim(rend_gs, input_images, size_average=False)) + \
@@ -345,6 +347,8 @@ class FlowEulerSampler(Sampler):
         model,
         slat_decoder_gs,
         slat_decoder_mesh,
+        std,
+        mean,
         dreamsim_model,
         apperance_learning_rate,
         start_t,
@@ -392,7 +396,7 @@ class FlowEulerSampler(Sampler):
             else:
                 # learning_rate = cosine_anealing(i - int(np.where(t_seq <= start_t)[0].min()), int(steps - np.where(t_seq <= start_t)[0].min()), apperance_learning_rate, 1e-5)
                 learning_rate = apperance_learning_rate
-            out = self.sample_slat_once_opt_delta_v(model, slat_decoder_gs, slat_decoder_mesh, dreamsim_model, learning_rate, input_images, extrinsics, intrinsics, sample, t, t_prev, cond, **kwargs)
+            out = self.sample_slat_once_opt_delta_v(model, slat_decoder_gs, slat_decoder_mesh, std, mean, dreamsim_model, learning_rate, input_images, extrinsics, intrinsics, sample, t, t_prev, cond, **kwargs)
             sample = out.pred_x_prev
             ret.pred_x_t.append(out.pred_x_prev)
             ret.pred_x_0.append(out.pred_x_0)
@@ -865,6 +869,8 @@ class FlowEulerGuidanceIntervalSampler(GuidanceIntervalSamplerMixin, FlowEulerSa
         model,
         slat_decoder_gs,
         slat_decoder_mesh,
+        std,
+        mean,
         dreamsim_model,
         apperance_learning_rate,
         start_t,
@@ -902,7 +908,7 @@ class FlowEulerGuidanceIntervalSampler(GuidanceIntervalSamplerMixin, FlowEulerSa
                 - 'pred_x_t': a list of prediction of x_t.
                 - 'pred_x_0': a list of prediction of x_0.
         """
-        return super().sample_slat_opt_delta_v(model, slat_decoder_gs, slat_decoder_mesh, dreamsim_model, apperance_learning_rate, start_t, input_images, extrinsics, intrinsics,noise, cond, steps, rescale_t, verbose, neg_cond=neg_cond, cfg_strength=cfg_strength, cfg_interval=cfg_interval, **kwargs)
+        return super().sample_slat_opt_delta_v(model, slat_decoder_gs, slat_decoder_mesh, std, mean, dreamsim_model, apperance_learning_rate, start_t, input_images, extrinsics, intrinsics,noise, cond, steps, rescale_t, verbose, neg_cond=neg_cond, cfg_strength=cfg_strength, cfg_interval=cfg_interval, **kwargs)
 
 
 class LatentMatchGuidanceIntervalSampler(GuidanceIntervalSamplerMixin, LatentMatchSampler):
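
What the change does, read directly from the diff: the slat-optimization sampling path (sample_slat_once_opt_delta_v / sample_slat_opt_delta_v) gains std and mean parameters, and the predicted latent is now mapped through pred_x_0 * std + mean before being decoded by slat_decoder_gs, presumably to undo the normalization the sampler works in before the decoder sees the latent. A minimal sketch of just that step; the latent width and the toy linear decoder are placeholders, not the pipeline's real modules:

import torch

# Illustrative placeholders; only the de-normalization step mirrors the diff.
latent_dim = 8                                    # assumed latent width
std = torch.ones(1, latent_dim)                   # per-channel scale of the latent (assumed shape)
mean = torch.zeros(1, latent_dim)                 # per-channel offset of the latent (assumed shape)
slat_decoder_gs = torch.nn.Linear(latent_dim, 3)  # toy stand-in for the Gaussian-splat decoder

pred_x_0 = torch.randn(4, latent_dim)             # prediction in the sampler's normalized latent space

# New behaviour from the commit: de-normalize before decoding
# (previously slat_decoder_gs(pred_x_0)).
pred_gs = slat_decoder_gs(pred_x_0 * std + mean)
print(pred_gs.shape)  # torch.Size([4, 3])

The same pair of statistics is threaded through FlowEulerGuidanceIntervalSampler.sample_slat_opt_delta_v, which simply forwards std and mean to the base implementation via super().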