xizaoqu
committed on
Commit
·
cef86dc
1
Parent(s):
db555c7
update
Browse files
app.py
CHANGED
|
@@ -71,10 +71,10 @@ KEY_TO_ACTION = {
|
|
| 71 |
}
|
| 72 |
|
| 73 |
example_images = [
|
| 74 |
-
["1", "assets/ice_plains.png", "turn
|
| 75 |
-
["2", "assets/place.png", "put item
|
| 76 |
-
["3", "assets/rain_sunflower_plains.png", "turn right
|
| 77 |
-
["4", "assets/desert.png", "turn 360 degree
|
| 78 |
]
|
| 79 |
|
| 80 |
def load_custom_checkpoint(algo, checkpoint_path):
|
|
@@ -264,10 +264,18 @@ def generate(keys, input_history, memory_frames, self_frames, self_actions, self
|
|
| 264 |
|
| 265 |
memory_frames = np.concatenate([memory_frames, new_frame[:,0]])
|
| 266 |
|
| 267 |
-
|
|
|
|
| 268 |
out_video = np.clip(out_video, a_min=0.0, a_max=1.0)
|
| 269 |
out_video = (out_video * 255).astype(np.uint8)
|
| 270 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 271 |
temporal_video_path = tempfile.NamedTemporaryFile(suffix='.mp4').name
|
| 272 |
save_video(out_video, temporal_video_path)
|
| 273 |
input_history += keys
|
|
@@ -289,7 +297,7 @@ def generate(keys, input_history, memory_frames, self_frames, self_actions, self
|
|
| 289 |
|
| 290 |
# np.savez(os.path.join(folder_path, "data_bundle.npz"), **data_dict)
|
| 291 |
|
| 292 |
-
return
|
| 293 |
|
| 294 |
def reset(selected_image):
|
| 295 |
self_frames = None
|
|
@@ -381,6 +389,24 @@ with gr.Blocks(css=css) as demo:
|
|
| 381 |
"""
|
| 382 |
)
|
| 383 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 384 |
# <div style="text-align: center;">
|
| 385 |
# <!-- Public Website -->
|
| 386 |
# <a style="display:inline-block" href="https://nirvanalan.github.io/projects/GA/">
|
|
@@ -403,25 +429,50 @@ with gr.Blocks(css=css) as demo:
|
|
| 403 |
# </a>
|
| 404 |
# </div>
|
| 405 |
|
| 406 |
-
example_actions = {"turn left
|
| 407 |
"turn 360 degree": "AAAAAAAAAAAAAAAAAAAAAAAA",
|
| 408 |
-
"turn right
|
| 409 |
-
"turn right
|
| 410 |
-
"turn right
|
| 411 |
-
"put item
|
| 412 |
|
| 413 |
selected_image = gr.State(ICE_PLAINS_IMAGE)
|
| 414 |
|
| 415 |
with gr.Row(variant="panel"):
|
| 416 |
-
|
| 417 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 418 |
|
| 419 |
|
| 420 |
with gr.Row(variant="panel"):
|
| 421 |
with gr.Column(scale=2):
|
| 422 |
-
|
| 423 |
-
|
| 424 |
-
gr.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 425 |
with gr.Row():
|
| 426 |
buttons = []
|
| 427 |
for action_key in list(example_actions.keys())[:2]:
|
|
@@ -437,11 +488,28 @@ with gr.Blocks(css=css) as demo:
|
|
| 437 |
buttons.append(gr.Button(action_key))
|
| 438 |
|
| 439 |
with gr.Column(scale=1):
|
| 440 |
-
|
| 441 |
-
|
| 442 |
-
|
| 443 |
-
|
| 444 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 445 |
|
| 446 |
sampling_timesteps_state = gr.State(worldmem.sampling_timesteps)
|
| 447 |
sampling_context_length_state = gr.State(worldmem.n_tokens)
|
|
@@ -457,24 +525,12 @@ with gr.Blocks(css=css) as demo:
|
|
| 457 |
def set_action(action):
|
| 458 |
return action
|
| 459 |
|
| 460 |
-
# gr.Markdown("### Action sequence examples.")
|
| 461 |
|
| 462 |
|
| 463 |
for button, action_key in zip(buttons, list(example_actions.keys())):
|
| 464 |
button.click(set_action, inputs=[gr.State(value=example_actions[action_key])], outputs=input_box)
|
| 465 |
|
| 466 |
-
|
| 467 |
-
gr.Markdown("### Click on the images below to reset the sequence and generate from the new image.")
|
| 468 |
-
|
| 469 |
-
with gr.Row():
|
| 470 |
-
image_display_1 = gr.Image(value=SUNFLOWERS_IMAGE, interactive=False, label="Sunflower Plains")
|
| 471 |
-
image_display_2 = gr.Image(value=DESERT_IMAGE, interactive=False, label="Desert")
|
| 472 |
-
image_display_3 = gr.Image(value=SAVANNA_IMAGE, interactive=False, label="Savanna")
|
| 473 |
-
image_display_4 = gr.Image(value=ICE_PLAINS_IMAGE, interactive=False, label="Ice Plains")
|
| 474 |
-
image_display_5 = gr.Image(value=SUNFLOWERS_RAIN_IMAGE, interactive=False, label="Rainy Sunflower Plains")
|
| 475 |
-
image_display_6 = gr.Image(value=PLACE_IMAGE, interactive=False, label="Place")
|
| 476 |
-
|
| 477 |
-
gr.Markdown("### Click the examples below for a quick review, and continue generating based on them.")
|
| 478 |
|
| 479 |
example_case = gr.Textbox(label="Case", visible=False)
|
| 480 |
image_output = gr.Image(visible=False)
|
|
@@ -499,29 +555,6 @@ with gr.Blocks(css=css) as demo:
|
|
| 499 |
)
|
| 500 |
|
| 501 |
|
| 502 |
-
|
| 503 |
-
gr.Markdown(
|
| 504 |
-
"""
|
| 505 |
-
## Instructions & Notes:
|
| 506 |
-
|
| 507 |
-
1. Enter an action sequence in the **"Action Sequence"** text box and click **"Generate"** to begin.
|
| 508 |
-
2. You can continue generation by clicking **"Generation"** again and again. Previous sequences are logged in the history panel.
|
| 509 |
-
3. Click **"Reset"** to clear the current sequence and start fresh.
|
| 510 |
-
4. Action sequences can be composed using the following keys:
|
| 511 |
-
- W: turn up
|
| 512 |
-
- S: turn down
|
| 513 |
-
- A: turn left
|
| 514 |
-
- D: turn right
|
| 515 |
-
- Q: move forward
|
| 516 |
-
- E: move backward
|
| 517 |
-
- N: no-op (do nothing)
|
| 518 |
-
- U: use item
|
| 519 |
-
5. Higher denoising steps produce more detailed results but take longer. 20 steps is a good balance between quality and speed. The same applies to context and memory length.
|
| 520 |
-
6. For faster performance, we recommend running the demo locally (~1s/frame on H100 vs ~5s on Spaces).
|
| 521 |
-
7. If you find this project interesting or useful, please consider giving it a โญ๏ธ on [GitHub]()!
|
| 522 |
-
8. For feedback or suggestions, feel free to open a GitHub issue or contact me directly at **zeqixiao1@gmail.com**.
|
| 523 |
-
"""
|
| 524 |
-
)
|
| 525 |
# input_box.submit(update_image_and_log, inputs=[input_box], outputs=[image_display, video_display, log_output])
|
| 526 |
submit_button.click(generate, inputs=[input_box, log_output, memory_frames, self_frames, self_actions, self_poses, self_memory_c2w, self_frame_idx], outputs=[image_display, video_display, log_output, memory_frames, self_frames, self_actions, self_poses, self_memory_c2w, self_frame_idx])
|
| 527 |
reset_btn.click(reset, inputs=[selected_image], outputs=[log_output, memory_frames, self_frames, self_actions, self_poses, self_memory_c2w, self_frame_idx])
|
|
|
|
| 71 |
}
|
| 72 |
|
| 73 |
example_images = [
|
| 74 |
+
["1", "assets/ice_plains.png", "turn rightgo backwardโlook upโturn leftโlook downโturn rightโgo forwardโturn left", 20, 3, 8],
|
| 75 |
+
["2", "assets/place.png", "put itemโgo backwardโput itemโgo backwardโgo around", 20, 3, 8],
|
| 76 |
+
["3", "assets/rain_sunflower_plains.png", "turn rightโlook upโturn rightโlook downโturn leftโgo backwardโturn left", 20, 3, 8],
|
| 77 |
+
["4", "assets/desert.png", "turn 360 degreeโturn rightโgo forwardโturn left", 20, 3, 8],
|
| 78 |
]
|
| 79 |
|
| 80 |
def load_custom_checkpoint(algo, checkpoint_path):
|
|
|
|
| 264 |
|
| 265 |
memory_frames = np.concatenate([memory_frames, new_frame[:,0]])
|
| 266 |
|
| 267 |
+
|
| 268 |
+
out_video = memory_frames.transpose(0,2,3,1).copy()
|
| 269 |
out_video = np.clip(out_video, a_min=0.0, a_max=1.0)
|
| 270 |
out_video = (out_video * 255).astype(np.uint8)
|
| 271 |
|
| 272 |
+
last_frame = out_video[-1].copy()
|
| 273 |
+
border_thickness = 2
|
| 274 |
+
out_video[-len(new_frame):, :border_thickness, :, :] = [255, 0, 0]
|
| 275 |
+
out_video[-len(new_frame):, -border_thickness:, :, :] = [255, 0, 0]
|
| 276 |
+
out_video[-len(new_frame):, :, :border_thickness, :] = [255, 0, 0]
|
| 277 |
+
out_video[-len(new_frame):, :, -border_thickness:, :] = [255, 0, 0]
|
| 278 |
+
|
| 279 |
temporal_video_path = tempfile.NamedTemporaryFile(suffix='.mp4').name
|
| 280 |
save_video(out_video, temporal_video_path)
|
| 281 |
input_history += keys
|
|
|
|
| 297 |
|
| 298 |
# np.savez(os.path.join(folder_path, "data_bundle.npz"), **data_dict)
|
| 299 |
|
| 300 |
+
return last_frame, temporal_video_path, input_history, memory_frames, self_frames, self_actions, self_poses, self_memory_c2w, self_frame_idx
|
| 301 |
|
| 302 |
def reset(selected_image):
|
| 303 |
self_frames = None
|
|
|
|
| 389 |
"""
|
| 390 |
)
|
| 391 |
|
| 392 |
+
gr.Markdown(
|
| 393 |
+
"""
|
| 394 |
+
## 🚀 How to Explore WorldMem
|
| 395 |
+
|
| 396 |
+
Follow these simple steps to get started:
|
| 397 |
+
|
| 398 |
+
1. **Choose a scene**.
|
| 399 |
+
2. **Input your action sequence**.
|
| 400 |
+
3. **Click "Generate"**.
|
| 401 |
+
|
| 402 |
+
- You can continuously click **"Generate"** to **extend the video** and observe how well the world maintains consistency over time.
|
| 403 |
+
- For best performance, we recommend **running locally** (1s/frame on H100) instead of Spaces (5s/frame).
|
| 404 |
+
- ⭐️ If you like this project, please [give it a star on GitHub]()!
|
| 405 |
+
- 💬 For questions or feedback, feel free to open an issue or email me at **zeqixiao1@gmail.com**.
|
| 406 |
+
|
| 407 |
+
Happy exploring! ๐
|
| 408 |
+
"""
|
| 409 |
+
)
|
| 410 |
# <div style="text-align: center;">
|
| 411 |
# <!-- Public Website -->
|
| 412 |
# <a style="display:inline-block" href="https://nirvanalan.github.io/projects/GA/">
|
|
|
|
| 429 |
# </a>
|
| 430 |
# </div>
|
| 431 |
|
| 432 |
+
example_actions = {"turn leftโturn right": "AAAAAAAAAAAADDDDDDDDDDDD",
|
| 433 |
"turn 360 degree": "AAAAAAAAAAAAAAAAAAAAAAAA",
|
| 434 |
+
"turn rightโgo backwardโlook upโturn leftโlook down": "DDDDDDDDEEEEEEEEEESSSAAAAAAAAWWW",
|
| 435 |
+
"turn rightโgo forwardโturn right": "DDDDDDDDDDDDQQQQQQQQQQQQQQQDDDDDDDDDDDD",
|
| 436 |
+
"turn rightโlook upโturn rightโlook down": "DDDDWWWDDDDDDDDDDDDDDDDDDDDSSS",
|
| 437 |
+
"put itemโgo backwardโput itemโgo backward":"SSUNNWWEEEEEEEEEAAASSUNNWWEEEEEEEEE"}
|
| 438 |
|
| 439 |
selected_image = gr.State(ICE_PLAINS_IMAGE)
|
| 440 |
|
| 441 |
with gr.Row(variant="panel"):
|
| 442 |
+
with gr.Column():
|
| 443 |
+
gr.Markdown("๐ผ๏ธ Start from this frame.")
|
| 444 |
+
image_display = gr.Image(value=selected_image.value, interactive=False, label="Current Frame")
|
| 445 |
+
with gr.Column():
|
| 446 |
+
gr.Markdown("๐๏ธ Generated videos. New contents are marked in red box.")
|
| 447 |
+
video_display = gr.Video(autoplay=True, loop=True)
|
| 448 |
+
|
| 449 |
+
gr.Markdown("### ๐๏ธ Choose a scene and start generation.")
|
| 450 |
+
|
| 451 |
+
with gr.Row():
|
| 452 |
+
image_display_1 = gr.Image(value=SUNFLOWERS_IMAGE, interactive=False, label="Sunflower Plains")
|
| 453 |
+
image_display_2 = gr.Image(value=DESERT_IMAGE, interactive=False, label="Desert")
|
| 454 |
+
image_display_3 = gr.Image(value=SAVANNA_IMAGE, interactive=False, label="Savanna")
|
| 455 |
+
image_display_4 = gr.Image(value=ICE_PLAINS_IMAGE, interactive=False, label="Ice Plains")
|
| 456 |
+
image_display_5 = gr.Image(value=SUNFLOWERS_RAIN_IMAGE, interactive=False, label="Rainy Sunflower Plains")
|
| 457 |
+
image_display_6 = gr.Image(value=PLACE_IMAGE, interactive=False, label="Place")
|
| 458 |
|
| 459 |
|
| 460 |
with gr.Row(variant="panel"):
|
| 461 |
with gr.Column(scale=2):
|
| 462 |
+
gr.Markdown("### ๐น๏ธ Input action sequences for interaction.")
|
| 463 |
+
input_box = gr.Textbox(label="Action Sequences", placeholder="Enter action sequences here, e.g. (AAAAAAAAAAAADDDDDDDDDDDD)", lines=1, max_lines=1)
|
| 464 |
+
log_output = gr.Textbox(label="History Sequences", interactive=False)
|
| 465 |
+
gr.Markdown(
|
| 466 |
+
"""
|
| 467 |
+
### 💡 Action Key Guide
|
| 468 |
+
|
| 469 |
+
<pre style="font-family: monospace; font-size: 14px; line-height: 1.6;">
|
| 470 |
+
W: Turn up S: Turn down A: Turn left D: Turn right
|
| 471 |
+
Q: Go forward E: Go backward N: No-op U: Use item
|
| 472 |
+
</pre>
|
| 473 |
+
"""
|
| 474 |
+
)
|
| 475 |
+
gr.Markdown("### ๐ Click to quickly set action sequence examples.")
|
| 476 |
with gr.Row():
|
| 477 |
buttons = []
|
| 478 |
for action_key in list(example_actions.keys())[:2]:
|
|
|
|
| 488 |
buttons.append(gr.Button(action_key))
|
| 489 |
|
| 490 |
with gr.Column(scale=1):
|
| 491 |
+
submit_button = gr.Button("๐ฌ Generate!", variant="primary")
|
| 492 |
+
reset_btn = gr.Button("๐ Reset")
|
| 493 |
+
gr.Markdown("<div style='flex-grow:1; height: 100px'></div>")
|
| 494 |
+
|
| 495 |
+
gr.Markdown("### โ๏ธ Advanced Settings")
|
| 496 |
+
|
| 497 |
+
slider_denoising_step = gr.Slider(
|
| 498 |
+
minimum=10, maximum=50, value=worldmem.sampling_timesteps, step=1,
|
| 499 |
+
label="Denoising Steps",
|
| 500 |
+
info="Higher values yield better quality but slower speed"
|
| 501 |
+
)
|
| 502 |
+
slider_context_length = gr.Slider(
|
| 503 |
+
minimum=2, maximum=10, value=worldmem.n_tokens, step=1,
|
| 504 |
+
label="Context Length",
|
| 505 |
+
info="How many previous frames in temporal context window."
|
| 506 |
+
)
|
| 507 |
+
slider_memory_length = gr.Slider(
|
| 508 |
+
minimum=4, maximum=16, value=worldmem.condition_similar_length, step=1,
|
| 509 |
+
label="Memory Length",
|
| 510 |
+
info="How many previous frames in memory window."
|
| 511 |
+
)
|
| 512 |
+
|
| 513 |
|
| 514 |
sampling_timesteps_state = gr.State(worldmem.sampling_timesteps)
|
| 515 |
sampling_context_length_state = gr.State(worldmem.n_tokens)
|
|
|
|
| 525 |
def set_action(action):
|
| 526 |
return action
|
| 527 |
|
|
|
|
| 528 |
|
| 529 |
|
| 530 |
for button, action_key in zip(buttons, list(example_actions.keys())):
|
| 531 |
button.click(set_action, inputs=[gr.State(value=example_actions[action_key])], outputs=input_box)
|
| 532 |
|
| 533 |
+
gr.Markdown("### ๐ Click to review generated examples, and continue generation based on them.")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 534 |
|
| 535 |
example_case = gr.Textbox(label="Case", visible=False)
|
| 536 |
image_output = gr.Image(visible=False)
|
|
|
|
| 555 |
)
|
| 556 |
|
| 557 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 558 |
# input_box.submit(update_image_and_log, inputs=[input_box], outputs=[image_display, video_display, log_output])
|
| 559 |
submit_button.click(generate, inputs=[input_box, log_output, memory_frames, self_frames, self_actions, self_poses, self_memory_c2w, self_frame_idx], outputs=[image_display, video_display, log_output, memory_frames, self_frames, self_actions, self_poses, self_memory_c2w, self_frame_idx])
|
| 560 |
reset_btn.click(reset, inputs=[selected_image], outputs=[log_output, memory_frames, self_frames, self_actions, self_poses, self_memory_c2w, self_frame_idx])
|