diff --git a/.gitattributes b/.gitattributes new file mode 100644 index 0000000000000000000000000000000000000000..d098c62b13925420b12e93df1f4e2ba1acf3ac80 --- /dev/null +++ b/.gitattributes @@ -0,0 +1,40 @@ +*.7z filter=lfs diff=lfs merge=lfs -text +*.arrow filter=lfs diff=lfs merge=lfs -text +*.bin filter=lfs diff=lfs merge=lfs -text +*.bz2 filter=lfs diff=lfs merge=lfs -text +*.ckpt filter=lfs diff=lfs merge=lfs -text +*.ftz filter=lfs diff=lfs merge=lfs -text +*.gz filter=lfs diff=lfs merge=lfs -text +*.h5 filter=lfs diff=lfs merge=lfs -text +*.joblib filter=lfs diff=lfs merge=lfs -text +*.lfs.* filter=lfs diff=lfs merge=lfs -text +*.mlmodel filter=lfs diff=lfs merge=lfs -text +*.model filter=lfs diff=lfs merge=lfs -text +*.msgpack filter=lfs diff=lfs merge=lfs -text +*.npy filter=lfs diff=lfs merge=lfs -text +*.npz filter=lfs diff=lfs merge=lfs -text +*.onnx filter=lfs diff=lfs merge=lfs -text +*.ot filter=lfs diff=lfs merge=lfs -text +*.parquet filter=lfs diff=lfs merge=lfs -text +*.pb filter=lfs diff=lfs merge=lfs -text +*.pickle filter=lfs diff=lfs merge=lfs -text +*.pkl filter=lfs diff=lfs merge=lfs -text +*.pt filter=lfs diff=lfs merge=lfs -text +*.pth filter=lfs diff=lfs merge=lfs -text +*.rar filter=lfs diff=lfs merge=lfs -text +*.safetensors filter=lfs diff=lfs merge=lfs -text +saved_model/**/* filter=lfs diff=lfs merge=lfs -text +*.tar.* filter=lfs diff=lfs merge=lfs -text +*.tar filter=lfs diff=lfs merge=lfs -text +*.tflite filter=lfs diff=lfs merge=lfs -text +*.tgz filter=lfs diff=lfs merge=lfs -text +*.wasm filter=lfs diff=lfs merge=lfs -text +*.xz filter=lfs diff=lfs merge=lfs -text +*.zip filter=lfs diff=lfs merge=lfs -text +*.zst filter=lfs diff=lfs merge=lfs -text +*tfevents* filter=lfs diff=lfs merge=lfs -text +*.png filter=lfs diff=lfs merge=lfs -text +*.jpg filter=lfs diff=lfs merge=lfs -text +*.jpeg filter=lfs diff=lfs merge=lfs -text +*.mp3 filter=lfs diff=lfs merge=lfs -text +*.pth filter=lfs diff=lfs merge=lfs -text diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..08610c946ea133b4c49d12337e2abbef03f24766 --- /dev/null +++ b/.gitignore @@ -0,0 +1,38 @@ +*.idea/ +*.vscode/ + +*__pycache__/ +*.pyc +*.pyo +*.pyd + +*.pt +*.pth +*.tar.gz +*.zip + +!**/train/ +**/train/* +!**/train/saved + +!**/inference/ +**/inference/* +!**/inference/saved + +!**/maps/ +**/maps/* +!**/maps/example +!**/maps/gpt4o +!**/maps/lisa + +# For taxabind_avs +**/dataset/ +**/checkpoints/ + +!**/lightning_logs/ +**/lightning_logs/* +!**/lightning_logs/saved + +# Saved weights & logs +**avs_rl_policy.pth +**/avs_rl_policy_21.5k/* \ No newline at end of file diff --git a/.vscode/launch.json b/.vscode/launch.json new file mode 100644 index 0000000000000000000000000000000000000000..6a1736dd3a13900ed66d69cb586a8fc6ce484e4f --- /dev/null +++ b/.vscode/launch.json @@ -0,0 +1,15 @@ +{ + "version": "0.2.0", + "configurations": [ + { + "name": "Debug app.py", + "type": "debugpy", + "request": "launch", + "program": "${workspaceFolder}/app.py", + "cwd": "${workspaceFolder}", + "console": "integratedTerminal", + "justMyCode": false, + "python": "/home/user/anaconda3/envs/vlm-search/bin/python3" + } + ] +} \ No newline at end of file diff --git a/README.md b/README.md new file mode 100644 index 0000000000000000000000000000000000000000..f11219855e4a44e0f24ebd0d4e0da2f529774791 --- /dev/null +++ b/README.md @@ -0,0 +1,13 @@ +--- +title: Search-TTA +emoji: 🦁 +colorFrom: green +colorTo: gray +sdk: gradio +sdk_version: 5.31.0 +app_file: app.py +pinned: 
false +short_description: Multimodal Test-time Adaptation Framework for Visual Search +--- + +Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference diff --git a/app.py b/app.py new file mode 100644 index 0000000000000000000000000000000000000000..c59a521dedbac48cbd9e1f663decae60902555e3 --- /dev/null +++ b/app.py @@ -0,0 +1,528 @@ +""" +Simplified Gradio demo for Search-TTA evaluation. +""" + +# ────────────────────────── imports ─────────────────────────────────── +from pathlib import Path +import matplotlib +matplotlib.use("Agg", force=True) + +import gradio as gr +import ctypes # for safely stopping background threads +import os, glob, threading, time +import torch +from PIL import Image +import json +import shutil +import spaces # integration with ZeroGPU on hf +from planner.test_parameter import * +from planner.model import PolicyNet +from planner.test_worker import TestWorker +from taxabind_avs.satbind.clip_seg_tta import ClipSegTTA + + +# Helper to kill a Python thread by injecting SystemExit +def _stop_thread(thread: threading.Thread): + """Forcefully raise SystemExit in the given thread (best-effort).""" + if thread is None or not thread.is_alive(): + return + tid = thread.ident + if tid is None: + return + # Ask CPython to raise SystemExit in the thread context + res = ctypes.pythonapi.PyThreadState_SetAsyncExc(ctypes.c_long(tid), ctypes.py_object(SystemExit)) + if res > 1: + # If it returned >1, cleanup and fail safe + ctypes.pythonapi.PyThreadState_SetAsyncExc(ctypes.c_long(tid), None) + +# ──────────── Thread Registry for Cleanup on Tab Switch ───────────── +_running_threads: list[threading.Thread] = [] +_running_threads_lock = threading.Lock() + +# Map worker threads to their ClipSegTTA instance so UI can read executing_tta flag +_thread_clip_map: dict[threading.Thread, ClipSegTTA] = {} + +# ──────────── Run directory rotation ───────────── +RUN_HISTORY_LIMIT = 30 # keep at most this many timestamped run directories per instance + +def _prune_old_run_dirs(base_dir: str, limit: int = RUN_HISTORY_LIMIT): + """Delete oldest timestamp-named run directories leaving only *limit* of the newest ones.""" + try: + dirs = [d for d in os.listdir(base_dir) if os.path.isdir(os.path.join(base_dir, d))] + dirs.sort() + if len(dirs) > limit: + for obsolete in dirs[:-limit]: + shutil.rmtree(os.path.join(base_dir, obsolete), ignore_errors=True) + except Exception: + pass + + +# CHANGE ME! 
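For reference, the following is a minimal, self-contained sketch of the ctypes technique that the `_stop_thread` helper above relies on (raising `SystemExit` asynchronously in a worker thread). The `stop_thread` and `worker` names and the timings are illustrative only, not part of the app; note that the injected exception is delivered only when the target thread next executes Python bytecode, so it cannot interrupt a blocking C call.

```python
import ctypes
import threading
import time

def stop_thread(thread: threading.Thread) -> None:
    """Best-effort: ask CPython to raise SystemExit inside `thread`."""
    if thread is None or not thread.is_alive() or thread.ident is None:
        return
    res = ctypes.pythonapi.PyThreadState_SetAsyncExc(
        ctypes.c_long(thread.ident), ctypes.py_object(SystemExit)
    )
    if res > 1:
        # More than one thread state was affected: undo the request.
        ctypes.pythonapi.PyThreadState_SetAsyncExc(ctypes.c_long(thread.ident), None)

def worker() -> None:
    try:
        while True:
            time.sleep(0.1)  # exception is raised only between bytecode instructions
    except SystemExit:
        print("worker: got SystemExit, cleaning up")

if __name__ == "__main__":
    t = threading.Thread(target=worker, daemon=True)
    t.start()
    time.sleep(0.5)
    stop_thread(t)
    t.join(timeout=2)
    print("worker alive:", t.is_alive())  # typically False (best-effort mechanism)
```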
+POLL_INTERVAL = 1.0 # For visualization + +# Prepare the model +device = torch.device('cuda') if USE_GPU and torch.cuda.is_available() else torch.device('cpu') +policy_net = PolicyNet(INPUT_DIM, EMBEDDING_DIM).to(device) +script_dir = Path(__file__).resolve().parent +print("real_script_dir: ", script_dir) +checkpoint = torch.load(f'{MODEL_PATH}/{MODEL_NAME}') +policy_net.load_state_dict(checkpoint['policy_model']) +print('Model loaded!') + +# Load metadata json +tgts_metadata_json_path = os.path.join(script_dir, "examples/metadata.json") +tgts_metadata = json.load(open(tgts_metadata_json_path)) + + +# ────────────────────────── Gradio process fn ───────────────────────── + +### integration with ZeroGPU on hf +# @spaces.GPU +def process_search_tta( + sat_path: str | None, + ground_path: str | None, + taxonomy: str | None = None, + session_threads: list[threading.Thread] | None = None, +): + """Run both TTA and non-TTA search episodes concurrently and stream both heat-maps.""" + + if session_threads is None: + session_threads = [] + + # Disable Run button and clear image/status outputs, hide sliders, clear frame states + yield ( + gr.update(interactive=False), + gr.update(value=None), + gr.update(value=None), + gr.update(value="Initializing model…", visible=True), + gr.update(value="Initializing model…", visible=True), + gr.update(visible=False), + gr.update(visible=False), + [], + [], + session_threads, + ) + + # Bail early if satellite image missing + if sat_path is None: + yield ( + gr.update(interactive=True), + gr.update(value=None), + gr.update(value=None), + gr.update(value="No satellite image provided.", visible=True), + gr.update(value="", visible=True), + gr.update(visible=False), + gr.update(visible=False), + [], + [], + session_threads, + ) + return + + # Prepare PIL images + sat_img = Image.open(sat_path).convert("RGB") + ground_img_pil = Image.open(ground_path).convert("RGB") if ground_path else None + + # Lookup target positions metadata (may be empty) + tgt_positions = [] + if taxonomy and taxonomy in tgts_metadata: + tgt_positions = [tuple(t) for t in tgts_metadata[taxonomy]["target_positions"]] + + # Helper to build a TestWorker with/without TTA + def build_planner(enable_tta: bool, save_dir: str, clip_obj): + # Lazily (re)create a ClipSegTTA instance per thread if not provided + local_clip = clip_obj + if LOAD_AVS_BENCH and local_clip is None: + local_clip = ClipSegTTA( + img_dir=AVS_IMG_DIR, + imo_dir=AVS_IMO_DIR, + json_path=AVS_INAT_JSON_PATH, + sat_to_img_ids_path=AVS_SAT_TO_IMG_IDS_PATH, + sat_checkpoint_path=AVS_SAT_CHECKPOINT_PATH, + load_pretrained_hf_ckpt=AVS_LOAD_PRETRAINED_HF_CHECKPOINT, + blur_kernel = AVS_GAUSSIAN_BLUR_KERNEL, + sample_index=-1, + device=device, + sat_to_img_ids_json_is_train_dict=False, + tax_to_filter_val=QUERY_TAX, + load_model=USE_CLIP_PREDS, + query_modality=QUERY_MODALITY, + sound_dir = AVS_SOUND_DIR, + sound_checkpoint_path=AVS_SOUND_CHECKPOINT_PATH, + ) + + if local_clip is not None: + # Feed inputs to ClipSegTTA copy + local_clip.img_paths = [ground_path] if ground_path else [] + local_clip.imo_path = sat_path + local_clip.imgs = ([local_clip.dataset.img_transform(ground_img_pil).to(device)] if ground_img_pil else []) + local_clip.imo = local_clip.dataset.imo_transform(sat_img).to(device) + local_clip.sounds = [] + local_clip.sound_ids = [] + local_clip.species_name = taxonomy or "" + local_clip.gt_mask_name = taxonomy.replace(" ", "_") if taxonomy else "" + local_clip.target_positions = tgt_positions if tgt_positions else [(0, 0)] + 
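+        # Wrap the shared policy network in a TestWorker that runs a single search episode; TTA on/off and the frame output directory are configured just below.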
+ planner = TestWorker( + meta_agent_id=0, + n_agent=1, + policy_net=policy_net, + global_step=-1, + device=device, + greedy=True, + save_image=SAVE_GIFS, + clip_seg_tta=local_clip, + ) + planner.execute_tta = enable_tta + planner.gifs_path = save_dir + return planner + + # ────────────── Per-run output directories ────────────── + # Ensure base directory exists + os.makedirs(GIFS_PATH, exist_ok=True) + + run_id = time.strftime("%Y%m%d_%H%M%S") # unique timestamp + run_root = os.path.join(GIFS_PATH, run_id) + gifs_dir_tta = os.path.join(run_root, "with_tta") + gifs_dir_no = os.path.join(run_root, "no_tta") + + os.makedirs(gifs_dir_tta, exist_ok=True) + os.makedirs(gifs_dir_no, exist_ok=True) + + # House-keep old runs so we never keep more than RUN_HISTORY_LIMIT + _prune_old_run_dirs(GIFS_PATH, RUN_HISTORY_LIMIT) + + # Shared dict to record if a thread hit an exception + error_flags = {"tta": False, "no": False} + + def _planner_thread(enable_tta: bool, save_dir: str, clip_obj, key: str): + """Prepare directory, build planner, run an episode, record errors.""" + try: + planner = build_planner(enable_tta, save_dir, clip_obj) + _thread_clip_map[threading.current_thread()] = planner.clip_seg_tta + planner.run_episode(0) + except Exception as exc: + # Mark that this planner crashed so UI can show an error status + error_flags[key] = True + # Log full traceback so developers can debug via console logs + import traceback, sys + traceback.print_exc() + # Still exit the thread + return + + # Launch both planners in background threads – preparation included + thread_tta = threading.Thread( + target=_planner_thread, + args=(True, gifs_dir_tta, None, "tta"), + daemon=True, + ) + thread_no = threading.Thread( + target=_planner_thread, + args=(False, gifs_dir_no, None, "no"), + daemon=True, + ) + # Track threads for this user session + session_threads.extend([thread_tta, thread_no]) + thread_tta.start() + thread_no.start() + + + sent_tta: set[str] = set() + sent_no: set[str] = set() + last_tta = None + last_no = None + # Track previous status strings so we can emit updates when only the + # status (Running…/Done.) changes even if no new frame was produced. + # Previous status values so we can detect changes and yield updates + prev_status_tta = "Initializing model…" + prev_status_no = "Initializing model…" + + try: + while thread_tta.is_alive() or thread_no.is_alive(): + updated = False + # Collect new frames from TTA dir + pngs = glob.glob(os.path.join(gifs_dir_tta, "*.png")) + pngs.sort(key=lambda p: int(os.path.splitext(os.path.basename(p))[0])) + for fp in pngs: + if fp not in sent_tta: + # Ensure file is fully written (non-empty & readable) + try: + if os.path.getsize(fp) == 0: + continue + with open(fp, "rb") as fh: + fh.read(1) + except Exception: + # Skip this round; we'll retry next poll + continue + sent_tta.add(fp) + last_tta = fp + updated = True + # Collect new frames from no-TTA dir + pngs = glob.glob(os.path.join(gifs_dir_no, "*.png")) + pngs.sort(key=lambda p: int(os.path.splitext(os.path.basename(p))[0])) + for fp in pngs: + if fp not in sent_no: + try: + if os.path.getsize(fp) == 0: + continue + with open(fp, "rb") as fh: + fh.read(1) + except Exception: + continue + sent_no.add(fp) + last_no = fp + updated = True + + # Determine status based on whether we already have a frame and whether + # the corresponding thread is still alive. + def _mk_status(last_frame, thread_alive, errored: bool, running_tta: bool=False): + if errored: + return "Error!" 
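+                # No frame has appeared on disk yet, so the planner thread is still initializing.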
+ if last_frame is None: + return "Initializing model…" + if not thread_alive: + return "Done." + return "Executing TTA (Scheduling GPUs)…" if running_tta else "Executing Planner…" + + exec_tta_flag = False + if thread_tta.is_alive(): + clip_obj = _thread_clip_map.get(thread_tta) + if clip_obj is not None and getattr(clip_obj, "executing_tta", False): + exec_tta_flag = True + + status_tta = _mk_status(last_tta, thread_tta.is_alive(), error_flags["tta"], exec_tta_flag) + status_no = _mk_status(last_no, thread_no.is_alive(), error_flags["no"], False) + + # Determine if we should reveal sliders (once corresponding thread has finished) + show_slider_tta = (not thread_tta.is_alive()) and (last_tta is not None) + show_slider_no = (not thread_no.is_alive()) and (last_no is not None) + + # Build slider updates + slider_tta_upd = gr.update() + slider_no_upd = gr.update() + frames_tta_upd = gr.update() + frames_no_upd = gr.update() + + if show_slider_tta: + n_tta_frames = max(len(sent_tta), 1) + slider_tta_upd = gr.update(visible=True, minimum=1, maximum=n_tta_frames, value=n_tta_frames) + frames_tta_upd = sorted(sent_tta, key=lambda p: int(os.path.splitext(os.path.basename(p))[0])) + if show_slider_no: + n_no_frames = max(len(sent_no), 1) + slider_no_upd = gr.update(visible=True, minimum=1, maximum=n_no_frames, value=n_no_frames) + frames_no_upd = sorted(sent_no, key=lambda p: int(os.path.splitext(os.path.basename(p))[0])) + + # Emit update if we have a new frame OR status changed OR slider visibility changed + if ( + updated + or status_tta != prev_status_tta + or status_no != prev_status_no + or show_slider_tta + or show_slider_no + ): + yield ( + gr.update(interactive=False), + last_tta, + last_no, + gr.update(value=status_tta, visible=True), + gr.update(value=status_no, visible=True), + slider_tta_upd, + slider_no_upd, + frames_tta_upd, + frames_no_upd, + session_threads, + ) + + prev_status_tta = status_tta + prev_status_no = status_no + + time.sleep(POLL_INTERVAL) + finally: + # Ensure background threads are stopped on cancel + for th in (thread_tta, thread_no): + if th.is_alive(): + _stop_thread(th) + th.join(timeout=1) + + # Remove finished threads from global registry + with _running_threads_lock: + # Clear session thread list + session_threads.clear() + + # Small delay to ensure last frame files are fully flushed + time.sleep(0.2) + # One last scan after both threads have finished to catch any frame + # that may have been written just before termination but after the last + # polling iteration. 
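+    # Frame files are named by step index, so sort numerically rather than lexicographically when replaying.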
+ for fp in sorted(glob.glob(os.path.join(gifs_dir_tta, "*.png")), key=lambda p: int(os.path.splitext(os.path.basename(p))[0])): + if fp not in sent_tta: + sent_tta.add(fp) + last_tta = fp + for fp in sorted(glob.glob(os.path.join(gifs_dir_no, "*.png")), key=lambda p: int(os.path.splitext(os.path.basename(p))[0])): + if fp not in sent_no: + sent_no.add(fp) + last_no = fp + + # Prepare frames list and slider configs + frames_tta = sorted(glob.glob(os.path.join(gifs_dir_tta, "*.png")), key=lambda p: int(os.path.splitext(os.path.basename(p))[0])) + frames_no = sorted(glob.glob(os.path.join(gifs_dir_no, "*.png")), key=lambda p: int(os.path.splitext(os.path.basename(p))[0])) + if last_tta is None and frames_tta: + last_tta = frames_tta[-1] + if last_no is None and frames_no: + last_no = frames_no[-1] + n_tta = len(frames_tta) or 1 # prevent zero-range slider + n_no = len(frames_no) or 1 + + # Final emit: re-enable button, hide statuses, show sliders set to last frame + yield ( + gr.update(interactive=True), + last_tta, + last_no, + gr.update(visible=False), + gr.update(visible=False), + gr.update(visible=True, minimum=1, maximum=n_tta, value=n_tta), + gr.update(visible=True, minimum=1, maximum=n_no, value=n_no), + frames_tta, + frames_no, + session_threads, + ) + + +# ────────────────────────── Gradio UI ───────────────────────────────── +with gr.Blocks(title="Search-TTA (Simplified)", theme=gr.themes.Base()) as demo: + + gr.Markdown( + """ + # Search-TTA: A Multimodal Test-Time Adaptation Framework for Visual Search in the Wild Demo + Click on any of the examples below and run the TTA demo. Check out the multimodal heatmap generation feature by switching to the other tab above.
+    Note that the model initialization, RL planner, and TTA updates are not fully GPU-optimized for this Hugging Face demo, so you may experience some lag during execution.
+ If you encounter an 'Error' status, refresh the browser and rerun the demo, or try again the next day. We will improve this in the future.
+ Project Website + """ + ) + + with gr.Row(variant="panel"): + with gr.Column(): + gr.Markdown("### Model Inputs") + sat_input = gr.Image( + label="Satellite Image", + sources=["upload"], + type="filepath", + height=320, + ) + taxonomy_input = gr.Textbox( + label="Full Taxonomy Name (optional)", + placeholder="e.g. Animalia Chordata Mammalia Rodentia Sciuridae Marmota marmota", + ) + ground_input = gr.Image( + label="Ground-level Image (optional)", + sources=["upload"], + type="filepath", + height=320, + ) + run_btn = gr.Button("Run Search-TTA", variant="primary") + + with gr.Column(): + gr.Markdown("### Live Heatmap Output") + display_img_tta = gr.Image(label="Heatmap (TTA per 20 steps)", type="filepath", height=400) # 512 + status_tta = gr.Markdown("") + slider_tta = gr.Slider(label="TTA Frame", minimum=1, maximum=1, step=1, value=1, visible=False) + + display_img_no_tta = gr.Image(label="Heatmap (no TTA)", type="filepath", height=400) # 512 + status_no_tta = gr.Markdown("") + slider_no = gr.Slider(label="No-TTA Frame", minimum=1, maximum=1, step=1, value=1, visible=False) + + frames_state_tta = gr.State([]) + frames_state_no = gr.State([]) + session_threads_state = gr.State([]) + + # Slider callbacks (updates image when user drags slider) + def _show_frame(idx: int, frames: list[str]): + # Slider is 1-indexed; convert to 0-indexed list access + if 1 <= idx <= len(frames): + return frames[idx - 1] + return gr.update() + + slider_tta.change(_show_frame, inputs=[slider_tta, frames_state_tta], outputs=display_img_tta) + slider_no.change(_show_frame, inputs=[slider_no, frames_state_no], outputs=display_img_no_tta) + + # EXAMPLES + with gr.Row(): + gr.Markdown("### Taxonomy") + with gr.Row(): + gr.Examples( + examples=[ + [ + "examples/Animalia_Chordata_Reptilia_Squamata_Varanidae_Varanus_salvator/410613_5.35573_100.28948.jpg", + "examples/Animalia_Chordata_Reptilia_Squamata_Varanidae_Varanus_salvator/461d8e6c-0e66-4acc-8ecd-bfd9c218bc14.jpg", + "Animalia Chordata Reptilia Squamata Varanidae Varanus salvator", + ], + [ + "examples/Animalia_Chordata_Mammalia_Carnivora_Canidae_Canis_aureus/1528408_13.00422_80.23033.jpg", + "examples/Animalia_Chordata_Mammalia_Carnivora_Canidae_Canis_aureus/37faabd2-a613-4461-b27e-82fe5955ecaf.jpg", + "Animalia Chordata Mammalia Carnivora Canidae Canis aureus", + ], + [ + "examples/Animalia_Chordata_Reptilia_Crocodylia_Alligatoridae_Caiman_crocodilus/340271_10.52832_-83.49678.jpg", + "examples/Animalia_Chordata_Reptilia_Crocodylia_Alligatoridae_Caiman_crocodilus/938aab7b-4509-4de7-afad-2c8ea51f4799.jpg", + "Animalia Chordata Reptilia Crocodylia Alligatoridae Caiman crocodilus", + ], + [ + "examples/Animalia_Chordata_Mammalia_Carnivora_Canidae_Urocyon_littoralis/304160_34.0144_-119.54417.jpg", + "examples/Animalia_Chordata_Mammalia_Carnivora_Canidae_Urocyon_littoralis/0cbdfbf2-6cfe-4d61-9602-c949f24d0293.jpg", + "Animalia Chordata Mammalia Carnivora Canidae Urocyon littoralis", + ], + ], + inputs=[sat_input, ground_input, taxonomy_input], + outputs=[run_btn, display_img_tta, display_img_no_tta, status_tta, status_no_tta, slider_tta, slider_no, frames_state_tta, frames_state_no], + fn=process_search_tta, + cache_examples=False, + ) + + run_btn.click( + fn=process_search_tta, + inputs=[sat_input, ground_input, taxonomy_input, session_threads_state], + outputs=[run_btn, display_img_tta, display_img_no_tta, status_tta, status_no_tta, slider_tta, slider_no, frames_state_tta, frames_state_no, session_threads_state], + ) + + # Footer to point out to model and data from 
app page. + gr.Markdown( + """ + The satellite image CLIP encoder is fine-tuned using [Sentinel-2 Level 2A](https://docs.sentinel-hub.com/api/latest/data/sentinel-2-l2a/) satellite image and taxonomy images (with GPS locations) from [iNaturalist](https://inaturalist.org/). The sound CLIP encoder is fine-tuned with a subset of the same taxonomy images and their corresponding sounds from [iNaturalist](https://inaturalist.org/). Some of these iNaturalist data are also used in [Taxabind](https://arxiv.org/abs/2411.00683). Note that while some of the examples above result in poor probability distributions, they will be improved using our test-time adaptation framework during the search process. + """ + ) + + +if __name__ == "__main__": + + # Build UI with explicit Tabs so we can detect tab selection and clean up + from app_multimodal_inference import demo as multimodal_demo + + with gr.Blocks() as root: + with gr.Tabs() as tabs: + with gr.TabItem("Multimodal Inference"): + multimodal_demo.render() + with gr.TabItem("Search-TTA"): + demo.render() + + # Hidden textbox purely to satisfy Gradio's need for an output component. + _cleanup_status = gr.Textbox(visible=False) + + outputs_on_tab = [_cleanup_status] + + def _on_tab_change(evt: gr.SelectData, session_threads: list[threading.Thread]): + # evt.value contains the name of the newly-selected tab. + if evt.value == "Multimodal Inference": + # Stop only threads started in this session + for th in list(session_threads): + if th is not None and th.is_alive(): + _stop_thread(th) + th.join(timeout=1) + session_threads.clear() + return "Stopped running Search-TTA threads." + return "" + + tabs.select(_on_tab_change, inputs=[session_threads_state], outputs=outputs_on_tab) + + root.queue(max_size=15) + root.launch(share=True) diff --git a/app_multimodal_inference.py b/app_multimodal_inference.py new file mode 100644 index 0000000000000000000000000000000000000000..d0dcfc5898c6164a313c973521bbe83bc4e097b5 --- /dev/null +++ b/app_multimodal_inference.py @@ -0,0 +1,350 @@ +""" +Search-TTA multimodal heatmap generation demo +""" + +# ────────────────────────── imports ─────────────────────────────────── +import cv2 +import gradio as gr +import torch +import numpy as np +from PIL import Image +import matplotlib.pyplot as plt +import io +import torchaudio +import spaces # integration with ZeroGPU on hf + +from torchvision import transforms +import open_clip +from taxabind_avs.satbind.clip_vision_per_patch_model import CLIPVisionPerPatchModel +from transformers import ClapAudioModelWithProjection +from transformers import ClapProcessor + +# ────────────────────────── global config & models ──────────────────── +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + +# BioCLIP (ground-image & text encoder) +bio_model, _, _ = open_clip.create_model_and_transforms("hf-hub:imageomics/bioclip") +bio_model = bio_model.to(device).eval() +bio_tokenizer = open_clip.get_tokenizer("hf-hub:imageomics/bioclip") + +# Satellite patch encoder CLIP-L-336 per-patch) +sat_model: CLIPVisionPerPatchModel = ( + CLIPVisionPerPatchModel.from_pretrained("derektan95/search-tta-sat") + .to(device) + .eval() +) + +# Sound CLAP model +sound_model: ClapAudioModelWithProjection = ( + ClapAudioModelWithProjection.from_pretrained("derektan95/search-tta-sound") + .to(device) + .eval() +) +sound_processor: ClapProcessor = ClapProcessor.from_pretrained("derektan95/search-tta-sound") +SAMPLE_RATE = 48000 + +logit_scale = torch.nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) 
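+# CLIP-style temperature: stored as a log-value and exponentiated before scaling similarities (exp(log(1/0.07)) β‰ˆ 14.3).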
+logit_scale = logit_scale.exp() +blur_kernel = (5,5) + +# ────────────────────────── transforms (exact spec) ─────────────────── +img_transform = transforms.Compose( + [ + transforms.Resize((256, 256)), + transforms.CenterCrop((224, 224)), + transforms.ToTensor(), + transforms.Normalize( + mean=[0.485, 0.456, 0.406], + std=[0.229, 0.224, 0.225], + ), + ] +) + +imo_transform = transforms.Compose( + [ + transforms.Resize((336, 336)), + transforms.ToTensor(), + transforms.Normalize( + mean=[0.485, 0.456, 0.406], + std=[0.229, 0.224, 0.225], + ), + ] +) + +def get_audio_clap(path_to_audio,format="mp3",padding="repeatpad",truncation="fusion"): + track, sr = torchaudio.load(path_to_audio, format=format) # torchaudio.load(path_to_audio) + track = track.mean(axis=0) + track = torchaudio.functional.resample(track, orig_freq=sr, new_freq=SAMPLE_RATE) + output = sound_processor(audios=track, sampling_rate=SAMPLE_RATE, max_length_s=10, return_tensors="pt",padding=padding,truncation=truncation) + return output + +# ────────────────────────── helpers ─────────────────────────────────── + +@torch.no_grad() +def _encode_ground(img_pil: Image.Image) -> torch.Tensor: + img = img_transform(img_pil).unsqueeze(0).to(device) + img_embeds, *_ = bio_model(img) + return img_embeds + + +@torch.no_grad() +def _encode_text(text: str) -> torch.Tensor: + toks = bio_tokenizer(text).to(device) + _, txt_embeds, _ = bio_model(text=toks) + return txt_embeds + + +@torch.no_grad() +def _encode_sat(img_pil: Image.Image) -> torch.Tensor: + imo = imo_transform(img_pil).unsqueeze(0).to(device) + imo_embeds = sat_model(imo) + return imo_embeds + + +@torch.no_grad() +def _encode_sound(sound) -> torch.Tensor: + processed_sound = get_audio_clap(sound) + for k in processed_sound.keys(): + processed_sound[k] = processed_sound[k].to(device) + unnormalized_audio_embeds = sound_model(**processed_sound).audio_embeds + sound_embeds = torch.nn.functional.normalize(unnormalized_audio_embeds, dim=-1) + return sound_embeds + + +def _similarity_heatmap(query: torch.Tensor, patches: torch.Tensor) -> np.ndarray: + sims = torch.matmul(query, patches.t()) * logit_scale + sims = sims.t().sigmoid() + sims = sims[1:].squeeze() # drop CLS token + side = int(np.sqrt(len(sims))) + sims = sims.reshape(side, side) + return sims.cpu().detach().numpy() + + +def _array_to_pil(arr: np.ndarray) -> Image.Image: + """ + Render arr with viridis, automatically stretching its own minβ†’max to 0β†’1 + so that the most-similar patches appear yellow. 
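+    A Gaussian blur (blur_kernel) is applied first to smooth patch-level noise before the stretch.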
+ """ + + # Gausian Smoothing + if blur_kernel != (0,0): + arr = cv2.GaussianBlur(arr, blur_kernel, 0) + + # --- contrast-stretch to local 0-1 range -------------------------- + arr_min, arr_max = float(arr.min()), float(arr.max()) + if arr_max - arr_min < 1e-6: # avoid /0 when the heat-map is flat + arr_scaled = np.zeros_like(arr) + else: + arr_scaled = (arr - arr_min) / (arr_max - arr_min) + # ------------------------------------------------------------------ + fig, ax = plt.subplots(figsize=(2.6, 2.6), dpi=96) + ax.imshow(arr_scaled, cmap="viridis", vmin=0.0, vmax=1.0) + ax.axis("off") + buf = io.BytesIO() + plt.tight_layout(pad=0) + fig.savefig(buf, format="png", bbox_inches="tight", pad_inches=0) + plt.close(fig) + buf.seek(0) + return Image.open(buf) + +# ────────────────────────── main inference ──────────────────────────── +# integration with ZeroGPU on hf +@spaces.GPU(duration=5) +def process( + sat_img: Image.Image, + taxonomy: str, + ground_img: Image.Image | None, + sound: torch.Tensor | None, +): + if sat_img is None: + return None, None + + patches = _encode_sat(sat_img) + + heat_ground, heat_text, heat_sound = None, None, None + + if ground_img is not None: + q_img = _encode_ground(ground_img) + heat_ground = _array_to_pil(_similarity_heatmap(q_img, patches)) + + if taxonomy.strip(): + q_txt = _encode_text(taxonomy.strip()) + heat_text = _array_to_pil(_similarity_heatmap(q_txt, patches)) + + if sound is not None: + q_sound = _encode_sound(sound) + heat_sound = _array_to_pil(_similarity_heatmap(q_sound, patches)) + + return heat_ground, heat_text, heat_sound + + +# ────────────────────────── Gradio UI ───────────────────────────────── +with gr.Blocks(title="Search-TTA", theme=gr.themes.Base()) as demo: + + gr.Markdown( + """ + # Search-TTA: A Multimodal Test-Time Adaptation Framework for Visual Search in the Wild Demo + Click on any of the examples below and run the multimodal inference demo. Check out the test-time adaptation feature by switching to the other tab above.
+ If you encounter any errors, refresh the browser and rerun the demo, or try again the next day. We will improve this in the future.
+ Project Website + """ + ) + + with gr.Row(variant="panel"): + + # LEFT COLUMN (satellite, taxonomy, run) + with gr.Column(): + sat_input = gr.Image( + label="Satellite Image", + sources=["upload"], + type="pil", + height=320, + ) + taxonomy_input = gr.Textbox( + label="Full Taxonomy Name (optional)", + placeholder="e.g. Animalia Chordata Mammalia Rodentia Sciuridae Marmota marmota", + ) + + # ─── NEW: sound input ─────────────────────────── + sound_input = gr.Audio( + label="Sound Input (optional)", + sources=["upload"], + type="filepath", + ) + run_btn = gr.Button("Run", variant="primary") + + # RIGHT COLUMN (ground image + two heat-maps) + with gr.Column(): + ground_input = gr.Image( + label="Ground-level Image (optional)", + sources=["upload"], + type="pil", + height=320, + ) + gr.Markdown("### Heat-map Results") + with gr.Row(): + # Separate label and image to avoid overlap + with gr.Column(scale=1, min_width=100): + gr.Markdown("**Ground Image Query**", elem_id="label-ground") + heat_ground_out = gr.Image( + show_label=False, + height=160, + ) + with gr.Column(scale=1, min_width=100): + gr.Markdown("**Text Query**", elem_id="label-text") + heat_text_out = gr.Image( + show_label=False, + height=160, + ) + with gr.Column(scale=1, min_width=100): + gr.Markdown("**Sound Query**", elem_id="label-sound") + heat_sound_out = gr.Image( + show_label=False, + height=160, + ) + + + # EXAMPLES + with gr.Row(): + gr.Markdown("### In-Domain Taxonomy") + with gr.Row(): + gr.Examples( + examples=[ + [ + "examples/Animalia_Chordata_Aves_Charadriiformes_Laridae_Larus_marinus/80645_39.76079_-74.10316.jpg", + "examples/Animalia_Chordata_Aves_Charadriiformes_Laridae_Larus_marinus/cc1ebaf9-899d-49f2-81c8-d452249a8087.jpg", + "Animalia Chordata Aves Charadriiformes Laridae Larus marinus", + "examples/Animalia_Chordata_Aves_Charadriiformes_Laridae_Larus_marinus/89758229.mp3" + ], + [ + "examples/Animalia_Chordata_Mammalia_Rodentia_Caviidae_Hydrochoerus_hydrochaeris/28871_-12.80255_-69.29999.jpg", + "examples/Animalia_Chordata_Mammalia_Rodentia_Caviidae_Hydrochoerus_hydrochaeris/1b8064f8-7deb-4b30-98cd-69da98ba6a3d.jpg", + "Animalia Chordata Mammalia Rodentia Caviidae Hydrochoerus hydrochaeris", + "examples/Animalia_Chordata_Mammalia_Rodentia_Caviidae_Hydrochoerus_hydrochaeris/166631961.mp3" + ], + [ + "examples/Animalia_Arthropoda_Malacostraca_Decapoda_Ocypodidae_Ocypode_quadrata/277303_38.72364_-75.07749.jpg", + "examples/Animalia_Arthropoda_Malacostraca_Decapoda_Ocypodidae_Ocypode_quadrata/0b9cc264-a2ba-44bd-8e41-0d01a6edd1e8.jpg", + "Animalia Arthropoda Malacostraca Decapoda Ocypodidae Ocypode quadrata", + "examples/Animalia_Arthropoda_Malacostraca_Decapoda_Ocypodidae_Ocypode_quadrata/12372063.mp3" + ], + [ + "examples/Animalia_Chordata_Mammalia_Rodentia_Sciuridae_Marmota_marmota/388246_45.49036_7.14796.jpg", + "examples/Animalia_Chordata_Mammalia_Rodentia_Sciuridae_Marmota_marmota/327e1f07-692b-4140-8a3e-bd098bc064ff.jpg", + "Animalia Chordata Mammalia Rodentia Sciuridae Marmota marmota", + "examples/Animalia_Chordata_Mammalia_Rodentia_Sciuridae_Marmota_marmota/59677071.mp3" + ], + [ + "examples/Animalia_Chordata_Reptilia_Squamata_Varanidae_Varanus_salvator/410613_5.35573_100.28948.jpg", + "examples/Animalia_Chordata_Reptilia_Squamata_Varanidae_Varanus_salvator/461d8e6c-0e66-4acc-8ecd-bfd9c218bc14.jpg", + "Animalia Chordata Reptilia Squamata Varanidae Varanus salvator", + None + ], + ], + inputs=[sat_input, ground_input, taxonomy_input, sound_input], + outputs=[heat_ground_out, heat_text_out, 
heat_sound_out], + fn=process, + cache_examples=False, + ) + + # EXAMPLES + with gr.Row(): + gr.Markdown("### Out-Domain Taxonomy") + with gr.Row(): + gr.Examples( + examples=[ + [ + "examples/Animalia_Chordata_Mammalia_Carnivora_Phocidae_Mirounga_angustirostris/27423_35.64005_-121.17595.jpg", + "examples/Animalia_Chordata_Mammalia_Carnivora_Phocidae_Mirounga_angustirostris/3aac526d-c921-452a-af6a-cb4f2f52e2c4.jpg", + "Animalia Chordata Mammalia Carnivora Phocidae Mirounga angustirostris", + "examples/Animalia_Chordata_Mammalia_Carnivora_Phocidae_Mirounga_angustirostris/3123948.mp3" + ], + [ + "examples/Animalia_Chordata_Mammalia_Carnivora_Canidae_Canis_aureus/1528408_13.00422_80.23033.jpg", + "examples/Animalia_Chordata_Mammalia_Carnivora_Canidae_Canis_aureus/37faabd2-a613-4461-b27e-82fe5955ecaf.jpg", + "Animalia Chordata Mammalia Carnivora Canidae Canis aureus", + "examples/Animalia_Chordata_Mammalia_Carnivora_Canidae_Canis_aureus/189318716.mp3" + ], + [ + "examples/Animalia_Chordata_Reptilia_Crocodylia_Alligatoridae_Caiman_crocodilus/340271_10.52832_-83.49678.jpg", + "examples/Animalia_Chordata_Reptilia_Crocodylia_Alligatoridae_Caiman_crocodilus/938aab7b-4509-4de7-afad-2c8ea51f4799.jpg", + "Animalia Chordata Reptilia Crocodylia Alligatoridae Caiman crocodilus", + "examples/Animalia_Chordata_Reptilia_Crocodylia_Alligatoridae_Caiman_crocodilus/45193295.mp3" + ], + [ + "examples/Animalia_Chordata_Mammalia_Carnivora_Canidae_Urocyon_littoralis/304160_34.0144_-119.54417.jpg", + "examples/Animalia_Chordata_Mammalia_Carnivora_Canidae_Urocyon_littoralis/0cbdfbf2-6cfe-4d61-9602-c949f24d0293.jpg", + "Animalia Chordata Mammalia Carnivora Canidae Urocyon littoralis", + None + ], + [ + "examples/Animalia_Chordata_Elasmobranchii_Carcharhiniformes_Carcharhinidae_Triaenodon_obesus/5041_-0.28573_-90.54837.jpg", + "examples/Animalia_Chordata_Elasmobranchii_Carcharhiniformes_Carcharhinidae_Triaenodon_obesus/c834edf7-b073-4cd5-8726-9c88ebe943c0.jpg", + "Animalia Chordata Elasmobranchii Carcharhiniformes Carcharhinidae Triaenodon obesus", + None + ], + ], + inputs=[sat_input, ground_input, taxonomy_input, sound_input], + outputs=[heat_ground_out, heat_text_out, heat_sound_out], + fn=process, + cache_examples=False, + ) + + # CALLBACK + run_btn.click( + fn=process, + inputs=[sat_input, taxonomy_input, ground_input, sound_input], + outputs=[heat_ground_out, heat_text_out, heat_sound_out], + ) + + # Footer to point out to model and data from app page. + gr.Markdown( + """ + The satellite image CLIP encoder is fine-tuned using [Sentinel-2 Level 2A](https://docs.sentinel-hub.com/api/latest/data/sentinel-2-l2a/) satellite image and taxonomy images (with GPS locations) from [iNaturalist](https://inaturalist.org/). The sound CLIP encoder is fine-tuned with a subset of the same taxonomy images and their corresponding sounds from [iNaturalist](https://inaturalist.org/). Some of these iNaturalist data are also used in [Taxabind](https://arxiv.org/abs/2411.00683). Note that while some of the examples above result in poor probability distributions, they will be improved using our test-time adaptation framework during the search process. 
+ """ + ) + +# LAUNCH +if __name__ == "__main__": + demo.queue(max_size=15) + demo.launch(share=True) diff --git a/examples/Animalia_Arthropoda_Malacostraca_Decapoda_Ocypodidae_Ocypode_quadrata/0b9cc264-a2ba-44bd-8e41-0d01a6edd1e8.jpg b/examples/Animalia_Arthropoda_Malacostraca_Decapoda_Ocypodidae_Ocypode_quadrata/0b9cc264-a2ba-44bd-8e41-0d01a6edd1e8.jpg new file mode 100644 index 0000000000000000000000000000000000000000..0ffd63f6e6eae1ace33246f8abecc88997a227ea --- /dev/null +++ b/examples/Animalia_Arthropoda_Malacostraca_Decapoda_Ocypodidae_Ocypode_quadrata/0b9cc264-a2ba-44bd-8e41-0d01a6edd1e8.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:08aee38091dbb62f0862a184acbc9432f50c03c63fdf357592df8efcacaab485 +size 134759 diff --git a/examples/Animalia_Arthropoda_Malacostraca_Decapoda_Ocypodidae_Ocypode_quadrata/12372063.mp3 b/examples/Animalia_Arthropoda_Malacostraca_Decapoda_Ocypodidae_Ocypode_quadrata/12372063.mp3 new file mode 100644 index 0000000000000000000000000000000000000000..f25d6a26dae4ddb1c7764ea5ec3c0066cc44b42c --- /dev/null +++ b/examples/Animalia_Arthropoda_Malacostraca_Decapoda_Ocypodidae_Ocypode_quadrata/12372063.mp3 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:575959883981159f2e40593bf5be87be006026c41da36a34d1e40783de648116 +size 54027 diff --git a/examples/Animalia_Arthropoda_Malacostraca_Decapoda_Ocypodidae_Ocypode_quadrata/277303_38.72364_-75.07749.jpg b/examples/Animalia_Arthropoda_Malacostraca_Decapoda_Ocypodidae_Ocypode_quadrata/277303_38.72364_-75.07749.jpg new file mode 100644 index 0000000000000000000000000000000000000000..3a8e17130c52e623889436f126fb287715535089 --- /dev/null +++ b/examples/Animalia_Arthropoda_Malacostraca_Decapoda_Ocypodidae_Ocypode_quadrata/277303_38.72364_-75.07749.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:1caa5d8bab960f559065f79ca554bed63e6b02764096874be6a58b34389855f6 +size 25627 diff --git a/examples/Animalia_Chordata_Aves_Charadriiformes_Laridae_Larus_marinus/80645_39.76079_-74.10316.jpg b/examples/Animalia_Chordata_Aves_Charadriiformes_Laridae_Larus_marinus/80645_39.76079_-74.10316.jpg new file mode 100644 index 0000000000000000000000000000000000000000..8fd23a3b82411137495ba6b71c7b6dd358cdc1c4 --- /dev/null +++ b/examples/Animalia_Chordata_Aves_Charadriiformes_Laridae_Larus_marinus/80645_39.76079_-74.10316.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:8350770efa7d8e38b91670e757bb82df26167f8989f946132ad978d238baa916 +size 26142 diff --git a/examples/Animalia_Chordata_Aves_Charadriiformes_Laridae_Larus_marinus/89758229.mp3 b/examples/Animalia_Chordata_Aves_Charadriiformes_Laridae_Larus_marinus/89758229.mp3 new file mode 100644 index 0000000000000000000000000000000000000000..2a85121f51df7f978961293eff638e883699fe48 --- /dev/null +++ b/examples/Animalia_Chordata_Aves_Charadriiformes_Laridae_Larus_marinus/89758229.mp3 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e2c7ad6df49668d29f9b7f9f9f0739b97ef4edc5219413a41d01983a9863cccc +size 2601487 diff --git a/examples/Animalia_Chordata_Aves_Charadriiformes_Laridae_Larus_marinus/cc1ebaf9-899d-49f2-81c8-d452249a8087.jpg b/examples/Animalia_Chordata_Aves_Charadriiformes_Laridae_Larus_marinus/cc1ebaf9-899d-49f2-81c8-d452249a8087.jpg new file mode 100644 index 0000000000000000000000000000000000000000..e2380dc1133ed25489c69ae1de9967a813caa1cb --- /dev/null +++ 
b/examples/Animalia_Chordata_Aves_Charadriiformes_Laridae_Larus_marinus/cc1ebaf9-899d-49f2-81c8-d452249a8087.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e346a1c1424e62a040c7e97f17c2e5ccb4c923422682105b2ccedd0ead736170 +size 28444 diff --git a/examples/Animalia_Chordata_Elasmobranchii_Carcharhiniformes_Carcharhinidae_Triaenodon_obesus/5041_-0.28573_-90.54837.jpg b/examples/Animalia_Chordata_Elasmobranchii_Carcharhiniformes_Carcharhinidae_Triaenodon_obesus/5041_-0.28573_-90.54837.jpg new file mode 100644 index 0000000000000000000000000000000000000000..1b046d43a5083e8207683bced18586dbe3fe8ece --- /dev/null +++ b/examples/Animalia_Chordata_Elasmobranchii_Carcharhiniformes_Carcharhinidae_Triaenodon_obesus/5041_-0.28573_-90.54837.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:624443bdb62b8d35e6e63be33e04404a85ad8902b70af67d878a013893656dc2 +size 15306 diff --git a/examples/Animalia_Chordata_Elasmobranchii_Carcharhiniformes_Carcharhinidae_Triaenodon_obesus/c834edf7-b073-4cd5-8726-9c88ebe943c0.jpg b/examples/Animalia_Chordata_Elasmobranchii_Carcharhiniformes_Carcharhinidae_Triaenodon_obesus/c834edf7-b073-4cd5-8726-9c88ebe943c0.jpg new file mode 100644 index 0000000000000000000000000000000000000000..829aeb164ad91ca1564548db2380160803a5dff5 --- /dev/null +++ b/examples/Animalia_Chordata_Elasmobranchii_Carcharhiniformes_Carcharhinidae_Triaenodon_obesus/c834edf7-b073-4cd5-8726-9c88ebe943c0.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e7f8be8790e7c5837d8d8e0d9285adad45138598caef21f528a591a0ab13ee9b +size 58003 diff --git a/examples/Animalia_Chordata_Mammalia_Artiodactyla_Cervidae_Cervus_nippon/1a51f978-c76d-4816-a716-473c21b5495c.jpg b/examples/Animalia_Chordata_Mammalia_Artiodactyla_Cervidae_Cervus_nippon/1a51f978-c76d-4816-a716-473c21b5495c.jpg new file mode 100644 index 0000000000000000000000000000000000000000..0eb620ff8de6b3bfb7e44c3809d10fca8fcd6020 --- /dev/null +++ b/examples/Animalia_Chordata_Mammalia_Artiodactyla_Cervidae_Cervus_nippon/1a51f978-c76d-4816-a716-473c21b5495c.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a47758183723ba17f72dee9acdaf4bcfba2b4d07d0af2e50c125b3fac665ca04 +size 116615 diff --git a/examples/Animalia_Chordata_Mammalia_Artiodactyla_Cervidae_Cervus_nippon/245767_53.0076_-6.35201.jpg b/examples/Animalia_Chordata_Mammalia_Artiodactyla_Cervidae_Cervus_nippon/245767_53.0076_-6.35201.jpg new file mode 100644 index 0000000000000000000000000000000000000000..2370e73283220a9faba6ddb7179bff1519750ac3 --- /dev/null +++ b/examples/Animalia_Chordata_Mammalia_Artiodactyla_Cervidae_Cervus_nippon/245767_53.0076_-6.35201.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:7e9e1c3907555774d831b34b9d7ef94b79b7fbe82c3e226baef75e0cf71194e4 +size 23001 diff --git a/examples/Animalia_Chordata_Mammalia_Carnivora_Canidae_Canis_aureus/1528408_13.00422_80.23033.jpg b/examples/Animalia_Chordata_Mammalia_Carnivora_Canidae_Canis_aureus/1528408_13.00422_80.23033.jpg new file mode 100644 index 0000000000000000000000000000000000000000..cbf5d4ba603adb07834bb5d4f321b465814fe79f --- /dev/null +++ b/examples/Animalia_Chordata_Mammalia_Carnivora_Canidae_Canis_aureus/1528408_13.00422_80.23033.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e9f1934026db176cdcea261d37eda0a02309e5f2647ecab336e53d571b40f8f4 +size 37212 diff --git a/examples/Animalia_Chordata_Mammalia_Carnivora_Canidae_Canis_aureus/189318716.mp3 
b/examples/Animalia_Chordata_Mammalia_Carnivora_Canidae_Canis_aureus/189318716.mp3 new file mode 100644 index 0000000000000000000000000000000000000000..e994349c23bfb8614fc7f6e35e4ec9be1755633c --- /dev/null +++ b/examples/Animalia_Chordata_Mammalia_Carnivora_Canidae_Canis_aureus/189318716.mp3 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:4639c226ea5a0464b98e89b33a1f821b6625c6637d206d3d355e05bc7c89c641 +size 148019 diff --git a/examples/Animalia_Chordata_Mammalia_Carnivora_Canidae_Canis_aureus/37faabd2-a613-4461-b27e-82fe5955ecaf.jpg b/examples/Animalia_Chordata_Mammalia_Carnivora_Canidae_Canis_aureus/37faabd2-a613-4461-b27e-82fe5955ecaf.jpg new file mode 100644 index 0000000000000000000000000000000000000000..de835cf70160678b36fc78be76f97748063de57c --- /dev/null +++ b/examples/Animalia_Chordata_Mammalia_Carnivora_Canidae_Canis_aureus/37faabd2-a613-4461-b27e-82fe5955ecaf.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a84eca02154ed12885c075378b0349d6950586a7887883bce414df48adb59746 +size 83083 diff --git a/examples/Animalia_Chordata_Mammalia_Carnivora_Canidae_Urocyon_littoralis/0cbdfbf2-6cfe-4d61-9602-c949f24d0293.jpg b/examples/Animalia_Chordata_Mammalia_Carnivora_Canidae_Urocyon_littoralis/0cbdfbf2-6cfe-4d61-9602-c949f24d0293.jpg new file mode 100644 index 0000000000000000000000000000000000000000..d25a95a88da478475364862dc92c3d85cff0f4c5 --- /dev/null +++ b/examples/Animalia_Chordata_Mammalia_Carnivora_Canidae_Urocyon_littoralis/0cbdfbf2-6cfe-4d61-9602-c949f24d0293.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ea5f2dffebd69cdded00548f8773c5a8a8849bbdfba04ae6385fbc2d0983d55f +size 75596 diff --git a/examples/Animalia_Chordata_Mammalia_Carnivora_Canidae_Urocyon_littoralis/304160_34.0144_-119.54417.jpg b/examples/Animalia_Chordata_Mammalia_Carnivora_Canidae_Urocyon_littoralis/304160_34.0144_-119.54417.jpg new file mode 100644 index 0000000000000000000000000000000000000000..def90799f19196092663468334976a847a59f416 --- /dev/null +++ b/examples/Animalia_Chordata_Mammalia_Carnivora_Canidae_Urocyon_littoralis/304160_34.0144_-119.54417.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:3b803c3c2e6fa921d9f83ba3aecccac0796a4cd4774c3263aae54fdfc49d13d6 +size 23309 diff --git a/examples/Animalia_Chordata_Mammalia_Carnivora_Phocidae_Mirounga_angustirostris/27423_35.64005_-121.17595.jpg b/examples/Animalia_Chordata_Mammalia_Carnivora_Phocidae_Mirounga_angustirostris/27423_35.64005_-121.17595.jpg new file mode 100644 index 0000000000000000000000000000000000000000..de7f0c6267cfe0d0abd4ef95d0c5612299dbfc68 --- /dev/null +++ b/examples/Animalia_Chordata_Mammalia_Carnivora_Phocidae_Mirounga_angustirostris/27423_35.64005_-121.17595.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:16dd378607b7303515593491a1247785ae49733da24bbc3ce21e85d6c6341ab2 +size 22576 diff --git a/examples/Animalia_Chordata_Mammalia_Carnivora_Phocidae_Mirounga_angustirostris/3123948.mp3 b/examples/Animalia_Chordata_Mammalia_Carnivora_Phocidae_Mirounga_angustirostris/3123948.mp3 new file mode 100644 index 0000000000000000000000000000000000000000..99fe366aae29123a967d5d909f9849de6df660a7 --- /dev/null +++ b/examples/Animalia_Chordata_Mammalia_Carnivora_Phocidae_Mirounga_angustirostris/3123948.mp3 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:96ca3a92e6f614cce82972dacb04f5c0c170c1aea3d70d15778af56820ed02c9 +size 276768 diff --git 
a/examples/Animalia_Chordata_Mammalia_Carnivora_Phocidae_Mirounga_angustirostris/3aac526d-c921-452a-af6a-cb4f2f52e2c4.jpg b/examples/Animalia_Chordata_Mammalia_Carnivora_Phocidae_Mirounga_angustirostris/3aac526d-c921-452a-af6a-cb4f2f52e2c4.jpg new file mode 100644 index 0000000000000000000000000000000000000000..6654076f7838645dc3ce9b32b26fcccdfe864891 --- /dev/null +++ b/examples/Animalia_Chordata_Mammalia_Carnivora_Phocidae_Mirounga_angustirostris/3aac526d-c921-452a-af6a-cb4f2f52e2c4.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:bdda6139885cf54acfb1d6c9a56373fbe39e205dac7eb99cd04dbe5eb206b9d6 +size 95413 diff --git a/examples/Animalia_Chordata_Mammalia_Rodentia_Caviidae_Hydrochoerus_hydrochaeris/166631961.mp3 b/examples/Animalia_Chordata_Mammalia_Rodentia_Caviidae_Hydrochoerus_hydrochaeris/166631961.mp3 new file mode 100644 index 0000000000000000000000000000000000000000..e15ebd11da89b57d7ebaef55cdf96c0a810a7175 --- /dev/null +++ b/examples/Animalia_Chordata_Mammalia_Rodentia_Caviidae_Hydrochoerus_hydrochaeris/166631961.mp3 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:cc02eca19d0c408d038e205d82f6624c0515858ac374cf7298161a14e169e6a9 +size 266258 diff --git a/examples/Animalia_Chordata_Mammalia_Rodentia_Caviidae_Hydrochoerus_hydrochaeris/1b8064f8-7deb-4b30-98cd-69da98ba6a3d.jpg b/examples/Animalia_Chordata_Mammalia_Rodentia_Caviidae_Hydrochoerus_hydrochaeris/1b8064f8-7deb-4b30-98cd-69da98ba6a3d.jpg new file mode 100644 index 0000000000000000000000000000000000000000..0066c1e1c26bece8235dd23fd035870958284e37 --- /dev/null +++ b/examples/Animalia_Chordata_Mammalia_Rodentia_Caviidae_Hydrochoerus_hydrochaeris/1b8064f8-7deb-4b30-98cd-69da98ba6a3d.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:fe636d18a1e068e85b0bb8cd05ff674eb4b19958cc34d75ef00d385f74254ecb +size 85294 diff --git a/examples/Animalia_Chordata_Mammalia_Rodentia_Caviidae_Hydrochoerus_hydrochaeris/28871_-12.80255_-69.29999.jpg b/examples/Animalia_Chordata_Mammalia_Rodentia_Caviidae_Hydrochoerus_hydrochaeris/28871_-12.80255_-69.29999.jpg new file mode 100644 index 0000000000000000000000000000000000000000..cb656a4c7e24dc5b49e79ebd7f1061512d7db5fb --- /dev/null +++ b/examples/Animalia_Chordata_Mammalia_Rodentia_Caviidae_Hydrochoerus_hydrochaeris/28871_-12.80255_-69.29999.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:d28dec20d4f9cba13386ab00f13ddd7cb36fee24ee466e8a5437dbfd778bc2d5 +size 23308 diff --git a/examples/Animalia_Chordata_Mammalia_Rodentia_Sciuridae_Marmota_marmota/327e1f07-692b-4140-8a3e-bd098bc064ff.jpg b/examples/Animalia_Chordata_Mammalia_Rodentia_Sciuridae_Marmota_marmota/327e1f07-692b-4140-8a3e-bd098bc064ff.jpg new file mode 100644 index 0000000000000000000000000000000000000000..11254703f2d85ee4b05bbe1784d8a55eb41e7249 --- /dev/null +++ b/examples/Animalia_Chordata_Mammalia_Rodentia_Sciuridae_Marmota_marmota/327e1f07-692b-4140-8a3e-bd098bc064ff.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:5bffc4c332ae6406bcb1b78cd23170bd7c71d58e2e7dac12fb812fc9aa39b8f0 +size 70314 diff --git a/examples/Animalia_Chordata_Mammalia_Rodentia_Sciuridae_Marmota_marmota/388246_45.49036_7.14796.jpg b/examples/Animalia_Chordata_Mammalia_Rodentia_Sciuridae_Marmota_marmota/388246_45.49036_7.14796.jpg new file mode 100644 index 0000000000000000000000000000000000000000..de28dc1d30551c4e980243a3806d8f3aba22eee4 --- /dev/null +++ 
b/examples/Animalia_Chordata_Mammalia_Rodentia_Sciuridae_Marmota_marmota/388246_45.49036_7.14796.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a5aa9ae1a1dc4c59191bc72005fc9904d4c390f07ce5cc5ed435eb5687ae1d64 +size 33328 diff --git a/examples/Animalia_Chordata_Mammalia_Rodentia_Sciuridae_Marmota_marmota/59677071.mp3 b/examples/Animalia_Chordata_Mammalia_Rodentia_Sciuridae_Marmota_marmota/59677071.mp3 new file mode 100644 index 0000000000000000000000000000000000000000..a1b5402460867d2b2574d3d02fcf5d1d338a7d55 --- /dev/null +++ b/examples/Animalia_Chordata_Mammalia_Rodentia_Sciuridae_Marmota_marmota/59677071.mp3 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:bb043991fe851d6a1e12f32c5a9277dad5a77a939cf15ccb4afcb215b4bc08e3 +size 92876 diff --git a/examples/Animalia_Chordata_Reptilia_Crocodylia_Alligatoridae_Caiman_crocodilus/340271_10.52832_-83.49678.jpg b/examples/Animalia_Chordata_Reptilia_Crocodylia_Alligatoridae_Caiman_crocodilus/340271_10.52832_-83.49678.jpg new file mode 100644 index 0000000000000000000000000000000000000000..da1f75add0023ba8926dffc94c3c64ee1a3ca805 --- /dev/null +++ b/examples/Animalia_Chordata_Reptilia_Crocodylia_Alligatoridae_Caiman_crocodilus/340271_10.52832_-83.49678.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:b31fbe934b245e7274289836b9eee781b2e33c4121dfbafebc473cd45d638825 +size 19839 diff --git a/examples/Animalia_Chordata_Reptilia_Crocodylia_Alligatoridae_Caiman_crocodilus/45193295.mp3 b/examples/Animalia_Chordata_Reptilia_Crocodylia_Alligatoridae_Caiman_crocodilus/45193295.mp3 new file mode 100644 index 0000000000000000000000000000000000000000..b803eca684f205f062a1683f083509f23940cf8a --- /dev/null +++ b/examples/Animalia_Chordata_Reptilia_Crocodylia_Alligatoridae_Caiman_crocodilus/45193295.mp3 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c4cd2e4fd7094a07d79da7fd54788705e8ce7567e65911d87edfd23ff1c0e484 +size 247762 diff --git a/examples/Animalia_Chordata_Reptilia_Crocodylia_Alligatoridae_Caiman_crocodilus/938aab7b-4509-4de7-afad-2c8ea51f4799.jpg b/examples/Animalia_Chordata_Reptilia_Crocodylia_Alligatoridae_Caiman_crocodilus/938aab7b-4509-4de7-afad-2c8ea51f4799.jpg new file mode 100644 index 0000000000000000000000000000000000000000..b87fa747ba1d4e0b0630970dbd6b441294191a3b --- /dev/null +++ b/examples/Animalia_Chordata_Reptilia_Crocodylia_Alligatoridae_Caiman_crocodilus/938aab7b-4509-4de7-afad-2c8ea51f4799.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a26f17668646cd25c77483565f6509ca7b21bba09ce92dac0f38d0ecbfdae3b1 +size 86323 diff --git a/examples/Animalia_Chordata_Reptilia_Squamata_Varanidae_Varanus_salvator/410613_5.35573_100.28948.jpg b/examples/Animalia_Chordata_Reptilia_Squamata_Varanidae_Varanus_salvator/410613_5.35573_100.28948.jpg new file mode 100644 index 0000000000000000000000000000000000000000..53d497a058b2995a75d724a808e74c13553594d4 --- /dev/null +++ b/examples/Animalia_Chordata_Reptilia_Squamata_Varanidae_Varanus_salvator/410613_5.35573_100.28948.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:3019573a982d10c4791e357a5bebadfbb245f145c57c60c8a53f2241ac8789fe +size 37040 diff --git a/examples/Animalia_Chordata_Reptilia_Squamata_Varanidae_Varanus_salvator/461d8e6c-0e66-4acc-8ecd-bfd9c218bc14.jpg b/examples/Animalia_Chordata_Reptilia_Squamata_Varanidae_Varanus_salvator/461d8e6c-0e66-4acc-8ecd-bfd9c218bc14.jpg new file mode 100644 index 
0000000000000000000000000000000000000000..9464acdaf32edf40e991cd6d72955d382a34fe5e --- /dev/null +++ b/examples/Animalia_Chordata_Reptilia_Squamata_Varanidae_Varanus_salvator/461d8e6c-0e66-4acc-8ecd-bfd9c218bc14.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:41b11b1ea9709a9fefabc2c7ddf8aa58a7881749474bf0ccadca3a02e3a97c76 +size 175798 diff --git a/examples/metadata.json b/examples/metadata.json new file mode 100644 index 0000000000000000000000000000000000000000..1072d38e6f2a10ad3ca6c1e54b4413f82ebfb4de --- /dev/null +++ b/examples/metadata.json @@ -0,0 +1,173 @@ +{ + "Animalia Chordata Reptilia Squamata Varanidae Varanus salvator": { + "id": 410613, + "sat_key": "410613_5.35573_100.28948", + "sat_path": "410613_5.35573_100.28948.jpg", + "taxonomy": "Animalia Chordata Reptilia Squamata Varanidae Varanus salvator", + "count": 6, + "spread": 58.00460580210422, + "sat_bounds": { + "min_lat": 5.344155081363914, + "max_lat": 5.367304914271601, + "min_lon": 100.27793148340874, + "max_lon": 100.30102851659126 + }, + "img_ids": [ + 707815, + 411949, + 701168, + 1619682, + 2100008, + 1548498 + ], + "target_positions": [ + [ + 225, + 240 + ], + [ + 232, + 275 + ], + [ + 277, + 449 + ], + [ + 220, + 369 + ], + [ + 180, + 393 + ], + [ + 294, + 478 + ] + ], + "num_landmarks": 2 + }, + "Animalia Chordata Mammalia Carnivora Canidae Canis aureus": { + "id": 1528408, + "sat_key": "1528408_13.00422_80.23033", + "sat_path": "1528408_13.00422_80.23033.jpg", + "taxonomy": "Animalia Chordata Mammalia Carnivora Canidae Canis aureus", + "count": 3, + "spread": 58.14007011752667, + "sat_bounds": { + "min_lat": 12.992649951077192, + "max_lat": 13.015790038631529, + "min_lon": 80.21853090802841, + "max_lon": 80.24212909197156 + }, + "img_ids": [ + 1528479, + 2555188, + 2555189 + ], + "target_positions": [ + [ + 309, + 128 + ], + [ + 239, + 428 + ], + [ + 240, + 419 + ] + ], + "num_landmarks": 3 + }, + "Animalia Chordata Reptilia Crocodylia Alligatoridae Caiman crocodilus": { + "id": 340271, + "sat_key": "340271_10.52832_-83.49678", + "sat_path": "340271_10.52832_-83.49678.jpg", + "taxonomy": "Animalia Chordata Reptilia Crocodylia Alligatoridae Caiman crocodilus", + "count": 7, + "spread": 40.13902957324975, + "sat_bounds": { + "min_lat": 10.516747947357544, + "max_lat": 10.53989204420829, + "min_lon": -83.50847402265151, + "max_lon": -83.48508597734848 + }, + "img_ids": [ + 1683531, + 1281855, + 223089, + 688111, + 330757, + 2408375, + 1955359 + ], + "target_positions": [ + [ + 347, + 75 + ], + [ + 47, + 22 + ], + [ + 111, + 43 + ], + [ + 116, + 51 + ], + [ + 86, + 108 + ], + [ + 31, + 62 + ], + [ + 4, + 78 + ] + ], + "num_landmarks": 3 + }, + "Animalia Chordata Mammalia Carnivora Canidae Urocyon littoralis": { + "id": 304160, + "sat_key": "304160_34.0144_-119.54417", + "sat_path": "304160_34.0144_-119.54417.jpg", + "taxonomy": "Animalia Chordata Mammalia Carnivora Canidae Urocyon littoralis", + "count": 3, + "spread": 237.64152837579553, + "sat_bounds": { + "min_lat": 34.00286041606169, + "max_lat": 34.02593956225012, + "min_lon": -119.55802743361286, + "max_lon": -119.53031256638712 + }, + "img_ids": [ + 304160, + 1473173, + 384867 + ], + "target_positions": [ + [ + 255, + 256 + ], + [ + 19, + 22 + ], + [ + 29, + 274 + ] + ], + "num_landmarks": 3 + } +} \ No newline at end of file diff --git a/inference/model/avs_rl_policy.pth b/inference/model/avs_rl_policy.pth new file mode 100644 index 0000000000000000000000000000000000000000..3721ea2885ee2ad7face94b1a9fecabc823c2044 --- /dev/null 
+++ b/inference/model/avs_rl_policy.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:44e642df9aaa2847ba44dd4707985c67ef712f5264272ef7993aeb7805c80f5a +size 52167246 diff --git a/maps/example/masks_val/MSK_0001.png b/maps/example/masks_val/MSK_0001.png new file mode 100644 index 0000000000000000000000000000000000000000..21ad5af72bb32265563af29f70ca80ff131568dd --- /dev/null +++ b/maps/example/masks_val/MSK_0001.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:318773e2c18275d84b5145d7e69836baa0bedd833f44b49f98e6619357677cff +size 75884 diff --git a/maps/gpt4o/envs_val/MSK_0001.png b/maps/gpt4o/envs_val/MSK_0001.png new file mode 100644 index 0000000000000000000000000000000000000000..cdfa027638f778caf542942085be8dffc28d9e1e --- /dev/null +++ b/maps/gpt4o/envs_val/MSK_0001.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:7af11bcef1972b7e047f53b597fef2a332d82c7feceb21aac6e14a57469c436b +size 2337 diff --git a/planner/env.py b/planner/env.py new file mode 100644 index 0000000000000000000000000000000000000000..7e40201a50084d7532a7c38e6d8b3543f8ea5ca6 --- /dev/null +++ b/planner/env.py @@ -0,0 +1,610 @@ +####################################################################### +# Name: env.py +# +# - Reads and processes training and test maps +# - Processes rewards, new frontiers given action +# - Updates a graph representation of environment for input into network +####################################################################### + +import sys +if sys.modules['TRAINING']: + from .parameter import * +else: + from .test_parameter import * + +import os +import cv2 +import copy +import matplotlib.image as mpimg +import matplotlib.pyplot as plt +from skimage import io +from skimage.measure import block_reduce +from scipy.ndimage import label, find_objects +from .sensor import * +from .graph_generator import * +from .node import * + + +class Env(): + def __init__(self, map_index, n_agent, k_size=20, plot=False, test=False, mask_index=None): + self.n_agent = n_agent + self.test = test + self.map_dir = GRIDMAP_SET_DIR + + # Import environment gridmap + self.map_list = os.listdir(self.map_dir) + self.map_list.sort(reverse=True) + + # NEW: Import segmentation utility map + self.seg_dir = MASK_SET_DIR + self.segmentation_mask, self.target_positions, self.target_found_idxs = None, [], [] + self.segmentation_mask_list = os.listdir(self.seg_dir) + self.segmentation_mask_list.sort(reverse=True) + + # # NEW: Find common files in both directories + self.map_index = map_index % len(self.map_list) + if mask_index is not None: + self.mask_index = mask_index % len(self.segmentation_mask_list) + else: + self.mask_index = map_index % len(self.segmentation_mask_list) + + # Import ground truth and segmentation mask + self.ground_truth, self.map_start_position = self.import_ground_truth( + os.path.join(self.map_dir, self.map_list[self.map_index])) + self.ground_truth_size = np.shape(self.ground_truth) + self.robot_belief = np.ones(self.ground_truth_size) * 127 # unexplored 127 + self.downsampled_belief = None + self.old_robot_belief = copy.deepcopy(self.robot_belief) + self.coverage_belief = np.ones(self.ground_truth_size) * 127 # unexplored 127 + + # Import segmentation mask + mask_filename = self.segmentation_mask_list[self.mask_index] + self.segmentation_mask = self.import_segmentation_mask( + os.path.join(self.seg_dir, mask_filename)) + + # Overwrite target positions if directory specified + if self.test and TARGETS_SET_DIR != "": + 
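+            # When a dedicated targets directory is supplied at test time, read the
+            # matching mask and extract ground-truth target locations (grey value 208)
+            # instead of relying on the segmentation mask alone.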
self.target_positions = self.import_targets( + os.path.join(TARGETS_SET_DIR, self.map_list[self.map_index])) + + self.segmentation_info_mask = None + self.segmentation_info_mask_unnormalized = None + self.filtered_seg_info_mask = None + self.num_targets_found = 0 + self.num_new_targets_found = 0 + self.resolution = 4 + self.sensor_range = SENSOR_RANGE + self.explored_rate = 0 + self.targets_found_rate = 0 + self.frontiers = None + self.start_positions = [] + self.plot = plot + self.frame_files = [] + self.graph_generator = Graph_generator(map_size=self.ground_truth_size, sensor_range=self.sensor_range, k_size=k_size, plot=plot) + self.node_coords, self.graph, self.node_utility, self.guidepost = None, None, None, None + + self.begin(self.map_start_position) + + + def find_index_from_coords(self, position): + index = np.argmin(np.linalg.norm(self.node_coords - position, axis=1)) + return index + + def begin(self, start_position): + self.robot_belief = self.ground_truth + self.downsampled_belief = block_reduce(self.robot_belief.copy(), block_size=(self.resolution, self.resolution), func=np.min) + self.frontiers = self.find_frontier() + self.old_robot_belief = copy.deepcopy(self.robot_belief) + + self.node_coords, self.graph, self.node_utility, self.guidepost = self.graph_generator.generate_graph( + self.robot_belief, self.frontiers) + + # Define start positions + if FIX_START_POSITION: + coords_res_row = int(self.robot_belief.shape[0]/NUM_COORDS_HEIGHT) + coords_res_col = int(self.robot_belief.shape[1]/NUM_COORDS_WIDTH) + self.start_positions = [(int(self.robot_belief.shape[1]/2)-coords_res_col/2,int(self.robot_belief.shape[0]/2)-coords_res_row/2) for _ in range(self.n_agent)] + else: + nearby_coords = self.graph_generator.get_neighbors_grid_coords(start_position) + itr = 0 + for i in range(self.n_agent): + if i == 0 or len(nearby_coords) == 0: + self.start_positions.append(start_position) + else: + idx = min(itr, len(nearby_coords)-1) + self.start_positions.append(nearby_coords[idx]) + itr += 1 + + for i in range(len(self.start_positions)): + self.start_positions[i] = self.node_coords[self.find_index_from_coords(self.start_positions[i])] + self.coverage_belief = self.update_robot_belief(self.start_positions[i], self.sensor_range, self.coverage_belief, + self.ground_truth) + + for start_position in self.start_positions: + self.graph_generator.route_node.append(start_position) + + # Info map from ground truth + rng_x = 0.5 * (self.ground_truth.shape[1] / NUM_COORDS_WIDTH) + rng_y = 0.5 * (self.ground_truth.shape[0] / NUM_COORDS_HEIGHT) + self.segmentation_info_mask = np.zeros((len(self.node_coords), 1)) + for i, node_coord in enumerate(self.node_coords): + max_x = min(node_coord[0] + int(math.ceil(rng_x)), self.ground_truth.shape[1]) + min_x = max(node_coord[0] - int(math.ceil(rng_x)), 0) + max_y = min(node_coord[1] + int(math.ceil(rng_y)), self.ground_truth.shape[0]) + min_y = max(node_coord[1] - int(math.ceil(rng_y)), 0) + + if TARGETS_SET_DIR == "": + exclude = {208} # Exclude target positions + else: + exclude = {} + self.segmentation_info_mask[i] = max(x for x in self.segmentation_mask[min_y:max_y, min_x:max_x].flatten() if x not in exclude) / 100.0 + + self.filtered_seg_info_mask = copy.deepcopy(self.segmentation_info_mask) + done, num_targets_found = self.check_done() + self.num_targets_found = num_targets_found + + + def multi_robot_step(self, next_position_list, dist_list, travel_dist_list): + reward_list = [] + for dist, robot_position in zip(dist_list, next_position_list): + 
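+            # Per-robot update: record the visited node, refresh the coverage belief with
+            # the sensor model, and build the reward from (i) travel cost, (ii) the
+            # segmentation-informed value of the reached node, and (iii) a small bonus
+            # for reaching a node that has not been covered before.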
self.graph_generator.route_node.append(robot_position) + next_node_index = self.find_index_from_coords(robot_position) + self.graph_generator.nodes_list[next_node_index].set_visited() + self.coverage_belief = self.update_robot_belief(robot_position, self.sensor_range, self.coverage_belief, + self.ground_truth) + self.robot_belief = self.ground_truth + self.downsampled_belief = block_reduce(self.robot_belief.copy(), + block_size=(self.resolution, self.resolution), + func=np.min) + + frontiers = self.find_frontier() + individual_reward = -dist / 32 + + info_gain_reward = 0 + robot_position_idx = self.find_index_from_coords(robot_position) + info_gain_reward = self.filtered_seg_info_mask[robot_position_idx][0] * 1.5 + if self.guidepost[robot_position_idx] == 0.0: + info_gain_reward += 0.2 + individual_reward += info_gain_reward + + reward_list.append(individual_reward) + + self.node_coords, self.graph, self.node_utility, self.guidepost = self.graph_generator.update_graph(self.robot_belief, self.old_robot_belief, frontiers, self.frontiers) + self.old_robot_belief = copy.deepcopy(self.robot_belief) + + self.filtered_seg_info_mask = [info[0] if self.guidepost[i] == 0.0 else 0.0 for i, info in enumerate(self.segmentation_info_mask)] + self.filtered_seg_info_mask = np.expand_dims(np.array(self.filtered_seg_info_mask), axis=1) + + self.frontiers = frontiers + self.explored_rate = self.evaluate_exploration_rate() + + done, num_targets_found = self.check_done() + self.num_new_targets_found = num_targets_found - self.num_targets_found + team_reward = 0.0 + + self.num_targets_found = num_targets_found + self.targets_found_rate = self.evaluate_targets_found_rate() + + if done: + team_reward += 40 + for i in range(len(reward_list)): + reward_list[i] += team_reward + + return reward_list, done + + + def import_ground_truth(self, map_index): + # occupied 1, free 255, unexplored 127 + + try: + ground_truth = (io.imread(map_index, 1)).astype(int) + if np.all(ground_truth == 0): + ground_truth = (io.imread(map_index, 1) * 255).astype(int) + except: + new_map_index = self.map_dir + '/' + self.map_list[0] + ground_truth = (io.imread(new_map_index, 1)).astype(int) + print('could not read the map_path ({}), hence skipping it and using ({}).'.format(map_index, new_map_index)) + + robot_location = np.nonzero(ground_truth == 208) + robot_location = np.array([np.array(robot_location)[1, 127], np.array(robot_location)[0, 127]]) + ground_truth = (ground_truth > 150) + ground_truth = ground_truth * 254 + 1 + return ground_truth, robot_location + + + def import_segmentation_mask(self, map_index): + mask = cv2.imread(map_index).astype(int) + return mask + + def import_targets(self, map_index): + # occupied 1, free 255, unexplored 127, target 208 + mask = cv2.imread(map_index).astype(int) + target_positions = self.find_target_locations(mask) + return target_positions + + + def find_target_locations(self, image_array, grey_value=208): + + grey_pixels = np.where(image_array == grey_value) + binary_array = np.zeros_like(image_array, dtype=bool) + binary_array[grey_pixels] = True + labeled_array, num_features = label(binary_array) + slices = find_objects(labeled_array) + + # Calculate the center of each box + centers = [] + for slice in slices: + row_center = (slice[0].start + slice[0].stop - 1) // 2 + col_center = (slice[1].start + slice[1].stop - 1) // 2 + centers.append((col_center, row_center)) # (y,x) + + return centers + + def free_cells(self): + index = np.where(self.ground_truth == 255) + free = np.asarray([index[1], 
index[0]]).T + return free + + def update_robot_belief(self, robot_position, sensor_range, robot_belief, ground_truth): + robot_belief = sensor_work(robot_position, sensor_range, robot_belief, ground_truth) + return robot_belief + + + def check_done(self): + done = False + num_targets_found = 0 + self.target_found_idxs = [] + for i, target in enumerate(self.target_positions): + if self.coverage_belief[target[1], target[0]] == 255: + num_targets_found += 1 + self.target_found_idxs.append(i) + + if TERMINATE_ON_TGTS_FOUND and num_targets_found >= len(self.target_positions): + done = True + if not TERMINATE_ON_TGTS_FOUND and np.sum(self.coverage_belief == 255) / np.sum(self.ground_truth == 255) >= 0.99: + done = True + + return done, num_targets_found + + + def calculate_num_observed_frontiers(self, old_frontiers, frontiers): + frontiers_to_check = frontiers[:, 0] + frontiers[:, 1] * 1j + pre_frontiers_to_check = old_frontiers[:, 0] + old_frontiers[:, 1] * 1j + frontiers_num = np.intersect1d(frontiers_to_check, pre_frontiers_to_check).shape[0] + pre_frontiers_num = pre_frontiers_to_check.shape[0] + delta_num = pre_frontiers_num - frontiers_num + + return delta_num + + def calculate_reward(self, dist, frontiers): + reward = 0 + reward -= dist / 64 + + frontiers_to_check = frontiers[:, 0] + frontiers[:, 1] * 1j + pre_frontiers_to_check = self.frontiers[:, 0] + self.frontiers[:, 1] * 1j + frontiers_num = np.intersect1d(frontiers_to_check, pre_frontiers_to_check).shape[0] + pre_frontiers_num = pre_frontiers_to_check.shape[0] + delta_num = pre_frontiers_num - frontiers_num + + reward += delta_num / 50 + + return reward + + def evaluate_exploration_rate(self): + rate = np.sum(self.coverage_belief == 255) / np.sum(self.ground_truth == 255) + return rate + + def evaluate_targets_found_rate(self): + if len(self.target_positions) == 0: + return 0 + else: + rate = self.num_targets_found / len(self.target_positions) + return rate + + def calculate_new_free_area(self): + old_free_area = self.old_robot_belief == 255 + current_free_area = self.robot_belief == 255 + + new_free_area = (current_free_area.astype(np.int) - old_free_area.astype(np.int)) * 255 + + return new_free_area, np.sum(old_free_area) + + def calculate_dist_path(self, path): + dist = 0 + start = path[0] + end = path[-1] + for index in path: + if index == end: + break + dist += np.linalg.norm(self.node_coords[start] - self.node_coords[index]) + start = index + return dist + + def find_frontier(self): + y_len = self.downsampled_belief.shape[0] + x_len = self.downsampled_belief.shape[1] + mapping = self.downsampled_belief.copy() + belief = self.downsampled_belief.copy() + # 0-1 unknown area map + mapping = (mapping == 127) * 1 + mapping = np.lib.pad(mapping, ((1, 1), (1, 1)), 'constant', constant_values=0) + fro_map = mapping[2:][:, 1:x_len + 1] + mapping[:y_len][:, 1:x_len + 1] + mapping[1:y_len + 1][:, 2:] + \ + mapping[1:y_len + 1][:, :x_len] + mapping[:y_len][:, 2:] + mapping[2:][:, :x_len] + mapping[2:][:, + 2:] + \ + mapping[:y_len][:, :x_len] + ind_free = np.where(belief.ravel(order='F') == 255)[0] + ind_fron_1 = np.where(1 < fro_map.ravel(order='F'))[0] + ind_fron_2 = np.where(fro_map.ravel(order='F') < 8)[0] + ind_fron = np.intersect1d(ind_fron_1, ind_fron_2) + ind_to = np.intersect1d(ind_free, ind_fron) + + map_x = x_len + map_y = y_len + x = np.linspace(0, map_x - 1, map_x) + y = np.linspace(0, map_y - 1, map_y) + t1, t2 = np.meshgrid(x, y) + points = np.vstack([t1.T.ravel(), t2.T.ravel()]).T + + f = points[ind_to] + f = 
f.astype(int) + + f = f * self.resolution + + return f + + + + def plot_env(self, n, path, step, travel_dist, robots_route, img_path_override=None, sat_path_override=None, msk_name_override=None, sound_id_override=None): + + plt.switch_backend('agg') + plt.cla() + color_list = ["r", "g", "c", "m", "y", "k"] + + if not LOAD_AVS_BENCH: + fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 5)) + else: + fig, (ax1, ax2, ax3, ax4) = plt.subplots(1, 4, figsize=(20, 5.5)) + + ### Fig: Segmentation Mask ### + if LOAD_AVS_BENCH: + ax = ax1 + image = mpimg.imread(img_path_override) + ax.imshow(image) + ax.set_title("Ground Image") + ax.axis("off") + + ### Fig: Environment ### + msk_name = "" + if LOAD_AVS_BENCH: + image = mpimg.imread(sat_path_override) + msk_name = msk_name_override + + ### Fig1: Environment ### + ax = ax2 + ax.imshow(image) + ax.axis((0, self.ground_truth_size[1], self.ground_truth_size[0], 0)) + ax.set_title("Image") + for i, route in enumerate(robots_route): + robot_marker_color = color_list[i % len(color_list)] + xPoints = route[0] + yPoints = route[1] + ax.plot(xPoints, yPoints, c=robot_marker_color, linewidth=2) + ax.plot(xPoints[-1], yPoints[-1], markersize=12, zorder=99, marker="^", ls="-", c=robot_marker_color, mec="black") + ax.plot(xPoints[0], yPoints[0], 'co', c=robot_marker_color, markersize=8, zorder=5) + + # Sensor range + rng_x = 0.5 * (self.ground_truth.shape[1] / NUM_COORDS_WIDTH) + rng_y = 0.5 * (self.ground_truth.shape[0] / NUM_COORDS_HEIGHT) + max_x = min(xPoints[-1] + int(math.ceil(rng_x)), self.ground_truth.shape[1]) + min_x = max(xPoints[-1] - int(math.ceil(rng_x)), 0) + max_y = min(yPoints[-1] + int(math.ceil(rng_y)), self.ground_truth.shape[0]) + min_y = max(yPoints[-1] - int(math.ceil(rng_y)), 0) + ax.plot((min_x, min_x), (min_y, max_y), c=robot_marker_color, linewidth=1) + ax.plot((min_x, max_x), (max_y, max_y), c=robot_marker_color, linewidth=1) + ax.plot((max_x, max_x), (max_y, min_y), c=robot_marker_color, linewidth=1) + ax.plot((max_x, min_x), (min_y, min_y), c=robot_marker_color, linewidth=1) + + + ### Fig: Graph ### + ax = ax3 if LOAD_AVS_BENCH else ax1 + ax.imshow(self.coverage_belief, cmap='gray') + ax.axis((0, self.ground_truth_size[1], self.ground_truth_size[0], 0)) + ax.set_title("Information Graph") + if VIZ_GRAPH_EDGES: + for i in range(len(self.graph_generator.x)): + ax.plot(self.graph_generator.x[i], self.graph_generator.y[i], 'tan', zorder=1) + ax.scatter(self.node_coords[:, 0], self.node_coords[:, 1], c=self.filtered_seg_info_mask, zorder=5, s=8) + + for i, route in enumerate(robots_route): + robot_marker_color = color_list[i % len(color_list)] + xPoints = route[0] + yPoints = route[1] + ax.plot(xPoints, yPoints, c=robot_marker_color, linewidth=2) + ax.plot(xPoints[-1], yPoints[-1], markersize=12, zorder=99, marker="^", ls="-", c=robot_marker_color, mec="black") + ax.plot(xPoints[0], yPoints[0], 'co', c=robot_marker_color, markersize=8, zorder=5) + + # Sensor range + rng_x = 0.5 * (self.ground_truth.shape[1] / NUM_COORDS_WIDTH) + rng_y = 0.5 * (self.ground_truth.shape[0] / NUM_COORDS_HEIGHT) + max_x = min(xPoints[-1] + int(math.ceil(rng_x)), self.ground_truth.shape[1]) + min_x = max(xPoints[-1] - int(math.ceil(rng_x)), 0) + max_y = min(yPoints[-1] + int(math.ceil(rng_y)), self.ground_truth.shape[0]) + min_y = max(yPoints[-1] - int(math.ceil(rng_y)), 0) + ax.plot((min_x, min_x), (min_y, max_y), c=robot_marker_color, linewidth=1) + ax.plot((min_x, max_x), (max_y, max_y), c=robot_marker_color, linewidth=1) + ax.plot((max_x, max_x), (max_y, 
min_y), c=robot_marker_color, linewidth=1) + ax.plot((max_x, min_x), (min_y, min_y), c=robot_marker_color, linewidth=1) + + # Plot target positions + for target in self.target_positions: + if self.coverage_belief[target[1], target[0]] == 255: + ax.plot(target[0], target[1], color='g', marker='x', linestyle='-', markersize=12, markeredgewidth=4, zorder=99) + else: + ax.plot(target[0], target[1], color='r', marker='x', linestyle='-', markersize=12, markeredgewidth=4, zorder=99) + + ### Fig: Segmentation Mask ### + ax = ax4 if LOAD_AVS_BENCH else ax2 + if LOAD_AVS_BENCH and USE_CLIP_PREDS: + H, W = self.ground_truth_size + mask_viz = self.segmentation_info_mask.squeeze().reshape((NUM_COORDS_WIDTH, NUM_COORDS_HEIGHT)).T + im = ax.imshow( + mask_viz, + cmap="viridis", + origin="upper", + extent=[0, W, H, 0], + interpolation="nearest", + zorder=0, + ) + ax.set_xlim(0, W) + ax.set_ylim(H, 0) + ax.set_axis_off() + else: + im = ax.imshow(self.segmentation_mask.mean(axis=-1), cmap='viridis', vmin=0, vmax=100) # cmap='gray' + ax.axis((0, self.ground_truth_size[1], self.ground_truth_size[0], 0)) + ax.set_title(f"Predicted Mask (Normalized)") + for i, route in enumerate(robots_route): + robot_marker_color = color_list[i % len(color_list)] + xPoints = route[0] + yPoints = route[1] + ax.plot(xPoints, yPoints, c=robot_marker_color, linewidth=2) + ax.plot(xPoints[-1], yPoints[-1], markersize=12, zorder=99, marker="^", ls="-", c=robot_marker_color, mec="black") + ax.plot(xPoints[0], yPoints[0], 'co', c=robot_marker_color, markersize=8, zorder=5) + + # Sensor range + rng_x = 0.5 * (self.ground_truth.shape[1] / NUM_COORDS_WIDTH) + rng_y = 0.5 * (self.ground_truth.shape[0] / NUM_COORDS_HEIGHT) + max_x = min(xPoints[-1] + int(math.ceil(rng_x)), self.ground_truth.shape[1]) + min_x = max(xPoints[-1] - int(math.ceil(rng_x)), 0) + max_y = min(yPoints[-1] + int(math.ceil(rng_y)), self.ground_truth.shape[0]) + min_y = max(yPoints[-1] - int(math.ceil(rng_y)), 0) + ax.plot((min_x, min_x), (min_y, max_y), c=robot_marker_color, linewidth=1) + ax.plot((min_x, max_x), (max_y, max_y), c=robot_marker_color, linewidth=1) + ax.plot((max_x, max_x), (max_y, min_y), c=robot_marker_color, linewidth=1) + ax.plot((max_x, min_x), (min_y, min_y), c=robot_marker_color, linewidth=1) + + # Add a colorbar + cbar = fig.colorbar(im, ax=ax, fraction=0.046, pad=0.04) + cbar.set_label("Normalized Probs") + + if sound_id_override is not None: + plt.suptitle('Targets Found: {}/{} Coverage ratio: {:.4g} Travel Dist: {:.4g} \n ({}) \n (Sound ID: {})'.format(self.num_targets_found, len(self.target_positions), self.explored_rate, travel_dist, msk_name, sound_id_override)) + elif msk_name != "": + plt.suptitle('Targets Found: {}/{} Coverage ratio: {:.4g} Travel Dist: {:.4g} \n ({})'.format(self.num_targets_found, len(self.target_positions), self.explored_rate, travel_dist, msk_name)) + else: + plt.suptitle('Targets Found: {}/{} Coverage ratio: {:.4g} Travel Dist: {:.4g}'.format(self.num_targets_found, len(self.target_positions), self.explored_rate, travel_dist)) + + plt.tight_layout() + plt.savefig('{}/{}_{}_samples.png'.format(path, n, step, dpi=100)) + frame = '{}/{}_{}_samples.png'.format(path, n, step) + self.frame_files.append(frame) + plt.close() + + + #################### + # ADDED: For app.py + #################### + + def plot_heatmap(self, save_dir, step, travel_dist, robots_route=None): + """Plot only the segmentation heatmap and save it as ``{step}.png`` in + ``save_dir``. 
This lightweight helper is meant for asynchronous + streaming in the Gradio demo when full `plot_env` is too heavy. + + Parameters + ---------- + save_dir : str + Directory to save the generated PNG file. + step : int + Current timestep; becomes the filename ``{step}.png``. + robots_route : list | None + Optional list of routes (xPoints, yPoints) to overlay. + Returns + ------- + str + Full path to the generated PNG file. + """ + import os + plt.switch_backend('agg') + # Do not clear the global figure state in case it interferes with + # the current figure. Each call creates its own Figure object that + # we close explicitly at the end, so a global clear is unnecessary + # and may break concurrent drawing. + # plt.cla() + + color_list = ["r", "g", "c", "m", "y", "k"] + fig, ax = plt.subplots(1, 1, figsize=(6, 6)) + + # Select the mask to visualise + # if TAXABIND_TTA and USE_CLIP_PREDS: + side_dim = int(np.sqrt(self.segmentation_info_mask.shape[0])) + mask_viz = self.segmentation_info_mask.squeeze().reshape((side_dim, side_dim)).T + + # Properly map image to pixel coordinates and keep limits fixed + H, W = self.ground_truth_size # rows (y), cols (x) + im = ax.imshow( + mask_viz, + cmap="viridis", + origin="upper", + extent=[0, W, H, 0], # x: 0..W, y: H..0 (origin at top-left) + interpolation="nearest", # keep cell edges sharp & aligned + zorder=0, + ) + ax.set_xlim(0, W) + ax.set_ylim(H, 0) + ax.set_axis_off() # hide ticks but keep limits + # else: + # im = ax.imshow(self.segmentation_mask.mean(axis=-1), cmap='viridis', vmin=0, vmax=100) + # ax.axis((0, self.ground_truth_size[1], self.ground_truth_size[0], 0)) + + # Optionally overlay robot paths + if robots_route is not None: + for i, route in enumerate(robots_route): + robot_marker_color = color_list[i % len(color_list)] + xPoints, yPoints = route + ax.plot(xPoints, yPoints, c=robot_marker_color, linewidth=2) + ax.plot(xPoints[-1], yPoints[-1], markersize=12, zorder=99, marker="^", ls="-", c=robot_marker_color, mec="black") + ax.plot(xPoints[0], yPoints[0], 'co', c=robot_marker_color, markersize=8, zorder=5) + + # Plot target positions + for target in self.target_positions: + if self.coverage_belief[target[1], target[0]] == 255: + # ax.plot(target[0], target[1], 'go', markersize=8, zorder=99) + ax.plot(target[0], target[1], color='g', marker='x', linestyle='-', markersize=12, markeredgewidth=4, zorder=99) + else: + # ax.plot(target[0], target[1], 'ro', markersize=8, zorder=99) + ax.plot(target[0], target[1], color='r', marker='x', linestyle='-', markersize=12, markeredgewidth=4, zorder=99) + + # Sensor range + rng_x = 0.5 * (self.ground_truth.shape[1] / NUM_COORDS_WIDTH) + rng_y = 0.5 * (self.ground_truth.shape[0] / NUM_COORDS_HEIGHT) + max_x = min(xPoints[-1] + int(math.ceil(rng_x)), self.ground_truth.shape[1]) + min_x = max(xPoints[-1] - int(math.ceil(rng_x)), 0) + max_y = min(yPoints[-1] + int(math.ceil(rng_y)), self.ground_truth.shape[0]) + min_y = max(yPoints[-1] - int(math.ceil(rng_y)), 0) + ax.plot((min_x, min_x), (min_y, max_y), c=robot_marker_color, linewidth=1) + ax.plot((min_x, max_x), (max_y, max_y), c=robot_marker_color, linewidth=1) + ax.plot((max_x, max_x), (max_y, min_y), c=robot_marker_color, linewidth=1) + ax.plot((max_x, min_x), (min_y, min_y), c=robot_marker_color, linewidth=1) + + # Color bar + cbar = fig.colorbar(im, ax=ax, fraction=0.046, pad=0.04) + cbar.set_label("Normalized Probs") + + # Change coverage to 1dp + plt.suptitle('Targets Found: {}/{} Coverage: {:.1f}% Steps: {}/{}'.format( + self.num_targets_found, 
\ + len(self.target_positions), + self.explored_rate*100, + step+1, + NUM_EPS_STEPS), + y=0.94, # Closer to plot + ) + + plt.tight_layout() + os.makedirs(save_dir, exist_ok=True) + out_path = os.path.join(save_dir, f"{step}.png") + # Save atomically: write to temp file then move into place so the poller never sees a partial file. + tmp_path = out_path + ".tmp" + fig.savefig(tmp_path, dpi=100, format='png') + os.replace(tmp_path, out_path) # atomic on same filesystem + plt.close(fig) + return out_path \ No newline at end of file diff --git a/planner/graph.py b/planner/graph.py new file mode 100644 index 0000000000000000000000000000000000000000..68d671478edcf687c8b474160715358584631318 --- /dev/null +++ b/planner/graph.py @@ -0,0 +1,167 @@ +####################################################################### +# Name: env.py +# +# - Adapted from https://gist.github.com/betandr/541a1f6466b6855471de5ca30b74cb31 +# - Simple graph class to perform distance calculations (E.g. A-Star, Djikstra) +####################################################################### + + +class Edge: + def __init__(self, to_node, length): + self.to_node = to_node + self.length = length + + +class Graph: + def __init__(self): + self.nodes = set() + self.edges = dict() + + def add_node(self, node): + self.nodes.add(node) + + def add_edge(self, from_node, to_node, length): + edge = Edge(to_node, length) + if from_node in self.edges: + from_node_edges = self.edges[from_node] + else: + self.edges[from_node] = dict() + from_node_edges = self.edges[from_node] + from_node_edges[to_node] = edge + + def clear_edge(self, from_node): + if from_node in self.edges: + self.edges[from_node] = dict() + +def min_dist(q, dist): + """ + Returns the node with the smallest distance in q. + Implemented to keep the main algorithm clean. 
+ """ + min_node = None + for node in q: + if min_node == None: + min_node = node + elif dist[node] < dist[min_node]: + min_node = node + + return min_node + + +INFINITY = float('Infinity') + + +def dijkstra(graph, source): + q = set() + dist = {} + prev = {} + + for v in graph.nodes: + dist[v] = INFINITY # unknown distance from source to v + prev[v] = INFINITY # previous node in optimal path from source + q.add(v) # all nodes initially in q (unvisited nodes) + + # distance from source to source + dist[source] = 0 + + while q: + # node with the least distance selected first + u = min_dist(q, dist) + + q.remove(u) + + try: + if u in graph.edges: + for _, v in graph.edges[u].items(): + alt = dist[u] + v.length + if alt < dist[v.to_node]: + # a shorter path to v has been found + dist[v.to_node] = alt + prev[v.to_node] = u + except: + pass + + return dist, prev + + +def to_array(prev, from_node): + """Creates an ordered list of labels as a route.""" + previous_node = prev[from_node] + route = [from_node] + + while previous_node != INFINITY: + route.append(previous_node) + temp = previous_node + previous_node = prev[temp] + + route.reverse() + return route + + +def h(index, destination, node_coords): + current = node_coords[index] + end = node_coords[destination] + h = abs(end[0] - current[0]) + abs(end[1] - current[1]) + return h + + +def a_star(start, destination, node_coords, graph): + if start == destination: + return [], 0 + if str(destination) in graph.edges[str(start)].keys(): + cost = graph.edges[str(start)][str(destination)].length + return [start, destination], cost + open_list = {start} + closed_list = set([]) + + g = {start: 0} + parents = {start: start} + + while len(open_list) > 0: + n = None + h_n = 1e5 + for v in open_list: + h_v = h(v, destination, node_coords) + if n is not None: + h_n = h(n, destination, node_coords) + if n is None or g[v] + h_v < g[n] + h_n: + n = v + + if n is None: + print('Path does not exist!') + return None, 1e5 + + if n == destination: + reconst_path = [] + while parents[n] != n: + reconst_path.append(n) + n = parents[n] + reconst_path.append(start) + reconst_path.reverse() + return reconst_path, g[destination] + + for edge in graph.edges[str(n)].values(): + m = int(edge.to_node) + cost = edge.length + if m not in open_list and m not in closed_list: + open_list.add(m) + parents[m] = n + g[m] = g[n] + cost + + else: + if g[m] > g[n] + cost: + g[m] = g[n] + cost + parents[m] = n + + if m in closed_list: + closed_list.remove(m) + open_list.add(m) + + open_list.remove(n) + closed_list.add(n) + + print('Path does not exist!') + return None, 1e5 + + + diff --git a/planner/graph_generator.py b/planner/graph_generator.py new file mode 100644 index 0000000000000000000000000000000000000000..89617c2786227681280237138751373b7b30f589 --- /dev/null +++ b/planner/graph_generator.py @@ -0,0 +1,300 @@ +####################################################################### +# Name: graph_generator.py +# +# - Wrapper for graph.py +# - Sends the formatted inputs into graph.py to get useful info +####################################################################### + +import sys +if sys.modules['TRAINING']: + from .parameter import * +else: + from .test_parameter import * + +import numpy as np +import shapely.geometry +from sklearn.neighbors import NearestNeighbors +from .node import Node +from .graph import Graph, a_star + + +class Graph_generator: + def __init__(self, map_size, k_size, sensor_range, plot=False): + self.k_size = k_size + self.graph = Graph() + 
self.node_coords = None + self.plot = plot + self.x = [] + self.y = [] + self.map_x = map_size[1] + self.map_y = map_size[0] + self.uniform_points, self.grid_coords = self.generate_uniform_points() + self.sensor_range = sensor_range + self.route_node = [] + self.nodes_list = [] + self.node_utility = None + self.guidepost = None + + + def edge_clear_all_nodes(self): + self.graph = Graph() + self.x = [] + self.y = [] + + + def edge_clear(self, coords): + node_index = str(self.find_index_from_coords(self.node_coords, coords)) + self.graph.clear_edge(node_index) + + + def generate_graph(self, robot_belief, frontiers): + self.edge_clear_all_nodes() + free_area = self.free_area(robot_belief) + + free_area_to_check = free_area[:, 0] + free_area[:, 1] * 1j + uniform_points_to_check = self.uniform_points[:, 0] + self.uniform_points[:, 1] * 1j + _, _, candidate_indices = np.intersect1d(free_area_to_check, uniform_points_to_check, return_indices=True) + node_coords = self.uniform_points[candidate_indices] + + self.node_coords = self.unique_coords(node_coords).reshape(-1, 2) + self.find_nearest_neighbor_all_nodes(self.node_coords, robot_belief) + + self.node_utility = [] + for coords in self.node_coords: + node = Node(coords, frontiers, robot_belief) + self.nodes_list.append(node) + utility = node.utility + self.node_utility.append(utility) + self.node_utility = np.array(self.node_utility) + + self.guidepost = np.zeros((self.node_coords.shape[0], 1)) + x = self.node_coords[:,0] + self.node_coords[:,1]*1j + for node in self.route_node: + index = np.argwhere(x.reshape(-1) == node[0]+node[1]*1j)[0] + self.guidepost[index] = 1 + + return self.node_coords, self.graph.edges, self.node_utility, self.guidepost + + + def update_graph(self, robot_belief, old_robot_belief, frontiers, old_frontiers): + new_free_area = self.free_area((robot_belief - old_robot_belief > 0) * 255) + free_area_to_check = new_free_area[:, 0] + new_free_area[:, 1] * 1j + uniform_points_to_check = self.uniform_points[:, 0] + self.uniform_points[:, 1] * 1j + _, _, candidate_indices = np.intersect1d(free_area_to_check, uniform_points_to_check, return_indices=True) + new_node_coords = self.uniform_points[candidate_indices] + self.node_coords = np.concatenate((self.node_coords, new_node_coords)) + + old_node_to_update = [] + for coords in new_node_coords: + neighbor_indices = self.find_k_neighbor(coords, self.node_coords, robot_belief) + old_node_to_update += neighbor_indices + old_node_to_update = set(old_node_to_update) + for index in old_node_to_update: + coords = self.node_coords[index] + self.edge_clear(coords) + self.find_k_neighbor(coords, self.node_coords, robot_belief) + + old_frontiers_to_check = old_frontiers[:, 0] + old_frontiers[:, 1] * 1j + new_frontiers_to_check = frontiers[:, 0] + frontiers[:, 1] * 1j + observed_frontiers_index = np.where( + np.isin(old_frontiers_to_check, new_frontiers_to_check, assume_unique=True) == False) + new_frontiers_index = np.where( + np.isin(new_frontiers_to_check, old_frontiers_to_check, assume_unique=True) == False) + observed_frontiers = old_frontiers[observed_frontiers_index] + new_frontiers = frontiers[new_frontiers_index] + for node in self.nodes_list: + if node.zero_utility_node is True: + pass + else: + node.update_observable_frontiers(observed_frontiers, new_frontiers, robot_belief) + + for new_coords in new_node_coords: + node = Node(new_coords, frontiers, robot_belief) + self.nodes_list.append(node) + + self.node_utility = [] + for i, coords in enumerate(self.node_coords): + utility = 
self.nodes_list[i].utility + self.node_utility.append(utility) + self.node_utility = np.array(self.node_utility) + + self.guidepost = np.zeros((self.node_coords.shape[0], 1)) + x = self.node_coords[:, 0] + self.node_coords[:, 1] * 1j + for node in self.route_node: + index = np.argwhere(x.reshape(-1) == node[0] + node[1] * 1j) + self.guidepost[index] = 1 + + return self.node_coords, self.graph.edges, self.node_utility, self.guidepost + + + def generate_uniform_points(self): + padding_x = 0.5 * (self.map_x / NUM_COORDS_WIDTH) + padding_y = 0.5 * (self.map_y / NUM_COORDS_HEIGHT) + x = np.linspace(padding_x, self.map_x - padding_x - 1, NUM_COORDS_WIDTH).round().astype(int) + y = np.linspace(padding_y, self.map_y - padding_y - 1, NUM_COORDS_HEIGHT).round().astype(int) + + t1, t2 = np.meshgrid(x, y) + points = np.vstack([t1.T.ravel(), t2.T.ravel()]).T + matrix = np.stack((t1, t2), axis=-1) + return points, matrix + + + def free_area(self, robot_belief): + index = np.where(robot_belief == 255) + free = np.asarray([index[1], index[0]]).T + return free + + + def unique_coords(self, coords): + x = coords[:, 0] + coords[:, 1] * 1j + indices = np.unique(x, return_index=True)[1] + coords = np.array([coords[idx] for idx in sorted(indices)]) + return coords + + + def find_k_neighbor(self, coords, node_coords, robot_belief): + dist_list = np.linalg.norm((coords-node_coords), axis=-1) + sorted_index = np.argsort(dist_list) + k = 0 + neighbor_index_list = [] + while k < self.k_size and k< node_coords.shape[0]: + neighbor_index = sorted_index[k] + neighbor_index_list.append(neighbor_index) + dist = dist_list[k] + start = coords + end = node_coords[neighbor_index] + if not self.check_collision(start, end, robot_belief): + a = str(self.find_index_from_coords(node_coords, start)) + b = str(neighbor_index) + self.graph.add_node(a) + self.graph.add_edge(a, b, dist) + + if self.plot: + self.x.append([start[0], end[0]]) + self.y.append([start[1], end[1]]) + k += 1 + return neighbor_index_list + + + def find_k_neighbor_all_nodes(self, node_coords, robot_belief): + X = node_coords + if len(node_coords) >= self.k_size: + knn = NearestNeighbors(n_neighbors=self.k_size) + else: + knn = NearestNeighbors(n_neighbors=len(node_coords)) + knn.fit(X) + distances, indices = knn.kneighbors(X) + + for i, p in enumerate(X): + for j, neighbour in enumerate(X[indices[i][:]]): + start = p + end = neighbour + if not self.check_collision(start, end, robot_belief): + a = str(self.find_index_from_coords(node_coords, p)) + b = str(self.find_index_from_coords(node_coords, neighbour)) + self.graph.add_node(a) + self.graph.add_edge(a, b, distances[i, j]) + + if self.plot: + self.x.append([p[0], neighbour[0]]) + self.y.append([p[1], neighbour[1]]) + + + def find_nearest_neighbor_all_nodes(self, node_coords, robot_belief): + for i, p in enumerate(node_coords): + filtered_coords = self.get_neighbors_grid_coords(p) + + for j, neighbour in enumerate(filtered_coords): + start = p + end = neighbour + if not self.check_collision(start, end, robot_belief): + a = str(self.find_index_from_coords(node_coords, p)) + b = str(self.find_index_from_coords(node_coords, neighbour)) + self.graph.add_node(a) + self.graph.add_edge(a, b, np.linalg.norm(start-end)) + + if self.plot: + self.x.append([p[0], neighbour[0]]) + self.y.append([p[1], neighbour[1]]) + + + def find_index_from_coords(self, node_coords, p): + return np.where(np.linalg.norm(node_coords - p, axis=1) < 1e-5)[0][0] + + + def find_closest_index_from_coords(self, node_coords, p): + return 
np.argmin(np.linalg.norm(node_coords - p, axis=1)) + + + def find_index_from_grid_coords_2d(self, p): + diffs = np.linalg.norm(self.grid_coords - p, axis=2) + indices = np.where(diffs < 1e-5) + + if indices[0].size > 0: + return indices[0][0], indices[1][0] + else: + raise ValueError(f"Coordinate {p} not found in self.grid_coords.") + + + def find_closest_index_from_grid_coords_2d(self, p): + distances = np.linalg.norm(self.grid_coords - p, axis=2) + flat_index = np.argmin(distances) + return np.unravel_index(flat_index, distances.shape) + + + def check_collision(self, start, end, robot_belief): + collision = False + line = shapely.geometry.LineString([start, end]) + + sortx = np.sort([start[0], end[0]]) + sorty = np.sort([start[1], end[1]]) + + robot_belief = robot_belief[sorty[0]:sorty[1]+1, sortx[0]:sortx[1]+1] + + occupied_area_index = np.where(robot_belief == 1) + occupied_area_coords = np.asarray([occupied_area_index[1]+sortx[0], occupied_area_index[0]+sorty[0]]).T + unexplored_area_index = np.where(robot_belief == 127) + unexplored_area_coords = np.asarray([unexplored_area_index[1]+sortx[0], unexplored_area_index[0]+sorty[0]]).T + unfree_area_coords = np.concatenate((occupied_area_coords, unexplored_area_coords)) + + for i in range(unfree_area_coords.shape[0]): + coords = ([(unfree_area_coords[i][0], unfree_area_coords[i][1]), + (unfree_area_coords[i][0] + 1, unfree_area_coords[i][1]), + (unfree_area_coords[i][0], unfree_area_coords[i][1] + 1), + (unfree_area_coords[i][0] + 1, unfree_area_coords[i][1] + 1)]) + obstacle = shapely.geometry.Polygon(coords) + collision = line.intersects(obstacle) + if collision: + break + + return collision + + + def find_shortest_path(self, current, destination, node_coords): + start_node = str(self.find_index_from_coords(node_coords, current)) + end_node = str(self.find_index_from_coords(node_coords, destination)) + route, dist = a_star(int(start_node), int(end_node), self.node_coords, self.graph) + if start_node != end_node: + assert route != [] + route = list(map(str, route)) + return dist, route + + def get_neighbors_grid_coords(self, coord): + # Return the 8 closest neighbors of a given coordinate + + nearest_coord = self.node_coords[self.find_closest_index_from_coords(self.node_coords, coord)] + rows, cols = self.grid_coords.shape[:2] + neighbors = [] + i, j = self.find_index_from_grid_coords_2d(nearest_coord) + + # Create a range of indices for rows and columns + row_range = np.clip([i - 1, i, i + 1], 0, rows - 1) + col_range = np.clip([j - 1, j, j + 1], 0, cols - 1) + + # Iterate over the valid indices + for ni in row_range: + for nj in col_range: + if (ni, nj) != (i, j): # Skip the center point + neighbors.append(tuple(self.grid_coords[ni, nj])) + + return neighbors \ No newline at end of file diff --git a/planner/model.py b/planner/model.py new file mode 100644 index 0000000000000000000000000000000000000000..5cede92ba760d7cc0f789397cc11aa8e0b7a61cb --- /dev/null +++ b/planner/model.py @@ -0,0 +1,312 @@ +####################################################################### +# Name: model.py +# +# - Attention-based encoders & decoders +# - Policy Net: Input = Augmented Graph, Output = Node to go to +# - Critic Net: Input = Augmented Graph + Action, Output = Q_Value +####################################################################### + +import torch +import torch.nn as nn +import math + + +class SingleHeadAttention(nn.Module): + def __init__(self, embedding_dim): + super(SingleHeadAttention, self).__init__() + self.input_dim = 
embedding_dim + self.embedding_dim = embedding_dim + self.value_dim = embedding_dim + self.key_dim = self.value_dim + self.tanh_clipping = 10 + self.norm_factor = 1 / math.sqrt(self.key_dim) + + self.w_query = nn.Parameter(torch.Tensor(self.input_dim, self.key_dim)) + self.w_key = nn.Parameter(torch.Tensor(self.input_dim, self.key_dim)) + + self.init_parameters() + + def init_parameters(self): + for param in self.parameters(): + stdv = 1. / math.sqrt(param.size(-1)) + param.data.uniform_(-stdv, stdv) + + def forward(self, q, k, mask=None): + + n_batch, n_key, n_dim = k.size() + n_query = q.size(1) + + k_flat = k.reshape(-1, n_dim) + q_flat = q.reshape(-1, n_dim) + + shape_k = (n_batch, n_key, -1) + shape_q = (n_batch, n_query, -1) + + Q = torch.matmul(q_flat, self.w_query).view(shape_q) + K = torch.matmul(k_flat, self.w_key).view(shape_k) + + U = self.norm_factor * torch.matmul(Q, K.transpose(1, 2)) + U = self.tanh_clipping * torch.tanh(U) + + if mask is not None: + U = U.masked_fill(mask == 1, -1e8) + attention = torch.log_softmax(U, dim=-1) # n_batch*n_query*n_key + + return attention + + +class MultiHeadAttention(nn.Module): + def __init__(self, embedding_dim, n_heads=8): + super(MultiHeadAttention, self).__init__() + self.n_heads = n_heads + self.input_dim = embedding_dim + self.embedding_dim = embedding_dim + self.value_dim = self.embedding_dim // self.n_heads + self.key_dim = self.value_dim + self.norm_factor = 1 / math.sqrt(self.key_dim) + + self.w_query = nn.Parameter(torch.Tensor(self.n_heads, self.input_dim, self.key_dim)) + self.w_key = nn.Parameter(torch.Tensor(self.n_heads, self.input_dim, self.key_dim)) + self.w_value = nn.Parameter(torch.Tensor(self.n_heads, self.input_dim, self.value_dim)) + self.w_out = nn.Parameter(torch.Tensor(self.n_heads, self.value_dim, self.embedding_dim)) + + self.init_parameters() + + def init_parameters(self): + for param in self.parameters(): + stdv = 1. 
/ math.sqrt(param.size(-1)) + param.data.uniform_(-stdv, stdv) + + def forward(self, q, k=None, v=None, key_padding_mask=None, attn_mask=None): + if k is None: + k = q + if v is None: + v = q + + n_batch, n_key, n_dim = k.size() + n_query = q.size(1) + n_value = v.size(1) + + k_flat = k.contiguous().view(-1, n_dim) + v_flat = v.contiguous().view(-1, n_dim) + q_flat = q.contiguous().view(-1, n_dim) + shape_v = (self.n_heads, n_batch, n_value, -1) + shape_k = (self.n_heads, n_batch, n_key, -1) + shape_q = (self.n_heads, n_batch, n_query, -1) + + Q = torch.matmul(q_flat, self.w_query).view(shape_q) # n_heads*batch_size*n_query*key_dim + K = torch.matmul(k_flat, self.w_key).view(shape_k) # n_heads*batch_size*targets_size*key_dim + V = torch.matmul(v_flat, self.w_value).view(shape_v) # n_heads*batch_size*targets_size*value_dim + + U = self.norm_factor * torch.matmul(Q, K.transpose(2, 3)) # n_heads*batch_size*n_query*targets_size + + if attn_mask is not None: + attn_mask = attn_mask.view(1, n_batch, n_query, n_key).expand_as(U) + + if key_padding_mask is not None: + key_padding_mask = key_padding_mask.repeat(1, n_query, 1) + key_padding_mask = key_padding_mask.view(1, n_batch, n_query, n_key).expand_as(U) # copy for n_heads times + + if attn_mask is not None and key_padding_mask is not None: + mask = (attn_mask + key_padding_mask) + elif attn_mask is not None: + mask = attn_mask + elif key_padding_mask is not None: + mask = key_padding_mask + else: + mask = None + + if mask is not None: + U = U.masked_fill(mask > 0, -1e8) + + attention = torch.softmax(U, dim=-1) # n_heads*batch_size*n_query*targets_size + heads = torch.matmul(attention, V) # n_heads*batch_size*n_query*value_dim + out = torch.mm( + heads.permute(1, 2, 0, 3).reshape(-1, self.n_heads * self.value_dim), + # batch_size*n_query*n_heads*value_dim + self.w_out.view(-1, self.embedding_dim) + # n_heads*value_dim*embedding_dim + ).view(-1, n_query, self.embedding_dim) + + + return out, attention # batch_size*n_query*embedding_dim + + +class Normalization(nn.Module): + def __init__(self, embedding_dim): + super(Normalization, self).__init__() + self.normalizer = nn.LayerNorm(embedding_dim) + + def forward(self, input): + return self.normalizer(input.view(-1, input.size(-1))).view(*input.size()) + + +class EncoderLayer(nn.Module): + def __init__(self, embedding_dim, n_head): + super(EncoderLayer, self).__init__() + self.multiHeadAttention = MultiHeadAttention(embedding_dim, n_head) + self.normalization1 = Normalization(embedding_dim) + self.feedForward = nn.Sequential(nn.Linear(embedding_dim, 512), nn.ReLU(inplace=True), + nn.Linear(512, embedding_dim)) + self.normalization2 = Normalization(embedding_dim) + + def forward(self, src, key_padding_mask=None, attn_mask=None): + h0 = src + h = self.normalization1(src) + h, _ = self.multiHeadAttention(q=h, key_padding_mask=key_padding_mask, attn_mask=attn_mask) + h = h + h0 + h1 = h + h = self.normalization2(h) + h = self.feedForward(h) + h2 = h + h1 + return h2 + + +class DecoderLayer(nn.Module): + def __init__(self, embedding_dim, n_head): + super(DecoderLayer, self).__init__() + self.multiHeadAttention = MultiHeadAttention(embedding_dim, n_head) + self.normalization1 = Normalization(embedding_dim) + self.feedForward = nn.Sequential(nn.Linear(embedding_dim, 512), + nn.ReLU(inplace=True), + nn.Linear(512, embedding_dim)) + self.normalization2 = Normalization(embedding_dim) + + def forward(self, tgt, memory, key_padding_mask=None, attn_mask=None): + h0 = tgt + tgt = self.normalization1(tgt) + 
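+        # Pre-norm cross-attention: normalize the memory as well, then let the normalized
+        # query attend over the encoder output; the attention weights w are returned to the caller.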
memory = self.normalization1(memory) + h, w = self.multiHeadAttention(q=tgt, k=memory, v=memory, key_padding_mask=key_padding_mask, attn_mask=attn_mask) + h = h + h0 + h1 = h + h = self.normalization2(h) + h = self.feedForward(h) + h2 = h + h1 + return h2, w + + +class Encoder(nn.Module): + def __init__(self, embedding_dim=128, n_head=8, n_layer=1): + super(Encoder, self).__init__() + self.layers = nn.ModuleList(EncoderLayer(embedding_dim, n_head) for i in range(n_layer)) + + def forward(self, src, key_padding_mask=None, attn_mask=None): + for layer in self.layers: + src = layer(src, key_padding_mask=key_padding_mask, attn_mask=attn_mask) + return src + + +class Decoder(nn.Module): + def __init__(self, embedding_dim=128, n_head=8, n_layer=1): + super(Decoder, self).__init__() + self.layers = nn.ModuleList([DecoderLayer(embedding_dim, n_head) for i in range(n_layer)]) + + def forward(self, tgt, memory, key_padding_mask=None, attn_mask=None): + for layer in self.layers: + tgt, w = layer(tgt, memory, key_padding_mask=key_padding_mask, attn_mask=attn_mask) + return tgt, w + + +class PolicyNet(nn.Module): + def __init__(self, input_dim, embedding_dim): + super(PolicyNet, self).__init__() + self.initial_embedding = nn.Linear(input_dim, embedding_dim) # layer for non-end position + + self.current_embedding = nn.Linear(embedding_dim * 2, embedding_dim) + + self.encoder = Encoder(embedding_dim=embedding_dim, n_head=8, n_layer=6) + self.decoder = Decoder(embedding_dim=embedding_dim, n_head=8, n_layer=1) + self.pointer = SingleHeadAttention(embedding_dim) + + def encode_graph(self, node_inputs, node_padding_mask, edge_mask): + node_feature = self.initial_embedding(node_inputs) + enhanced_node_feature = self.encoder(src=node_feature, key_padding_mask=node_padding_mask, attn_mask=edge_mask) + + return enhanced_node_feature + + def output_policy(self, enhanced_node_feature, edge_inputs, current_index, edge_padding_mask, node_padding_mask): + k_size = edge_inputs.size()[2] + current_edge = torch.gather(edge_inputs, 1, current_index.repeat(1, 1, k_size)) + current_edge = current_edge.permute(0, 2, 1) + embedding_dim = enhanced_node_feature.size()[2] + + neigboring_feature = torch.gather(enhanced_node_feature, 1, current_edge.repeat(1, 1, embedding_dim)) + + current_node_feature = torch.gather(enhanced_node_feature, 1, current_index.repeat(1, 1, embedding_dim)) + + if edge_padding_mask is not None: + current_mask = torch.gather(edge_padding_mask, 1, current_index.repeat(1,1,k_size)).to(enhanced_node_feature.device) + else: + current_mask = None + current_mask[:,:,0] = 1 # don't stay at current position + + if not 0 in current_mask: + current_mask[:,:,0] = 0 + + enhanced_current_node_feature, _ = self.decoder(current_node_feature, enhanced_node_feature, node_padding_mask) + enhanced_current_node_feature = self.current_embedding(torch.cat((enhanced_current_node_feature, current_node_feature), dim=-1)) + logp = self.pointer(enhanced_current_node_feature, neigboring_feature, current_mask) + logp= logp.squeeze(1) # batch_size*k_size + + return logp + + def forward(self, node_inputs, edge_inputs, current_index, node_padding_mask=None, edge_padding_mask=None, edge_mask=None): + enhanced_node_feature = self.encode_graph(node_inputs, node_padding_mask, edge_mask) + logp = self.output_policy(enhanced_node_feature, edge_inputs, current_index, edge_padding_mask, node_padding_mask) + return logp + + +class QNet(nn.Module): + def __init__(self, input_dim, embedding_dim): + super(QNet, self).__init__() + 
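+        # Critic network: reuses the attention encoder/decoder of the policy, then concatenates
+        # (decoded current-node, current-node, neighbouring-node) features and maps them through
+        # action_embedding and q_values_layer to one Q-value per candidate neighbour.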
self.initial_embedding = nn.Linear(input_dim, embedding_dim) # layer for non-end position + self.action_embedding = nn.Linear(embedding_dim*3, embedding_dim) + + self.encoder = Encoder(embedding_dim=embedding_dim, n_head=8, n_layer=6) + self.decoder = Decoder(embedding_dim=embedding_dim, n_head=8, n_layer=1) + + self.q_values_layer = nn.Linear(embedding_dim, 1) + + def encode_graph(self, node_inputs, node_padding_mask, edge_mask): + embedding_feature = self.initial_embedding(node_inputs) + embedding_feature = self.encoder(src=embedding_feature, key_padding_mask=node_padding_mask, attn_mask=edge_mask) + + return embedding_feature + + def output_q_values(self, enhanced_node_feature, edge_inputs, current_index, edge_padding_mask, node_padding_mask): + k_size = edge_inputs.size()[2] + current_edge = torch.gather(edge_inputs, 1, current_index.repeat(1, 1, k_size)) + current_edge = current_edge.permute(0, 2, 1) + embedding_dim = enhanced_node_feature.size()[2] + + neigboring_feature = torch.gather(enhanced_node_feature, 1, current_edge.repeat(1, 1, embedding_dim)) + + current_node_feature = torch.gather(enhanced_node_feature, 1, current_index.repeat(1, 1, embedding_dim)) + + enhanced_current_node_feature, attention_weights = self.decoder(current_node_feature, enhanced_node_feature, node_padding_mask) + action_features = torch.cat((enhanced_current_node_feature.repeat(1, k_size, 1), current_node_feature.repeat(1, k_size, 1), neigboring_feature), dim=-1) + action_features = self.action_embedding(action_features) + q_values = self.q_values_layer(action_features) + + if edge_padding_mask is not None: + current_mask = torch.gather(edge_padding_mask, 1, current_index.repeat(1, 1, k_size)).to( + enhanced_node_feature.device) + else: + current_mask = None + current_mask[:, :, 0] = 1 # don't stay at current position + + if not 0 in current_mask: + current_mask[:,:,0] = 0 + + current_mask = current_mask.permute(0, 2, 1) + zero = torch.zeros_like(q_values).to(q_values.device) + q_values = torch.where(current_mask == 1, zero, q_values) + + return q_values, attention_weights + + def forward(self, node_inputs, edge_inputs, current_index, node_padding_mask=None, edge_padding_mask=None, + edge_mask=None): + enhanced_node_feature = self.encode_graph(node_inputs, node_padding_mask, edge_mask) + q_values, attention_weights = self.output_q_values(enhanced_node_feature, edge_inputs, current_index, edge_padding_mask, node_padding_mask) + return q_values, attention_weights + diff --git a/planner/node.py b/planner/node.py new file mode 100644 index 0000000000000000000000000000000000000000..b856ef1487226e8010bbc352ae7ecdb33276eef4 --- /dev/null +++ b/planner/node.py @@ -0,0 +1,96 @@ +####################################################################### +# Name: node.py +# +# - Contains info per node on graph (edge) +# - Contains: Position, Utility, Visitation History +####################################################################### + +import sys +if sys.modules['TRAINING']: + from .parameter import * +else: + from .test_parameter import * + +import numpy as np +import shapely.geometry + + +class Node(): + def __init__(self, coords, frontiers, robot_belief): + self.coords = coords + self.observable_frontiers = [] + self.sensor_range = SENSOR_RANGE + self.initialize_observable_frontiers(frontiers, robot_belief) + self.utility = self.get_node_utility() + if self.utility == 0: + self.zero_utility_node = True + else: + self.zero_utility_node = False + + def initialize_observable_frontiers(self, frontiers, 
robot_belief): + dist_list = np.linalg.norm(frontiers - self.coords, axis=-1) + frontiers_in_range = frontiers[dist_list < self.sensor_range - 10] + for point in frontiers_in_range: + collision = self.check_collision(self.coords, point, robot_belief) + if not collision: + self.observable_frontiers.append(point) + + def get_node_utility(self): + return len(self.observable_frontiers) + + def update_observable_frontiers(self, observed_frontiers, new_frontiers, robot_belief): + if observed_frontiers != []: + observed_index = [] + for i, point in enumerate(self.observable_frontiers): + if point[0] + point[1] * 1j in observed_frontiers[:, 0] + observed_frontiers[:, 1] * 1j: + observed_index.append(i) + for index in reversed(observed_index): + self.observable_frontiers.pop(index) + # + if new_frontiers != []: + dist_list = np.linalg.norm(new_frontiers - self.coords, axis=-1) + new_frontiers_in_range = new_frontiers[dist_list < self.sensor_range - 15] + for point in new_frontiers_in_range: + collision = self.check_collision(self.coords, point, robot_belief) + if not collision: + self.observable_frontiers.append(point) + + self.utility = self.get_node_utility() + if self.utility == 0: + self.zero_utility_node = True + else: + self.zero_utility_node = False + + def set_visited(self): + self.observable_frontiers = [] + self.utility = 0 + self.zero_utility_node = True + + def check_collision(self, start, end, robot_belief): + collision = False + line = shapely.geometry.LineString([start, end]) + + sortx = np.sort([start[0], end[0]]) + sorty = np.sort([start[1], end[1]]) + + robot_belief = robot_belief[sorty[0]:sorty[1] + 1, sortx[0]:sortx[1] + 1] + + occupied_area_index = np.where(robot_belief == 1) + occupied_area_coords = np.asarray( + [occupied_area_index[1] + sortx[0], occupied_area_index[0] + sorty[0]]).T + unexplored_area_index = np.where(robot_belief == 127) + unexplored_area_coords = np.asarray( + [unexplored_area_index[1] + sortx[0], unexplored_area_index[0] + sorty[0]]).T + unfree_area_coords = np.concatenate((occupied_area_coords, unexplored_area_coords)) + + for i in range(unfree_area_coords.shape[0]): + coords = ([(unfree_area_coords[i][0], unfree_area_coords[i][1]), + (unfree_area_coords[i][0] + 1, unfree_area_coords[i][1]), + (unfree_area_coords[i][0], unfree_area_coords[i][1] + 1), + (unfree_area_coords[i][0] + 1, unfree_area_coords[i][1] + 1)]) + obstacle = shapely.geometry.Polygon(coords) + collision = line.intersects(obstacle) + if collision: + break + + return collision diff --git a/planner/robot.py b/planner/robot.py new file mode 100644 index 0000000000000000000000000000000000000000..898c3146d7d7c0a2c6d4463e4184a6d89b93898d --- /dev/null +++ b/planner/robot.py @@ -0,0 +1,58 @@ +####################################################################### +# Name: robot.py +# +# - Stores S(t), A(t), R(t), S(t+1) +####################################################################### + +import torch +from copy import deepcopy + +class Robot: + def __init__(self, robot_id, position, plot=False): + self.robot_id = robot_id + self.plot = plot + self.travel_dist = 0 + self.robot_position = position + self.observations = None + self.trajectory_coords = [] + self.targets_found_on_path = [] + + self.episode_buffer = [] + for i in range(15): + self.episode_buffer.append([]) + + if self.plot: + self.xPoints = [self.robot_position[0]] + self.yPoints = [self.robot_position[1]] + + def save_observations(self, observations): + node_inputs, edge_inputs, current_index, node_padding_mask, 
edge_padding_mask, edge_mask = observations + self.episode_buffer[0] += deepcopy(node_inputs).to('cpu') + self.episode_buffer[1] += deepcopy(edge_inputs).to('cpu') + self.episode_buffer[2] += deepcopy(current_index).to('cpu') + self.episode_buffer[3] += deepcopy(node_padding_mask).to('cpu') + self.episode_buffer[4] += deepcopy(edge_padding_mask).to('cpu') + self.episode_buffer[5] += deepcopy(edge_mask).to('cpu') + + def save_action(self, action_index): + self.episode_buffer[6] += action_index.unsqueeze(0).unsqueeze(0) + + def save_reward_done(self, reward, done): + self.episode_buffer[7] += deepcopy(torch.FloatTensor([[[reward]]])).to('cpu') + self.episode_buffer[8] += deepcopy(torch.tensor([[[(int(done))]]])).to('cpu') + if self.plot: + self.xPoints.append(self.robot_position[0]) + self.yPoints.append(self.robot_position[1]) + + def save_next_observations(self, observations): + node_inputs, edge_inputs, current_index, node_padding_mask, edge_padding_mask, edge_mask = observations + self.episode_buffer[9] += deepcopy(node_inputs).to('cpu') + self.episode_buffer[10] += deepcopy(edge_inputs).to('cpu') + self.episode_buffer[11] += deepcopy(current_index).to('cpu') + self.episode_buffer[12] += deepcopy(node_padding_mask).to('cpu') + self.episode_buffer[13] += deepcopy(edge_padding_mask).to('cpu') + self.episode_buffer[14] += deepcopy(edge_mask).to('cpu') + + def save_trajectory_coords(self, robot_position_coords, num_target_found): + self.trajectory_coords.append(robot_position_coords) + self.targets_found_on_path.append(num_target_found) diff --git a/planner/sensor.py b/planner/sensor.py new file mode 100644 index 0000000000000000000000000000000000000000..98b010b54268c79cebd077039eab90bb426bc2e8 --- /dev/null +++ b/planner/sensor.py @@ -0,0 +1,128 @@ +####################################################################### +# Name: sensor.py +# +# - Computes sensor related checks (e.g. 
collision, utility etc) +####################################################################### + +import sys +if sys.modules['TRAINING']: + from .parameter import * +else: + from .test_parameter import * + +import math +import numpy as np +import copy + +def collision_check(x0, y0, x1, y1, ground_truth, robot_belief): + x0 = x0.round() + y0 = y0.round() + x1 = x1.round() + y1 = y1.round() + dx, dy = abs(x1 - x0), abs(y1 - y0) + x, y = x0, y0 + error = dx - dy + x_inc = 1 if x1 > x0 else -1 + y_inc = 1 if y1 > y0 else -1 + dx *= 2 + dy *= 2 + + collision_flag = 0 + max_collision = 10 + + while 0 <= x < ground_truth.shape[1] and 0 <= y < ground_truth.shape[0]: + k = ground_truth.item(y, x) + if k == 1 and collision_flag < max_collision: + collision_flag += 1 + if collision_flag >= max_collision: + break + + if k !=1 and collision_flag > 0: + break + + if x == x1 and y == y1: + break + + robot_belief.itemset((y, x), k) + + if error > 0: + x += x_inc + error -= dy + else: + y += y_inc + error += dx + + return robot_belief + + +def sensor_work(robot_position, sensor_range, robot_belief, ground_truth, sensor_model=SENSOR_MODEL): + x0 = robot_position[0] + y0 = robot_position[1] + rng_x = 0.5 * (ground_truth.shape[1] / NUM_COORDS_WIDTH) + rng_y = 0.5 * (ground_truth.shape[0] / NUM_COORDS_HEIGHT) + + if sensor_model == "rectangular": # TODO: add collision check + max_x = min(x0 + int(math.ceil(rng_x)), ground_truth.shape[1]) + min_x = max(x0 - int(math.ceil(rng_x)), 0) + max_y = min(y0 + int(math.ceil(rng_y)), ground_truth.shape[0]) + min_y = max(y0 - int(math.ceil(rng_y)), 0) + robot_belief[min_y:max_y, min_x:max_x] = ground_truth[min_y:max_y, min_x:max_x] + else: + sensor_angle_inc = 0.5 / 180 * np.pi + sensor_angle = 0 + while sensor_angle < 2 * np.pi: + x1 = x0 + np.cos(sensor_angle) * sensor_range + y1 = y0 + np.sin(sensor_angle) * sensor_range + robot_belief = collision_check(x0, y0, x1, y1, ground_truth, robot_belief) + sensor_angle += sensor_angle_inc + return robot_belief + + +def unexplored_area_check(x0, y0, x1, y1, current_belief): + x0 = x0.round() + y0 = y0.round() + x1 = x1.round() + y1 = y1.round() + dx, dy = abs(x1 - x0), abs(y1 - y0) + x, y = x0, y0 + error = dx - dy + x_inc = 1 if x1 > x0 else -1 + y_inc = 1 if y1 > y0 else -1 + dx *= 2 + dy *= 2 + + while 0 <= x < current_belief.shape[1] and 0 <= y < current_belief.shape[0]: + k = current_belief.item(y, x) + if x == x1 and y == y1: + break + + if k == 1: + break + + if k == 127: + current_belief.itemset((y, x), 0) + break + + if error > 0: + x += x_inc + error -= dy + else: + y += y_inc + error += dx + + return current_belief + + +def calculate_utility(waypoint_position, sensor_range, robot_belief): + sensor_angle_inc = 5 / 180 * np.pi + sensor_angle = 0 + x0 = waypoint_position[0] + y0 = waypoint_position[1] + current_belief = copy.deepcopy(robot_belief) + while sensor_angle < 2 * np.pi: + x1 = x0 + np.cos(sensor_angle) * sensor_range + y1 = y0 + np.sin(sensor_angle) * sensor_range + current_belief = unexplored_area_check(x0, y0, x1, y1, current_belief) + sensor_angle += sensor_angle_inc + utility = np.sum(robot_belief == 127) - np.sum(current_belief == 127) + return utility diff --git a/planner/test_info_surfing.py b/planner/test_info_surfing.py new file mode 100644 index 0000000000000000000000000000000000000000..306311f2570ffa81838f64b8431885904433647c --- /dev/null +++ b/planner/test_info_surfing.py @@ -0,0 +1,1071 @@ +####################################################################### +# Name: test_info_surfing.py +# 
+# - Runs robot in environment using Info Surfing Planner +####################################################################### + +import sys +sys.modules['TRAINING'] = False # False = Inference Testing + +import copy +import os +import imageio +import numpy as np +import matplotlib.pyplot as plt +from pathlib import Path +from time import time +from types import SimpleNamespace +from skimage.transform import resize +from taxabind_avs.satbind.kmeans_clustering import CombinedSilhouetteInertiaClusterer +from .env import Env +from .test_parameter import * + + +OPPOSITE_ACTIONS = {1: 3, 2: 4, 3: 1, 4: 2, 5: 7, 6: 8, 7: 5, 8: 6} +# color +agentColor = (1, 0.2, 0.6) +agentCommColor = (1, 0.6, 0.2) +obstacleColor = (0., 0., 0.) +targetNotFound = (0., 1., 0.) +targetFound = (0.545, 0.27, 0.075) +highestProbColor = (1., 0., 0.) +highestUncertaintyColor = (0., 0., 1.) +lowestProbColor = (1., 1., 1.) + + +class ISEnv: + """Custom Environment that follows gym interface""" + metadata = {'render.modes': ['human']} + + def __init__(self, global_step=0, state=None, shape=(24, 24), numAgents=8, observationSize=11, sensorSize=1, diag=False, save_image=False, clip_seg_tta=None): + + self.global_step = global_step + self.infoMap = None + self.targetMap = None + self.agents = [] + self.targets = [] + self.numAgents = numAgents + self.found_target = [] + self.shape = shape + self.observationSize = observationSize + self.sensorSize = sensorSize + self.diag = diag + self.communicateCircle = 11 + self.distribs = [] + self.mask = None + self.finished = False + self.action_vects = [[-1., 0.], [0., 1.], [1., 0], [0., -1.]] if not diag else [[-1., 0.], [0., 1.], [1., 0], [0., -1.], [-0.707, -0.707], [-0.707, 0.707], [0.707, 0.707], [0.707, -0.707]] + self.actionlist = [] + self.IS_step = 0 + self.save_image = save_image + self.clip_seg_tta = clip_seg_tta + self.perf_metrics = dict() + self.steps_to_first_tgt = None + self.steps_to_mid_tgt = None + self.steps_to_last_tgt = None + self.targets_found_on_path = [] + self.step_since_tta = 0 + self.IS_frame_files = [] + self.bad_mask_init = False + + # define env + self.env = Env(map_index=self.global_step, n_agent=numAgents, k_size=K_SIZE, plot=save_image, test=True) + + # Overwrite state + if self.clip_seg_tta is not None: + self.clip_seg_tta.reset(sample_idx=self.global_step) + + # Override target positions in env + self.env.target_positions = [(pose[1], pose[0]) for pose in self.clip_seg_tta.target_positions] + + # Override segmentation mask + if not USE_CLIP_PREDS and OVERRIDE_MASK_DIR != "": + score_mask_path = os.path.join(OVERRIDE_MASK_DIR, self.clip_seg_tta.gt_mask_name) + print("score_mask_path: ", score_mask_path) + if os.path.exists(score_mask_path): + self.env.segmentation_mask = self.env.import_segmentation_mask(score_mask_path) + self.env.begin(self.env.map_start_position) + else: + print(f"\n\n{RED}ERROR: Trying to override, but score mask not found at path:{NC} ", score_mask_path) + self.bad_mask_init = True + + # Save clustered embeds from sat encoder + if USE_CLIP_PREDS: + self.kmeans_clusterer = CombinedSilhouetteInertiaClusterer( + k_min=1, + k_max=8, + k_avg_max=4, + silhouette_threshold=0.15, + relative_threshold=0.15, + random_state=0, + min_patch_size=5, + n_smooth_iter=2, + ignore_label=-1, + plot=self.save_image, + gifs_dir = GIFS_PATH + ) + # Generate kmeans clusters + self.kmeans_sat_embeds_clusters = self.kmeans_clusterer.fit_predict( + patch_embeds=self.clip_seg_tta.patch_embeds, + map_shape=(CLIP_GRIDS_DIMS[0], CLIP_GRIDS_DIMS[1]), + ) + 
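# --- Illustrative sketch (assumption, not part of this module) ---
# The clusterer above groups the satellite patch embeddings into spatial
# regions that the Poisson-test TTA statistics later operate on. A minimal
# stand-in with a fixed k and plain scikit-learn KMeans (no silhouette/inertia
# model selection and no label smoothing, unlike CombinedSilhouetteInertiaClusterer)
# could look like the helper below; `patch_embeds` is assumed to be an
# (H*W, D) array of per-patch features.
def _kmeans_label_map_sketch(patch_embeds, map_shape=(24, 24), k=4, seed=0):
    import numpy as np
    from sklearn.cluster import KMeans
    flat = np.asarray(patch_embeds).reshape(map_shape[0] * map_shape[1], -1)
    labels = KMeans(n_clusters=k, random_state=seed).fit_predict(flat)
    return labels.reshape(map_shape)  # (H, W) integer region labels
# ------------------------------------------------------------------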
+ if EXECUTE_TTA: + print("Will execute TTA...") + + IS_info_map = copy.deepcopy(self.env.segmentation_info_mask) + IS_agent_loc = copy.deepcopy(self.env.start_positions) + IS_target_loc = copy.deepcopy(self.env.target_positions) + state=[IS_info_map, IS_agent_loc, IS_target_loc] + self.setWorld(state) + + + def init_render(self): + """ + Call this once (e.g., in __init__ or just before the scenario loop) + to initialize storage for agent paths and turn interactive plotting on. + """ + # Keep track of each agent's trajectory + self.trajectories = [[] for _ in range(self.numAgents)] + self.trajectories_upscaled = [[] for _ in range(self.numAgents)] + + # Turn on interactive mode so we can update the same figure repeatedly + plt.ion() + plt.figure(figsize=(6,6)) + plt.title("Information Map with Agents, Targets, and Sensor Ranges") + + + def record_positions(self): + """ + Call this after all agents have moved in a step (or whenever you want to update + the trajectory). It appends the current positions of each agent to `self.trajectories`. + """ + for idx, agent in enumerate(self.agents): + self.trajectories[idx].append((agent.row, agent.col)) + self.trajectories_upscaled[idx].append(self.env.graph_generator.grid_coords[agent.row, agent.col]) + + + def render(self, episode_num, step_num): + """ + Renders the current state in a single matplotlib plot. + Ensures consistent image size for GIF generation. + """ + + # Completely reset the figure to avoid leftover state + plt.close('all') + fig = plt.figure(figsize=(6.4, 4.8), dpi=100) + ax = fig.add_subplot(111) + + # Plot the information map + ax.imshow(self.infoMap, origin='lower', cmap='gray') + + # Show agent positions and their trajectories + for idx, agent in enumerate(self.agents): + positions = self.trajectories[idx] + if len(positions) > 1: + rows = [p[0] for p in positions] + cols = [p[1] for p in positions] + ax.plot(cols, rows, linewidth=1) + + ax.scatter(agent.col, agent.row, marker='o', s=50) + + # Plot target locations + for t in self.targets: + color = 'green' if np.isnan(t.time_found) else 'red' + ax.scatter(t.col, t.row, marker='x', s=100, color=color) + + # Title and axis formatting + ax.set_title(f"Step: {self.IS_step}") + ax.invert_yaxis() + + # Create output folder if it doesn't exist + if not os.path.exists(GIFS_PATH): + os.makedirs(GIFS_PATH) + + # Save the frame with consistent canvas + frame_path = f'{GIFS_PATH}/IS_{episode_num}_{step_num}.png' + plt.savefig(frame_path, bbox_inches='tight', pad_inches=0.1) + self.IS_frame_files.append(frame_path) + + # Cleanup + plt.close(fig) + + + def setWorld(self, state=None): + """ + 1. empty all the element + 2. 
create the new episode + """ + if state is not None: + self.infoMap = copy.deepcopy(state[0].reshape(self.shape).T) + agents = [] + self.numAgents = len(state[1]) + for a in range(1, self.numAgents + 1): + abs_pos = state[1].pop(0) + abs_pos = np.array(abs_pos) + row, col = self.env.graph_generator.find_closest_index_from_grid_coords_2d(np.array(abs_pos)) + agents.append(Agent(ID=a, row=row, col=col, sensorSize=self.sensorSize, infoMap=np.copy(self.infoMap), + uncertaintyMap=np.copy(self.infoMap), shape=self.shape, numAgents=self.numAgents)) + self.agents = agents + + targets, n_targets = [], 1 + for t in range(len(state[2])): + abs_pos = state[2].pop(0) + abs_pos = np.array(abs_pos) + row, col = self.env.graph_generator.find_closest_index_from_grid_coords_2d(abs_pos) + targets.append(Target(ID=n_targets, row=row, col=col, time_found=np.nan)) + n_targets = n_targets + 1 + self.targets = targets + + def extractObservation(self, agent): + """ + Extract observations from information map + """ + + transform_row = self.observationSize // 2 - agent.row + transform_col = self.observationSize // 2 - agent.col + + observation_layers = np.zeros((1, self.observationSize, self.observationSize)) + min_row = max((agent.row - self.observationSize // 2), 0) + max_row = min((agent.row + self.observationSize // 2 + 1), self.shape[0]) + min_col = max((agent.col - self.observationSize // 2), 0) + max_col = min((agent.col + self.observationSize // 2 + 1), self.shape[1]) + + observation = np.full((self.observationSize, self.observationSize), 0.) + infoMap = np.full((self.observationSize, self.observationSize), 0.) + densityMap = np.full((self.observationSize, self.observationSize), 0.) + + infoMap[(min_row + transform_row):(max_row + transform_row), + (min_col + transform_col):(max_col + transform_col)] = self.infoMap[ + min_row:max_row, min_col:max_col] + observation_layers[0] = infoMap + + return observation_layers + + + def listNextValidActions(self, agent_id, prev_action=0): + """ + No movement: 0 + North (-1,0): 1 + East (0,1): 2 + South (1,0): 3 + West (0,-1): 4 + """ + available_actions = [0] + agent = self.agents[agent_id - 1] + + MOVES = [(-1, 0), (0, 1), (1, 0), (0, -1), (-1, -1), (-1, 1), (1, 1), (1, -1)] + size = 4 + self.diag * 4 + for action in range(size): + out_of_bounds = agent.row + MOVES[action][0] >= self.shape[0] \ + or agent.row + MOVES[action][0] < 0\ + or agent.col + MOVES[action][1] >= self.shape[1] \ + or agent.col + MOVES[action][1] < 0 + + if (not out_of_bounds) and not (prev_action == OPPOSITE_ACTIONS[action + 1]): + available_actions.append(action + 1) + + return np.array(available_actions) + + + def executeAction(self, agentID, action, timeStep): + """ + No movement: 0 + North (-1,0): 1 + East (0,1): 2 + South (1,0): 3 + West (0,-1): 4 + LeftUp (-1,-1) : 5 + RightUP (-1,1) :6 + RightDown (1,1) :7 + RightLeft (1,-1) :8 + """ + agent = self.agents[agentID - 1] + origLoc = agent.getLocation() + + if (action >= 1) and (action <= 8): + agent.move(action) + row, col = agent.getLocation() + + # If the move is not valid, roll it back + if (row < 0) or (col < 0) or (row >= self.shape[0]) or (col >= self.shape[1]): + self.updateInfoCheckTarget(agentID, timeStep, origLoc) + return 0 + + elif action == 0: + self.updateInfoCheckTarget(agentID, timeStep, origLoc) + return 0 + + else: + print("INVALID ACTION: {}".format(action)) + sys.exit() + + newLoc = agent.getLocation() + self.updateInfoCheckTarget(agentID, timeStep, origLoc) + return action + + + def updateInfoCheckTarget(self, agentID, 
timeStep, origLoc): + """ + update the self.infoMap and check whether the agent has found a target + """ + agent = self.agents[agentID - 1] + transform_row = self.sensorSize // 2 - agent.row + transform_col = self.sensorSize // 2 - agent.col + + min_row = max((agent.row - self.sensorSize // 2), 0) + max_row = min((agent.row + self.sensorSize // 2 + 1), self.shape[0]) + min_col = max((agent.col - self.sensorSize // 2), 0) + max_col = min((agent.col + self.sensorSize // 2 + 1), self.shape[1]) + for t in self.targets: + if (t.row == agent.row) and (t.col == agent.col): + t.updateFound(timeStep) + self.found_target.append(t) + t.status = True + + self.infoMap[min_row:max_row, min_col:max_col] *= 0.05 + + + def updateInfoEntireTrajectory(self, agentID): + """ + update the self.infoMap and check whether the agent has found a target + """ + traj = self.trajectories[agentID - 1] + + for (row,col) in traj: + min_row = max((row - self.sensorSize // 2), 0) + max_row = min((row + self.sensorSize // 2 + 1), self.shape[0]) + min_col = max((col - self.sensorSize // 2), 0) + max_col = min((col + self.sensorSize // 2 + 1), self.shape[1]) + self.infoMap[min_row:max_row, min_col:max_col] *= 0.05 + + + # Execute one time step within the environment + def step(self, agentID, action, timeStep): + """ + the agents execute the actions + No movement: 0 + North (-1,0): 1 + East (0,1): 2 + South (1,0): 3 + West (0,-1): 4 + """ + assert (agentID > 0) + + self.executeAction(agentID, action, timeStep) + + + def observe(self, agentID): + assert (agentID > 0) + vectorObs = self.extractObservation(self.agents[agentID - 1]) + return [vectorObs] + + + def check_finish(self): + if TERMINATE_ON_TGTS_FOUND: + found_status = [t.time_found for t in self.targets] + d = False + if np.isnan(found_status).sum() == 0: + d = True + return d + else: + return False + + + def gradVec(self, observation, agent): + a = observation[0] + + # Make info & unc cells with low value as 0 + a[a < 0.0002] = 0.0 + + # Center square from 11x11 + a_11x11 = a[4:7, 4:7] + m_11x11 = np.array((a_11x11)) + + # Center square from 9x9 + a_9x9 = self.pooling(a, (3, 3), stride=(1, 1), method='max', pad=False) + a_9x9 = a_9x9[3:6, 3:6] + m_9x9 = np.array((a_9x9)) + + # Center square from 6x6 + a_6x6 = self.pooling(a, (6, 6), stride=(1, 1), method='max', pad=False) + a_6x6 = a_6x6[1:4, 1:4] + m_6x6 = np.array((a_6x6)) + + # Center square from 3x3 + a_3x3 = self.pooling(a, (5, 5), stride=(3, 3), method='max', pad=False) + m_3x3 = np.array((a_3x3)) + + # Merging multiScales with weights + m = m_3x3 * 0.25 + m_6x6 * 0.25 + m_9x9 * 0.25 + m_11x11 * 0.25 + a = m + + adx, ady = np.gradient(a) + den = np.linalg.norm(np.array([adx[1, 1], ady[1, 1]])) + if (den != 0) and (not np.isnan(den)): + infovec = np.array([adx[1, 1], ady[1, 1]]) / den + else: + infovec = 0 + agentvec = [] + + if len(agentvec) == 0: + den = np.linalg.norm(infovec) + if (den != 0) and (not np.isnan(den)): + direction = infovec / den + else: + direction = self.action_vects[np.random.randint(4 + self.diag * 4)] + else: + den = np.linalg.norm(np.mean(agentvec, 0)) + if (den != 0) and (not np.isnan(den)): + agentvec = np.mean(agentvec, 0) / den + else: + agentvec = 0 + + den = np.linalg.norm(0.6 * infovec + 0.4 * agentvec) + if (den != 0) and (not np.isnan(den)): + direction = (0.6 * infovec + 0.4 * agentvec) / den + else: + direction = self.action_vects[np.random.randint(4 + self.diag * 4)] + + action_vec = [[0, 0], [-1, 0], [0, 1], [1, 0], [0, -1]] if not self.diag else [[0., 0.], [-1., 0.], [0., 1.], 
[1., 0], [0., -1.], [-0.707, -0.707], [-0.707, 0.707], [0.707, 0.707], [0.707, -0.707]] + actionid = np.argmax([np.dot(direction, a) for a in action_vec]) + actionid = self.best_valid_action(actionid, agent, direction) + return actionid + + + def best_valid_action(self, actionid, agent, direction): + if len(self.actionlist) > 1: + if self.action_invalid(actionid, agent): + action_vec = [[0, 0], [-1, 0], [0, 1], [1, 0], [0, -1]] if not self.diag else [[0., 0.], [-1., 0.], [0., 1.], [1., 0], [0., -1.], [-0.707, -0.707], [-0.707, 0.707], [0.707, 0.707], [0.707, -0.707]] + actionid = np.array([np.dot(direction, a) for a in action_vec]) + actionid = actionid.argsort() + pi = 3 + self.diag*4 + while self.action_invalid(actionid[pi], agent) and pi >= 0: + pi -= 1 + if pi == -1: + return OPPOSITE_ACTIONS[self.actionlist[self.IS_step - 1][agent - 1]] + elif actionid[pi] == 0: + return OPPOSITE_ACTIONS[self.actionlist[self.IS_step - 1][agent - 1]] + else: + return actionid[pi] + return actionid + + + def action_invalid(self, action, agent): + # Going back to the previous cell is disabled + if action == OPPOSITE_ACTIONS[self.actionlist[self.IS_step - 1][agent - 1]]: + return True + # Move N,E,S,W + if (action >= 1) and (action <= 8): + agent = self.agents[agent - 1] + agent.move(action) + row, col = agent.getLocation() + + # If the move is not valid, roll it back + if ((row < 0) or (col < 0) or (row >= self.shape[0]) or (col >= self.shape[1])): + agent.reverseMove(action) + return True + + agent.reverseMove(action) + return False + return False + + + def step_all_parallel(self): + actions = [] + reward = 0 + # Decide actions for each agent + for agent_id in range(1, self.numAgents + 1): + o = self.observe(agent_id) + actions.append(self.gradVec(o[0], agent_id)) + self.actionlist.append(actions) + + # Execute those actions + for agent_id in range(1, self.numAgents + 1): + self.step(agent_id, actions[agent_id - 1], self.IS_step) + + # Record for visualization + self.record_positions() + + def is_scenario(self, max_step=512, episode_number=0): + + # Return all metrics as None if faulty mask init + if self.bad_mask_init: + self.perf_metrics['tax'] = None + self.perf_metrics['travel_dist'] = None + self.perf_metrics['travel_steps'] = None + self.perf_metrics['steps_to_first_tgt'] = None + self.perf_metrics['steps_to_mid_tgt'] = None + self.perf_metrics['steps_to_last_tgt'] = None + self.perf_metrics['explored_rate'] = None + self.perf_metrics['targets_found'] = None + self.perf_metrics['targets_total'] = None + self.perf_metrics['kmeans_k'] = None + self.perf_metrics['tgts_gt_score'] = None + self.perf_metrics['clip_inference_time'] = None + self.perf_metrics['tta_time'] = None + self.perf_metrics['success_rate'] = None + return + + eps_start = time() + self.IS_step = 0 + self.finished = False + reward = 0 + + # Initialize the rendering just once before the loop + self.init_render() + self.record_positions() + + # Initial Setup + if LOAD_AVS_BENCH and USE_CLIP_PREDS: + if NUM_COORDS_WIDTH != CLIP_GRIDS_DIMS[0] or NUM_COORDS_HEIGHT != CLIP_GRIDS_DIMS[1]: # If heatmap is resized from clip original dims + heatmap = self.convert_heatmap_resolution(self.clip_seg_tta.heatmap, full_dims=(512, 512), new_dims=(NUM_COORDS_WIDTH, NUM_COORDS_HEIGHT)) + self.env.segmentation_info_mask = np.expand_dims(heatmap.T.flatten(), axis=1) + unnormalized_heatmap = self.convert_heatmap_resolution(self.clip_seg_tta.heatmap_unnormalized, full_dims=(512, 512), new_dims=(NUM_COORDS_WIDTH, NUM_COORDS_HEIGHT)) + 
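# --- Shape bookkeeping sketch (assumption, simplified) ---
# The two conversions above bring the 512x512 CLIP heatmaps down to the
# planner's NUM_COORDS_WIDTH x NUM_COORDS_HEIGHT grid; the env then stores
# each as a flat (num_nodes, 1) column. A direct-resize stand-in (the real
# convert_heatmap_resolution() defined further below samples at the graph's
# grid_coords instead) could look like this:
def _heatmap_to_info_mask_sketch(heatmap_512, grid_w=24, grid_h=24):
    import numpy as np
    from skimage.transform import resize
    small = resize(heatmap_512, (grid_h, grid_w), order=1, anti_aliasing=True)  # bilinear
    return np.expand_dims(small.T.flatten(), axis=1)  # (grid_w * grid_h, 1) column
# ----------------------------------------------------------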
self.env.segmentation_info_mask_unnormalized = np.expand_dims(unnormalized_heatmap.T.flatten(), axis=1) + self.infoMap = copy.deepcopy(heatmap) + print("Resized heatmap to", NUM_COORDS_WIDTH, "x", NUM_COORDS_HEIGHT) + else: + self.env.segmentation_info_mask = np.expand_dims(self.clip_seg_tta.heatmap.T.flatten(), axis=1) + self.env.segmentation_info_mask_unnormalized = np.expand_dims(self.clip_seg_tta.heatmap_unnormalized.T.flatten(), axis=1) + self.infoMap = copy.deepcopy(self.clip_seg_tta.heatmap) + + self.targets_found_on_path.append(self.env.num_new_targets_found) + + while self.IS_step < max_step and not self.check_finish(): + self.step_all_parallel() + self.IS_step += 1 + + # Render after each step + if self.save_image: + self.render(episode_num=self.global_step, step_num=self.IS_step) + + # Update in env + next_position_list = [self.trajectories_upscaled[i][-1] for i, agent in enumerate(self.agents)] + dist_list = [0 for _ in range(self.numAgents)] + travel_dist_list = [self.compute_travel_distance(traj) for traj in self.trajectories] + self.env.multi_robot_step(next_position_list, dist_list, travel_dist_list) + self.targets_found_on_path.append(self.env.num_new_targets_found) + + # TTA Update via Poisson Test (with KMeans clustering stats) + robot_id = 0 # Assume 1 agent for now + robot_traj = self.trajectories[robot_id] + if LOAD_AVS_BENCH and USE_CLIP_PREDS and EXECUTE_TTA: + flat_traj_coords = [robot_traj[i][1] * self.shape[0] + robot_traj[i][0] for i in range(len(robot_traj))] + robot = SimpleNamespace( + trajectory_coords=flat_traj_coords, + targets_found_on_path=self.targets_found_on_path + ) + self.poisson_tta_update(robot, self.global_step, self.IS_step) + self.infoMap = copy.deepcopy(self.env.segmentation_info_mask.reshape((self.shape[1],self.shape[0])).T) + self.updateInfoEntireTrajectory(robot_id) + + # Update metrics + self.log_metrics(step=self.IS_step-1) + + ### Save a frame to generate gif of robot trajectories ### + if self.save_image: + robots_route = [ ([], []) ] # Assume 1 robot + for point in self.trajectories_upscaled[robot_id]: + robots_route[robot_id][0].append(point[0]) + robots_route[robot_id][1].append(point[1]) + if not os.path.exists(GIFS_PATH): + os.makedirs(GIFS_PATH) + if LOAD_AVS_BENCH: + sound_id_override = None if self.clip_seg_tta.sound_ids == [] else self.clip_seg_tta.sound_ids[0] + self.env.plot_env( + self.global_step, + GIFS_PATH, + self.IS_step-1, + max(travel_dist_list), + robots_route, + img_path_override=self.clip_seg_tta.img_paths[0], # Viz 1st + sat_path_override=self.clip_seg_tta.imo_path, + msk_name_override=self.clip_seg_tta.species_name, + sound_id_override=sound_id_override, + ) + else: + self.env.plot_env( + self.global_step, + GIFS_PATH, + self.IS_step-1, + max(travel_dist_list), + robots_route + ) + + # Log metrics + if LOAD_AVS_BENCH: + tax = Path(self.clip_seg_tta.gt_mask_name).stem + self.perf_metrics['tax'] = " ".join(tax.split("_")[1:]) + else: + self.perf_metrics['tax'] = None + travel_distances = [self.compute_travel_distance(traj) for traj in self.trajectories] + self.perf_metrics['travel_dist'] = max(travel_distances) + self.perf_metrics['travel_steps'] = self.IS_step + self.perf_metrics['steps_to_first_tgt'] = self.steps_to_first_tgt + self.perf_metrics['steps_to_mid_tgt'] = self.steps_to_mid_tgt + self.perf_metrics['steps_to_last_tgt'] = self.steps_to_last_tgt + self.perf_metrics['explored_rate'] = self.env.explored_rate + self.perf_metrics['targets_found'] = self.env.targets_found_rate + 
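# (targets_found above is the fraction of targets found, in [0, 1];
#  the episode's total target count is logged on the next line.)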
self.perf_metrics['targets_total'] = len(self.env.target_positions) + if USE_CLIP_PREDS: + self.perf_metrics['kmeans_k'] = self.kmeans_clusterer.final_k + self.perf_metrics['tgts_gt_score'] = self.clip_seg_tta.tgts_gt_score + self.perf_metrics['clip_inference_time'] = self.clip_seg_tta.clip_inference_time + self.perf_metrics['tta_time'] = self.clip_seg_tta.tta_time + else: + self.perf_metrics['kmeans_k'] = None + self.perf_metrics['tgts_gt_score'] = None + self.perf_metrics['clip_inference_time'] = None + self.perf_metrics['tta_time'] = None + if FORCE_LOGGING_DONE_TGTS_FOUND and self.env.targets_found_rate == 1.0: + self.perf_metrics['success_rate'] = True + else: + self.perf_metrics['success_rate'] = self.env.check_done()[0] + + # save gif + if self.save_image: + path = GIFS_PATH + self.make_gif(path, self.global_step) + + print(YELLOW, f"[Eps {episode_number} Completed] Time Taken: {time()-eps_start:.2f}s, Steps: {self.IS_step}", NC) + + + def asStride(self, arr, sub_shape, stride): + """ + Get a strided sub-matrices view of an ndarray. + See also skimage.util.shape.view_as_windows() + """ + s0, s1 = arr.strides[:2] + m1, n1 = arr.shape[:2] + m2, n2 = sub_shape + view_shape = (1+(m1-m2)//stride[0], 1+(n1-n2)//stride[1], m2, n2)+arr.shape[2:] + strides = (stride[0]*s0, stride[1]*s1, s0, s1)+arr.strides[2:] + subs = np.lib.stride_tricks.as_strided(arr, view_shape, strides=strides) + return subs + + + def pooling(self, mat, ksize, stride=None, method='max', pad=False): + """ + Overlapping pooling on 2D or 3D data. + + : ndarray, input array to pool. + : tuple of 2, kernel size in (ky, kx). + : tuple of 2 or None, stride of pooling window. + If None, same as (non-overlapping pooling). + : str, 'max for max-pooling, + 'mean' for mean-pooling. + : bool, pad or not. If no pad, output has size + (n-f)//s+1, n being size, f being kernel size, s stride. + if pad, output has size ceil(n/s). + + Return : pooled matrix. + """ + + m, n = mat.shape[:2] + ky, kx = ksize + if stride is None: + stride = (ky, kx) + sy, sx = stride + + _ceil = lambda x, y: int(np.ceil(x/float(y))) + + if pad: + ny = _ceil(m,sy) + nx = _ceil(n,sx) + size = ((ny-1)*sy+ky, (nx-1)*sx+kx) + mat.shape[2:] + mat_pad = np.full(size,np.nan) + mat_pad[:m,:n,...] = mat + else: + mat_pad = mat[:(m-ky)//sy*sy+ky, :(n-kx)//sx*sx+kx, ...] + + view = self.asStride(mat_pad,ksize,stride) + + if method == 'max': + result = np.nanmax(view,axis=(2,3)) + else: + result = np.nanmean(view,axis=(2,3)) + + return result + + + def compute_travel_distance(self, trajectory): + distance = 0.0 + for i in range(1, len(trajectory)): + # Convert the tuple positions to numpy arrays for easy computation. + prev_pos = np.array(trajectory[i-1]) + curr_pos = np.array(trajectory[i]) + # Euclidean distance between consecutive positions. 
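# e.g. a diagonal move from cell (3, 4) to (4, 5) adds sqrt(2) ~= 1.414 grid
# units, while an axis-aligned move adds exactly 1.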
+ distance += np.linalg.norm(curr_pos - prev_pos) + return distance + + ################################################################################ + # SPPP Related Fns + ################################################################################ + + def log_metrics(self, step): + # Update tgt found metrics + if self.steps_to_first_tgt is None and self.env.num_targets_found == 1: + self.steps_to_first_tgt = step + 1 + if self.steps_to_mid_tgt is None and self.env.num_targets_found == int(len(self.env.target_positions) / 2): + self.steps_to_mid_tgt = step + 1 + if self.steps_to_last_tgt is None and self.env.num_targets_found == len(self.env.target_positions): + self.steps_to_last_tgt = step + 1 + + + def transpose_flat_idx(self, idx, H=NUM_COORDS_HEIGHT, W=NUM_COORDS_WIDTH): + """ + Transpose a flat index from an ``HΓ—W`` grid to the equivalent + position in the ``WΓ—H`` transposed grid while **keeping the result + in 1-D**. + """ + # --- Safety check to catch out-of-range indices --- + assert 0 <= idx < H * W, f"idx {idx} out of bounds for shape ({H}, {W})" + + # Original (row, col) + row, col = divmod(idx, W) + # After transpose these coordinates swap + row_T, col_T = col, row + + # Flatten back into 1-D (row-major) for the WΓ—H grid + return row_T * H + col_T + + + def poisson_tta_update(self, robot, episode, step): + + # Generate Kmeans Clusters Stats + # Scale index back to CLIP_GRIDS_DIMS to be compatible with CLIP patch size + if NUM_COORDS_WIDTH != CLIP_GRIDS_DIMS[0] or NUM_COORDS_HEIGHT != CLIP_GRIDS_DIMS[1]: + # High-res remap via pixel coordinates preserves exact neighbourhood + filt_traj_coords, filt_targets_found_on_path = self.scale_trajectory( + robot.trajectory_coords, + self.env.target_positions, + old_dims=(NUM_COORDS_HEIGHT, NUM_COORDS_WIDTH), + full_dims=(512, 512), + new_dims=(CLIP_GRIDS_DIMS[0], CLIP_GRIDS_DIMS[1]) + ) + else: + filt_traj_coords = [self.transpose_flat_idx(idx) for idx in robot.trajectory_coords] + filt_targets_found_on_path = robot.targets_found_on_path + + region_stats_dict = self.kmeans_clusterer.compute_region_statistics( + self.kmeans_sat_embeds_clusters, + self.clip_seg_tta.heatmap_unnormalized, + filt_traj_coords, + episode_num=episode, + step_num=step + ) + + # Prep & execute TTA + self.step_since_tta += 1 + if robot.targets_found_on_path[-1] or self.step_since_tta % STEPS_PER_TTA == 0: + + num_cells = self.clip_seg_tta.heatmap.shape[0] * self.clip_seg_tta.heatmap.shape[1] + pos_sample_weight_scale, neg_sample_weight_scale = [], [] + + for i, sample_loc in enumerate(filt_traj_coords): + label = self.kmeans_clusterer.get_label_id(self.kmeans_sat_embeds_clusters, sample_loc) + num_patches = region_stats_dict[label]['num_patches'] + patches_visited = region_stats_dict[label]['patches_visited'] + expectation = region_stats_dict[label]['expectation'] + + # Exponent like focal loss to wait for more samples before confidently decreasing + pos_weight = 4.0 + neg_weight = min(1.0, (patches_visited/(3*num_patches))**GAMMA_EXPONENT) + pos_sample_weight_scale.append(pos_weight) + neg_sample_weight_scale.append(neg_weight) + + # Adaptative LR (as samples increase, increase LR to fit more datapoints) + adaptive_lr = MIN_LR + (MAX_LR - MIN_LR) * (step / num_cells) + + # TTA Update + self.clip_seg_tta.execute_tta( + filt_traj_coords, + filt_targets_found_on_path, + tta_steps=NUM_TTA_STEPS, + lr=adaptive_lr, + pos_sample_weight=pos_sample_weight_scale, + neg_sample_weight=neg_sample_weight_scale, + reset_weights=RESET_WEIGHTS + ) + if NUM_COORDS_WIDTH 
!= CLIP_GRIDS_DIMS[0] or NUM_COORDS_HEIGHT != CLIP_GRIDS_DIMS[1]: # If heatmap is resized from clip original dims + heatmap = self.convert_heatmap_resolution(self.clip_seg_tta.heatmap, full_dims=(512, 512), new_dims=(NUM_COORDS_WIDTH, NUM_COORDS_HEIGHT)) + self.env.segmentation_info_mask = np.expand_dims(heatmap.T.flatten(), axis=1) + unnormalized_heatmap = self.convert_heatmap_resolution(self.clip_seg_tta.heatmap_unnormalized, full_dims=(512, 512), new_dims=(NUM_COORDS_WIDTH, NUM_COORDS_HEIGHT)) + self.env.segmentation_info_mask_unnormalized = np.expand_dims(unnormalized_heatmap.T.flatten(), axis=1) + print("~Resized heatmap to", NUM_COORDS_WIDTH, "x", NUM_COORDS_HEIGHT) + else: + self.env.segmentation_info_mask = np.expand_dims(self.clip_seg_tta.heatmap.T.flatten(), axis=1) + self.env.segmentation_info_mask_unnormalized = np.expand_dims(self.clip_seg_tta.heatmap_unnormalized.T.flatten(), axis=1) + + self.step_since_tta = 0 + + + def convert_heatmap_resolution(self, heatmap, full_dims=(512, 512), new_dims=(24, 24)): + heatmap_large = resize(heatmap, full_dims, order=1, # order=1 β†’ bilinear + mode='reflect', anti_aliasing=True) + + coords = self.env.graph_generator.grid_coords # (N, N, 2) + rows, cols = coords[...,1], coords[...,0] + heatmap_resized = heatmap_large[rows, cols] + heatmap_resized = heatmap_resized.reshape(new_dims[1], new_dims[0]) + return heatmap_resized + + + def convert_labelmap_resolution(self, labelmap, full_dims=(512, 512), new_dims=(24, 24)): + """ + 1) Upsample via nearest‐neighbor to full_dims + 2) Sample back down to your graph grid using grid_coords + """ + # 1) Upsample with nearest‐neighbor, preserving integer labels + up = resize( + labelmap, + full_dims, + order=0, # nearest‐neighbor + mode='edge', # padding mode + preserve_range=True, # don't normalize labels + anti_aliasing=False # must be False for labels + ).astype(labelmap.dtype) # back to original integer dtype + + # 2) Downsample via your precomputed grid coords (NΓ—NΓ—2) + coords = self.env.graph_generator.grid_coords # shape (N, N, 2) + rows = coords[...,1].astype(int) + cols = coords[...,0].astype(int) + + small = up[rows, cols] # shape (N, N) + small = small.reshape(new_dims[0], new_dims[1]) + return small + + + def scale_trajectory(self, + flat_indices, + targets, + old_dims=(17, 17), + full_dims=(512, 512), + new_dims=(24, 24)): + """ + Args: + flat_indices: list of ints in [0..old_H*old_W-1] + targets: list of (y_pix, x_pix) in [0..full_H-1] + old_dims: (old_H, old_W) + full_dims: (full_H, full_W) + new_dims: (new_H, new_W) + + Returns: + new_flat_traj: list of unique flattened indices in new_HΓ—new_W + counts: list of ints, same length as new_flat_traj + """ + old_H, old_W = old_dims + full_H, full_W = full_dims + new_H, new_W = new_dims + + # 1) bin targets into new grid + cell_h_new = full_H / new_H + cell_w_new = full_W / new_W + grid_counts = [[0]*new_W for _ in range(new_H)] + for x_pix, y_pix in targets: # note (x, y) order as in original implementation + i_t = min(int(y_pix / cell_h_new), new_H - 1) + j_t = min(int(x_pix / cell_w_new), new_W - 1) + grid_counts[i_t][j_t] += 1 + + # 2) Walk the trajectory indices and project each old cell's *entire + # pixel footprint* onto the finer 24Γ—24 grid. 
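# Worked example with the default dims above: old_dims=(17, 17) over
# full_dims=(512, 512) gives ~30.1 px per old cell, and new_dims=(24, 24)
# gives ~21.3 px per new cell. The old cell at (row=5, col=3) spans pixels
# y in [150.6, 180.7), x in [90.4, 120.5), which overlaps new rows 7-8 and
# new cols 4-5, i.e. flat indices {172, 173, 196, 197} in the 24x24 grid.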
+ cell_h_full = full_H / old_H + cell_w_full = full_W / old_W + + seen = set() + new_flat_traj = [] + + for node_idx in flat_indices: + if node_idx < 0 or node_idx >= len(self.env.graph_generator.node_coords): + continue + + coord_xy = self.env.graph_generator.node_coords[node_idx] + try: + row_old, col_old = self.env.graph_generator.find_index_from_grid_coords_2d(coord_xy) + except Exception: + continue + + # Bounding box of the old cell in full-resolution pixel space + y0 = row_old * cell_h_full + y1 = (row_old + 1) * cell_h_full + x0 = col_old * cell_w_full + x1 = (col_old + 1) * cell_w_full + + # Which new-grid rows & cols overlap? (inclusive ranges) + i_start = max(0, min(int(y0 / cell_h_new), new_H - 1)) + i_end = max(0, min(int((y1 - 1) / cell_h_new), new_H - 1)) + j_start = max(0, min(int(x0 / cell_w_new), new_W - 1)) + j_end = max(0, min(int((x1 - 1) / cell_w_new), new_W - 1)) + + for ii in range(i_start, i_end + 1): + for jj in range(j_start, j_end + 1): + f_new = ii * new_W + jj + if f_new not in seen: + seen.add(f_new) + new_flat_traj.append(f_new) + + # 3) annotate counts + counts = [] + for f in new_flat_traj: + i_new, j_new = divmod(f, new_W) + counts.append(grid_counts[i_new][j_new]) + + return new_flat_traj, counts + + + ################################################################################ + + def make_gif(self, path, n): + """ Generate a gif given list of images """ + with imageio.get_writer('{}/{}_target_rate_{:.2f}.gif'.format(path, n, self.env.targets_found_rate), mode='I', + fps=5) as writer: + for frame in self.env.frame_files: + image = imageio.imread(frame) + writer.append_data(image) + print('gif complete\n') + + # Remove files + for filename in self.env.frame_files[:-1]: + os.remove(filename) + + # For KMeans gif + if LOAD_AVS_BENCH and USE_CLIP_PREDS: + with imageio.get_writer('{}/{}_kmeans_stats.gif'.format(path, n), mode='I', + fps=5) as writer: + for frame in self.kmeans_clusterer.kmeans_frame_files: + image = imageio.imread(frame) + writer.append_data(image) + print('Kmeans Clusterer gif complete\n') + + # Remove files + for filename in self.kmeans_clusterer.kmeans_frame_files[:-1]: + os.remove(filename) + + + # IS gif + with imageio.get_writer('{}/{}_IS.gif'.format(path, n), mode='I', + fps=5) as writer: + for frame in self.IS_frame_files: + image = imageio.imread(frame) + writer.append_data(image) + print('Kmeans Clusterer gif complete\n') + + # Remove files + for filename in self.IS_frame_files[:-1]: + os.remove(filename) + + ################################################################################ + + +class Agent: + def __init__(self, ID, infoMap=None, uncertaintyMap=None, shape=None, row=0, col=0, sensorSize=9, numAgents=8): + self.ID = ID + self.row = row + self.col = col + self.numAgents = numAgents + self.sensorSize = sensorSize + + def setLocation(self, row, col): + self.row = row + self.col = col + + def getLocation(self): + return [self.row, self.col] + + def move(self, action): + """ + No movement: 0 + North (-1,0): 1 + East (0,1): 2 + South (1,0): 3 + West (0,-1): 4 + LeftUp (-1,-1) : 5 + RightUP (-1,1) :6 + RightDown (1,1) :7 + RightLeft (1,-1) :8 + check valid action of the agent. 
be sure not to be out of the boundary + """ + if action == 0: + return 0 + elif action == 1: + self.row -= 1 + elif action == 2: + self.col += 1 + elif action == 3: + self.row += 1 + elif action == 4: + self.col -= 1 + elif action == 5: + self.row -= 1 + self.col -= 1 + elif action == 6: + self.row -= 1 + self.col += 1 + elif action == 7: + self.row += 1 + self.col += 1 + elif action == 8: + self.row += 1 + self.col -= 1 + + def reverseMove(self, action): + if action == 0: + return 0 + elif action == 1: + self.row += 1 + elif action == 2: + self.col -= 1 + elif action == 3: + self.row -= 1 + elif action == 4: + self.col += 1 + elif action == 5: + self.row += 1 + self.col += 1 + elif action == 6: + self.row += 1 + self.col -= 1 + elif action == 7: + self.row -= 1 + self.col -= 1 + elif action == 8: + self.row -= 1 + self.col += 1 + else: + print("agent can only move NESW/1234") + sys.exit() + + +class Target: + def __init__(self, row, col, ID, time_found=np.nan): + self.row = row + self.col = col + self.ID = ID + self.time_found = time_found + self.status = None + self.time_visited = time_found + + def getLocation(self): + return self.row, self.col + + def updateFound(self, timeStep): + if np.isnan(self.time_found): + self.time_found = timeStep + + def updateVisited(self, timeStep): + if np.isnan(self.time_visited): + self.time_visited = timeStep + + +if __name__ == "__main__": + + search_env = Env(map_index=1, k_size=K_SIZE, n_agent=NUM_ROBOTS, plot=SAVE_GIFS) + + IS_info_map = search_env.segmentation_info_mask + IS_agent_loc = search_env.start_positions + IS_target_loc = [[312, 123], [123, 312], [312, 312], [123, 123]] + + env = ISEnv(state=[IS_info_map, IS_agent_loc, IS_target_loc], shape=(NUM_COORDS_HEIGHT, NUM_COORDS_WIDTH)) + env.is_scenario(NUM_EPS_STEPS) + print() diff --git a/planner/test_parameter.py b/planner/test_parameter.py new file mode 100644 index 0000000000000000000000000000000000000000..0fd0808a446b306d7e43397d2674ca160f03158a --- /dev/null +++ b/planner/test_parameter.py @@ -0,0 +1,118 @@ +############################################################################################ +# Name: test_parameter.py +# +# NOTE: Change all your hyper-params here for eval +# Simple How-To Guide: +# 1. CLIP TTA: USE_CLIP_PREDS=True, EXECUTE_TTA=True +# 2. CLIP (No TTA): USE_CLIP_PREDS=True, EXECUTE_TTA=False +# 3. Custom masks (e.g. 
LISA): USE_CLIP_PREDS=False, EXECUTE_TTA=False +############################################################################################ + +import os +import sys +sys.modules['TRAINING'] = False # False = Inference Testing + +############################################################### +# Overload Params +############################################################### + +OPT_VARS = {} +def getenv(var_name, default=None, cast_type=str): + try: + value = os.environ.get(var_name, None) + if value is None: + result = default + elif cast_type == bool: + result = value.lower() in ("true", "1", "yes") + else: + result = cast_type(value) + except (ValueError, TypeError): + result = default + + OPT_VARS[var_name] = result # Log the result + return result + +############################################################### +# General +############################################################### + +# --- GENERAL --- # +USE_GPU = False +NUM_GPU = getenv("NUM_GPU", default=1, cast_type=int) # the number of GPUs +NUM_META_AGENT = getenv("NUM_META_AGENT", default=2, cast_type=int) # the number of concurrent processes +NUM_EPS_STEPS = getenv("NUM_EPS_STEPS", default=400, cast_type=int) +FIX_START_POSITION = getenv("FIX_START_POSITION", default=True, cast_type=bool) # Whether to fix the starting position of the robots (middle index) +NUM_ROBOTS = 1 # Only allow for 1 robot +NUM_COORDS_WIDTH=24 # How many node coords across width? +NUM_COORDS_HEIGHT=24 # How many node coords across height? +CLIP_GRIDS_DIMS=[24,24] # [16,16] if 'openai/clip-vit-large-patch14-336' +SENSOR_RANGE=80 # Only applicable to 'circle' sensor model +SENSOR_MODEL="rectangular" # "rectangular", "circle" (NOTE: no colllision check for rectangular) +TERMINATE_ON_TGTS_FOUND = True # Whether to terminate episode when all targets found +FORCE_LOGGING_DONE_TGTS_FOUND = True # Whether to force csv logging when all targets found + + +# --- Planner Params --- # +POLICY = getenv("POLICY", default="RL", cast_type=str) +NUM_TEST = 800 # Overriden if LOAD_AVS_BENCH +NUM_RUN = 1 +MODEL_NAME = "avs_rl_policy.pth" +INPUT_DIM = 4 +EMBEDDING_DIM = 128 +K_SIZE = 8 + + +# --- Folders & Visualizations --- # +GRIDMAP_SET_DIR = "maps/gpt4o/envs_val" +MASK_SET_DIR = "maps/example/masks_val" # Overriden if LOAD_AVS_BENCH +TARGETS_SET_DIR = "" +# TARGETS_SET_DIR = "maps/example/gt_masks_val_with_tgts" # Overriden if LOAD_AVS_BENCH +OVERRIDE_MASK_DIR = getenv("OVERRIDE_MASK_DIR", default="", cast_type=str) # Override initial score mask from CLIP +SAVE_GIFS = getenv("SAVE_GIFS", default=True, cast_type=bool) # do you want to save GIFs +FOLDER_NAME = 'avs_search' +MODEL_PATH = f'inference/model' +GIFS_PATH = f'inference/test_results/gifs/{FOLDER_NAME}' +LOG_PATH = f'inference/test_results/log/{FOLDER_NAME}' +LOG_TEMPLATE_XLSX = f'inference/template.xlsx' +CSV_EXPT_NAME = getenv("CSV_EXPT_NAME", default="data", cast_type=str) +VIZ_GRAPH_EDGES = False # do you want to visualize the graph edges + + +####################################################################### +# AVS Params +####################################################################### + +# General PARAMS +USE_CLIP_PREDS = getenv("USE_CLIP_PREDS", default=True, cast_type=bool) # If false, use custom masks from OVERRIDE_MASK_DIR +QUERY_TAX = getenv("QUERY_TAX", default="", cast_type=str) # "" = Test all tax (can accept taxonomy substrings) +EXECUTE_TTA = getenv("EXECUTE_TTA", default=True, cast_type=bool) # Whether to execute TTA mask updates +QUERY_MODALITY = getenv("QUERY_MODALITY", 
default="image", cast_type=str) # "image", "text", "sound" +STEPS_PER_TTA = 20 # no. steps before each TTA series +NUM_TTA_STEPS = 1 # no. of TTA steps during each series +RESET_WEIGHTS = True +MIN_LR = 1e-6 +MAX_LR = 1e-5 +GAMMA_EXPONENT = 2 + +# Paths related to AVS (TRAIN w/ TARGETS) +LOAD_AVS_BENCH = True # Whether to init AVS datasets +AVS_IMG_DIR = '/mnt/hdd/avs_bench_ds/inat21' +AVS_IMO_DIR = '/mnt/hdd/avs_bench_ds/sat_jpg/train_512px' +AVS_INAT_JSON_PATH = '/mnt/hdd/avs_bench_ds/inat21/train.json' +AVS_SOUND_DIR = '/mnt/hdd/avs_bench_ds/sound_mp3/test' +AVS_GAUSSIAN_BLUR_KERNEL = (5,5) +AVS_SAT_TO_IMG_IDS_PATH = getenv("AVS_SAT_TO_IMG_IDS_PATH", default="search_tri_modal|val_in_domain", cast_type=str) +AVS_LOAD_PRETRAINED_HF_CHECKPOINT = getenv("AVS_LOAD_PRETRAINED_HF_CHECKPOINT", default=True, cast_type=bool) # If false, load locally using CHECKPOINT_PATHs +AVS_SAT_CHECKPOINT_PATH = getenv("AVS_SAT_CHECKPOINT_PATH", default="", cast_type=str) +AVS_SOUND_CHECKPOINT_PATH = getenv("AVS_SOUND_CHECKPOINT_PATH", default="", cast_type=str) + +####################################################################### +# UTILS +####################################################################### + +# COLORS (for printing) +RED='\033[1;31m' +GREEN='\033[1;32m' +YELLOW='\033[1;93m' +NC_BOLD='\033[1m' # Bold, No Color +NC='\033[0m' # No Color diff --git a/planner/test_worker.py b/planner/test_worker.py new file mode 100644 index 0000000000000000000000000000000000000000..f96c4b935e779761b4ce323822755d66e8676ba3 --- /dev/null +++ b/planner/test_worker.py @@ -0,0 +1,590 @@ +####################################################################### +# Name: test_worker.py +# +# - Runs robot in environment using RL Planner +####################################################################### + +from .test_parameter import * + +import imageio +import os +import copy +import numpy as np +import torch +from time import time +from pathlib import Path +from skimage.transform import resize +from taxabind_avs.satbind.kmeans_clustering import CombinedSilhouetteInertiaClusterer +from .env import Env +from .robot import Robot + +np.seterr(invalid='raise', divide='raise') + + +class TestWorker: + def __init__(self, meta_agent_id, n_agent, policy_net, global_step, device='cuda', greedy=False, save_image=False, clip_seg_tta=None): + self.device = device + self.greedy = greedy + self.n_agent = n_agent + self.metaAgentID = meta_agent_id + self.global_step = global_step + self.k_size = K_SIZE + self.save_image = save_image + self.clip_seg_tta = clip_seg_tta + self.execute_tta = EXECUTE_TTA # Added to interface with app.py + + self.env = Env(map_index=self.global_step, n_agent=n_agent, k_size=self.k_size, plot=save_image, test=True) + self.local_policy_net = policy_net + + self.robot_list = [] + self.all_robot_positions = [] + for i in range(self.n_agent): + robot_position = self.env.start_positions[i] + robot = Robot(robot_id=i, position=robot_position, plot=save_image) + self.robot_list.append(robot) + self.all_robot_positions.append(robot_position) + + self.perf_metrics = dict() + self.bad_mask_init = False + + # NOTE: Option to override gifs_path to interface with app.py + self.gifs_path = GIFS_PATH + + # NOTE: updated due to app.py (hf does not allow heatmap to persist) + if LOAD_AVS_BENCH: + if clip_seg_tta is not None: + heatmap, heatmap_unnormalized, heatmap_unnormalized_initial, patch_embeds = self.clip_seg_tta.reset(sample_idx=self.global_step) + self.clip_seg_tta.heatmap = heatmap + 
self.clip_seg_tta.heatmap_unnormalized = heatmap_unnormalized + self.clip_seg_tta.heatmap_unnormalized_initial = heatmap_unnormalized_initial + self.clip_seg_tta.patch_embeds = patch_embeds + + # Override target positions in env + self.env.target_positions = [(pose[1], pose[0]) for pose in self.clip_seg_tta.target_positions] + + # Override segmentation mask + if not USE_CLIP_PREDS and OVERRIDE_MASK_DIR != "": + score_mask_path = os.path.join(OVERRIDE_MASK_DIR, self.clip_seg_tta.gt_mask_name) + print("score_mask_path: ", score_mask_path) + if os.path.exists(score_mask_path): + self.env.segmentation_mask = self.env.import_segmentation_mask(score_mask_path) + self.env.begin(self.env.map_start_position) + else: + print(f"\n\n{RED}ERROR: Trying to override, but score mask not found at path:{NC} ", score_mask_path) + self.bad_mask_init = True + + # Save clustered embeds from sat encoder + if USE_CLIP_PREDS: + self.kmeans_clusterer = CombinedSilhouetteInertiaClusterer( + k_min=1, + k_max=8, + k_avg_max=4, + silhouette_threshold=0.15, + relative_threshold=0.15, + random_state=0, + min_patch_size=5, + n_smooth_iter=2, + ignore_label=-1, + plot=self.save_image, + gifs_dir = GIFS_PATH + ) + # Generate kmeans clusters + self.kmeans_sat_embeds_clusters = self.kmeans_clusterer.fit_predict( + patch_embeds=self.clip_seg_tta.patch_embeds, + map_shape=(CLIP_GRIDS_DIMS[0], CLIP_GRIDS_DIMS[1]), + ) + print("Chosen k:", self.kmeans_clusterer.final_k) + + # if EXECUTE_TTA: + # print("Will execute TTA...") + + # Define Poisson TTA params + self.step_since_tta = 0 + self.steps_to_first_tgt = None + self.steps_to_mid_tgt = None + self.steps_to_last_tgt = None + + + def run_episode(self, curr_episode): + + # Return all metrics as None if faulty mask init + if self.bad_mask_init: + self.perf_metrics['tax'] = None + self.perf_metrics['travel_dist'] = None + self.perf_metrics['travel_steps'] = None + self.perf_metrics['steps_to_first_tgt'] = None + self.perf_metrics['steps_to_mid_tgt'] = None + self.perf_metrics['steps_to_last_tgt'] = None + self.perf_metrics['explored_rate'] = None + self.perf_metrics['targets_found'] = None + self.perf_metrics['targets_total'] = None + self.perf_metrics['kmeans_k'] = None + self.perf_metrics['tgts_gt_score'] = None + self.perf_metrics['clip_inference_time'] = None + self.perf_metrics['tta_time'] = None + self.perf_metrics['success_rate'] = None + return + + eps_start = time() + done = False + for robot_id, deciding_robot in enumerate(self.robot_list): + deciding_robot.observations = self.get_observations(deciding_robot.robot_position) + if LOAD_AVS_BENCH and USE_CLIP_PREDS: + if NUM_COORDS_WIDTH != CLIP_GRIDS_DIMS[0] or NUM_COORDS_HEIGHT != CLIP_GRIDS_DIMS[1]: # If heatmap is resized from clip original dims + heatmap = self.convert_heatmap_resolution(self.clip_seg_tta.heatmap, full_dims=(512, 512), new_dims=(NUM_COORDS_WIDTH, NUM_COORDS_HEIGHT)) + self.env.segmentation_info_mask = np.expand_dims(heatmap.T.flatten(), axis=1) + unnormalized_heatmap = self.convert_heatmap_resolution(self.clip_seg_tta.heatmap_unnormalized, full_dims=(512, 512), new_dims=(NUM_COORDS_WIDTH, NUM_COORDS_HEIGHT)) + self.env.segmentation_info_mask_unnormalized = np.expand_dims(unnormalized_heatmap.T.flatten(), axis=1) + print("Resized heatmap to", NUM_COORDS_WIDTH, "x", NUM_COORDS_HEIGHT) + else: + self.env.segmentation_info_mask = np.expand_dims(self.clip_seg_tta.heatmap.T.flatten(), axis=1) + self.env.segmentation_info_mask_unnormalized = np.expand_dims(self.clip_seg_tta.heatmap_unnormalized.T.flatten(), 
axis=1) + + ### Run episode ### + for step in range(NUM_EPS_STEPS): + + next_position_list = [] + dist_list = [] + travel_dist_list = [] + dist_array = np.zeros((self.n_agent, 1)) + for robot_id, deciding_robot in enumerate(self.robot_list): + observations = deciding_robot.observations + + ### Forward pass through policy to get next position ### + next_position, action_index = self.select_node(observations) + dist = np.linalg.norm(next_position - deciding_robot.robot_position) + + ### Log results of action (e.g. distance travelled) ### + dist_array[robot_id] = dist + dist_list.append(dist) + travel_dist_list.append(deciding_robot.travel_dist) + next_position_list.append(next_position) + self.all_robot_positions[robot_id] = next_position + + arriving_sequence = np.argsort(dist_list) + next_position_list = np.array(next_position_list) + dist_list = np.array(dist_list) + travel_dist_list = np.array(travel_dist_list) + next_position_list = next_position_list[arriving_sequence] + dist_list = dist_list[arriving_sequence] + travel_dist_list = travel_dist_list[arriving_sequence] + + ### Take Action (Deconflict if 2 agents choose the same target position) ### + next_position_list, dist_list = self.solve_conflict(arriving_sequence, next_position_list, dist_list) + reward_list, done = self.env.multi_robot_step(next_position_list, dist_list, travel_dist_list) + + ### Update observations + rewards from action ### + for reward, robot_id in zip(reward_list, arriving_sequence): + robot = self.robot_list[robot_id] + robot.save_trajectory_coords(self.env.find_index_from_coords(robot.robot_position), self.env.num_new_targets_found) + + # # TTA Update via Poisson Test (with KMeans clustering stats) + if LOAD_AVS_BENCH and USE_CLIP_PREDS and self.execute_tta: + self.poisson_tta_update(robot, self.global_step, step) + + robot.observations = self.get_observations(robot.robot_position) + robot.save_reward_done(reward, done) + + # Update metrics + self.log_metrics(step=step) + + ### Save a frame to generate gif of robot trajectories ### + if self.save_image: + robots_route = [] + for robot in self.robot_list: + robots_route.append([robot.xPoints, robot.yPoints]) + if not os.path.exists(self.gifs_path): + os.makedirs(self.gifs_path) + if LOAD_AVS_BENCH: + # NOTE: Replaced since using app.py + self.env.plot_heatmap(self.gifs_path, step, max(travel_dist_list), robots_route) + + if done: + break + + if LOAD_AVS_BENCH: + tax = Path(self.clip_seg_tta.gt_mask_name).stem + self.perf_metrics['tax'] = " ".join(tax.split("_")[1:]) + else: + self.perf_metrics['tax'] = None + self.perf_metrics['travel_dist'] = max(travel_dist_list) + self.perf_metrics['travel_steps'] = step + 1 + self.perf_metrics['steps_to_first_tgt'] = self.steps_to_first_tgt + self.perf_metrics['steps_to_mid_tgt'] = self.steps_to_mid_tgt + self.perf_metrics['steps_to_last_tgt'] = self.steps_to_last_tgt + self.perf_metrics['explored_rate'] = self.env.explored_rate + self.perf_metrics['targets_found'] = self.env.targets_found_rate + self.perf_metrics['targets_total'] = len(self.env.target_positions) + if USE_CLIP_PREDS: + self.perf_metrics['kmeans_k'] = self.kmeans_clusterer.final_k + self.perf_metrics['tgts_gt_score'] = self.clip_seg_tta.tgts_gt_score + self.perf_metrics['clip_inference_time'] = self.clip_seg_tta.clip_inference_time + self.perf_metrics['tta_time'] = self.clip_seg_tta.tta_time + else: + self.perf_metrics['kmeans_k'] = None + self.perf_metrics['tgts_gt_score'] = None + self.perf_metrics['clip_inference_time'] = None + 
self.perf_metrics['tta_time'] = None + if FORCE_LOGGING_DONE_TGTS_FOUND and self.env.targets_found_rate == 1.0: + self.perf_metrics['success_rate'] = True + else: + self.perf_metrics['success_rate'] = done + + # save gif + if self.save_image: + path = self.gifs_path # NOTE: Set to self.gifs_path since using app.py + self.make_gif(path, curr_episode) + + print(YELLOW, f"[Eps {curr_episode} Completed] Time Taken: {time()-eps_start:.2f}s, Steps: {step+1}", NC) + + def get_observations(self, robot_position): + """ Get robot's sensor observation of environment given position """ + current_node_index = self.env.find_index_from_coords(robot_position) + current_index = torch.tensor([current_node_index]).unsqueeze(0).unsqueeze(0).to(self.device) # (1,1,1) + + node_coords = copy.deepcopy(self.env.node_coords) + graph = copy.deepcopy(self.env.graph) + node_utility = copy.deepcopy(self.env.node_utility) + guidepost = copy.deepcopy(self.env.guidepost) + segmentation_info_mask = copy.deepcopy(self.env.filtered_seg_info_mask) + + n_nodes = node_coords.shape[0] + node_coords = node_coords / 640 + node_utility = node_utility / 50 + node_utility_inputs = node_utility.reshape((n_nodes, 1)) + + occupied_node = np.zeros((n_nodes, 1)) + for position in self.all_robot_positions: + index = self.env.find_index_from_coords(position) + if index == current_index.item(): + occupied_node[index] = -1 + else: + occupied_node[index] = 1 + + node_inputs = np.concatenate((node_coords, segmentation_info_mask, guidepost), axis=1) + node_inputs = torch.FloatTensor(node_inputs).unsqueeze(0).to(self.device) + node_padding_mask = None + + graph = list(graph.values()) + edge_inputs = [] + for node in graph: + node_edges = list(map(int, node)) + edge_inputs.append(node_edges) + + bias_matrix = self.calculate_edge_mask(edge_inputs) + edge_mask = torch.from_numpy(bias_matrix).float().unsqueeze(0).to(self.device) + + for edges in edge_inputs: + while len(edges) < self.k_size: + edges.append(0) + + edge_inputs = torch.tensor(edge_inputs).unsqueeze(0).to(self.device) + edge_padding_mask = torch.zeros((1, len(edge_inputs), K_SIZE), dtype=torch.int64).to(self.device) + one = torch.ones_like(edge_padding_mask, dtype=torch.int64).to(self.device) + edge_padding_mask = torch.where(edge_inputs == 0, one, edge_padding_mask) + + observations = node_inputs, edge_inputs, current_index, node_padding_mask, edge_padding_mask, edge_mask + return observations + + + def select_node(self, observations): + """ Forward pass through policy to get next position to go to on map """ + node_inputs, edge_inputs, current_index, node_padding_mask, edge_padding_mask, edge_mask = observations + with torch.no_grad(): + logp_list = self.local_policy_net(node_inputs, edge_inputs, current_index, node_padding_mask, edge_padding_mask, edge_mask) + + if self.greedy: + action_index = torch.argmax(logp_list, dim=1).long() + else: + action_index = torch.multinomial(logp_list.exp(), 1).long().squeeze(1) + + next_node_index = edge_inputs[:, current_index.item(), action_index.item()] + + next_position = self.env.node_coords[next_node_index] + + return next_position, action_index + + def solve_conflict(self, arriving_sequence, next_position_list, dist_list): + """ Deconflict if 2 agents choose the same target position """ + for j, [robot_id, next_position] in enumerate(zip(arriving_sequence, next_position_list)): + moving_robot = self.robot_list[robot_id] + # if next_position[0] + next_position[1] * 1j in (next_position_list[:, 0] + next_position_list[:, 1] * 1j)[:j]: + # 
dist_to_next_position = np.argsort(np.linalg.norm(self.env.node_coords - next_position, axis=1)) + # k = 0 + # while next_position[0] + next_position[1] * 1j in (next_position_list[:, 0] + next_position_list[:, 1] * 1j)[:j]: + # k += 1 + # next_position = self.env.node_coords[dist_to_next_position[k]] + + dist = np.linalg.norm(next_position - moving_robot.robot_position) + next_position_list[j] = next_position + dist_list[j] = dist + moving_robot.travel_dist += dist + moving_robot.robot_position = next_position + + return next_position_list, dist_list + + def work(self, currEpisode): + ''' + Interacts with the environment. The agent gets either gradients or experience buffer + ''' + self.run_episode(currEpisode) + + def calculate_edge_mask(self, edge_inputs): + size = len(edge_inputs) + bias_matrix = np.ones((size, size)) + for i in range(size): + for j in range(size): + if j in edge_inputs[i]: + bias_matrix[i][j] = 0 + return bias_matrix + + def make_gif(self, path, n): + """ Generate a gif given list of images """ + with imageio.get_writer('{}/{}_target_rate_{:.2f}.gif'.format(path, n, self.env.targets_found_rate), mode='I', + fps=5) as writer: + for frame in self.env.frame_files: + image = imageio.imread(frame) + writer.append_data(image) + print('gif complete\n') + + # Remove files + for filename in self.env.frame_files[:-1]: + os.remove(filename) + + # For gif during TTA + if LOAD_AVS_BENCH: + with imageio.get_writer('{}/{}_kmeans_stats.gif'.format(path, n), mode='I', + fps=5) as writer: + for frame in self.kmeans_clusterer.kmeans_frame_files: + image = imageio.imread(frame) + writer.append_data(image) + print('Kmeans Clusterer gif complete\n') + + # Remove files + for filename in self.kmeans_clusterer.kmeans_frame_files[:-1]: + os.remove(filename) + + ################################################################################ + # SPPP Related Fns + ################################################################################ + + def log_metrics(self, step): + # Update tgt found metrics + if self.steps_to_first_tgt is None and self.env.num_targets_found == 1: + self.steps_to_first_tgt = step + 1 + if self.steps_to_mid_tgt is None and self.env.num_targets_found == int(len(self.env.target_positions) / 2): + self.steps_to_mid_tgt = step + 1 + if self.steps_to_last_tgt is None and self.env.num_targets_found == len(self.env.target_positions): + self.steps_to_last_tgt = step + 1 + + + def transpose_flat_idx(self, idx, H=NUM_COORDS_HEIGHT, W=NUM_COORDS_WIDTH): + """ + Transpose a flat index from an ``HΓ—W`` grid to the equivalent + position in the ``WΓ—H`` transposed grid while **keeping the result + in 1-D**. 
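+
+        Example (arbitrary values, for illustration only): with ``H=3`` and
+        ``W=4``, ``idx=5`` is cell ``(row=1, col=1)``; in the transposed 4x3
+        grid that same cell is still ``(row=1, col=1)``, whose row-major flat
+        index is ``1 * H + 1 = 4``, i.e. the value returned by ``row_T * H + col_T``.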
+ """ + # --- Safety check to catch out-of-range indices --- + assert 0 <= idx < H * W, f"idx {idx} out of bounds for shape ({H}, {W})" + + # Original (row, col) + row, col = divmod(idx, W) + # After transpose these coordinates swap + row_T, col_T = col, row + + # Flatten back into 1-D (row-major) for the WΓ—H grid + return row_T * H + col_T + + + def poisson_tta_update(self, robot, episode, step): + + # Generate Kmeans Clusters Stats + # Scale index back to CLIP_GRIDS_DIMS to be compatible with CLIP patch size + if NUM_COORDS_WIDTH != CLIP_GRIDS_DIMS[0] or NUM_COORDS_HEIGHT != CLIP_GRIDS_DIMS[1]: + # High-res remap via pixel coordinates preserves exact neighbourhood + filt_traj_coords, filt_targets_found_on_path = self.scale_trajectory( + robot.trajectory_coords, + self.env.target_positions, + old_dims=(NUM_COORDS_HEIGHT, NUM_COORDS_WIDTH), + full_dims=(512, 512), + new_dims=(CLIP_GRIDS_DIMS[0], CLIP_GRIDS_DIMS[1]) + ) + else: + filt_traj_coords = [self.transpose_flat_idx(idx) for idx in robot.trajectory_coords] + filt_targets_found_on_path = robot.targets_found_on_path + + region_stats_dict = self.kmeans_clusterer.compute_region_statistics( + self.kmeans_sat_embeds_clusters, + self.clip_seg_tta.heatmap_unnormalized, + filt_traj_coords, + episode_num=episode, + step_num=step + ) + + # Prep & execute TTA + self.step_since_tta += 1 + if robot.targets_found_on_path[-1] or self.step_since_tta % STEPS_PER_TTA == 0: + + # NOTE: integration with app.py on hf + self.clip_seg_tta.executing_tta = True + + num_cells = self.clip_seg_tta.heatmap.shape[0] * self.clip_seg_tta.heatmap.shape[1] + pos_sample_weight_scale, neg_sample_weight_scale = [], [] + + for i, sample_loc in enumerate(filt_traj_coords): + label = self.kmeans_clusterer.get_label_id(self.kmeans_sat_embeds_clusters, sample_loc) + num_patches = region_stats_dict[label]['num_patches'] + patches_visited = region_stats_dict[label]['patches_visited'] + expectation = region_stats_dict[label]['expectation'] + + # Exponent like focal loss to wait for more samples before confidently decreasing + pos_weight = 4.0 + neg_weight = min(1.0, (patches_visited/(3*num_patches))**GAMMA_EXPONENT) + pos_sample_weight_scale.append(pos_weight) + neg_sample_weight_scale.append(neg_weight) + + # # # Adaptative LR (as samples increase, increase LR to fit more datapoints) + adaptive_lr = MIN_LR + (MAX_LR - MIN_LR) * (step / num_cells) + + # TTA Update + # NOTE: updated due to app.py (hf does not allow heatmap to persist) + heatmap = self.clip_seg_tta.execute_tta( + filt_traj_coords, + filt_targets_found_on_path, + tta_steps=NUM_TTA_STEPS, + lr=adaptive_lr, + pos_sample_weight=pos_sample_weight_scale, + neg_sample_weight=neg_sample_weight_scale, + reset_weights=RESET_WEIGHTS + ) + self.clip_seg_tta.heatmap = heatmap + + if NUM_COORDS_WIDTH != CLIP_GRIDS_DIMS[0] or NUM_COORDS_HEIGHT != CLIP_GRIDS_DIMS[1]: # If heatmap is resized from clip original dims + heatmap = self.convert_heatmap_resolution(self.clip_seg_tta.heatmap, full_dims=(512, 512), new_dims=(NUM_COORDS_WIDTH, NUM_COORDS_HEIGHT)) + self.env.segmentation_info_mask = np.expand_dims(heatmap.T.flatten(), axis=1) + unnormalized_heatmap = self.convert_heatmap_resolution(self.clip_seg_tta.heatmap_unnormalized, full_dims=(512, 512), new_dims=(NUM_COORDS_WIDTH, NUM_COORDS_HEIGHT)) + self.env.segmentation_info_mask_unnormalized = np.expand_dims(unnormalized_heatmap.T.flatten(), axis=1) + print("~Resized heatmap to", NUM_COORDS_WIDTH, "x", NUM_COORDS_HEIGHT) + else: + self.env.segmentation_info_mask = 
np.expand_dims(self.clip_seg_tta.heatmap.T.flatten(), axis=1) + self.env.segmentation_info_mask_unnormalized = np.expand_dims(self.clip_seg_tta.heatmap_unnormalized.T.flatten(), axis=1) + + self.step_since_tta = 0 + + # NOTE: integration with app.py on hf + self.clip_seg_tta.executing_tta = False + + + def convert_heatmap_resolution(self, heatmap, full_dims=(512, 512), new_dims=(24, 24)): + heatmap_large = resize(heatmap, full_dims, order=1, # order=1 β†’ bilinear + mode='reflect', anti_aliasing=True) + + coords = self.env.graph_generator.grid_coords # (N, N, 2) + rows, cols = coords[...,1], coords[...,0] + heatmap_resized = heatmap_large[rows, cols] + heatmap_resized = heatmap_resized.reshape(new_dims[1], new_dims[0]) + return heatmap_resized + + + def convert_labelmap_resolution(self, labelmap, full_dims=(512, 512), new_dims=(24, 24)): + """ + 1) Upsample via nearest‐neighbor to full_dims + 2) Sample back down to your graph grid using grid_coords + """ + # 1) Upsample with nearest‐neighbor, preserving integer labels + up = resize( + labelmap, + full_dims, + order=0, # nearest‐neighbor + mode='edge', # padding mode + preserve_range=True, # don't normalize labels + anti_aliasing=False # must be False for labels + ).astype(labelmap.dtype) # back to original integer dtype + + # 2) Downsample via your precomputed grid coords + coords = self.env.graph_generator.grid_coords # shape (N, N, 2) + rows = coords[...,1].astype(int) + cols = coords[...,0].astype(int) + + small = up[rows, cols] # shape (N, N) + small = small.reshape(new_dims[0], new_dims[1]) + return small + + + def scale_trajectory(self, + flat_indices, + targets, + old_dims=(17, 17), + full_dims=(512, 512), + new_dims=(24, 24)): + """ + Args: + flat_indices: list of ints in [0..old_H*old_W-1] + targets: list of (y_pix, x_pix) in [0..full_H-1] + old_dims: (old_H, old_W) + full_dims: (full_H, full_W) + new_dims: (new_H, new_W) + + Returns: + new_flat_traj: list of unique flattened indices in new_HΓ—new_W + counts: list of ints, same length as new_flat_traj + """ + old_H, old_W = old_dims + full_H, full_W = full_dims + new_H, new_W = new_dims + + # 1) bin targets into new grid + cell_h_new = full_H / new_H + cell_w_new = full_W / new_W + grid_counts = [[0]*new_W for _ in range(new_H)] + for x_pix, y_pix in targets: # note (x, y) order as in original implementation + i_t = min(int(y_pix / cell_h_new), new_H - 1) + j_t = min(int(x_pix / cell_w_new), new_W - 1) + grid_counts[i_t][j_t] += 1 + + # 2) Walk the trajectory indices and project each old cell's *entire + # pixel footprint* onto the finer 24Γ—24 grid. + cell_h_full = full_H / old_H + cell_w_full = full_W / old_W + + seen = set() + new_flat_traj = [] + + for node_idx in flat_indices: + if node_idx < 0 or node_idx >= len(self.env.graph_generator.node_coords): + continue + + coord_xy = self.env.graph_generator.node_coords[node_idx] + try: + row_old, col_old = self.env.graph_generator.find_index_from_grid_coords_2d(coord_xy) + except Exception: + continue + + # Bounding box of the old cell in full-resolution pixel space + y0 = row_old * cell_h_full + y1 = (row_old + 1) * cell_h_full + x0 = col_old * cell_w_full + x1 = (col_old + 1) * cell_w_full + + # Which new-grid rows & cols overlap? 
(inclusive ranges) + i_start = max(0, min(int(y0 / cell_h_new), new_H - 1)) + i_end = max(0, min(int((y1 - 1) / cell_h_new), new_H - 1)) + j_start = max(0, min(int(x0 / cell_w_new), new_W - 1)) + j_end = max(0, min(int((x1 - 1) / cell_w_new), new_W - 1)) + + for ii in range(i_start, i_end + 1): + for jj in range(j_start, j_end + 1): + f_new = ii * new_W + jj + if f_new not in seen: + seen.add(f_new) + new_flat_traj.append(f_new) + + # 3) annotate counts + counts = [] + for f in new_flat_traj: + i_new, j_new = divmod(f, new_W) + counts.append(grid_counts[i_new][j_new]) + + return new_flat_traj, counts + + ################################################################################ diff --git a/planner/worker.py b/planner/worker.py new file mode 100644 index 0000000000000000000000000000000000000000..6abc2909d0aa4cacc4ab38b7d00a874e3bb3d263 --- /dev/null +++ b/planner/worker.py @@ -0,0 +1,272 @@ +####################################################################### +# Name: worker.py +# +# - Runs robot in environment for N steps +# - Collects & Returns S(t), A(t), R(t), S(t+1) +####################################################################### + +from .parameter import * + +import os +import json +import copy +import imageio +import numpy as np +import torch +from time import time +from .env import Env +from .robot import Robot + +class Worker: + def __init__(self, meta_agent_id, n_agent, policy_net, q_net, global_step, device='cuda', greedy=False, save_image=False, clip_seg_tta=None): + self.device = device + self.greedy = greedy + self.n_agent = n_agent + self.metaAgentID = meta_agent_id + self.global_step = global_step + self.node_padding_size = NODE_PADDING_SIZE + self.k_size = K_SIZE + self.save_image = save_image + self.clip_seg_tta = clip_seg_tta + + # Randomize map_index + mask_index = None + if MASKS_RAND_INDICES_PATH != "": + with open(MASKS_RAND_INDICES_PATH, 'r') as f: + mask_index_rand_json = json.load(f) + mask_index = mask_index_rand_json[self.global_step % len(mask_index_rand_json)] + print("mask_index: ", mask_index) + + self.env = Env(map_index=self.global_step, n_agent=n_agent, k_size=self.k_size, plot=save_image, mask_index=mask_index) + self.local_policy_net = policy_net + self.local_q_net = q_net + + self.robot_list = [] + self.all_robot_positions = [] + + for i in range(self.n_agent): + robot_position = self.env.start_positions[i] + robot = Robot(robot_id=i, position=robot_position, plot=save_image) + self.robot_list.append(robot) + self.all_robot_positions.append(robot_position) + + self.perf_metrics = dict() + self.episode_buffer = [] + for i in range(15): + self.episode_buffer.append([]) + + + def run_episode(self, curr_episode): + + eps_start = time() + done = False + for robot_id, deciding_robot in enumerate(self.robot_list): + deciding_robot.observations = self.get_observations(deciding_robot.robot_position) + + ### Run episode ### + for step in range(NUM_EPS_STEPS): + + next_position_list = [] + dist_list = [] + travel_dist_list = [] + dist_array = np.zeros((self.n_agent, 1)) + for robot_id, deciding_robot in enumerate(self.robot_list): + observations = deciding_robot.observations + deciding_robot.save_observations(observations) + + ### Forward pass through policy to get next position ### + next_position, action_index = self.select_node(observations) + deciding_robot.save_action(action_index) + + dist = np.linalg.norm(next_position - deciding_robot.robot_position) + + ### Log results of action (e.g. 
distance travelled) ### + dist_array[robot_id] = dist + dist_list.append(dist) + travel_dist_list.append(deciding_robot.travel_dist) + next_position_list.append(next_position) + self.all_robot_positions[robot_id] = next_position + + arriving_sequence = np.argsort(dist_list) + next_position_list = np.array(next_position_list) + dist_list = np.array(dist_list) + travel_dist_list = np.array(travel_dist_list) + next_position_list = next_position_list[arriving_sequence] + dist_list = dist_list[arriving_sequence] + travel_dist_list = travel_dist_list[arriving_sequence] + + ### Take Action (Deconflict if 2 agents choose the same target position) ### + next_position_list, dist_list = self.solve_conflict(arriving_sequence, next_position_list, dist_list) + reward_list, done = self.env.multi_robot_step(next_position_list, dist_list, travel_dist_list) + + ### Update observations + rewards from action ### + for reward, robot_id in zip(reward_list, arriving_sequence): + robot = self.robot_list[robot_id] + robot.observations = self.get_observations(robot.robot_position) + robot.save_trajectory_coords(self.env.find_index_from_coords(robot.robot_position), self.env.num_new_targets_found) + robot.save_reward_done(reward, done) + robot.save_next_observations(robot.observations) + + ### Save a frame to generate gif of robot trajectories ### + if self.save_image: + robots_route = [] + for robot in self.robot_list: + robots_route.append([robot.xPoints, robot.yPoints]) + if not os.path.exists(GIFS_PATH): + os.makedirs(GIFS_PATH) + self.env.plot_env(self.global_step, GIFS_PATH, step, max(travel_dist_list), robots_route) + + if done: + break + + for robot in self.robot_list: + for i in range(15): + self.episode_buffer[i] += robot.episode_buffer[i] + + self.perf_metrics['travel_dist'] = max(travel_dist_list) + self.perf_metrics['explored_rate'] = self.env.explored_rate + self.perf_metrics['targets_found'] = self.env.targets_found_rate + self.perf_metrics['success_rate'] = done + + # save gif + if self.save_image: + path = GIFS_PATH + self.make_gif(path, curr_episode) + + print(YELLOW, f"[Eps {curr_episode} Completed] Time Taken: {time()-eps_start:.2f}s, Steps: {step+1}", NC) + + + def get_observations(self, robot_position): + """ Get robot's sensor observation of environment given position """ + current_node_index = self.env.find_index_from_coords(robot_position) + current_index = torch.tensor([current_node_index]).unsqueeze(0).unsqueeze(0).to(self.device) # (1,1,1) + + node_coords = copy.deepcopy(self.env.node_coords) + graph = copy.deepcopy(self.env.graph) + node_utility = copy.deepcopy(self.env.node_utility) + guidepost = copy.deepcopy(self.env.guidepost) + segmentation_info_mask = copy.deepcopy(self.env.filtered_seg_info_mask) + + n_nodes = node_coords.shape[0] + node_coords = node_coords / 640 + node_utility = node_utility / 50 + + node_utility_inputs = node_utility.reshape((n_nodes, 1)) + + occupied_node = np.zeros((n_nodes, 1)) + for position in self.all_robot_positions: + index = self.env.find_index_from_coords(position) + if index == current_index.item(): + occupied_node[index] = -1 + else: + occupied_node[index] = 1 + + node_inputs = np.concatenate((node_coords, segmentation_info_mask, guidepost), axis=1) + node_inputs = torch.FloatTensor(node_inputs).unsqueeze(0).to(self.device) # (1, node_padding_size+1, 3) + + assert node_coords.shape[0] < self.node_padding_size + padding = torch.nn.ZeroPad2d((0, 0, 0, self.node_padding_size - node_coords.shape[0])) + node_inputs = padding(node_inputs) + + 
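+        # NOTE: node features are zero-padded up to NODE_PADDING_SIZE (and the edge
+        # inputs/mask are padded further below) so every sample has a fixed shape for
+        # the shared episode buffer; the padding masks mark these dummy entries,
+        # presumably so the attention layers of the policy/Q-networks can ignore them.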
node_padding_mask = torch.zeros((1, 1, node_coords.shape[0]), dtype=torch.int64).to(self.device) + node_padding = torch.ones((1, 1, self.node_padding_size - node_coords.shape[0]), dtype=torch.int64).to( + self.device) + node_padding_mask = torch.cat((node_padding_mask, node_padding), dim=-1) + + graph = list(graph.values()) + edge_inputs = [] + for node in graph: + node_edges = list(map(int, node)) + edge_inputs.append(node_edges) + + bias_matrix = self.calculate_edge_mask(edge_inputs) + edge_mask = torch.from_numpy(bias_matrix).float().unsqueeze(0).to(self.device) + + assert len(edge_inputs) < self.node_padding_size + padding = torch.nn.ConstantPad2d( + (0, self.node_padding_size - len(edge_inputs), 0, self.node_padding_size - len(edge_inputs)), 1) + edge_mask = padding(edge_mask) + padding2 = torch.nn.ZeroPad2d((0, 0, 0, self.node_padding_size - len(edge_inputs))) + + for edges in edge_inputs: + while len(edges) < self.k_size: + edges.append(0) + + edge_inputs = torch.tensor(edge_inputs).unsqueeze(0).to(self.device) # (1, node_padding_size+1, k_size) + edge_inputs = padding2(edge_inputs) + + edge_padding_mask = torch.zeros((1, len(edge_inputs), K_SIZE), dtype=torch.int64).to(self.device) + one = torch.ones_like(edge_padding_mask, dtype=torch.int64).to(self.device) + edge_padding_mask = torch.where(edge_inputs == 0, one, edge_padding_mask) + + observations = node_inputs, edge_inputs, current_index, node_padding_mask, edge_padding_mask, edge_mask + return observations + + + def select_node(self, observations): + """ Forward pass through policy to get next position to go to on map """ + node_inputs, edge_inputs, current_index, node_padding_mask, edge_padding_mask, edge_mask = observations + with torch.no_grad(): + logp_list = self.local_policy_net(node_inputs, edge_inputs, current_index, node_padding_mask, + edge_padding_mask, edge_mask) + + if self.greedy: + action_index = torch.argmax(logp_list, dim=1).long() + else: + action_index = torch.multinomial(logp_list.exp(), 1).long().squeeze(1) + + next_node_index = edge_inputs[:, current_index.item(), action_index.item()] + + next_position = self.env.node_coords[next_node_index] + + return next_position, action_index + + + def solve_conflict(self, arriving_sequence, next_position_list, dist_list): + """ Deconflict if 2 agents choose the same target position """ + for j, [robot_id, next_position] in enumerate(zip(arriving_sequence, next_position_list)): + moving_robot = self.robot_list[robot_id] + # if next_position[0] + next_position[1] * 1j in (next_position_list[:, 0] + next_position_list[:, 1] * 1j)[:j]: + # dist_to_next_position = np.argsort(np.linalg.norm(self.env.node_coords - next_position, axis=1)) + # k = 0 + # while next_position[0] + next_position[1] * 1j in (next_position_list[:, 0] + next_position_list[:, 1] * 1j)[:j]: + # k += 1 + # next_position = self.env.node_coords[dist_to_next_position[k]] + + dist = np.linalg.norm(next_position - moving_robot.robot_position) + next_position_list[j] = next_position + dist_list[j] = dist + moving_robot.travel_dist += dist + moving_robot.robot_position = next_position + + return next_position_list, dist_list + + + def work(self, currEpisode): + ''' + Interacts with the environment. 
The agent gets either gradients or experience buffer + ''' + self.run_episode(currEpisode) + + def calculate_edge_mask(self, edge_inputs): + size = len(edge_inputs) + bias_matrix = np.ones((size, size)) + for i in range(size): + for j in range(size): + if j in edge_inputs[i]: + bias_matrix[i][j] = 0 + return bias_matrix + + + def make_gif(self, path, n): + """ Generate a gif given list of images """ + with imageio.get_writer('{}/{}_target_rate_{:.2f}.gif'.format(path, n, self.env.targets_found_rate), mode='I', + fps=5) as writer: + for frame in self.env.frame_files: + image = imageio.imread(frame) + writer.append_data(image) + print('gif complete\n') + + # Remove files + for filename in self.env.frame_files[:-1]: + os.remove(filename) \ No newline at end of file diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..92dc9c5e3fcf1940f69f063679662be9590408a2 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,24 @@ +# python 3.10.14 + +# Search-TTA +scipy==1.14.1 +scikit-learn==1.6.1 +scikit-image==0.24.0 +matplotlib==3.9.1 +imageio==2.36.0 +shapely==2.0.7 +rasterio==1.4.1 +kneed==0.8.5 +easydict==1.13 + +# Taxabind +numpy==1.26.3 +torch==2.4.0 +torchvision==0.19.0 +torchaudio==2.4.0 +pytorch-lightning==2.2.1 +open_clip_torch==2.30.0 +transformers==4.45.1 +tokenizers==0.20.3 +opencv-python==4.10.0.84 +gradio==4.44.1 \ No newline at end of file diff --git a/taxabind_avs/LICENSE b/taxabind_avs/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..261eeb9e9f8b2b4b0d119366dda99c6fd7d35c64 --- /dev/null +++ b/taxabind_avs/LICENSE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). 
+ + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. 
You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. 
In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/taxabind_avs/README.md b/taxabind_avs/README.md new file mode 100644 index 0000000000000000000000000000000000000000..bc6410a407f561c691db935a6ed0701be6181375 --- /dev/null +++ b/taxabind_avs/README.md @@ -0,0 +1 @@ +This folder is adapted from the [Taxabind repository](https://github.com/mvrl/TaxaBind). 
\ No newline at end of file diff --git a/taxabind_avs/satbind/clip_seg_tta.py b/taxabind_avs/satbind/clip_seg_tta.py new file mode 100644 index 0000000000000000000000000000000000000000..f17334f2b2c150b22ce959909e981757294b96a5 --- /dev/null +++ b/taxabind_avs/satbind/clip_seg_tta.py @@ -0,0 +1,564 @@ +############################################################################## +# Name: clip_seg_tta.py +# +# - Performs TTA on sat encoder a collected measurements +############################################################################### + +import os +import sys +sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) +sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) + +import cv2 +import time +import torch +import numpy as np +import open_clip +import torch.nn as nn +import spaces +from PIL import Image +from matplotlib import pyplot as plt +from dataset import SatNatDataset +from model_sat import SatBind +from transformers import ClapAudioModelWithProjection +from clip_vision_per_patch_model import CLIPVisionPerPatchModel +from types import SimpleNamespace +from config_sat import config + + +class ClipSegTTA: + def __init__( + self, + img_dir: str, + imo_dir: str, + json_path: str, + sat_to_img_ids_path: str, + sat_checkpoint_path: str, + load_pretrained_hf_ckpt: bool = True, + sample_index: int = 0, # Set using 'reset' + blur_kernel = (5,5), # (0,0) for no gaussian blur + batch_size: int = 1, + num_workers: int = 1, + device: str = "cuda", + sat_to_img_ids_json_is_train_dict: bool = True, + tax_to_filter_val: str = "", + load_model: bool = True, + query_modality: str = "image", # image, text, sound + sound_dir: str = None, + sound_checkpoint_path: str = None, + ): + + self.img_dir = img_dir + self.imo_dir = imo_dir + self.json_path = json_path + self.sat_to_img_ids_path = sat_to_img_ids_path + self.sat_checkpoint_path = sat_checkpoint_path + self.pretrained_hf_ckpt = load_pretrained_hf_ckpt + self.sample_index = sample_index + self.blur_kernel = blur_kernel + self.batch_size = batch_size + self.num_workers = num_workers + self.device = device + self.sat_to_img_ids_json_is_train_dict = sat_to_img_ids_json_is_train_dict + self.tax_to_filter_val = tax_to_filter_val + self.load_model = load_model + self.query_modality = query_modality + self.sound_dir = sound_dir + self.sound_checkpoint_path = sound_checkpoint_path + + # Prepare the dataset + start_time = time.time() + self.load_data() + print(f"Dataset loaded in {(time.time()-start_time):.2f}s.") + + if self.load_model: + start_time = time.time() + + # Load the global model (original/frozen checkpoint) + self.load_global_model() + self.tokenizer = open_clip.get_tokenizer(config.image_encoder_finetuned) + print(f"Global model loaded in {(time.time()-start_time):.2f}s.") + + # Create the local model that will be adapted for TTA + if self.pretrained_hf_ckpt: + imo_encoder = CLIPVisionPerPatchModel.from_pretrained(config.sat_encoder_finetuned) + imo_encoder.to(self.device) + imo_encoder.eval() + bio_model, *_ = open_clip.create_model_and_transforms(config.image_encoder_finetuned) + bio_model.to(self.device) + bio_model.eval() + logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) + self.model_local = SimpleNamespace(imo_encoder=imo_encoder, bio_model=bio_model, logit_scale=logit_scale) + + # Load sound model if provided + if self.query_modality =="sound": + self.sound_model = ClapAudioModelWithProjection.from_pretrained(config.sound_encoder_finetuned) + self.sound_model.to(self.device) + 
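+                    # NOTE: the CLAP audio encoder only embeds the sound query; it is kept in
+                    # eval mode and its weights are never optimized during TTA (only the
+                    # satellite encoder is adapted).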
self.sound_model.eval() + print("~Loaded HF checkpoint") + else: + self.model_local = SatBind(train_dataset=None, val_dataset=None) + self.model_local.to(self.device) + self.model_local.eval() + + # Load sound model if provided + if self.query_modality =="sound" and self.sound_checkpoint_path: + from soundbind.model_sound import AudioBind + self.sound_model = AudioBind.load_from_checkpoint(self.sound_checkpoint_path, train_dataset=None, val_dataset=None) + self.sound_model.to(self.device) + self.sound_model.eval() + print("~Loaded local checkpoint") + + self.clip_inference_time = 0.0 + self.tta_time = 0.0 + + # NOTE: integration with app.py on hf + if sample_index >= 0: + self.reset(sample_idx=self.sample_index) + self.executing_tta = False + + + def load_data(self): + """Load or initialize the dataset.""" + self.dataset = SatNatDataset( + img_dir=self.img_dir, + imo_dir=self.imo_dir, + json_path=self.json_path, + sat_to_img_ids_path=self.sat_to_img_ids_path, + patch_size=config.patch_size, + mode="val", + get_img_path=True, + sat_to_img_ids_json_is_train_dict=self.sat_to_img_ids_json_is_train_dict, + tax_to_filter_val=self.tax_to_filter_val, + sound_dir=self.sound_dir + ) + + def reset(self, sample_idx): + """Reset the parameters & local model for the current sample.""" + if self.load_model: + self.reset_local_model() # Reset to global weights as init + + # NOTE: integration with app.py on hf + if sample_idx >= 0: + self.img_paths, self.imo_path, self.imgs, self.imo, self.sounds, self.sound_ids, self.species_name, self.target_positions, self.gt_mask_name = self.dataset.get_search_ds_data(sample_idx) + self.imgs = self.imgs.to(self.device) + + self.tgts_gt_score = None + if self.load_model: + self.heatmap, self.heatmap_unnormalized, self.heatmap_unnormalized_initial, self.patch_embeds = None, None, None, None + img = self.imgs[0].unsqueeze(0).to(self.device) + imo = self.imo.unsqueeze(0).to(self.device) + txt = [self.species_name] + if self.sounds != []: + sound = self.sounds[0].to(self.device) + for k in sound.keys(): + sound[k] = sound[k].to(self.device) + else: + sound = None + self.generate_heatmap(img, imo, txt, sound=sound, modality=self.query_modality) + + # Find avg heatmap score for target positions + scores = [] + imo_orig = Image.open(self.imo_path) + for pos in self.target_positions: + row_trans = int(pos[0] * self.heatmap.shape[0] / imo_orig.size[0]) + col_trans = int(pos[1] * self.heatmap.shape[1] / imo_orig.size[1]) + scores.append(self.heatmap[row_trans, col_trans]) + self.tgts_gt_score = np.mean(scores) + + # NOTE: integration with app.py on hf + return self.heatmap, self.heatmap_unnormalized, self.heatmap_unnormalized_initial, self.patch_embeds + + + def load_global_model(self): + """Load the global SatBind model from checkpoint, move to device, and eval.""" + if self.pretrained_hf_ckpt: + print("Downloading HF checkpoint (if not already downloaded)...") + self.model_global = CLIPVisionPerPatchModel.from_pretrained(config.sat_encoder_finetuned) + else: + self.model_global = SatBind.load_from_checkpoint( + self.sat_checkpoint_path, train_dataset=None, val_dataset=None + ) + self.model_global = self.model_global.to(self.device) + self.model_global.eval() + + + def reset_local_model(self): + """ + Reset the local model to match the global model's parameters + and freeze/unfreeze layers for TTA. 
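+
+        Concretely, only the satellite (imo) encoder, plus the custom visual
+        projection in the non-HF case, is left trainable; the query-side
+        encoders (image/text/sound) are never optimized, so TTA adapts the
+        satellite branch alone.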
+ """ + start_time = time.time() + with torch.no_grad(): + local_params = self.model_local.imo_encoder.parameters() \ + if self.pretrained_hf_ckpt else self.model_local.parameters() + for param_global, param_local in zip( + self.model_global.parameters(), local_params + ): + param_local.data.copy_(param_global.data) + + if self.pretrained_hf_ckpt: + for param in self.model_local.imo_encoder.parameters(): + param.requires_grad = True + self.model_local.imo_encoder.eval() + else: + # Freeze everything except the satellite encoder & custom projection + for name, param in self.model_local.named_parameters(): + if "imo_encoder" in name or "visual_projection_custom" in name: + param.requires_grad = True + else: + param.requires_grad = False + self.model_local.eval() + + + # NOTE: integration with app.py on hf + @spaces.GPU(duration=120) + def execute_tta( + self, + patch_indices: list, + patch_is_pos: list, + pos_sample_weight: float, + neg_sample_weight: float, + tta_steps: int = 10, + lr: float = 2e-6, + reset_weights: bool = True, + num_viz_steps: int = 1, + viz_heatmap: bool = False, + ): + """ + Run test-time adaptation using the local model. The local model is first + reset to the global weights. After TTA, the global model remains + unchanged; only the local model is updated. + """ + + ### Option 1: SAMPLE FROM DATASET + # 1) Reset the local model to global weights + if reset_weights: + self.reset_local_model() + + ## NOTE: Added due to app.py (to allocate to GPU only when needed on HF) + print("Allocating models to GPU...") + self.device = torch.device("cuda") + self.model_local.imo_encoder.to(self.device) + self.model_local.bio_model.to(self.device) + + # 2) Prepare the sample(s) for TTA + img = self.imgs[0].unsqueeze(0).to(self.device) + imo = self.imo.unsqueeze(0).to(self.device) # vectorize + txt = [self.species_name] + if self.sounds != []: + sound = self.sounds[0].to(self.device) + for k in sound.keys(): + sound[k] = sound[k].to(self.device) + else: + sound = None + patch_indices = [idx+1 for idx in patch_indices] # Consider the [CLS] token offset + patch_idx = torch.tensor(patch_indices).to(self.device) + + # --------------------------------------------------------------------- + + # 5) Set up optimizer + local_params = self.model_local.imo_encoder.parameters() \ + if self.pretrained_hf_ckpt else self.model_local.parameters() + optimizer = torch.optim.Adam( + [p for p in local_params if p.requires_grad], lr=lr + ) + start_time = time.time() + + # 6) TTA loop + for step in range(tta_steps): + batch_size = imo.shape[0] + + # Query embeds + query_embeds = self.generate_query_embeds(img, imo, txt, sound=sound, modality=self.query_modality) + + # Sat Embeds + if self.pretrained_hf_ckpt: + imo_embeds = self.model_local.imo_encoder.vision_model(imo, return_dict=True).last_hidden_state + imo_embeds = imo_embeds[torch.arange(batch_size), patch_idx] + imo_embeds = self.model_local.imo_encoder.visual_projection(imo_embeds) + else: + imo_embeds = self.model_local.imo_encoder(imo).last_hidden_state # (batch, Patches, hidden_dim) + imo_embeds = imo_embeds[torch.arange(batch_size), patch_idx] # (batch, hidden_dim) + imo_embeds = self.model_local.visual_projection_custom(imo_embeds) # (batch_size, proj_dim) + imo_embeds = torch.nn.functional.normalize(imo_embeds, dim=-1) + + # Compute Similarity Loss + logit_scale = self.model_local.logit_scale.exp() + similarity = imo_embeds @ query_embeds.t() * logit_scale + + # Negative Log Likelihood loss for spatial poisson point process + patch_probs = 
similarity.squeeze().sigmoid() + counts = torch.tensor(patch_is_pos, dtype=torch.float32, device=similarity.device) + pos_weights = torch.tensor(pos_sample_weight, dtype=torch.float32, device=similarity.device) + neg_weights = torch.tensor(neg_sample_weight, dtype=torch.float32, device=similarity.device) + loss = (neg_weights * patch_probs - pos_weights * counts * torch.log(patch_probs + 1e-6)) + loss = loss.sum() + + # Backprop and update + optimizer.zero_grad() + loss.backward() + optimizer.step() + + self.tta_time = time.time() - start_time + + # Visualization every 'num_viz_steps' steps (if enabled) + if (step + 1) % num_viz_steps == 0 and viz_heatmap: + self.generate_heatmap(img, imo, txt, sound=sound, modality=self.query_modality) + self.visualize_heatmap( + step=step, + img_path_viz=self.img_paths[0], # Viz 1st image + imo_path_viz=self.imo_path, + patch_idx_viz=patch_idx, + patch_is_pos=patch_is_pos, + species_name=self.species_name + ) + + # Save final heatmap after TTA steps + self.generate_heatmap(img, imo, txt, sound=sound, modality=self.query_modality) + + ## NOTE: Added due to app.py (to allocate to GPU only when needed on HF) + print("Deallocating models from GPU...") + self.device = torch.device("cpu") + self.model_local.imo_encoder.to(self.device) + self.model_local.bio_model.to(self.device) + + return self.heatmap + + + def generate_query_embeds(self, img, imo, txt, sound=None, modality="image"): + + # Query Embeds + if modality == "image": + query_embeds, *_ = self.model_local.bio_model(img) # (batch_size, proj_dim) + if query_embeds.shape[0] > 1: + query_embeds = query_embeds.mean(dim=0, keepdim=True) # (1, proj_dim) + elif modality == "text": + txt_tokenized = self.tokenizer(txt).to(imo.device) + _, query_embeds, _ = self.model_local.bio_model(text=txt_tokenized) + elif modality == "sound": + if sound == None: + print("!!!! 
Sound modality requires sound input !!!") + exit(1) + if self.pretrained_hf_ckpt: + unnormalized_audio_embeds = self.sound_model(**sound).audio_embeds + else: + unnormalized_audio_embeds = self.sound_model.audio_encoder(sound) + query_embeds = torch.nn.functional.normalize(unnormalized_audio_embeds, dim=-1) + else: + raise ValueError("Invalid modality") + + return query_embeds + + + def generate_heatmap(self, img, imo, txt, sound=None, modality="image"): + + start_time = time.time() + + # Satellite encoder outputs + if self.pretrained_hf_ckpt: + imo_embeds = self.model_local.imo_encoder(imo) + else: + imo_embeds = self.model_local.imo_encoder(imo).last_hidden_state + imo_embeds = self.model_local.visual_projection_custom(imo_embeds) + imo_embeds = torch.nn.functional.normalize(imo_embeds, dim=-1) + + # Remove batch dimension -> (num_tokens, proj_dim) + imo_embeds = imo_embeds.squeeze(0) + self.patch_embeds = imo_embeds.clone()[1:].cpu().detach().numpy() + + # Ground image embedding (bio CLIP model) + query_embeds = self.generate_query_embeds(img, imo, txt, sound=sound, modality=modality) + + # Same logit scale as in SatBind + logit_scale = self.model_local.logit_scale.exp() + sim = query_embeds @ imo_embeds.t() * logit_scale + # Sigmoid to get similarity scores + scores = sim.t().sigmoid() # (num_tokens, 1) + + # Exclude [CLS] token at index 0 + score_no_cls = scores[1:].squeeze() # shape: (num_tokens-1,) + num_tokens = score_no_cls.shape[0] + side_dim = int(num_tokens**0.5) + sim_scores = score_no_cls.reshape(side_dim, side_dim).clone() + sim_scores = sim_scores.cpu().detach().numpy() + + self.clip_inference_time = time.time() - start_time + + # Gausian Smoothing + if self.blur_kernel != (0,0): + sim_scores = cv2.GaussianBlur(sim_scores, self.blur_kernel, 0) + + # Normalize to expectation + self.heatmap_unnormalized = sim_scores + scale = len(self.target_positions) / (self.heatmap_unnormalized.sum()) + self.heatmap_unnormalized *= scale + if self.heatmap_unnormalized_initial is None: + self.heatmap_unnormalized_initial = self.heatmap_unnormalized.copy() + + # Standard normalization to (0,1) + self.heatmap = sim_scores.copy() + self.heatmap = (self.heatmap - self.heatmap.min()) / (self.heatmap.max() - self.heatmap.min()) + + + def visualize_heatmap( + self, + step: int, + img_path_viz: str, + imo_path_viz: str, + patch_idx_viz: torch.Tensor, + patch_is_pos: list, + species_name: str + ): + """ + Visualization function that plots the ground image, satellite image with + highlighted patch, and the learned heatmap. 
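+
+        Positive patches are tinted blue and negative patches red on the
+        satellite image. Only invoked when ``viz_heatmap`` is enabled in
+        ``execute_tta``; intended for offline inspection.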
+ """ + + # Switch off gradients for visualization + with torch.no_grad(): + side_dim = self.heatmap.shape[0] + + # ----------------------------------------------------------------- + # Highlight the patch in the satellite image + sat_img_orig = Image.open(imo_path_viz) + sat_highlight = np.array( + self.dataset.debug_imo_viz_transform(sat_img_orig.copy()) + ) + + for idx, patch_idx in enumerate(patch_idx_viz): + + # Because patch_idx includes the [CLS] offset, subtract 1 + patch_idx_actual = patch_idx - 1 + + # Get dimensions (H x W) + H, W = sat_highlight.shape[0], sat_highlight.shape[1] + + # Number of patches in each dimension + patches_per_col = W // config.patch_size + patches_per_row = H // config.patch_size + + # Determine row/col in the patch grid + patch_row = patch_idx_actual // patches_per_col + patch_col = patch_idx_actual % patches_per_row + + # Pixel boundaries + x_start = patch_col * config.patch_size + x_end = (patch_col + 1) * config.patch_size + y_start = patch_row * config.patch_size + y_end = (patch_row + 1) * config.patch_size + + # Blue color for positive patches (transparent) + if patch_is_pos[idx]: + sat_highlight[y_start:y_end, x_start:x_end, 0] = 0 + sat_highlight[y_start:y_end, x_start:x_end, 1] = 0 + sat_highlight[y_start:y_end, x_start:x_end, 2] = 255 + # Red color for negative patches (transparent) + else: + sat_highlight[y_start:y_end, x_start:x_end, 0] = 255 + sat_highlight[y_start:y_end, x_start:x_end, 1] = 0 + sat_highlight[y_start:y_end, x_start:x_end, 2] = 0 + + + # ----------------------------------------------------------------- + # Plot results + fig, axes = plt.subplots(1, 3, figsize=(12, 6)) + fig.suptitle(f"Query: {species_name}") + + # Ground image + img_orig = Image.open(img_path_viz) + axes[0].imshow(img_orig) + axes[0].set_title("Ground Image") + axes[0].axis("off") + + # Satellite image + axes[1].imshow(sat_highlight) + axes[1].set_title("Sat Image") + axes[1].axis("off") + + # Heatmap + heatmap_np = self.heatmap_unnormalized + im = axes[2].imshow(heatmap_np, cmap="viridis") + axes[2].set_title( + f"Heatmap at TTA Step {step:03d} ({side_dim}x{side_dim})" + ) + axes[2].axis("off") + fig.colorbar(im, ax=axes[2], fraction=0.046, pad=0.04) + + plt.tight_layout() + plt.show() + + +if __name__ == "__main__": + + ########################################### + # Example with Image Modality + ########################################### + + clip_seg_tta = ClipSegTTA( + img_dir="/mnt/hdd/avs_bench_ds/inat21", + imo_dir="/mnt/hdd/avs_bench_ds/sat_jpg/train_512px", + json_path="/mnt/hdd/avs_bench_ds/inat21/train.json", + sat_to_img_ids_path="search_tri_modal|val_in_domain", + patch_size=14, + load_pretrained_hf_ckpt=True, + sat_checkpoint_path="", + sample_index=261, + batch_size=1, + num_workers=1, + device="cuda", + sat_to_img_ids_json_is_train_dict=False, + query_modality="image" + ) + + # Image modality test + patch_indices = [50, 357] + patch_is_pos = [True, False] + pos_sample_weight = 1.0 + neg_sample_weight = 1.0 + clip_seg_tta.execute_tta( + patch_indices, + patch_is_pos, + pos_sample_weight, + neg_sample_weight, + tta_steps=10, # for sanity check + num_viz_steps=2, + viz_heatmap=True + ) + + ########################################### + # Example with Sound Modality + ########################################### + + # # Sound Modality Test + # clip_seg_tta = ClipSegTTA( + # img_dir="/mnt/hdd/avs_bench_ds/inat21", + # imo_dir="/mnt/hdd/avs_bench_ds/sat_jpg/train_512px", + # json_path="/mnt/hdd/avs_bench_ds/inat21/train.json", + # 
sat_to_img_ids_path="search_quad_modal|val_in_domain", + # sound_dir='/mnt/hdd/avs_bench_ds/sound_mp3/test', + # patch_size=14, + # sat_checkpoint_path="", + # sound_checkpoint_path = "", + # sample_index=120, + # batch_size=1, + # num_workers=1, + # device="cuda", + # sat_to_img_ids_json_is_train_dict=False, + # query_modality="sound" + # ) + + # # Sound modality test + # patch_indices = [422, 32] + # patch_is_pos = [True, False] + # pos_sample_weight = 1.0 + # neg_sample_weight = 1.0 + # clip_seg_tta.execute_tta( + # patch_indices, + # patch_is_pos, + # pos_sample_weight, + # neg_sample_weight, + # tta_steps=30, + # num_viz_steps=2, + # viz_heatmap=True + # ) \ No newline at end of file diff --git a/taxabind_avs/satbind/clip_vision_per_patch_model.py b/taxabind_avs/satbind/clip_vision_per_patch_model.py new file mode 100644 index 0000000000000000000000000000000000000000..d7461295d79f9cf82904bdf59120827e0d954c53 --- /dev/null +++ b/taxabind_avs/satbind/clip_vision_per_patch_model.py @@ -0,0 +1,32 @@ +############################################################################## +# Name: clip_vision_per_patch_model.py +# +# - Overloads CLIP template with custom functions +############################################################################### + +import torch +from transformers import CLIPVisionModelWithProjection, CLIPVisionConfig + +class CLIPVisionPerPatchModel(CLIPVisionModelWithProjection): + """ + Like CLIPVisionModelWithProjection but returns + per-patch embeddings instead of pooled CLS tokens. + """ + def __init__(self, config: CLIPVisionConfig): + super().__init__(config) + # everything else (self.vision_model, self.visual_projection) + # is set up for you by the parent class + + def forward(self, pixel_values, **kwargs): + # 1) run the ViT backbone β†’ last_hidden_state [B, n_patches, hidden_size] + outputs = self.vision_model(pixel_values, return_dict=True, **kwargs) + hidden_states = outputs.last_hidden_state + + # 2) project every patch token β†’ [B, n_patches, projection_dim] + patch_embeds = self.visual_projection(hidden_states) + + # 3) Postprocessing embeds + patch_embeds = torch.nn.functional.normalize(patch_embeds, dim=-1) + patch_embeds = patch_embeds.squeeze() # (Patches, proj_dim) + + return patch_embeds \ No newline at end of file diff --git a/taxabind_avs/satbind/config_sat.py b/taxabind_avs/satbind/config_sat.py new file mode 100644 index 0000000000000000000000000000000000000000..f1b0cd3ecfd52860c9ea3e208ad0bc0e967aa25c --- /dev/null +++ b/taxabind_avs/satbind/config_sat.py @@ -0,0 +1,45 @@ +############################################################################## +# Name: config_sat.py +# +# - Parameters for train/eval sat encoder +############################################################################### + +from easydict import EasyDict as edict + +config = edict() + +# Pixel level CLIP training +config.img_dir = '/mnt/hdd/avs_bench_ds/inat21' +config.imo_dir = '/mnt/hdd/avs_bench_ds/sat_jpg/train_512px' +config.imo_dir_val = '/mnt/hdd/avs_bench_ds/sat_jpg/test_512px' +config.train_json_path = '/mnt/hdd/avs_bench_ds/inat21/train.json' +config.val_json_path = '/mnt/hdd/avs_bench_ds/inat21/val.json' +config.sat_to_img_ids_train_json_path = 'clip_tri_modal|train' +config.sat_to_img_ids_val_json_path = 'clip_tri_modal|val' + +# batch_size * accumulate_grad_batches * devices = constant (i.e. 
256 * 8 * 2 = 4096) +config.batch_size = 32 +config.lr = 1e-4 +config.accumulate_grad_batches = 64 +config.max_epochs = 20 +config.num_workers = 16 +config.devices = 2 +config.val_check_interval = 0.5 +config.sat_encoder = 'openai/clip-vit-large-patch14-336' +config.avs_dataset = 'derektan95/avs-bench' +config.patch_size = 14 + +config.save_dir = 'checkpoints' +config.filename = 'satbind-{epoch:02d}-{val_loss:.2f}' + +config.locked_tuning = True + +config.resume_from_checkpoint = False +config.resume_checkpoint_name = 'satbind-resume' + +# huggingface finetuned +config.image_encoder_finetuned = 'hf-hub:imageomics/bioclip' +config.sound_encoder_finetuned = 'derektan95/search-tta-sound' # For eval only +config.sat_encoder_finetuned = 'derektan95/search-tta-sat' # For eval only + +print("config: \n", config) \ No newline at end of file diff --git a/taxabind_avs/satbind/dataset.py b/taxabind_avs/satbind/dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..96ae92418070dfe56072d123abcba1a35a4e4acb --- /dev/null +++ b/taxabind_avs/satbind/dataset.py @@ -0,0 +1,215 @@ +############################################################################## +# Name: dataset.py +# +# - Handles loading of trimodal dataset +# - https://huggingface.co/datasets/derektan95/avs-bench +############################################################################### + +import os +import sys +sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) + +import math +import json +import torch +from pathlib import Path +from torch.utils.data import Dataset +from torchvision import transforms +from PIL import Image +from datasets import load_dataset +from config_sat import config + + +class SatNatDataset(Dataset): + def __init__(self, img_dir, imo_dir, json_path, sat_to_img_ids_path, patch_size, mode='train', get_img_path=False, sat_to_img_ids_json_is_train_dict=True, tax_to_filter_val="", sound_dir=None): + self.img_dir = img_dir + self.imo_dir = imo_dir + self.patch_size = patch_size + self.get_img_path = get_img_path + self.mode = mode + self.sat_to_img_ids_json_is_train_dict = sat_to_img_ids_json_is_train_dict + self.tax_to_filter_val = tax_to_filter_val + self.sound_dir = sound_dir + self.current_epoch = 0 + + ## NOTE: Removed as unnecessary for app.py + # self.json = json.load(open(json_path, 'r')) + # self.images = self.json['images'] + # self.annot = self.json['annotations'] + # for i in range(len(self.images)): + # assert self.images[i]['id'] == self.annot[i]['id'] + # self.images[i]['label'] = self.annot[i]['category_id'] + # self.filtered_json = [d for d in self.images if d['latitude'] is not None and d['longitude'] is not None] + # self.species_text = list(set([" ".join(d['file_name'].split("/")[1].split("_")[1:]) for d in self.filtered_json])) + # self.inat_json_dict = { + # "images": {img["id"]: img for img in self.images}, + # "annotations": {ann["id"]: ann for ann in self.annot}, + # } + + # # Load from huggingface dataset + # ds_config = sat_to_img_ids_path.split("|")[0] + # ds_split = sat_to_img_ids_path.split("|")[1] + # self.sat_to_img_ids_json = load_dataset(config.avs_dataset, name=ds_config, split=ds_split) + # print("Loaded huggingface dataset: ", ds_config, ds_split) + + # # Expand dict + # self.sat_to_img_ids_tuples = [] + # if self.sat_to_img_ids_json_is_train_dict: + # # Convert from a huggingface list of dicts into dict of dicts (no duplicate keys) + # self.sat_to_img_ids_json = {sat_sample["sat_key"]: sat_sample for sat_sample in 
self.sat_to_img_ids_json} + # for sat_key, sat_sample in self.sat_to_img_ids_json.items(): + # id = sat_sample["id"] + # sat_path = sat_sample["sat_path"] + # img_ids = sat_sample["img_ids"] + # for img_id in img_ids: + # self.sat_to_img_ids_tuples.append((id, sat_path, img_id)) + # print("len(self.sat_to_img_ids_json): ", len(self.sat_to_img_ids_json)) + # print("len(self.sat_to_img_ids_tuples): ", len(self.sat_to_img_ids_tuples)) + # else: + # self.filtered_val_ds_by_tax = [d for d in self.sat_to_img_ids_json if self.tax_to_filter_val in d['taxonomy']] + + if mode == 'train': + self.img_transform = transforms.Compose([ + transforms.Resize((256, 256)), + transforms.RandomCrop((224, 224)), + transforms.RandomHorizontalFlip(0.5), + transforms.GaussianBlur(5, (0.01, 1.0)), + transforms.ToTensor(), + transforms.Normalize(mean=[0.485, 0.456, 0.406], + std=[0.229, 0.224, 0.225]) + ]) + self.imo_transform = transforms.Compose([ + transforms.Resize((336,336)), + transforms.GaussianBlur(5, (0.01, 1.0)), + transforms.ToTensor(), + transforms.Normalize(mean=[0.485, 0.456, 0.406], + std=[0.229, 0.224, 0.225]) + ]) + else: + self.img_transform = transforms.Compose([ + transforms.Resize((256, 256)), + transforms.CenterCrop((224, 224)), + transforms.ToTensor(), + transforms.Normalize(mean=[0.485, 0.456, 0.406], + std=[0.229, 0.224, 0.225]) + ]) + self.imo_transform = transforms.Compose([ + transforms.Resize((336,336)), + transforms.ToTensor(), + transforms.Normalize(mean=[0.485, 0.456, 0.406], + std=[0.229, 0.224, 0.225]) + ]) + self.debug_img_viz_transform = transforms.Compose([ + transforms.Resize((256, 256)), + transforms.CenterCrop((224, 224)) + ]) + self.debug_imo_viz_transform = transforms.Compose([ + transforms.Resize((336,336)) + ]) + + + def __len__(self): + return len(self.sat_to_img_ids_tuples) + + + def __getitem__(self, idx): + + if not self.sat_to_img_ids_json_is_train_dict: + print("Json is not dict. 
Please reformat for training!") + exit() + + ## Pixel-level CLIP + id, sat_path, img_id = self.sat_to_img_ids_tuples[idx] + imo_path = os.path.join(self.imo_dir, sat_path) + imo = self.imo_transform(Image.open(imo_path)) + sat_id = Path(sat_path).stem + + img_path = os.path.join(self.img_dir, self.inat_json_dict["images"][img_id]["file_name"]) + img = self.img_transform(Image.open(img_path)) + + # # Map lat-lon to pixel in sat img + sat_min_lon = self.sat_to_img_ids_json[sat_id]["sat_bounds"]["min_lon"] + sat_min_lat = self.sat_to_img_ids_json[sat_id]["sat_bounds"]["min_lat"] + sat_max_lon = self.sat_to_img_ids_json[sat_id]["sat_bounds"]["max_lon"] + sat_max_lat = self.sat_to_img_ids_json[sat_id]["sat_bounds"]["max_lat"] + + img_lon = self.inat_json_dict["images"][img_id]["longitude"] + img_lat = self.inat_json_dict["images"][img_id]["latitude"] + row, col = self.latlon_to_pixel(img_lat, img_lon, sat_min_lat, sat_max_lat, sat_min_lon, sat_max_lon, imo.shape[2], imo.shape[1]) + + patch_idx = self.pixel_to_patch_idx(row, col, self.patch_size, imo.shape[2], imo.shape[1]) + patch_idx += 1 # account for [CLS] token at the start of ViT input sequence + + species_text = " ".join(self.inat_json_dict["images"][img_id]['file_name'].split("/")[1].split("_")[1:]) + + if self.get_img_path: + return img_path, imo_path, img, imo, self.inat_json_dict["annotations"][img_id]['category_id'], patch_idx, species_text, self.species_text.index(species_text) + else: + return img, imo, self.inat_json_dict["annotations"][img_id]['category_id'], patch_idx, species_text, self.species_text.index(species_text) + + + def latlon_to_pixel(self, lat, lon, lat_min, lat_max, lon_min, lon_max, img_width, img_height): + lat_res = (lat_max - lat_min) / img_height + lon_res = (lon_max - lon_min) / img_width + col = int(math.floor((lon - lon_min) / lon_res)) + row = int(math.floor((lat_max - lat) / lat_res)) + return row, col + + + def pixel_to_patch_idx(self, row, col, patch_size, img_width, img_height): + patch_size_width = patch_size + patch_size_height = patch_size + patch_row = row // patch_size_height + patch_col = col // patch_size_width + patch_idx = patch_row * (img_width // patch_size) + patch_col + return patch_idx + + + def set_epoch(self, epoch): + self.current_epoch = epoch + + + def get_search_ds_data(self, idx): + + if self.sat_to_img_ids_json_is_train_dict: + print("Json is dict. 
Please reformat for target search!") + exit() + + bounded_idx = idx % len(self.filtered_val_ds_by_tax) + sat_sample = self.filtered_val_ds_by_tax[bounded_idx] + target_positions = sat_sample["target_positions"] + imo_path = os.path.join(self.imo_dir, sat_sample["sat_path"]) + imo = self.imo_transform(Image.open(imo_path)) + + img_paths = [] + imgs = [] + species_texts = [] + for img_id in sat_sample["img_ids"]: + img_path = os.path.join(self.img_dir, self.inat_json_dict["images"][img_id]["file_name"]) + img = self.img_transform(Image.open(img_path)) + img_paths.append(img_path) + imgs.append(img) + species_text = " ".join(self.inat_json_dict["images"][img_id]['file_name'].split("/")[1].split("_")[1:]) + species_texts.append(species_text) + imgs = torch.stack(imgs) + + if len(set(species_texts)) > 1: + print("Species mismatch in search dataset!") + exit() + else: + species_name = species_texts[0] + gt_mask_name = str(sat_sample["id"]) + "_" + sat_sample["taxonomy"] + ".png" + gt_mask_name = gt_mask_name.replace(" ", "_") + + # Consider sound if valid + sounds, sound_ids = [], [] + if self.sound_dir is not None and "sound_ids" in sat_sample: + sound_id = sat_sample["sound_ids"][0] + sound_path = os.path.join(self.sound_dir,"sounds_mp3",str(sound_id)+"."+'mp3') + + from soundbind.sound_encoder import get_audio_clap + sound = get_audio_clap(sound_path) + sounds.append(sound) + sound_ids.append(sound_id) + + return img_paths, imo_path, imgs, imo, sounds, sound_ids, species_name, target_positions, gt_mask_name \ No newline at end of file diff --git a/taxabind_avs/satbind/kmeans_clustering.py b/taxabind_avs/satbind/kmeans_clustering.py new file mode 100644 index 0000000000000000000000000000000000000000..e4f4657e6570900555fa3928d864e3c2a5b19493 --- /dev/null +++ b/taxabind_avs/satbind/kmeans_clustering.py @@ -0,0 +1,324 @@ +############################################################################## +# Name: kmeans_clustering.py +# +# - Performs k-means clustering on a patch embedding matrix +############################################################################### + +import math +import os +import numpy as np +import matplotlib.pyplot as plt +import matplotlib.patches as mpatches +from sklearn.cluster import KMeans +from sklearn.metrics import silhouette_score +from kneed import KneeLocator + + +class CombinedSilhouetteInertiaClusterer: + def __init__( + self, + k_min=1, + k_max=8, + k_avg_max=4, + silhouette_threshold=0.15, + relative_threshold=0.15, + random_state=0, + min_patch_size=5, + n_smooth_iter=2, + ignore_label=-1, + plot=False, + gifs_dir = "./" + ): + """ + Parameters + ---------- + k_min : int + Minimum number of clusters for KMeans. + k_max : int + Maximum number of clusters for KMeans. + k_avg_max : int + Upper bound on k after combining elbow & silhouette if they disagree. + silhouette_threshold : float + Minimum silhouette score at k=2 to justify splitting. + relative_threshold : float + Minimum % improvement in inertia from k=1β†’k=2 to justify splitting. + random_state : int + RNG seed for KMeans. + min_patch_size : int + Patches smaller than this threshold are smoothed. + n_smooth_iter : int + Number of smoothing iterations. + ignore_label : int + Label to ignore in smoothing step. 
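+ plot : bool
+ If True, save cluster/heatmap debug figures (the constructor currently hard-codes this to False).
+ gifs_dir : str
+ Directory where k-means visualization frames are written.
+
+ Example (illustrative sketch only; the embedding shape is an assumption)
+ ------------------------------------------------------------------------
+ >>> clusterer = CombinedSilhouetteInertiaClusterer(k_max=8, k_avg_max=4)
+ >>> labels_2d = clusterer.fit_predict(patch_embeds, map_shape=(24, 24))  # patch_embeds: (576, proj_dim) array
+ >>> labels_2d.shape
+ (24, 24)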
+ """ + self.k_min = k_min + self.k_max = k_max + self.k_avg_max = k_avg_max + self.silhouette_threshold = silhouette_threshold + self.relative_threshold = relative_threshold + self.random_state = random_state + + self.min_patch_size = min_patch_size + self.n_smooth_iter = n_smooth_iter + self.ignore_label = ignore_label + self.plot = False #plot + self.gifs_dir = gifs_dir + + self.final_k = None + self.final_labels_1d = None + self.smoothed_labels_2d = None + self.kmeans_frame_files = [] + + ############################## + # Helper functions + ############################## + + def combined_silhouette_inertia_clustering( + self, + X, + k_min=1, + k_max=8, + k_avg_max=4, + silhouette_threshold=0.2, + relative_threshold=0.05, + random_state=0 + ): + """ + Runs KMeans for k in [k_min..k_max] exactly once each, + collects silhouette scores & inertias, and returns best_k. + """ + n_samples = len(X) + if n_samples < 2: + return 1, np.zeros(n_samples, dtype=int), [None], [None] + + # --- Fit once for k=1 --- + km1 = KMeans(n_clusters=1, random_state=random_state).fit(X) + inertia_k1 = km1.inertia_ / n_samples + silhouette_k1 = None # undefined for k=1 + + # If k_max=1, no reason to check further + if k_max < 2: + return 1, km1.labels_, [silhouette_k1], [inertia_k1] + + # --- Fit once for k=2 --- + km2 = KMeans(n_clusters=2, random_state=random_state).fit(X) + inertia_k2 = km2.inertia_ / n_samples + sil_k2 = silhouette_score(X, km2.labels_) + relative_improvement = (inertia_k1 - inertia_k2) / inertia_k1 + + # If improvement is too small or silhouette is too low => remain at k=1 + if (relative_improvement < relative_threshold) or (sil_k2 < silhouette_threshold): + return 1, km1.labels_, [silhouette_k1, sil_k2], [inertia_k1, inertia_k2] + + # --- Otherwise fit k=2..k_max and gather inertias & silhouettes --- + all_k = range(2, k_max + 1) + kmeans_models = {} + inertias = [] + silhouettes = [] + + # We already have k=2 + kmeans_models[2] = km2 + inertias.append(inertia_k2) + silhouettes.append(sil_k2) + + for k in range(3, k_max + 1): + km = KMeans(n_clusters=k, random_state=random_state).fit(X) + kmeans_models[k] = km + + norm_inertia = km.inertia_ / n_samples + inertias.append(norm_inertia) + + # If k>n_samples, silhouette_score is meaningless, but in normal usage k< default to k=1.") + best_k_elbow = 1 # fallback + + print(f"Silhouette-based best_k={best_k_sil}, elbow-based best_k={best_k_elbow}") + + # Combine if there's disagreement + if best_k_sil == best_k_elbow: + final_k = max(1, min(best_k_sil, k_avg_max)) # best_k_sil + else: + avg_k = 0.5 * (best_k_sil + best_k_elbow) + final_k = int(math.ceil(avg_k)) + final_k = max(1, min(final_k, k_avg_max)) + + assert (final_k <= k_avg_max), f"Final k={final_k} is greater than k_avg_max={k_avg_max}" + + # Get final labels from the chosen KMeans model + if final_k == 1: + final_labels = km1.labels_ + else: + final_labels = kmeans_models[final_k].labels_ + + return final_k, final_labels, [silhouette_k1] + silhouettes, [inertia_k1] + inertias + + + def compute_region_statistics(self, label_map, heatmap, visited_indices, episode_num=0, step_num=0): + """ + Computes region statistics for the current smoothed label map. 
+ """ + # Flatten the cluster map and the heatmap to handle indexing uniformly + label_map_2d = label_map + label_map_1d = label_map.ravel() + heatmap_1d = heatmap.ravel() + + # Identify unique labels (excluding ignore_label if present) + unique_labels = np.unique(label_map_1d) + region_dict = {} + for lbl in unique_labels: + if lbl == self.ignore_label: + continue + region_dict[lbl] = { + 'num_patches': 0, + 'patches_visited': 0, + 'expectation': 0.0 + } + + # Accumulate totals for all patches + total_patches = len(label_map_1d) + for i in range(total_patches): + lbl = label_map_1d[i] + if lbl == self.ignore_label: + continue + region_dict[lbl]['num_patches'] += 1 + region_dict[lbl]['expectation'] += float(heatmap_1d[i]) + + # # Exponential distribution (waiting time) = num_patches / expected_num_tgts + for lbl in region_dict: + region_dict[lbl]['expectation'] = region_dict[lbl]['num_patches'] / region_dict[lbl]['expectation'] + + # Count only unique visited patches by converting to a set. + unique_visited = set(visited_indices) + for vi in unique_visited: + if vi < 0 or vi >= total_patches: + continue + lbl = label_map_1d[vi] + if lbl == self.ignore_label: + continue + region_dict[lbl]['patches_visited'] += 1 + + if self.plot: + self.plot_cluster_map(label_map_2d, heatmap, visited_indices, region_dict, episode_num, step_num) + + return region_dict + + + def plot_cluster_map(self, cluster_map, heatmap, path_taken, region_stats_dict, episode_num, step_num, cmap='tab20'): + + # 4) Plot (side-by-side) if requested + fig, axes = plt.subplots(1, 3, figsize=(12, 6)) + + axes[0].imshow(cluster_map, cmap='tab20') + axes[0].set_title(f"Raw KMeans Clusters") + axes[0].axis('off') + + axes[1].imshow(heatmap, cmap="viridis") + axes[1].set_title("Heatmap") + axes[1].axis('off') + + axes[2].imshow(cluster_map, cmap='tab20') + axes[2].set_title("Raw KMeans Clusters") + axes[2].axis('off') + + path_rows, path_cols = [], [] + for i, idx in enumerate(path_taken): + rr = idx // cluster_map.shape[1] + cc = idx % cluster_map.shape[1] + path_rows.append(rr) + path_cols.append(cc) + axes[2].plot(path_cols, path_rows, c="r", linewidth=2) + axes[2].plot(path_cols[-1], path_rows[-1], markersize=12, zorder=99, marker="^", ls="-", c="r", mec="black") + axes[2].plot(path_cols[0], path_rows[0], 'co', c="r", markersize=8, zorder=5) + + # Create legend patches for each region. + unique_labels = sorted(region_stats_dict.keys()) + max_label = max(unique_labels) if unique_labels else 1 + cm = plt.get_cmap(cmap) + legend_patches = [] + for lbl in unique_labels: + # Normalize the label to [0,1] for colormap lookup. + norm_value = lbl / max_label if max_label > 0 else 0.5 + color = cm(norm_value) + patch = mpatches.Patch(color=color, label=f"R{lbl}") + legend_patches.append(patch) + + # Add legends to both subplots. 
+ axes[0].legend(handles=legend_patches, title="Regions", loc='upper right') + axes[2].legend(handles=legend_patches, title="Regions", loc='upper right') + + # Build the legend text for each region using the provided format: + legend_lines = [] + for label, stats in region_stats_dict.items(): + line = f"R{label}: patches={stats['num_patches']}, E={stats['expectation']:.3f}, visited={stats['patches_visited']}" + legend_lines.append(line) + legend_text = "\n".join(legend_lines) + + # Add the legend text as a subtitle at the bottom of the figure + fig.text(0.5, 0.05, legend_text, ha='center', va='bottom', fontsize=10, + bbox=dict(facecolor='white', alpha=0.8, edgecolor='gray')) + + # Adjust layout to reserve space for the legend text + plt.tight_layout() + plt.subplots_adjust(bottom=0.1) + + if not os.path.exists(self.gifs_dir): + os.makedirs(self.gifs_dir) + + plt.savefig(f'{self.gifs_dir}/kmeans_{episode_num}_{step_num}.png'.format(dpi=150)) + self.kmeans_frame_files.append(f'{self.gifs_dir}/kmeans_{episode_num}_{step_num}.png') + plt.close() + + + def get_label_id(self, label_map, patch_idx): + return label_map.ravel()[patch_idx] + + + def get_probs(self, patch_idx, heatmap): + return heatmap.ravel()[patch_idx] + + + ############################## + # Main functions + ############################## + def fit_predict(self, patch_embeds, map_shape): + """ + Main function to obtain smoothed labelmap + """ + # 1) Run combined silhouette & inertia + best_k, final_labels, silhouettes, inertias = self.combined_silhouette_inertia_clustering( + X=patch_embeds, + k_min=self.k_min, + k_max=self.k_max, + k_avg_max=self.k_avg_max, + silhouette_threshold=self.silhouette_threshold, + relative_threshold=self.relative_threshold, + random_state=self.random_state + ) + self.final_k = best_k + self.final_labels_1d = final_labels.copy() + + # 2) Reshape for display + H, W = map_shape + cluster_map = final_labels.reshape(H, W) + + # 3) Apply smoothing + cluster_map_smoothed = cluster_map + self.smoothed_labels_2d = cluster_map_smoothed.copy() + + # 5) Return the smoothed 2D labels + return self.smoothed_labels_2d diff --git a/taxabind_avs/satbind/model_sat.py b/taxabind_avs/satbind/model_sat.py new file mode 100644 index 0000000000000000000000000000000000000000..8b00d7cd0b05cc712b7dc776d5cb54db4acc7ffb --- /dev/null +++ b/taxabind_avs/satbind/model_sat.py @@ -0,0 +1,190 @@ +############################################################################## +# Name: model_sat.py +# +# - Training wrapper for satellite image CLIP model +############################################################################### + +import os +import sys +sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) + +import open_clip +import pytorch_lightning as pl +import torch +import torch.nn as nn +import numpy as np +from torch.utils.data import DataLoader +from transformers import CLIPVisionModelWithProjection +from dataset import SatNatDataset +from pytorch_lightning.callbacks import ModelCheckpoint +from config_sat import config + + +def create_pairwise_mask(labels): + labels = labels.reshape(-1) + num_samples = len(labels) + pairwise_mask = torch.zeros(num_samples, num_samples).to(labels.device) + + for i in range(num_samples): + pairwise_mask[i, :] = (labels == labels[i]) + + return pairwise_mask + +def clip_loss(similarity: torch.Tensor, label) -> torch.Tensor: + overhead_img_loss = contrastive_loss(similarity, label) + ground_img_loss = contrastive_loss(similarity.t(), label.t()) + return 
0.5*torch.mean(torch.sum(overhead_img_loss, dim=-1)) + 0.5*torch.mean(torch.sum(ground_img_loss, dim=-1)) + +def contrastive_loss(logits: torch.Tensor, label) -> torch.Tensor: + gt = create_pairwise_mask(label) + return -gt*torch.log(logits.softmax(-1)+1e-6) + + +class SatBind(pl.LightningModule): + def __init__(self, train_dataset, val_dataset, **kwargs): + super().__init__() + self.train_dataset = train_dataset + self.val_dataset = val_dataset + + #initialize bio CLIP with frozen weights + self.bio_model, *_ = open_clip.create_model_and_transforms(config.image_encoder_finetuned) + if config.locked_tuning: + for param in self.bio_model.parameters(): + param.requires_grad = False + + #initialize CLIP with trainable weights + self.imo_encoder = CLIPVisionModelWithProjection.from_pretrained(config.sat_encoder).train() + for layer in self.imo_encoder.children(): + if hasattr(layer, 'reset_parameters'): + layer.reset_parameters() + + self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) + self.batch_size = kwargs.get('batch_size', config.batch_size) + self.lr = kwargs.get('lr', config.lr) + + # Custom + clip_cfg = self.imo_encoder.config + self.visual_projection_custom = nn.Linear(clip_cfg.hidden_size, 512, bias=False) # clip_cfg.projection_dim) + + + def forward(self, batch): + img, imo, label, patch_idx, *_ = batch + batch_size = img.shape[0] + + #compute bioclip embeddings + img_embeds, *_ = self.bio_model(img) # (batch_size, proj_dim) + + # Similarity computation + imo_embeds = self.imo_encoder(imo).last_hidden_state # (batch, Patches, hidden_dim) + imo_embeds = imo_embeds[torch.arange(batch_size), patch_idx] # (batch, hidden_dim) + imo_embeds = self.visual_projection_custom(imo_embeds) # (batch_size, proj_dim) + + return img_embeds, imo_embeds, label + + + def shared_step(self, batch, return_sim_matrix=False): + + img_embeds, imo_embeds, label, *_ = self(batch) + imo_embeds = torch.nn.functional.normalize(imo_embeds, dim=-1) + + #exponentiate the log of temperrature + logit_scale = self.logit_scale.exp() + + #compute similarity + img_to_imo_sim = img_embeds @ imo_embeds.t() * logit_scale + + if return_sim_matrix: + img_to_imo_sim_copy = img_to_imo_sim.clone().detach() + + loss = clip_loss(img_to_imo_sim, label) + + if return_sim_matrix: + return loss, img_to_imo_sim_copy + else: + return loss + + + def training_step(self, batch, batch_idx): + loss = self.shared_step(batch) + self.log('train_loss', loss, sync_dist=True, prog_bar=True, on_epoch=True, batch_size=self.batch_size) + self.log('temperature', self.logit_scale.data, prog_bar=True, on_epoch=True, batch_size=self.batch_size) + return loss + + def validation_step(self, batch, batch_idx): + loss = self.shared_step(batch) + self.log('val_loss', loss, sync_dist=True, prog_bar=True, on_epoch=True, batch_size=self.batch_size) + return loss + + def train_dataloader(self): + return DataLoader(self.train_dataset, + batch_size=self.batch_size, + num_workers=config.num_workers, + shuffle=True, # True + persistent_workers=False) + + def val_dataloader(self): + return DataLoader(self.val_dataset, + batch_size=self.batch_size, + num_workers=config.num_workers, + shuffle=False, + persistent_workers=False) + + def configure_optimizers(self): + params = self.parameters() + self.optim = torch.optim.AdamW(params, + lr=self.lr, + betas=(0.9,0.98), + eps=1e-6 + ) + self.scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts( + optimizer=self.optim, + T_0=20, + eta_min=1e-6 + ) + return [self.optim], [self.scheduler] + + +if 
__name__ == '__main__': + img_dir = config.img_dir + imo_dir = config.imo_dir + imo_dir_val = config.imo_dir_val + train_json_path = config.train_json_path + val_json_path = config.val_json_path + sat_to_img_ids_train_json_path = config.sat_to_img_ids_train_json_path + sat_to_img_ids_val_json_path = config.sat_to_img_ids_val_json_path + patch_size = config.patch_size + + #define dataset + train_dataset = SatNatDataset(img_dir, imo_dir, train_json_path, sat_to_img_ids_train_json_path, patch_size) + val_dataset = SatNatDataset(img_dir, imo_dir_val, val_json_path, sat_to_img_ids_val_json_path, patch_size, mode='val') + + #define model + model = SatBind(train_dataset=train_dataset, val_dataset=val_dataset) + torch.cuda.empty_cache() + + checkpoint = ModelCheckpoint( + monitor='val_loss', + dirpath=config.save_dir, + filename=config.filename, + mode='min', + save_top_k=1, + save_last=True + ) + checkpoint.CHECKPOINT_NAME_LAST = config.filename + "-LAST" + + trainer = pl.Trainer( + accelerator='gpu', + strategy='ddp_find_unused_parameters_true', # supress pl issues with 'unused trainable params' + devices=config.devices, + max_epochs=config.max_epochs, + num_nodes=1, + callbacks=[checkpoint], + accumulate_grad_batches=config.accumulate_grad_batches, + log_every_n_steps=1, + val_check_interval=config.val_check_interval, + ) + + if config.resume_from_checkpoint: + trainer.fit(model, ckpt_path=f"{config.save_dir}/{config.resume_checkpoint_name}.ckpt") + else: + trainer.fit(model) \ No newline at end of file diff --git a/taxabind_avs/scripts/download_sat_imgs.py b/taxabind_avs/scripts/download_sat_imgs.py new file mode 100644 index 0000000000000000000000000000000000000000..394f429fec75d5688e16e1926fa7f87e40bb85ac --- /dev/null +++ b/taxabind_avs/scripts/download_sat_imgs.py @@ -0,0 +1,145 @@ +############################################################################## +# Name: download_sat_imgs.py +# +# - Downloads satellite images and relevant data from huggingface +# - https://huggingface.co/datasets/MVRL/iSatNat +############################################################################### + +import os +import itertools +import requests +import time +import threading +from datasets import load_dataset +from concurrent.futures import ThreadPoolExecutor, as_completed +from tqdm import tqdm +from PIL import Image, UnidentifiedImageError +from io import BytesIO + +# Load the dataset from Hugging Face +mode = "train" # train or test +ds = load_dataset("MVRL/iSatNat", split=f"{mode}") +num_rows = len(ds) +resize_size = 512 # Resize images to this size (i.e. same FOV - increase resolution) + +# Directory where images will be saved +sat_save_dir = f"/mnt/hdd/inat2021_ds/sat_{mode}" +os.makedirs(sat_save_dir, exist_ok=True) + +# Dictionary to store download failures: {key: sat_url} +download_failures = {} + +# Create a global progress bar for images processed. +pbar = tqdm(total=num_rows, desc="Images Processed") + +# Create a lock for thread-safe updates to the progress bar. +progress_lock = threading.Lock() + +def download_image(row): + """ + Download the image from row['sat_url'] and save it as sat_save_dir/{row['key']}.jpeg. + If the download fails, it will retry up to 3 times before recording a failure. + Each attempt prints a success or retry message. + The progress bar is updated once per image processed. 
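+ Returns None when the image is saved successfully (or already exists), and the
+ (key, sat_url) tuple when all retries fail, so the caller can record it in download_failures.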
+ """ + key = row["key"] + sat_url = row["sat_url"] + file_path = os.path.join(sat_save_dir, f"{key}.jpg") + + if resize_size != 256: + sat_url = sat_url.replace("width=256", f"width={resize_size}") + sat_url = sat_url.replace("height=256", f"height={resize_size}") + + # Check if file already exists; if so, skip the download. + if os.path.exists(file_path): + print(f"SKIPPED: Image for key {key} already exists.") + with progress_lock: + pbar.update(1) + return None + + # Optional: use headers to mimic a browser. + headers = { + "User-Agent": ( + "Mozilla/5.0 (Windows NT 10.0; Win64; x64) " + "AppleWebKit/537.36 (KHTML, like Gecko) " + "Chrome/115.0.0.0 Safari/537.36" + ) + } + + max_retries = 10 + success = False + for attempt in range(1, max_retries + 1): + try: + response = requests.get(sat_url, headers=headers, timeout=10) + response.raise_for_status() # Raise an error for bad HTTP status codes. + # Convert the image to JPEG if necessary + image = Image.open(BytesIO(response.content)) + image = image.convert("RGB") # Ensure compatibility with JPEG format + image.save(file_path, "JPEG") + + # Test to see file is corrupted + with Image.open(file_path) as img: + img.verify() # Does minimal decoding, good for quick validation + success = True + break # Exit the loop if the download is successful. + + # OSError can catch issues like truncated files, permission errors, etc. + except (UnidentifiedImageError, OSError): + if attempt < max_retries: + print(f"[Corrupted] Retrying: Failed attempt {attempt} for key {key} from URL {sat_url}") + time.sleep(2) + + except Exception as e: + if attempt < max_retries: + print(f"Retrying: Failed attempt {attempt} for key {key} from URL {sat_url}") + time.sleep(2) + # else: + + + + if not success: + print(f"FAILURE: Could not download image for key {key} from URL: {sat_url} after {max_retries} attempts") + + # Update the progress bar regardless of success or failure. + with progress_lock: + pbar.update(1) + if not success: + return (key, sat_url) + return None + +def chunked_iterator(iterable, chunk_size): + """ + Yield successive chunks of size `chunk_size` from the iterable. + """ + iterator = iter(iterable) + while True: + chunk = list(itertools.islice(iterator, chunk_size)) + if not chunk: + break + yield chunk + +# Define chunk size and number of worker threads. +chunk_size = 1000 # Process 10,000 rows at a time. +max_workers = 32 # Number of threads to use in parallel. + +# Process the dataset in chunks. +try: + for chunk in chunked_iterator(ds, chunk_size): + with ThreadPoolExecutor(max_workers=max_workers) as executor: + # Submit a download task for each row in the chunk. + futures = {executor.submit(download_image, row): row for row in chunk} + # As each future completes, record any failures. + for future in as_completed(futures): + result = future.result() + if result: + key, sat_url = result + download_failures[key] = sat_url +except: + print(f"Download failures: {download_failures}") + print("len(download_failures):", len(download_failures)) + +# Close the progress bar when done. 
+pbar.close() + +print(f"Download failures: {download_failures}") +print("len(download_failures):", len(download_failures)) diff --git a/taxabind_avs/scripts/download_sound_and_img_pairs.py b/taxabind_avs/scripts/download_sound_and_img_pairs.py new file mode 100644 index 0000000000000000000000000000000000000000..effef3afcec085b18ce3087a84770d260f9f44d9 --- /dev/null +++ b/taxabind_avs/scripts/download_sound_and_img_pairs.py @@ -0,0 +1,361 @@ +############################################################################## +# Name: download_sound_and_img_pairs.py +# +# - Downloads sound and image pairs from huggingface +# - https://huggingface.co/datasets/MVRL/iSoundNat +############################################################################### + +import os +import itertools +import requests +import time +import threading +import ffmpeg +import pandas as pd +from concurrent.futures import ThreadPoolExecutor, as_completed +from tqdm import tqdm # progress bar +from PIL import Image + +############################## +# SETUP: DIRECTORIES & DATASET +############################## + +mode = "train" # or "validation" or "test" + +# Define which split to use and CSV paths. +splits = { + 'train': 'train_df.csv', + 'validation': 'val_df.csv', + 'test': 'test_df.csv' +} +# Here we load the training CSV; adjust as needed. +df = pd.read_csv("hf://datasets/MVRL/iSoundNat/" + splits[mode]) + +# If you want to skip to a specific row (for example, row index 19000), +# then slice the DataFrame accordingly. +start_index = 0 +if start_index > 0: + df = df.iloc[start_index:].reset_index(drop=True) + +# Directories for saving images and audio files +image_save_dir = f"/mnt/hdd/inat2021_ds/sound_{mode}/images" +audio_save_dir = f"/mnt/hdd/inat2021_ds/sound_{mode}/sounds_mp3" +os.makedirs(image_save_dir, exist_ok=True) +os.makedirs(audio_save_dir, exist_ok=True) + +# Convert dataframe rows to a list of dictionaries for iteration. +rows = df.to_dict("records") +num_rows = len(rows) + +# Dictionaries to record failures for pairs (keyed by id) +image_failures = {} +audio_failures = {} + +# Global progress bar and lock (one update per pair processed) +pbar = tqdm(total=num_rows, desc="Pairs Processed") +progress_lock = threading.Lock() + +############################## +# HELPER FUNCTIONS +############################## + +def convert_image_to_jpeg(temp_path, final_path): + """ + Opens the image at temp_path using Pillow. + If its format is not JPEG, converts it. + Saves the image as JPEG to final_path. + """ + try: + with Image.open(temp_path) as im: + if im.format != "JPEG": + rgb_im = im.convert("RGB") + rgb_im.save(final_path, "JPEG") + else: + # If already JPEG, simply rename the file. + os.rename(temp_path, final_path) + except Exception as e: + print(f"Error converting image {temp_path}: {e}") + raise e + +def is_audio_corrupted(file_path): + """ + Uses ffmpeg.probe() to check if an audio file is readable. + Returns True if the file is corrupted or unreadable. + """ + try: + ffmpeg.probe(file_path) + return False + except ffmpeg.Error as e: + print(f"Error probing audio '{file_path}': {e}") + return True + +def is_mp3_format(file_path): + """ + Probes the file and checks whether 'mp3' is part of the format name. 
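+ Returns False when probing fails, so unreadable files fall through to the MP3 conversion path.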
+ """ + try: + info = ffmpeg.probe(file_path) + format_name = info.get("format", {}).get("format_name", "") + return "mp3" in format_name + except Exception as e: + print(f"Error checking mp3 format for '{file_path}': {e}") + return False + +def convert_to_mp3(input_file, output_file): + """ + Converts the input audio file to MP3 using the libmp3lame codec. + """ + try: + stream = ffmpeg.input(input_file) + stream = ffmpeg.output(stream, output_file, acodec="libmp3lame") + ffmpeg.run(stream, quiet=True) + except ffmpeg.Error as e: + print(f"Error converting audio '{input_file}' to MP3: {e}") + raise e + +############################## +# DOWNLOAD FUNCTIONS WITH RETRIES +############################## + +def download_image(row, image_url, image_id, image_save_path): + """ + Downloads the image from row["image_url"]. + Saves a temporary file then converts (if needed) to JPEG as {id}. + """ + + temp_path = image_save_path + ".temp" + headers = { + "User-Agent": ( + "Mozilla/5.0 (Windows NT 10.0; Win64; x64) " + "AppleWebKit/537.36 (KHTML, like Gecko) " + "Chrome/115.0.0.0 Safari/537.36" + ) + } + max_retries = 3 + success = False + for attempt in range(1, max_retries + 1): + try: + response = requests.get(image_url, headers=headers, timeout=10) + response.raise_for_status() + with open(temp_path, "wb") as f: + f.write(response.content) + success = True + break # Exit loop on success. + except Exception as e: + if attempt < max_retries: + time.sleep(2) + else: + print(f"FAILURE: Could not download image {image_id} from {image_url} after {max_retries} attempts") + success = False + if not success: + if os.path.exists(temp_path): + try: + os.remove(temp_path) + except Exception: + pass + return False + + try: + convert_image_to_jpeg(temp_path, image_save_path) + except Exception as e: + print(f"Error processing image {image_id}: {e}") + if os.path.exists(temp_path): + try: + os.remove(temp_path) + except Exception: + pass + if os.path.exists(image_save_path): + try: + os.remove(image_save_path) + except Exception: + pass + return False + finally: + if os.path.exists(temp_path): + try: + os.remove(temp_path) + except Exception: + pass + + return os.path.exists(image_save_path) + +def download_audio(row, audio_url, audio_id, audio_save_path): + """ + Downloads the audio file from row["sound_url"]. + Saves it to a temporary file, checks for corruption, and if needed converts it to MP3 + as {id}.mp3 using ffmpeg-python. + """ + + # temp_path = os.path.join(audio_save_dir, f"{audio_id}_temp") + temp_path = audio_save_path + ".temp" + headers = { + "User-Agent": ( + "Mozilla/5.0 (Windows NT 10.0; Win64; x64) " + "AppleWebKit/537.36 (KHTML, like Gecko) " + "Chrome/115.0.0.0 Safari/537.36" + ) + } + max_retries = 5 + success = False + for attempt in range(1, max_retries + 1): + try: + response = requests.get(audio_url, headers=headers, timeout=10) + response.raise_for_status() + with open(temp_path, "wb") as f: + f.write(response.content) + success = True + break + except Exception as e: + if attempt < max_retries: + time.sleep(2) + else: + print(f"FAILURE: Could not download audio {audio_id} from {audio_url} after {max_retries} attempts") + success = False + if not success: + if os.path.exists(temp_path): + try: + os.remove(temp_path) + except Exception: + pass + return False + + # Check if the downloaded audio is corrupted. 
+ if is_audio_corrupted(temp_path): + print(f"Audio file {audio_id} is corrupted.") + try: + os.remove(temp_path) + except Exception: + pass + return False + + # Check if the audio is already in MP3 format. + if is_mp3_format(temp_path): + try: + os.rename(temp_path, audio_save_path) + except Exception as e: + print(f"Error renaming audio {audio_id}: {e}") + try: + os.remove(temp_path) + except Exception: + pass + return False + return True + else: + try: + convert_to_mp3(temp_path, audio_save_path) + except Exception as e: + print(f"Error converting audio {audio_id}: {e}") + try: + os.remove(temp_path) + except Exception: + pass + return False + finally: + if os.path.exists(temp_path): + try: + os.remove(temp_path) + except Exception: + pass + + return os.path.exists(audio_save_path) + +def download_pair(row): + """ + Downloads both the image and audio for a given row. + If either download/conversion fails, deletes any successfully downloaded file + and marks the pair as a failure. + """ + + # If the final image already exists, assume it's already downloaded. + image_url = row["image_url"] + image_id = row["id"] + img_save_path = os.path.join(image_save_dir, f"{image_id}.jpg") + + img_exists = False + if os.path.exists(img_save_path): + img_exists = True + + + # If the final audio already exists, assume it's already downloaded. + audio_url = row["sound_url"] + audio_id = row["id"] + audio_save_path = os.path.join(audio_save_dir, f"{audio_id}.mp3") + + audio_exists = False + if os.path.exists(audio_save_path): + audio_exists = True + + # Skip the download if both files already exist. + if not (img_exists and audio_exists): + + image_success = download_image(row, image_url, image_id, img_save_path) + audio_success = download_audio(row, audio_url, audio_id, audio_save_path) + + # If either download failed, delete any successfully downloaded file. + if not (image_success and audio_success): + image_failures[row["id"]] = row["image_url"] + audio_failures[row["id"]] = row["sound_url"] + success = False + else: + success = True + else: + success = True + print(f"SKIPPED: Image {image_id} and Audio {audio_id} already exists.") + + with progress_lock: + pbar.update(1) + return success + +def chunked_iterator(iterable, chunk_size): + """ + Yields successive chunks of size chunk_size from the iterable. + """ + iterator = iter(iterable) + while True: + chunk = list(itertools.islice(iterator, chunk_size)) + if not chunk: + break + yield chunk + +############################## +# PROCESS THE DATASET IN CHUNKS +############################## + +chunk_size = 999999 # Adjust based on memory and dataset size. +max_workers = 8 # Number of threads for parallel downloads. + +# Process rows in chunks using multi-threading. +try: + for chunk in chunked_iterator(rows, chunk_size): + with ThreadPoolExecutor(max_workers=max_workers) as executor: + futures = {executor.submit(download_pair, row): row for row in chunk} + for future in as_completed(futures): + try: + future.result() # True if both downloads succeeded. 
+ except Exception as e: + row = futures[future] + print(f"Error processing row {row['id']}: {e}") +except Exception as e: + print(f"An error occurred during processing: {e}") + +pbar.close() + +print("Image download failures:", image_failures) +print("Audio download failures:", audio_failures) +print("len(image_failures):", len(image_failures)) +print("len(audio_failures):", len(audio_failures)) + +############################## +# REMOVE FAILURE ROWS FROM ORIGINAL DATAFRAME AND EXPORT +############################## + +# Combine IDs from both failure dictionaries. +failure_ids = set(image_failures.keys()).union(set(audio_failures.keys())) +print(f"Total failed pairs: {len(failure_ids)}") + +# Remove failed rows from the original dataframe (preserving original order). +successful_df = df[~df["id"].isin(failure_ids)] + +output_csv = f"/mnt/hdd/inat2021_ds/sound_{mode}/sound_image_pairs_filtered.csv" +successful_df.to_csv(output_csv, index=False) +print(f"Exported {len(successful_df)} successful rows to {output_csv}") diff --git a/taxabind_avs/soundbind/config_sound.py b/taxabind_avs/soundbind/config_sound.py new file mode 100644 index 0000000000000000000000000000000000000000..109378567ff076a5cf948c16f704efa2b9ba11a3 --- /dev/null +++ b/taxabind_avs/soundbind/config_sound.py @@ -0,0 +1,31 @@ +############################################################################## +# Name: config_sound.py +# +# - Parameters for train/eval sound encoder +############################################################################### + +from easydict import EasyDict as edict + +config = edict() +config.train_df = 'clip_quad_modal|train' +config.val_df = 'clip_quad_modal|val' +config.data_path_train = '/mnt/hdd/avs_bench_ds/sound_mp3/train' +config.data_path_val = '/mnt/hdd/avs_bench_ds/sound_mp3/test' + +config.batch_size = 256 +config.lr = 1e-4 +config.accumulate_grad_batches = 8 +config.max_epochs = 20 +config.num_workers = 16 +config.devices = 2 +config.val_check_interval = 0.5 +config.sound_encoder = 'laion/clap-htsat-fused' +config.avs_dataset = 'derektan95/avs-bench' +config.save_dir = 'checkpoints' +config.filename = 'soundbind-{epoch:02d}-{val_loss:.2f}' +config.locked_tuning = True + +# huggingface finetuned +config.image_encoder_finetuned = 'hf-hub:imageomics/bioclip' + +print("config: \n", config) \ No newline at end of file diff --git a/taxabind_avs/soundbind/dataloader.py b/taxabind_avs/soundbind/dataloader.py new file mode 100644 index 0000000000000000000000000000000000000000..f1829489e001732e6fd5c26852ab29f02ae4cd69 --- /dev/null +++ b/taxabind_avs/soundbind/dataloader.py @@ -0,0 +1,72 @@ +############################################################################## +# Name: dataloader.py +# +# - Handles loading of quadmodal dataset +# - https://huggingface.co/datasets/derektan95/avs-bench +############################################################################### + +import os +from PIL import Image +from torch.utils.data import Dataset +from torchvision import transforms +from config_sound import config +from sound_encoder import get_audio_clap +from datasets import load_dataset + + +class INatDataset(Dataset): + def __init__(self, + data_file, + mode='train'): + + # Load from huggingface dataset + ds_config = data_file.split("|")[0] + ds_split = data_file.split("|")[1] + self.data_file = load_dataset(config.avs_dataset, name=ds_config, split=ds_split).to_pandas() + print("Loaded huggingface dataset: ", ds_config, ds_split) + + self.mode = mode + if mode=='train': + self.transform 
= transforms.Compose([ + transforms.Resize((256, 256)), + transforms.RandomCrop((224, 224)), + transforms.RandomHorizontalFlip(0.5), + transforms.GaussianBlur(5, (0.01, 1.0)), + transforms.ToTensor(), + transforms.Normalize(mean=[0.485, 0.456, 0.406], + std=[0.229, 0.224, 0.225]) + ]) + else: + self.transform = transforms.Compose([ + transforms.Resize((256, 256)), + transforms.CenterCrop((224, 224)), + transforms.ToTensor(), + transforms.Normalize(mean=[0.485, 0.456, 0.406], + std=[0.229, 0.224, 0.225]) + ]) + self.species_text = self.data_file['scientific_name'].tolist() + self.species_classes = list(set(self.species_text)) + print("mode: ", self.mode) + print("len(self.data_file): ", len(self.data_file)) + + def __len__(self): + return len(self.data_file) + + def get_sample(self,idx): + sample = self.data_file.iloc[idx] + id = sample.id + sound_format = sample.sound_format + data_path = config['data_path_train'] if self.mode == 'train' else config['data_path_val'] + image_path = os.path.join(data_path,"images",str(id)+".jpg") + sound_path = os.path.join(data_path,"sounds_mp3",str(id)+"."+'mp3') + sound = get_audio_clap(sound_path) + + for k in sound.keys(): + sound[k] = sound[k].squeeze(0) + image = self.transform(Image.open(image_path).convert("RGB")) + + return image, sound + + def __getitem__(self, idx): + image, sound = self.get_sample(idx) + return image, sound, self.species_classes.index(self.data_file.iloc[idx]['scientific_name']) diff --git a/taxabind_avs/soundbind/model_sound.py b/taxabind_avs/soundbind/model_sound.py new file mode 100644 index 0000000000000000000000000000000000000000..10efb1d6f90cf3e4548d0da8721a6934291359f3 --- /dev/null +++ b/taxabind_avs/soundbind/model_sound.py @@ -0,0 +1,148 @@ +############################################################################## +# Name: model_sound.py +# +# - Training wrapper for sound CLAP model +############################################################################### + +import os +import sys +sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) + +import open_clip +import pytorch_lightning as pl +import torch +import torch.nn as nn +import numpy as np +import os +import random +from sound_encoder import CLAP_audiomodel_withProjection as AudioEncoder +from torch.utils.data import DataLoader +from config_sound import config +from dataloader import INatDataset +from pytorch_lightning.callbacks import ModelCheckpoint + + +def clip_loss(similarity: torch.Tensor) -> torch.Tensor: + audio_loss = contrastive_loss(similarity) + ground_img_loss = contrastive_loss(similarity.t()) + return 0.5*audio_loss + 0.5*ground_img_loss + +def contrastive_loss(logits: torch.Tensor) -> torch.Tensor: + + return nn.functional.cross_entropy(logits[:logits.shape[1]], torch.arange(logits.shape[1], device=logits.device)) + +class AudioBind(pl.LightningModule): + def __init__(self, train_dataset, val_dataset, **kwargs): + super().__init__() + self.train_dataset = train_dataset + self.val_dataset = val_dataset + self.model, *_ = open_clip.create_model_and_transforms(config.image_encoder_finetuned) + if config.locked_tuning: + for param in self.model.parameters(): + param.requires_grad = False + self.audio_encoder = AudioEncoder(freeze=False) + self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) + self.batch_size = kwargs.get('batch_size') + self.num_workers = kwargs.get('num_workers') + self.lr = kwargs.get('lr', 1e-4) + + def forward(self, image, audio): + with torch.no_grad(): + image_embeds, *_ = self.model(image) + 
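+ # The BioCLIP image tower runs under no_grad (frozen when config.locked_tuning is True);
+ # only the CLAP audio encoder below (and the logit scale) receive gradients.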
unnormalized_audio_embeds = self.audio_encoder(audio) + audio_embeds = torch.nn.functional.normalize(unnormalized_audio_embeds, dim=-1) + return image_embeds, audio_embeds + + def shared_step(self, batch): + image, audio, *_ = batch + image_embeds, audio_embeds = self(image, audio) + logit_scale = self.logit_scale.exp() + logits_per_img = torch.matmul(image_embeds,audio_embeds.t())*logit_scale + cross_contrastive_loss = clip_loss(logits_per_img) + return cross_contrastive_loss + + def training_step(self, batch, batch_idx): + loss = self.shared_step(batch) + self.log('train_loss', loss, sync_dist=True, prog_bar=True, on_epoch=True, batch_size=self.batch_size) + self.log('temperature', self.logit_scale.data, prog_bar=True, on_epoch=True, batch_size=self.batch_size) + return loss + + def on_train_batch_end(self,outputs,batch, batch_idx): + if self.logit_scale.data > np.log(100): + self.logit_scale.data = torch.clamp(self.logit_scale.data, 0, np.log(100)) + + def validation_step(self, batch, batch_idx): + loss = self.shared_step(batch) + self.log('val_loss', loss, sync_dist=True, prog_bar=True, on_epoch=True, batch_size=self.batch_size) + return loss + + def train_dataloader(self): + return DataLoader(self.train_dataset, + batch_size=self.batch_size, + num_workers=self.num_workers, + shuffle=True, + persistent_workers=False) + + def val_dataloader(self): + return DataLoader(self.val_dataset, + batch_size=self.batch_size, + num_workers=self.num_workers, + shuffle=False, + persistent_workers=False) + + def configure_optimizers(self): + params = self.parameters() + self.optim = torch.optim.AdamW(params, + lr=self.lr, + betas=(0.9,0.98), + eps=1e-6 + ) + self.scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts( + optimizer=self.optim, + T_0=20 + ) + return [self.optim], [self.scheduler] + +def seed_everything(seed=42): + """ + seed: int + """ + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + np.random.seed(seed) + random.seed(seed) + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = False + os.environ["PYTHONHASHSEED"] = str(seed) + +if __name__=='__main__': + import warnings + warnings.filterwarnings("ignore") + torch.set_warn_always(False) + + seed_everything() + train_dataset = INatDataset(data_file=config.train_df, mode='train') + val_dataset = INatDataset(data_file=config.val_df, mode='val') + kwargs = {'batch_size':config.batch_size, 'num_workers': config.num_workers} + + model = AudioBind(train_dataset, val_dataset, **kwargs) + torch.cuda.empty_cache() + + checkpoint = ModelCheckpoint( + monitor='val_loss', + dirpath=config.save_dir, + filename=config.filename, + mode='min', + save_top_k=3 + ) + trainer = pl.Trainer( + accelerator='gpu', + devices=config.devices, + strategy='ddp', + max_epochs=config.max_epochs, + num_nodes=1, + callbacks=[checkpoint], + accumulate_grad_batches=config.accumulate_grad_batches, + log_every_n_steps=1 + ) + trainer.fit(model) \ No newline at end of file diff --git a/taxabind_avs/soundbind/sound_encoder.py b/taxabind_avs/soundbind/sound_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..4ed354c2485cfb77bf9863b500097de172e56c91 --- /dev/null +++ b/taxabind_avs/soundbind/sound_encoder.py @@ -0,0 +1,53 @@ +############################################################################## +# Name: sound_encoder.py +# +# - Wrapper for sound CLAP model +############################################################################### + +import os +import sys +sys.path.insert(0, 
os.path.dirname(os.path.abspath(__file__))) + +import pytorch_lightning as pl +import torch +import torchaudio +from transformers import ClapProcessor +from transformers import ClapAudioModelWithProjection +from config_sound import config + + +processor = ClapProcessor.from_pretrained(config.sound_encoder) +SAMPLE_RATE = 48000 + +def get_audio_clap(path_to_audio,format="mp3",padding="repeatpad",truncation="fusion"): + track, sr = torchaudio.load(path_to_audio, format=format) + track = track.mean(axis=0) + track = torchaudio.functional.resample(track, orig_freq=sr, new_freq=SAMPLE_RATE) + output = processor(audios=track, sampling_rate=SAMPLE_RATE, max_length_s=10, return_tensors="pt",padding=padding,truncation=truncation) + return output + + +class CLAP_audiomodel_withProjection(pl.LightningModule): + def __init__(self,freeze=False): + super().__init__() + if freeze: + self.model = ClapAudioModelWithProjection.from_pretrained(config.sound_encoder).eval() + for params in self.model.parameters(): + params.requires_grad=False + else: + self.model = ClapAudioModelWithProjection.from_pretrained(config.sound_encoder).train() + def forward(self,audio): + batch_embeddings_audio = self.model(**audio)['audio_embeds'] + return batch_embeddings_audio + +if __name__ == '__main__': + path_to_audio ="/mnt/hdd/inat2021_ds/sound_train/sounds_mp3/165878447.mp3" + sample = get_audio_clap(path_to_audio) + print(sample.keys()) + + sample['input_features'] = torch.concat([sample['input_features'],sample['input_features']],axis=0) + sample['is_longer'] = torch.concat([sample['is_longer'],sample['is_longer']],axis=0) + print(sample['input_features'].shape,sample['is_longer'].shape) + model = CLAP_audiomodel_withProjection(freeze=False) + audio_feat = model(sample) + print(audio_feat.shape) \ No newline at end of file