diff --git a/.gitattributes b/.gitattributes
new file mode 100644
index 0000000000000000000000000000000000000000..d098c62b13925420b12e93df1f4e2ba1acf3ac80
--- /dev/null
+++ b/.gitattributes
@@ -0,0 +1,40 @@
+*.7z filter=lfs diff=lfs merge=lfs -text
+*.arrow filter=lfs diff=lfs merge=lfs -text
+*.bin filter=lfs diff=lfs merge=lfs -text
+*.bz2 filter=lfs diff=lfs merge=lfs -text
+*.ckpt filter=lfs diff=lfs merge=lfs -text
+*.ftz filter=lfs diff=lfs merge=lfs -text
+*.gz filter=lfs diff=lfs merge=lfs -text
+*.h5 filter=lfs diff=lfs merge=lfs -text
+*.joblib filter=lfs diff=lfs merge=lfs -text
+*.lfs.* filter=lfs diff=lfs merge=lfs -text
+*.mlmodel filter=lfs diff=lfs merge=lfs -text
+*.model filter=lfs diff=lfs merge=lfs -text
+*.msgpack filter=lfs diff=lfs merge=lfs -text
+*.npy filter=lfs diff=lfs merge=lfs -text
+*.npz filter=lfs diff=lfs merge=lfs -text
+*.onnx filter=lfs diff=lfs merge=lfs -text
+*.ot filter=lfs diff=lfs merge=lfs -text
+*.parquet filter=lfs diff=lfs merge=lfs -text
+*.pb filter=lfs diff=lfs merge=lfs -text
+*.pickle filter=lfs diff=lfs merge=lfs -text
+*.pkl filter=lfs diff=lfs merge=lfs -text
+*.pt filter=lfs diff=lfs merge=lfs -text
+*.pth filter=lfs diff=lfs merge=lfs -text
+*.rar filter=lfs diff=lfs merge=lfs -text
+*.safetensors filter=lfs diff=lfs merge=lfs -text
+saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+*.tar.* filter=lfs diff=lfs merge=lfs -text
+*.tar filter=lfs diff=lfs merge=lfs -text
+*.tflite filter=lfs diff=lfs merge=lfs -text
+*.tgz filter=lfs diff=lfs merge=lfs -text
+*.wasm filter=lfs diff=lfs merge=lfs -text
+*.xz filter=lfs diff=lfs merge=lfs -text
+*.zip filter=lfs diff=lfs merge=lfs -text
+*.zst filter=lfs diff=lfs merge=lfs -text
+*tfevents* filter=lfs diff=lfs merge=lfs -text
+*.png filter=lfs diff=lfs merge=lfs -text
+*.jpg filter=lfs diff=lfs merge=lfs -text
+*.jpeg filter=lfs diff=lfs merge=lfs -text
+*.mp3 filter=lfs diff=lfs merge=lfs -text
diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000000000000000000000000000000000000..08610c946ea133b4c49d12337e2abbef03f24766
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,38 @@
+*.idea/
+*.vscode/
+
+*__pycache__/
+*.pyc
+*.pyo
+*.pyd
+
+*.pt
+*.pth
+*.tar.gz
+*.zip
+
+!**/train/
+**/train/*
+!**/train/saved
+
+!**/inference/
+**/inference/*
+!**/inference/saved
+
+!**/maps/
+**/maps/*
+!**/maps/example
+!**/maps/gpt4o
+!**/maps/lisa
+
+# For taxabind_avs
+**/dataset/
+**/checkpoints/
+
+!**/lightning_logs/
+**/lightning_logs/*
+!**/lightning_logs/saved
+
+# Saved weights & logs
+**avs_rl_policy.pth
+**/avs_rl_policy_21.5k/*
\ No newline at end of file
diff --git a/.vscode/launch.json b/.vscode/launch.json
new file mode 100644
index 0000000000000000000000000000000000000000..6a1736dd3a13900ed66d69cb586a8fc6ce484e4f
--- /dev/null
+++ b/.vscode/launch.json
@@ -0,0 +1,15 @@
+{
+ "version": "0.2.0",
+ "configurations": [
+ {
+ "name": "Debug app.py",
+ "type": "debugpy",
+ "request": "launch",
+ "program": "${workspaceFolder}/app.py",
+ "cwd": "${workspaceFolder}",
+ "console": "integratedTerminal",
+ "justMyCode": false,
+ "python": "/home/user/anaconda3/envs/vlm-search/bin/python3"
+ }
+ ]
+}
\ No newline at end of file
diff --git a/README.md b/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..f11219855e4a44e0f24ebd0d4e0da2f529774791
--- /dev/null
+++ b/README.md
@@ -0,0 +1,13 @@
+---
+title: Search-TTA
+emoji: π¦
+colorFrom: green
+colorTo: gray
+sdk: gradio
+sdk_version: 5.31.0
+app_file: app.py
+pinned: false
+short_description: Multimodal Test-time Adaptation Framework for Visual Search
+---
+
+Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
diff --git a/app.py b/app.py
new file mode 100644
index 0000000000000000000000000000000000000000..c59a521dedbac48cbd9e1f663decae60902555e3
--- /dev/null
+++ b/app.py
@@ -0,0 +1,528 @@
+"""
+Simplified Gradio demo for Search-TTA evaluation.
+"""
+
+# ────────────────────────── imports ───────────────────────────────────
+from pathlib import Path
+import matplotlib
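+# Use the non-interactive Agg backend so figures render headlessly (frames are saved from background threads)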
+matplotlib.use("Agg", force=True)
+
+import gradio as gr
+import ctypes # for safely stopping background threads
+import os, glob, threading, time
+import torch
+from PIL import Image
+import json
+import shutil
+import spaces # integration with ZeroGPU on hf
+from planner.test_parameter import *
+from planner.model import PolicyNet
+from planner.test_worker import TestWorker
+from taxabind_avs.satbind.clip_seg_tta import ClipSegTTA
+
+
+# Helper to kill a Python thread by injecting SystemExit
+def _stop_thread(thread: threading.Thread):
+ """Forcefully raise SystemExit in the given thread (best-effort)."""
+ if thread is None or not thread.is_alive():
+ return
+ tid = thread.ident
+ if tid is None:
+ return
+ # Ask CPython to raise SystemExit in the thread context
+ res = ctypes.pythonapi.PyThreadState_SetAsyncExc(ctypes.c_long(tid), ctypes.py_object(SystemExit))
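+    # PyThreadState_SetAsyncExc returns the number of thread states modified:
+    # 0 means the thread id was not found, 1 means the exception was queued.
+    # The exception is only delivered once the target thread next runs Python bytecode.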
+ if res > 1:
+        # More than one thread state was modified: undo the request and fail safe
+ ctypes.pythonapi.PyThreadState_SetAsyncExc(ctypes.c_long(tid), None)
+
+# ──────────── Thread Registry for Cleanup on Tab Switch ─────────────
+_running_threads: list[threading.Thread] = []
+_running_threads_lock = threading.Lock()
+
+# Map worker threads to their ClipSegTTA instance so UI can read executing_tta flag
+_thread_clip_map: dict[threading.Thread, ClipSegTTA] = {}
+
+# ──────────── Run directory rotation ─────────────
+RUN_HISTORY_LIMIT = 30 # keep at most this many timestamped run directories per instance
+
+def _prune_old_run_dirs(base_dir: str, limit: int = RUN_HISTORY_LIMIT):
+ """Delete oldest timestamp-named run directories leaving only *limit* of the newest ones."""
+ try:
+ dirs = [d for d in os.listdir(base_dir) if os.path.isdir(os.path.join(base_dir, d))]
+ dirs.sort()
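+        # Timestamped names (%Y%m%d_%H%M%S) sort lexicographically in chronological order, oldest first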
+ if len(dirs) > limit:
+ for obsolete in dirs[:-limit]:
+ shutil.rmtree(os.path.join(base_dir, obsolete), ignore_errors=True)
+ except Exception:
+ pass
+
+
+# CHANGE ME!
+POLL_INTERVAL = 1.0  # seconds between polls of the frame directories for visualization
+
+# Prepare the model
+device = torch.device('cuda') if USE_GPU and torch.cuda.is_available() else torch.device('cpu')
+policy_net = PolicyNet(INPUT_DIM, EMBEDDING_DIM).to(device)
+script_dir = Path(__file__).resolve().parent
+print("real_script_dir: ", script_dir)
+checkpoint = torch.load(f'{MODEL_PATH}/{MODEL_NAME}', map_location=device)
+policy_net.load_state_dict(checkpoint['policy_model'])
+print('Model loaded!')
+
+# Load metadata json
+tgts_metadata_json_path = os.path.join(script_dir, "examples/metadata.json")
+with open(tgts_metadata_json_path) as f:
+    tgts_metadata = json.load(f)
+
+
+# ────────────────────────── Gradio process fn ─────────────────────────
+
+### integration with ZeroGPU on hf
+# @spaces.GPU
+def process_search_tta(
+ sat_path: str | None,
+ ground_path: str | None,
+ taxonomy: str | None = None,
+ session_threads: list[threading.Thread] | None = None,
+):
+ """Run both TTA and non-TTA search episodes concurrently and stream both heat-maps."""
+
+ if session_threads is None:
+ session_threads = []
+
+ # Disable Run button and clear image/status outputs, hide sliders, clear frame states
+ yield (
+ gr.update(interactive=False),
+ gr.update(value=None),
+ gr.update(value=None),
+        gr.update(value="Initializing model…", visible=True),
+        gr.update(value="Initializing model…", visible=True),
+ gr.update(visible=False),
+ gr.update(visible=False),
+ [],
+ [],
+ session_threads,
+ )
+
+ # Bail early if satellite image missing
+ if sat_path is None:
+ yield (
+ gr.update(interactive=True),
+ gr.update(value=None),
+ gr.update(value=None),
+ gr.update(value="No satellite image provided.", visible=True),
+ gr.update(value="", visible=True),
+ gr.update(visible=False),
+ gr.update(visible=False),
+ [],
+ [],
+ session_threads,
+ )
+ return
+
+ # Prepare PIL images
+ sat_img = Image.open(sat_path).convert("RGB")
+ ground_img_pil = Image.open(ground_path).convert("RGB") if ground_path else None
+
+ # Lookup target positions metadata (may be empty)
+ tgt_positions = []
+ if taxonomy and taxonomy in tgts_metadata:
+ tgt_positions = [tuple(t) for t in tgts_metadata[taxonomy]["target_positions"]]
+
+ # Helper to build a TestWorker with/without TTA
+ def build_planner(enable_tta: bool, save_dir: str, clip_obj):
+ # Lazily (re)create a ClipSegTTA instance per thread if not provided
+ local_clip = clip_obj
+ if LOAD_AVS_BENCH and local_clip is None:
+ local_clip = ClipSegTTA(
+ img_dir=AVS_IMG_DIR,
+ imo_dir=AVS_IMO_DIR,
+ json_path=AVS_INAT_JSON_PATH,
+ sat_to_img_ids_path=AVS_SAT_TO_IMG_IDS_PATH,
+ sat_checkpoint_path=AVS_SAT_CHECKPOINT_PATH,
+ load_pretrained_hf_ckpt=AVS_LOAD_PRETRAINED_HF_CHECKPOINT,
+ blur_kernel = AVS_GAUSSIAN_BLUR_KERNEL,
+ sample_index=-1,
+ device=device,
+ sat_to_img_ids_json_is_train_dict=False,
+ tax_to_filter_val=QUERY_TAX,
+ load_model=USE_CLIP_PREDS,
+ query_modality=QUERY_MODALITY,
+ sound_dir = AVS_SOUND_DIR,
+ sound_checkpoint_path=AVS_SOUND_CHECKPOINT_PATH,
+ )
+
+ if local_clip is not None:
+ # Feed inputs to ClipSegTTA copy
+ local_clip.img_paths = [ground_path] if ground_path else []
+ local_clip.imo_path = sat_path
+ local_clip.imgs = ([local_clip.dataset.img_transform(ground_img_pil).to(device)] if ground_img_pil else [])
+ local_clip.imo = local_clip.dataset.imo_transform(sat_img).to(device)
+ local_clip.sounds = []
+ local_clip.sound_ids = []
+ local_clip.species_name = taxonomy or ""
+ local_clip.gt_mask_name = taxonomy.replace(" ", "_") if taxonomy else ""
+ local_clip.target_positions = tgt_positions if tgt_positions else [(0, 0)]
+
+ planner = TestWorker(
+ meta_agent_id=0,
+ n_agent=1,
+ policy_net=policy_net,
+ global_step=-1,
+ device=device,
+ greedy=True,
+ save_image=SAVE_GIFS,
+ clip_seg_tta=local_clip,
+ )
+ planner.execute_tta = enable_tta
+ planner.gifs_path = save_dir
+ return planner
+
+    # ────────────── Per-run output directories ──────────────
+ # Ensure base directory exists
+ os.makedirs(GIFS_PATH, exist_ok=True)
+
+ run_id = time.strftime("%Y%m%d_%H%M%S") # unique timestamp
+ run_root = os.path.join(GIFS_PATH, run_id)
+ gifs_dir_tta = os.path.join(run_root, "with_tta")
+ gifs_dir_no = os.path.join(run_root, "no_tta")
+
+ os.makedirs(gifs_dir_tta, exist_ok=True)
+ os.makedirs(gifs_dir_no, exist_ok=True)
+
+ # House-keep old runs so we never keep more than RUN_HISTORY_LIMIT
+ _prune_old_run_dirs(GIFS_PATH, RUN_HISTORY_LIMIT)
+
+ # Shared dict to record if a thread hit an exception
+ error_flags = {"tta": False, "no": False}
+
+ def _planner_thread(enable_tta: bool, save_dir: str, clip_obj, key: str):
+ """Prepare directory, build planner, run an episode, record errors."""
+ try:
+ planner = build_planner(enable_tta, save_dir, clip_obj)
+ _thread_clip_map[threading.current_thread()] = planner.clip_seg_tta
+ planner.run_episode(0)
+ except Exception as exc:
+ # Mark that this planner crashed so UI can show an error status
+ error_flags[key] = True
+ # Log full traceback so developers can debug via console logs
+ import traceback, sys
+ traceback.print_exc()
+ # Still exit the thread
+ return
+
+    # Launch both planners in background threads (preparation included)
+ thread_tta = threading.Thread(
+ target=_planner_thread,
+ args=(True, gifs_dir_tta, None, "tta"),
+ daemon=True,
+ )
+ thread_no = threading.Thread(
+ target=_planner_thread,
+ args=(False, gifs_dir_no, None, "no"),
+ daemon=True,
+ )
+ # Track threads for this user session
+ session_threads.extend([thread_tta, thread_no])
+ thread_tta.start()
+ thread_no.start()
+
+
+ sent_tta: set[str] = set()
+ sent_no: set[str] = set()
+ last_tta = None
+ last_no = None
+    # Track previous status strings so we can emit updates when only the
+    # status (Executing…/Done.) changes even if no new frame was produced.
+    prev_status_tta = "Initializing model…"
+    prev_status_no = "Initializing model…"
+
+ try:
+ while thread_tta.is_alive() or thread_no.is_alive():
+ updated = False
+ # Collect new frames from TTA dir
+ pngs = glob.glob(os.path.join(gifs_dir_tta, "*.png"))
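+            # Frames are written with their step index as the filename, so a numeric sort is chronological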
+ pngs.sort(key=lambda p: int(os.path.splitext(os.path.basename(p))[0]))
+ for fp in pngs:
+ if fp not in sent_tta:
+ # Ensure file is fully written (non-empty & readable)
+ try:
+ if os.path.getsize(fp) == 0:
+ continue
+ with open(fp, "rb") as fh:
+ fh.read(1)
+ except Exception:
+ # Skip this round; we'll retry next poll
+ continue
+ sent_tta.add(fp)
+ last_tta = fp
+ updated = True
+ # Collect new frames from no-TTA dir
+ pngs = glob.glob(os.path.join(gifs_dir_no, "*.png"))
+ pngs.sort(key=lambda p: int(os.path.splitext(os.path.basename(p))[0]))
+ for fp in pngs:
+ if fp not in sent_no:
+ try:
+ if os.path.getsize(fp) == 0:
+ continue
+ with open(fp, "rb") as fh:
+ fh.read(1)
+ except Exception:
+ continue
+ sent_no.add(fp)
+ last_no = fp
+ updated = True
+
+ # Determine status based on whether we already have a frame and whether
+ # the corresponding thread is still alive.
+ def _mk_status(last_frame, thread_alive, errored: bool, running_tta: bool=False):
+ if errored:
+ return "Error!"
+ if last_frame is None:
+                    return "Initializing model…"
+ if not thread_alive:
+ return "Done."
+                return "Executing TTA (Scheduling GPUs)…" if running_tta else "Executing Planner…"
+
+ exec_tta_flag = False
+ if thread_tta.is_alive():
+ clip_obj = _thread_clip_map.get(thread_tta)
+ if clip_obj is not None and getattr(clip_obj, "executing_tta", False):
+ exec_tta_flag = True
+
+ status_tta = _mk_status(last_tta, thread_tta.is_alive(), error_flags["tta"], exec_tta_flag)
+ status_no = _mk_status(last_no, thread_no.is_alive(), error_flags["no"], False)
+
+ # Determine if we should reveal sliders (once corresponding thread has finished)
+ show_slider_tta = (not thread_tta.is_alive()) and (last_tta is not None)
+ show_slider_no = (not thread_no.is_alive()) and (last_no is not None)
+
+ # Build slider updates
+ slider_tta_upd = gr.update()
+ slider_no_upd = gr.update()
+ frames_tta_upd = gr.update()
+ frames_no_upd = gr.update()
+
+ if show_slider_tta:
+ n_tta_frames = max(len(sent_tta), 1)
+ slider_tta_upd = gr.update(visible=True, minimum=1, maximum=n_tta_frames, value=n_tta_frames)
+ frames_tta_upd = sorted(sent_tta, key=lambda p: int(os.path.splitext(os.path.basename(p))[0]))
+ if show_slider_no:
+ n_no_frames = max(len(sent_no), 1)
+ slider_no_upd = gr.update(visible=True, minimum=1, maximum=n_no_frames, value=n_no_frames)
+ frames_no_upd = sorted(sent_no, key=lambda p: int(os.path.splitext(os.path.basename(p))[0]))
+
+ # Emit update if we have a new frame OR status changed OR slider visibility changed
+ if (
+ updated
+ or status_tta != prev_status_tta
+ or status_no != prev_status_no
+ or show_slider_tta
+ or show_slider_no
+ ):
+ yield (
+ gr.update(interactive=False),
+ last_tta,
+ last_no,
+ gr.update(value=status_tta, visible=True),
+ gr.update(value=status_no, visible=True),
+ slider_tta_upd,
+ slider_no_upd,
+ frames_tta_upd,
+ frames_no_upd,
+ session_threads,
+ )
+
+ prev_status_tta = status_tta
+ prev_status_no = status_no
+
+ time.sleep(POLL_INTERVAL)
+    finally:
+        # Ensure background threads are stopped on cancel
+        for th in (thread_tta, thread_no):
+            if th.is_alive():
+                _stop_thread(th)
+                th.join(timeout=1)
+            # Drop the per-thread ClipSegTTA reference so the map does not grow across runs
+            _thread_clip_map.pop(th, None)
+
+        # Clear this session's thread list (under the shared lock)
+        with _running_threads_lock:
+            session_threads.clear()
+
+ # Small delay to ensure last frame files are fully flushed
+ time.sleep(0.2)
+ # One last scan after both threads have finished to catch any frame
+ # that may have been written just before termination but after the last
+ # polling iteration.
+ for fp in sorted(glob.glob(os.path.join(gifs_dir_tta, "*.png")), key=lambda p: int(os.path.splitext(os.path.basename(p))[0])):
+ if fp not in sent_tta:
+ sent_tta.add(fp)
+ last_tta = fp
+ for fp in sorted(glob.glob(os.path.join(gifs_dir_no, "*.png")), key=lambda p: int(os.path.splitext(os.path.basename(p))[0])):
+ if fp not in sent_no:
+ sent_no.add(fp)
+ last_no = fp
+
+ # Prepare frames list and slider configs
+ frames_tta = sorted(glob.glob(os.path.join(gifs_dir_tta, "*.png")), key=lambda p: int(os.path.splitext(os.path.basename(p))[0]))
+ frames_no = sorted(glob.glob(os.path.join(gifs_dir_no, "*.png")), key=lambda p: int(os.path.splitext(os.path.basename(p))[0]))
+ if last_tta is None and frames_tta:
+ last_tta = frames_tta[-1]
+ if last_no is None and frames_no:
+ last_no = frames_no[-1]
+ n_tta = len(frames_tta) or 1 # prevent zero-range slider
+ n_no = len(frames_no) or 1
+
+ # Final emit: re-enable button, hide statuses, show sliders set to last frame
+ yield (
+ gr.update(interactive=True),
+ last_tta,
+ last_no,
+ gr.update(visible=False),
+ gr.update(visible=False),
+ gr.update(visible=True, minimum=1, maximum=n_tta, value=n_tta),
+ gr.update(visible=True, minimum=1, maximum=n_no, value=n_no),
+ frames_tta,
+ frames_no,
+ session_threads,
+ )
+
+
+# ────────────────────────── Gradio UI ─────────────────────────────────
+with gr.Blocks(title="Search-TTA (Simplified)", theme=gr.themes.Base()) as demo:
+
+ gr.Markdown(
+ """
+ # Search-TTA: A Multimodal Test-Time Adaptation Framework for Visual Search in the Wild Demo
+ Click on any of the examples below and run the TTA demo. Check out the multimodal heatmap generation feature by switching to the other tab above.
+        Note that model initialization, the RL planner, and the TTA updates are not fully GPU-optimized in this Hugging Face demo, so you may experience some lag during execution.
+ If you encounter an 'Error' status, refresh the browser and rerun the demo, or try again the next day. We will improve this in the future.
+ Project Website
+ """
+ )
+
+ with gr.Row(variant="panel"):
+ with gr.Column():
+ gr.Markdown("### Model Inputs")
+ sat_input = gr.Image(
+ label="Satellite Image",
+ sources=["upload"],
+ type="filepath",
+ height=320,
+ )
+ taxonomy_input = gr.Textbox(
+ label="Full Taxonomy Name (optional)",
+ placeholder="e.g. Animalia Chordata Mammalia Rodentia Sciuridae Marmota marmota",
+ )
+ ground_input = gr.Image(
+ label="Ground-level Image (optional)",
+ sources=["upload"],
+ type="filepath",
+ height=320,
+ )
+ run_btn = gr.Button("Run Search-TTA", variant="primary")
+
+ with gr.Column():
+ gr.Markdown("### Live Heatmap Output")
+ display_img_tta = gr.Image(label="Heatmap (TTA per 20 steps)", type="filepath", height=400) # 512
+ status_tta = gr.Markdown("")
+ slider_tta = gr.Slider(label="TTA Frame", minimum=1, maximum=1, step=1, value=1, visible=False)
+
+ display_img_no_tta = gr.Image(label="Heatmap (no TTA)", type="filepath", height=400) # 512
+ status_no_tta = gr.Markdown("")
+ slider_no = gr.Slider(label="No-TTA Frame", minimum=1, maximum=1, step=1, value=1, visible=False)
+
+ frames_state_tta = gr.State([])
+ frames_state_no = gr.State([])
+ session_threads_state = gr.State([])
+
+ # Slider callbacks (updates image when user drags slider)
+ def _show_frame(idx: int, frames: list[str]):
+ # Slider is 1-indexed; convert to 0-indexed list access
+ if 1 <= idx <= len(frames):
+ return frames[idx - 1]
+ return gr.update()
+
+ slider_tta.change(_show_frame, inputs=[slider_tta, frames_state_tta], outputs=display_img_tta)
+ slider_no.change(_show_frame, inputs=[slider_no, frames_state_no], outputs=display_img_no_tta)
+
+ # EXAMPLES
+ with gr.Row():
+ gr.Markdown("### Taxonomy")
+ with gr.Row():
+ gr.Examples(
+ examples=[
+ [
+ "examples/Animalia_Chordata_Reptilia_Squamata_Varanidae_Varanus_salvator/410613_5.35573_100.28948.jpg",
+ "examples/Animalia_Chordata_Reptilia_Squamata_Varanidae_Varanus_salvator/461d8e6c-0e66-4acc-8ecd-bfd9c218bc14.jpg",
+ "Animalia Chordata Reptilia Squamata Varanidae Varanus salvator",
+ ],
+ [
+ "examples/Animalia_Chordata_Mammalia_Carnivora_Canidae_Canis_aureus/1528408_13.00422_80.23033.jpg",
+ "examples/Animalia_Chordata_Mammalia_Carnivora_Canidae_Canis_aureus/37faabd2-a613-4461-b27e-82fe5955ecaf.jpg",
+ "Animalia Chordata Mammalia Carnivora Canidae Canis aureus",
+ ],
+ [
+ "examples/Animalia_Chordata_Reptilia_Crocodylia_Alligatoridae_Caiman_crocodilus/340271_10.52832_-83.49678.jpg",
+ "examples/Animalia_Chordata_Reptilia_Crocodylia_Alligatoridae_Caiman_crocodilus/938aab7b-4509-4de7-afad-2c8ea51f4799.jpg",
+ "Animalia Chordata Reptilia Crocodylia Alligatoridae Caiman crocodilus",
+ ],
+ [
+ "examples/Animalia_Chordata_Mammalia_Carnivora_Canidae_Urocyon_littoralis/304160_34.0144_-119.54417.jpg",
+ "examples/Animalia_Chordata_Mammalia_Carnivora_Canidae_Urocyon_littoralis/0cbdfbf2-6cfe-4d61-9602-c949f24d0293.jpg",
+ "Animalia Chordata Mammalia Carnivora Canidae Urocyon littoralis",
+ ],
+ ],
+ inputs=[sat_input, ground_input, taxonomy_input],
+ outputs=[run_btn, display_img_tta, display_img_no_tta, status_tta, status_no_tta, slider_tta, slider_no, frames_state_tta, frames_state_no],
+ fn=process_search_tta,
+ cache_examples=False,
+ )
+
+ run_btn.click(
+ fn=process_search_tta,
+ inputs=[sat_input, ground_input, taxonomy_input, session_threads_state],
+ outputs=[run_btn, display_img_tta, display_img_no_tta, status_tta, status_no_tta, slider_tta, slider_no, frames_state_tta, frames_state_no, session_threads_state],
+ )
+
+ # Footer to point out to model and data from app page.
+ gr.Markdown(
+ """
+        The satellite image CLIP encoder is fine-tuned using [Sentinel-2 Level 2A](https://docs.sentinel-hub.com/api/latest/data/sentinel-2-l2a/) satellite images and taxonomy images (with GPS locations) from [iNaturalist](https://inaturalist.org/). The sound CLIP encoder is fine-tuned with a subset of the same taxonomy images and their corresponding sounds from [iNaturalist](https://inaturalist.org/). Some of these iNaturalist data are also used in [Taxabind](https://arxiv.org/abs/2411.00683). Note that while some of the examples above result in poor probability distributions, they will be improved using our test-time adaptation framework during the search process.
+ """
+ )
+
+
+if __name__ == "__main__":
+
+ # Build UI with explicit Tabs so we can detect tab selection and clean up
+ from app_multimodal_inference import demo as multimodal_demo
+
+ with gr.Blocks() as root:
+ with gr.Tabs() as tabs:
+ with gr.TabItem("Multimodal Inference"):
+ multimodal_demo.render()
+ with gr.TabItem("Search-TTA"):
+ demo.render()
+
+        # Event wiring must live inside the Blocks context, otherwise Gradio raises
+        # "Cannot call select outside of a gradio.Blocks context".
+        # Hidden textbox purely to satisfy Gradio's need for an output component.
+        _cleanup_status = gr.Textbox(visible=False)
+
+        outputs_on_tab = [_cleanup_status]
+
+        def _on_tab_change(evt: gr.SelectData, session_threads: list[threading.Thread]):
+            # evt.value contains the name of the newly-selected tab.
+            if evt.value == "Multimodal Inference":
+                # Stop only threads started in this session
+                for th in list(session_threads):
+                    if th is not None and th.is_alive():
+                        _stop_thread(th)
+                        th.join(timeout=1)
+                session_threads.clear()
+                return "Stopped running Search-TTA threads."
+            return ""
+
+        tabs.select(_on_tab_change, inputs=[session_threads_state], outputs=outputs_on_tab)
+
+ root.queue(max_size=15)
+ root.launch(share=True)
diff --git a/app_multimodal_inference.py b/app_multimodal_inference.py
new file mode 100644
index 0000000000000000000000000000000000000000..d0dcfc5898c6164a313c973521bbe83bc4e097b5
--- /dev/null
+++ b/app_multimodal_inference.py
@@ -0,0 +1,350 @@
+"""
+Search-TTA multimodal heatmap generation demo
+"""
+
+# ────────────────────────── imports ───────────────────────────────────
+import cv2
+import gradio as gr
+import torch
+import numpy as np
+from PIL import Image
+import matplotlib.pyplot as plt
+import io
+import torchaudio
+import spaces # integration with ZeroGPU on hf
+
+from torchvision import transforms
+import open_clip
+from taxabind_avs.satbind.clip_vision_per_patch_model import CLIPVisionPerPatchModel
+from transformers import ClapAudioModelWithProjection
+from transformers import ClapProcessor
+
+# ────────────────────────── global config & models ────────────────────
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+# BioCLIP (ground-image & text encoder)
+bio_model, _, _ = open_clip.create_model_and_transforms("hf-hub:imageomics/bioclip")
+bio_model = bio_model.to(device).eval()
+bio_tokenizer = open_clip.get_tokenizer("hf-hub:imageomics/bioclip")
+
+# Satellite patch encoder (CLIP-L-336, per-patch)
+sat_model: CLIPVisionPerPatchModel = (
+ CLIPVisionPerPatchModel.from_pretrained("derektan95/search-tta-sat")
+ .to(device)
+ .eval()
+)
+
+# Sound CLAP model
+sound_model: ClapAudioModelWithProjection = (
+ ClapAudioModelWithProjection.from_pretrained("derektan95/search-tta-sound")
+ .to(device)
+ .eval()
+)
+sound_processor: ClapProcessor = ClapProcessor.from_pretrained("derektan95/search-tta-sound")
+SAMPLE_RATE = 48000
+
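+# logit_scale mirrors CLIP's temperature initialization, ln(1/0.07); its exp (~14.3)
+# scales the patch similarities before the sigmoid in _similarity_heatmap below.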
+logit_scale = torch.nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
+logit_scale = logit_scale.exp()
+blur_kernel = (5,5)
+
+# ────────────────────────── transforms (exact spec) ───────────────────
+img_transform = transforms.Compose(
+ [
+ transforms.Resize((256, 256)),
+ transforms.CenterCrop((224, 224)),
+ transforms.ToTensor(),
+ transforms.Normalize(
+ mean=[0.485, 0.456, 0.406],
+ std=[0.229, 0.224, 0.225],
+ ),
+ ]
+)
+
+imo_transform = transforms.Compose(
+ [
+ transforms.Resize((336, 336)),
+ transforms.ToTensor(),
+ transforms.Normalize(
+ mean=[0.485, 0.456, 0.406],
+ std=[0.229, 0.224, 0.225],
+ ),
+ ]
+)
+
+def get_audio_clap(path_to_audio, format="mp3", padding="repeatpad", truncation="fusion"):
+    track, sr = torchaudio.load(path_to_audio, format=format)
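+    # Mix to mono and resample to the 48 kHz rate expected by the CLAP feature extractor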
+ track = track.mean(axis=0)
+ track = torchaudio.functional.resample(track, orig_freq=sr, new_freq=SAMPLE_RATE)
+ output = sound_processor(audios=track, sampling_rate=SAMPLE_RATE, max_length_s=10, return_tensors="pt",padding=padding,truncation=truncation)
+ return output
+
+# ────────────────────────── helpers ───────────────────────────────────
+
+@torch.no_grad()
+def _encode_ground(img_pil: Image.Image) -> torch.Tensor:
+ img = img_transform(img_pil).unsqueeze(0).to(device)
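+    # open_clip's CLIP forward returns (image_features, text_features, logit_scale); keep only the image features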
+ img_embeds, *_ = bio_model(img)
+ return img_embeds
+
+
+@torch.no_grad()
+def _encode_text(text: str) -> torch.Tensor:
+ toks = bio_tokenizer(text).to(device)
+ _, txt_embeds, _ = bio_model(text=toks)
+ return txt_embeds
+
+
+@torch.no_grad()
+def _encode_sat(img_pil: Image.Image) -> torch.Tensor:
+ imo = imo_transform(img_pil).unsqueeze(0).to(device)
+ imo_embeds = sat_model(imo)
+ return imo_embeds
+
+
+@torch.no_grad()
+def _encode_sound(sound) -> torch.Tensor:
+ processed_sound = get_audio_clap(sound)
+ for k in processed_sound.keys():
+ processed_sound[k] = processed_sound[k].to(device)
+ unnormalized_audio_embeds = sound_model(**processed_sound).audio_embeds
+ sound_embeds = torch.nn.functional.normalize(unnormalized_audio_embeds, dim=-1)
+ return sound_embeds
+
+
+def _similarity_heatmap(query: torch.Tensor, patches: torch.Tensor) -> np.ndarray:
+ sims = torch.matmul(query, patches.t()) * logit_scale
+ sims = sims.t().sigmoid()
+ sims = sims[1:].squeeze() # drop CLS token
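+    # The remaining patch tokens form a square grid (e.g. 24x24 for a 336-px ViT with 14-px patches)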
+ side = int(np.sqrt(len(sims)))
+ sims = sims.reshape(side, side)
+ return sims.cpu().detach().numpy()
+
+
+def _array_to_pil(arr: np.ndarray) -> Image.Image:
+ """
+    Render arr with viridis, automatically stretching its own min–max to 0–1
+ so that the most-similar patches appear yellow.
+ """
+
+    # Gaussian smoothing
+ if blur_kernel != (0,0):
+ arr = cv2.GaussianBlur(arr, blur_kernel, 0)
+
+ # --- contrast-stretch to local 0-1 range --------------------------
+ arr_min, arr_max = float(arr.min()), float(arr.max())
+ if arr_max - arr_min < 1e-6: # avoid /0 when the heat-map is flat
+ arr_scaled = np.zeros_like(arr)
+ else:
+ arr_scaled = (arr - arr_min) / (arr_max - arr_min)
+ # ------------------------------------------------------------------
+ fig, ax = plt.subplots(figsize=(2.6, 2.6), dpi=96)
+ ax.imshow(arr_scaled, cmap="viridis", vmin=0.0, vmax=1.0)
+ ax.axis("off")
+ buf = io.BytesIO()
+ plt.tight_layout(pad=0)
+ fig.savefig(buf, format="png", bbox_inches="tight", pad_inches=0)
+ plt.close(fig)
+ buf.seek(0)
+ return Image.open(buf)
+
+# ────────────────────────── main inference ────────────────────────────
+# integration with ZeroGPU on hf
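+# duration sets the expected per-call GPU time (seconds) used by the ZeroGPU scheduler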
+@spaces.GPU(duration=5)
+def process(
+ sat_img: Image.Image,
+ taxonomy: str,
+ ground_img: Image.Image | None,
+    sound: str | None,  # gr.Audio(type="filepath") passes a path
+):
+ if sat_img is None:
+        return None, None, None
+
+ patches = _encode_sat(sat_img)
+
+ heat_ground, heat_text, heat_sound = None, None, None
+
+ if ground_img is not None:
+ q_img = _encode_ground(ground_img)
+ heat_ground = _array_to_pil(_similarity_heatmap(q_img, patches))
+
+ if taxonomy.strip():
+ q_txt = _encode_text(taxonomy.strip())
+ heat_text = _array_to_pil(_similarity_heatmap(q_txt, patches))
+
+ if sound is not None:
+ q_sound = _encode_sound(sound)
+ heat_sound = _array_to_pil(_similarity_heatmap(q_sound, patches))
+
+ return heat_ground, heat_text, heat_sound
+
+
+# ────────────────────────── Gradio UI ─────────────────────────────────
+with gr.Blocks(title="Search-TTA", theme=gr.themes.Base()) as demo:
+
+ gr.Markdown(
+ """
+ # Search-TTA: A Multimodal Test-Time Adaptation Framework for Visual Search in the Wild Demo
+ Click on any of the examples below and run the multimodal inference demo. Check out the test-time adaptation feature by switching to the other tab above.
+ If you encounter any errors, refresh the browser and rerun the demo, or try again the next day. We will improve this in the future.
+ Project Website
+ """
+ )
+
+ with gr.Row(variant="panel"):
+
+ # LEFT COLUMN (satellite, taxonomy, run)
+ with gr.Column():
+ sat_input = gr.Image(
+ label="Satellite Image",
+ sources=["upload"],
+ type="pil",
+ height=320,
+ )
+ taxonomy_input = gr.Textbox(
+ label="Full Taxonomy Name (optional)",
+ placeholder="e.g. Animalia Chordata Mammalia Rodentia Sciuridae Marmota marmota",
+ )
+
+            # ─── NEW: sound input ───────────────────────────
+ sound_input = gr.Audio(
+ label="Sound Input (optional)",
+ sources=["upload"],
+ type="filepath",
+ )
+ run_btn = gr.Button("Run", variant="primary")
+
+ # RIGHT COLUMN (ground image + two heat-maps)
+ with gr.Column():
+ ground_input = gr.Image(
+ label="Ground-level Image (optional)",
+ sources=["upload"],
+ type="pil",
+ height=320,
+ )
+ gr.Markdown("### Heat-map Results")
+ with gr.Row():
+ # Separate label and image to avoid overlap
+ with gr.Column(scale=1, min_width=100):
+ gr.Markdown("**Ground Image Query**", elem_id="label-ground")
+ heat_ground_out = gr.Image(
+ show_label=False,
+ height=160,
+ )
+ with gr.Column(scale=1, min_width=100):
+ gr.Markdown("**Text Query**", elem_id="label-text")
+ heat_text_out = gr.Image(
+ show_label=False,
+ height=160,
+ )
+ with gr.Column(scale=1, min_width=100):
+ gr.Markdown("**Sound Query**", elem_id="label-sound")
+ heat_sound_out = gr.Image(
+ show_label=False,
+ height=160,
+ )
+
+
+ # EXAMPLES
+ with gr.Row():
+ gr.Markdown("### In-Domain Taxonomy")
+ with gr.Row():
+ gr.Examples(
+ examples=[
+ [
+ "examples/Animalia_Chordata_Aves_Charadriiformes_Laridae_Larus_marinus/80645_39.76079_-74.10316.jpg",
+ "examples/Animalia_Chordata_Aves_Charadriiformes_Laridae_Larus_marinus/cc1ebaf9-899d-49f2-81c8-d452249a8087.jpg",
+ "Animalia Chordata Aves Charadriiformes Laridae Larus marinus",
+ "examples/Animalia_Chordata_Aves_Charadriiformes_Laridae_Larus_marinus/89758229.mp3"
+ ],
+ [
+ "examples/Animalia_Chordata_Mammalia_Rodentia_Caviidae_Hydrochoerus_hydrochaeris/28871_-12.80255_-69.29999.jpg",
+ "examples/Animalia_Chordata_Mammalia_Rodentia_Caviidae_Hydrochoerus_hydrochaeris/1b8064f8-7deb-4b30-98cd-69da98ba6a3d.jpg",
+ "Animalia Chordata Mammalia Rodentia Caviidae Hydrochoerus hydrochaeris",
+ "examples/Animalia_Chordata_Mammalia_Rodentia_Caviidae_Hydrochoerus_hydrochaeris/166631961.mp3"
+ ],
+ [
+ "examples/Animalia_Arthropoda_Malacostraca_Decapoda_Ocypodidae_Ocypode_quadrata/277303_38.72364_-75.07749.jpg",
+ "examples/Animalia_Arthropoda_Malacostraca_Decapoda_Ocypodidae_Ocypode_quadrata/0b9cc264-a2ba-44bd-8e41-0d01a6edd1e8.jpg",
+ "Animalia Arthropoda Malacostraca Decapoda Ocypodidae Ocypode quadrata",
+ "examples/Animalia_Arthropoda_Malacostraca_Decapoda_Ocypodidae_Ocypode_quadrata/12372063.mp3"
+ ],
+ [
+ "examples/Animalia_Chordata_Mammalia_Rodentia_Sciuridae_Marmota_marmota/388246_45.49036_7.14796.jpg",
+ "examples/Animalia_Chordata_Mammalia_Rodentia_Sciuridae_Marmota_marmota/327e1f07-692b-4140-8a3e-bd098bc064ff.jpg",
+ "Animalia Chordata Mammalia Rodentia Sciuridae Marmota marmota",
+ "examples/Animalia_Chordata_Mammalia_Rodentia_Sciuridae_Marmota_marmota/59677071.mp3"
+ ],
+ [
+ "examples/Animalia_Chordata_Reptilia_Squamata_Varanidae_Varanus_salvator/410613_5.35573_100.28948.jpg",
+ "examples/Animalia_Chordata_Reptilia_Squamata_Varanidae_Varanus_salvator/461d8e6c-0e66-4acc-8ecd-bfd9c218bc14.jpg",
+ "Animalia Chordata Reptilia Squamata Varanidae Varanus salvator",
+ None
+ ],
+ ],
+ inputs=[sat_input, ground_input, taxonomy_input, sound_input],
+ outputs=[heat_ground_out, heat_text_out, heat_sound_out],
+ fn=process,
+ cache_examples=False,
+ )
+
+ # EXAMPLES
+ with gr.Row():
+ gr.Markdown("### Out-Domain Taxonomy")
+ with gr.Row():
+ gr.Examples(
+ examples=[
+ [
+ "examples/Animalia_Chordata_Mammalia_Carnivora_Phocidae_Mirounga_angustirostris/27423_35.64005_-121.17595.jpg",
+ "examples/Animalia_Chordata_Mammalia_Carnivora_Phocidae_Mirounga_angustirostris/3aac526d-c921-452a-af6a-cb4f2f52e2c4.jpg",
+ "Animalia Chordata Mammalia Carnivora Phocidae Mirounga angustirostris",
+ "examples/Animalia_Chordata_Mammalia_Carnivora_Phocidae_Mirounga_angustirostris/3123948.mp3"
+ ],
+ [
+ "examples/Animalia_Chordata_Mammalia_Carnivora_Canidae_Canis_aureus/1528408_13.00422_80.23033.jpg",
+ "examples/Animalia_Chordata_Mammalia_Carnivora_Canidae_Canis_aureus/37faabd2-a613-4461-b27e-82fe5955ecaf.jpg",
+ "Animalia Chordata Mammalia Carnivora Canidae Canis aureus",
+ "examples/Animalia_Chordata_Mammalia_Carnivora_Canidae_Canis_aureus/189318716.mp3"
+ ],
+ [
+ "examples/Animalia_Chordata_Reptilia_Crocodylia_Alligatoridae_Caiman_crocodilus/340271_10.52832_-83.49678.jpg",
+ "examples/Animalia_Chordata_Reptilia_Crocodylia_Alligatoridae_Caiman_crocodilus/938aab7b-4509-4de7-afad-2c8ea51f4799.jpg",
+ "Animalia Chordata Reptilia Crocodylia Alligatoridae Caiman crocodilus",
+ "examples/Animalia_Chordata_Reptilia_Crocodylia_Alligatoridae_Caiman_crocodilus/45193295.mp3"
+ ],
+ [
+ "examples/Animalia_Chordata_Mammalia_Carnivora_Canidae_Urocyon_littoralis/304160_34.0144_-119.54417.jpg",
+ "examples/Animalia_Chordata_Mammalia_Carnivora_Canidae_Urocyon_littoralis/0cbdfbf2-6cfe-4d61-9602-c949f24d0293.jpg",
+ "Animalia Chordata Mammalia Carnivora Canidae Urocyon littoralis",
+ None
+ ],
+ [
+ "examples/Animalia_Chordata_Elasmobranchii_Carcharhiniformes_Carcharhinidae_Triaenodon_obesus/5041_-0.28573_-90.54837.jpg",
+ "examples/Animalia_Chordata_Elasmobranchii_Carcharhiniformes_Carcharhinidae_Triaenodon_obesus/c834edf7-b073-4cd5-8726-9c88ebe943c0.jpg",
+ "Animalia Chordata Elasmobranchii Carcharhiniformes Carcharhinidae Triaenodon obesus",
+ None
+ ],
+ ],
+ inputs=[sat_input, ground_input, taxonomy_input, sound_input],
+ outputs=[heat_ground_out, heat_text_out, heat_sound_out],
+ fn=process,
+ cache_examples=False,
+ )
+
+ # CALLBACK
+ run_btn.click(
+ fn=process,
+ inputs=[sat_input, taxonomy_input, ground_input, sound_input],
+ outputs=[heat_ground_out, heat_text_out, heat_sound_out],
+ )
+
+ # Footer to point out to model and data from app page.
+ gr.Markdown(
+ """
+        The satellite image CLIP encoder is fine-tuned using [Sentinel-2 Level 2A](https://docs.sentinel-hub.com/api/latest/data/sentinel-2-l2a/) satellite images and taxonomy images (with GPS locations) from [iNaturalist](https://inaturalist.org/). The sound CLIP encoder is fine-tuned with a subset of the same taxonomy images and their corresponding sounds from [iNaturalist](https://inaturalist.org/). Some of these iNaturalist data are also used in [Taxabind](https://arxiv.org/abs/2411.00683). Note that while some of the examples above result in poor probability distributions, they will be improved using our test-time adaptation framework during the search process.
+ """
+ )
+
+# LAUNCH
+if __name__ == "__main__":
+ demo.queue(max_size=15)
+ demo.launch(share=True)
diff --git a/examples/Animalia_Arthropoda_Malacostraca_Decapoda_Ocypodidae_Ocypode_quadrata/0b9cc264-a2ba-44bd-8e41-0d01a6edd1e8.jpg b/examples/Animalia_Arthropoda_Malacostraca_Decapoda_Ocypodidae_Ocypode_quadrata/0b9cc264-a2ba-44bd-8e41-0d01a6edd1e8.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..0ffd63f6e6eae1ace33246f8abecc88997a227ea
--- /dev/null
+++ b/examples/Animalia_Arthropoda_Malacostraca_Decapoda_Ocypodidae_Ocypode_quadrata/0b9cc264-a2ba-44bd-8e41-0d01a6edd1e8.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:08aee38091dbb62f0862a184acbc9432f50c03c63fdf357592df8efcacaab485
+size 134759
diff --git a/examples/Animalia_Arthropoda_Malacostraca_Decapoda_Ocypodidae_Ocypode_quadrata/12372063.mp3 b/examples/Animalia_Arthropoda_Malacostraca_Decapoda_Ocypodidae_Ocypode_quadrata/12372063.mp3
new file mode 100644
index 0000000000000000000000000000000000000000..f25d6a26dae4ddb1c7764ea5ec3c0066cc44b42c
--- /dev/null
+++ b/examples/Animalia_Arthropoda_Malacostraca_Decapoda_Ocypodidae_Ocypode_quadrata/12372063.mp3
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:575959883981159f2e40593bf5be87be006026c41da36a34d1e40783de648116
+size 54027
diff --git a/examples/Animalia_Arthropoda_Malacostraca_Decapoda_Ocypodidae_Ocypode_quadrata/277303_38.72364_-75.07749.jpg b/examples/Animalia_Arthropoda_Malacostraca_Decapoda_Ocypodidae_Ocypode_quadrata/277303_38.72364_-75.07749.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..3a8e17130c52e623889436f126fb287715535089
--- /dev/null
+++ b/examples/Animalia_Arthropoda_Malacostraca_Decapoda_Ocypodidae_Ocypode_quadrata/277303_38.72364_-75.07749.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:1caa5d8bab960f559065f79ca554bed63e6b02764096874be6a58b34389855f6
+size 25627
diff --git a/examples/Animalia_Chordata_Aves_Charadriiformes_Laridae_Larus_marinus/80645_39.76079_-74.10316.jpg b/examples/Animalia_Chordata_Aves_Charadriiformes_Laridae_Larus_marinus/80645_39.76079_-74.10316.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..8fd23a3b82411137495ba6b71c7b6dd358cdc1c4
--- /dev/null
+++ b/examples/Animalia_Chordata_Aves_Charadriiformes_Laridae_Larus_marinus/80645_39.76079_-74.10316.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:8350770efa7d8e38b91670e757bb82df26167f8989f946132ad978d238baa916
+size 26142
diff --git a/examples/Animalia_Chordata_Aves_Charadriiformes_Laridae_Larus_marinus/89758229.mp3 b/examples/Animalia_Chordata_Aves_Charadriiformes_Laridae_Larus_marinus/89758229.mp3
new file mode 100644
index 0000000000000000000000000000000000000000..2a85121f51df7f978961293eff638e883699fe48
--- /dev/null
+++ b/examples/Animalia_Chordata_Aves_Charadriiformes_Laridae_Larus_marinus/89758229.mp3
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:e2c7ad6df49668d29f9b7f9f9f0739b97ef4edc5219413a41d01983a9863cccc
+size 2601487
diff --git a/examples/Animalia_Chordata_Aves_Charadriiformes_Laridae_Larus_marinus/cc1ebaf9-899d-49f2-81c8-d452249a8087.jpg b/examples/Animalia_Chordata_Aves_Charadriiformes_Laridae_Larus_marinus/cc1ebaf9-899d-49f2-81c8-d452249a8087.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..e2380dc1133ed25489c69ae1de9967a813caa1cb
--- /dev/null
+++ b/examples/Animalia_Chordata_Aves_Charadriiformes_Laridae_Larus_marinus/cc1ebaf9-899d-49f2-81c8-d452249a8087.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:e346a1c1424e62a040c7e97f17c2e5ccb4c923422682105b2ccedd0ead736170
+size 28444
diff --git a/examples/Animalia_Chordata_Elasmobranchii_Carcharhiniformes_Carcharhinidae_Triaenodon_obesus/5041_-0.28573_-90.54837.jpg b/examples/Animalia_Chordata_Elasmobranchii_Carcharhiniformes_Carcharhinidae_Triaenodon_obesus/5041_-0.28573_-90.54837.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..1b046d43a5083e8207683bced18586dbe3fe8ece
--- /dev/null
+++ b/examples/Animalia_Chordata_Elasmobranchii_Carcharhiniformes_Carcharhinidae_Triaenodon_obesus/5041_-0.28573_-90.54837.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:624443bdb62b8d35e6e63be33e04404a85ad8902b70af67d878a013893656dc2
+size 15306
diff --git a/examples/Animalia_Chordata_Elasmobranchii_Carcharhiniformes_Carcharhinidae_Triaenodon_obesus/c834edf7-b073-4cd5-8726-9c88ebe943c0.jpg b/examples/Animalia_Chordata_Elasmobranchii_Carcharhiniformes_Carcharhinidae_Triaenodon_obesus/c834edf7-b073-4cd5-8726-9c88ebe943c0.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..829aeb164ad91ca1564548db2380160803a5dff5
--- /dev/null
+++ b/examples/Animalia_Chordata_Elasmobranchii_Carcharhiniformes_Carcharhinidae_Triaenodon_obesus/c834edf7-b073-4cd5-8726-9c88ebe943c0.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:e7f8be8790e7c5837d8d8e0d9285adad45138598caef21f528a591a0ab13ee9b
+size 58003
diff --git a/examples/Animalia_Chordata_Mammalia_Artiodactyla_Cervidae_Cervus_nippon/1a51f978-c76d-4816-a716-473c21b5495c.jpg b/examples/Animalia_Chordata_Mammalia_Artiodactyla_Cervidae_Cervus_nippon/1a51f978-c76d-4816-a716-473c21b5495c.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..0eb620ff8de6b3bfb7e44c3809d10fca8fcd6020
--- /dev/null
+++ b/examples/Animalia_Chordata_Mammalia_Artiodactyla_Cervidae_Cervus_nippon/1a51f978-c76d-4816-a716-473c21b5495c.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:a47758183723ba17f72dee9acdaf4bcfba2b4d07d0af2e50c125b3fac665ca04
+size 116615
diff --git a/examples/Animalia_Chordata_Mammalia_Artiodactyla_Cervidae_Cervus_nippon/245767_53.0076_-6.35201.jpg b/examples/Animalia_Chordata_Mammalia_Artiodactyla_Cervidae_Cervus_nippon/245767_53.0076_-6.35201.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..2370e73283220a9faba6ddb7179bff1519750ac3
--- /dev/null
+++ b/examples/Animalia_Chordata_Mammalia_Artiodactyla_Cervidae_Cervus_nippon/245767_53.0076_-6.35201.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:7e9e1c3907555774d831b34b9d7ef94b79b7fbe82c3e226baef75e0cf71194e4
+size 23001
diff --git a/examples/Animalia_Chordata_Mammalia_Carnivora_Canidae_Canis_aureus/1528408_13.00422_80.23033.jpg b/examples/Animalia_Chordata_Mammalia_Carnivora_Canidae_Canis_aureus/1528408_13.00422_80.23033.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..cbf5d4ba603adb07834bb5d4f321b465814fe79f
--- /dev/null
+++ b/examples/Animalia_Chordata_Mammalia_Carnivora_Canidae_Canis_aureus/1528408_13.00422_80.23033.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:e9f1934026db176cdcea261d37eda0a02309e5f2647ecab336e53d571b40f8f4
+size 37212
diff --git a/examples/Animalia_Chordata_Mammalia_Carnivora_Canidae_Canis_aureus/189318716.mp3 b/examples/Animalia_Chordata_Mammalia_Carnivora_Canidae_Canis_aureus/189318716.mp3
new file mode 100644
index 0000000000000000000000000000000000000000..e994349c23bfb8614fc7f6e35e4ec9be1755633c
--- /dev/null
+++ b/examples/Animalia_Chordata_Mammalia_Carnivora_Canidae_Canis_aureus/189318716.mp3
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:4639c226ea5a0464b98e89b33a1f821b6625c6637d206d3d355e05bc7c89c641
+size 148019
diff --git a/examples/Animalia_Chordata_Mammalia_Carnivora_Canidae_Canis_aureus/37faabd2-a613-4461-b27e-82fe5955ecaf.jpg b/examples/Animalia_Chordata_Mammalia_Carnivora_Canidae_Canis_aureus/37faabd2-a613-4461-b27e-82fe5955ecaf.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..de835cf70160678b36fc78be76f97748063de57c
--- /dev/null
+++ b/examples/Animalia_Chordata_Mammalia_Carnivora_Canidae_Canis_aureus/37faabd2-a613-4461-b27e-82fe5955ecaf.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:a84eca02154ed12885c075378b0349d6950586a7887883bce414df48adb59746
+size 83083
diff --git a/examples/Animalia_Chordata_Mammalia_Carnivora_Canidae_Urocyon_littoralis/0cbdfbf2-6cfe-4d61-9602-c949f24d0293.jpg b/examples/Animalia_Chordata_Mammalia_Carnivora_Canidae_Urocyon_littoralis/0cbdfbf2-6cfe-4d61-9602-c949f24d0293.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..d25a95a88da478475364862dc92c3d85cff0f4c5
--- /dev/null
+++ b/examples/Animalia_Chordata_Mammalia_Carnivora_Canidae_Urocyon_littoralis/0cbdfbf2-6cfe-4d61-9602-c949f24d0293.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:ea5f2dffebd69cdded00548f8773c5a8a8849bbdfba04ae6385fbc2d0983d55f
+size 75596
diff --git a/examples/Animalia_Chordata_Mammalia_Carnivora_Canidae_Urocyon_littoralis/304160_34.0144_-119.54417.jpg b/examples/Animalia_Chordata_Mammalia_Carnivora_Canidae_Urocyon_littoralis/304160_34.0144_-119.54417.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..def90799f19196092663468334976a847a59f416
--- /dev/null
+++ b/examples/Animalia_Chordata_Mammalia_Carnivora_Canidae_Urocyon_littoralis/304160_34.0144_-119.54417.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:3b803c3c2e6fa921d9f83ba3aecccac0796a4cd4774c3263aae54fdfc49d13d6
+size 23309
diff --git a/examples/Animalia_Chordata_Mammalia_Carnivora_Phocidae_Mirounga_angustirostris/27423_35.64005_-121.17595.jpg b/examples/Animalia_Chordata_Mammalia_Carnivora_Phocidae_Mirounga_angustirostris/27423_35.64005_-121.17595.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..de7f0c6267cfe0d0abd4ef95d0c5612299dbfc68
--- /dev/null
+++ b/examples/Animalia_Chordata_Mammalia_Carnivora_Phocidae_Mirounga_angustirostris/27423_35.64005_-121.17595.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:16dd378607b7303515593491a1247785ae49733da24bbc3ce21e85d6c6341ab2
+size 22576
diff --git a/examples/Animalia_Chordata_Mammalia_Carnivora_Phocidae_Mirounga_angustirostris/3123948.mp3 b/examples/Animalia_Chordata_Mammalia_Carnivora_Phocidae_Mirounga_angustirostris/3123948.mp3
new file mode 100644
index 0000000000000000000000000000000000000000..99fe366aae29123a967d5d909f9849de6df660a7
--- /dev/null
+++ b/examples/Animalia_Chordata_Mammalia_Carnivora_Phocidae_Mirounga_angustirostris/3123948.mp3
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:96ca3a92e6f614cce82972dacb04f5c0c170c1aea3d70d15778af56820ed02c9
+size 276768
diff --git a/examples/Animalia_Chordata_Mammalia_Carnivora_Phocidae_Mirounga_angustirostris/3aac526d-c921-452a-af6a-cb4f2f52e2c4.jpg b/examples/Animalia_Chordata_Mammalia_Carnivora_Phocidae_Mirounga_angustirostris/3aac526d-c921-452a-af6a-cb4f2f52e2c4.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..6654076f7838645dc3ce9b32b26fcccdfe864891
--- /dev/null
+++ b/examples/Animalia_Chordata_Mammalia_Carnivora_Phocidae_Mirounga_angustirostris/3aac526d-c921-452a-af6a-cb4f2f52e2c4.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:bdda6139885cf54acfb1d6c9a56373fbe39e205dac7eb99cd04dbe5eb206b9d6
+size 95413
diff --git a/examples/Animalia_Chordata_Mammalia_Rodentia_Caviidae_Hydrochoerus_hydrochaeris/166631961.mp3 b/examples/Animalia_Chordata_Mammalia_Rodentia_Caviidae_Hydrochoerus_hydrochaeris/166631961.mp3
new file mode 100644
index 0000000000000000000000000000000000000000..e15ebd11da89b57d7ebaef55cdf96c0a810a7175
--- /dev/null
+++ b/examples/Animalia_Chordata_Mammalia_Rodentia_Caviidae_Hydrochoerus_hydrochaeris/166631961.mp3
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:cc02eca19d0c408d038e205d82f6624c0515858ac374cf7298161a14e169e6a9
+size 266258
diff --git a/examples/Animalia_Chordata_Mammalia_Rodentia_Caviidae_Hydrochoerus_hydrochaeris/1b8064f8-7deb-4b30-98cd-69da98ba6a3d.jpg b/examples/Animalia_Chordata_Mammalia_Rodentia_Caviidae_Hydrochoerus_hydrochaeris/1b8064f8-7deb-4b30-98cd-69da98ba6a3d.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..0066c1e1c26bece8235dd23fd035870958284e37
--- /dev/null
+++ b/examples/Animalia_Chordata_Mammalia_Rodentia_Caviidae_Hydrochoerus_hydrochaeris/1b8064f8-7deb-4b30-98cd-69da98ba6a3d.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:fe636d18a1e068e85b0bb8cd05ff674eb4b19958cc34d75ef00d385f74254ecb
+size 85294
diff --git a/examples/Animalia_Chordata_Mammalia_Rodentia_Caviidae_Hydrochoerus_hydrochaeris/28871_-12.80255_-69.29999.jpg b/examples/Animalia_Chordata_Mammalia_Rodentia_Caviidae_Hydrochoerus_hydrochaeris/28871_-12.80255_-69.29999.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..cb656a4c7e24dc5b49e79ebd7f1061512d7db5fb
--- /dev/null
+++ b/examples/Animalia_Chordata_Mammalia_Rodentia_Caviidae_Hydrochoerus_hydrochaeris/28871_-12.80255_-69.29999.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:d28dec20d4f9cba13386ab00f13ddd7cb36fee24ee466e8a5437dbfd778bc2d5
+size 23308
diff --git a/examples/Animalia_Chordata_Mammalia_Rodentia_Sciuridae_Marmota_marmota/327e1f07-692b-4140-8a3e-bd098bc064ff.jpg b/examples/Animalia_Chordata_Mammalia_Rodentia_Sciuridae_Marmota_marmota/327e1f07-692b-4140-8a3e-bd098bc064ff.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..11254703f2d85ee4b05bbe1784d8a55eb41e7249
--- /dev/null
+++ b/examples/Animalia_Chordata_Mammalia_Rodentia_Sciuridae_Marmota_marmota/327e1f07-692b-4140-8a3e-bd098bc064ff.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:5bffc4c332ae6406bcb1b78cd23170bd7c71d58e2e7dac12fb812fc9aa39b8f0
+size 70314
diff --git a/examples/Animalia_Chordata_Mammalia_Rodentia_Sciuridae_Marmota_marmota/388246_45.49036_7.14796.jpg b/examples/Animalia_Chordata_Mammalia_Rodentia_Sciuridae_Marmota_marmota/388246_45.49036_7.14796.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..de28dc1d30551c4e980243a3806d8f3aba22eee4
--- /dev/null
+++ b/examples/Animalia_Chordata_Mammalia_Rodentia_Sciuridae_Marmota_marmota/388246_45.49036_7.14796.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:a5aa9ae1a1dc4c59191bc72005fc9904d4c390f07ce5cc5ed435eb5687ae1d64
+size 33328
diff --git a/examples/Animalia_Chordata_Mammalia_Rodentia_Sciuridae_Marmota_marmota/59677071.mp3 b/examples/Animalia_Chordata_Mammalia_Rodentia_Sciuridae_Marmota_marmota/59677071.mp3
new file mode 100644
index 0000000000000000000000000000000000000000..a1b5402460867d2b2574d3d02fcf5d1d338a7d55
--- /dev/null
+++ b/examples/Animalia_Chordata_Mammalia_Rodentia_Sciuridae_Marmota_marmota/59677071.mp3
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:bb043991fe851d6a1e12f32c5a9277dad5a77a939cf15ccb4afcb215b4bc08e3
+size 92876
diff --git a/examples/Animalia_Chordata_Reptilia_Crocodylia_Alligatoridae_Caiman_crocodilus/340271_10.52832_-83.49678.jpg b/examples/Animalia_Chordata_Reptilia_Crocodylia_Alligatoridae_Caiman_crocodilus/340271_10.52832_-83.49678.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..da1f75add0023ba8926dffc94c3c64ee1a3ca805
--- /dev/null
+++ b/examples/Animalia_Chordata_Reptilia_Crocodylia_Alligatoridae_Caiman_crocodilus/340271_10.52832_-83.49678.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:b31fbe934b245e7274289836b9eee781b2e33c4121dfbafebc473cd45d638825
+size 19839
diff --git a/examples/Animalia_Chordata_Reptilia_Crocodylia_Alligatoridae_Caiman_crocodilus/45193295.mp3 b/examples/Animalia_Chordata_Reptilia_Crocodylia_Alligatoridae_Caiman_crocodilus/45193295.mp3
new file mode 100644
index 0000000000000000000000000000000000000000..b803eca684f205f062a1683f083509f23940cf8a
--- /dev/null
+++ b/examples/Animalia_Chordata_Reptilia_Crocodylia_Alligatoridae_Caiman_crocodilus/45193295.mp3
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:c4cd2e4fd7094a07d79da7fd54788705e8ce7567e65911d87edfd23ff1c0e484
+size 247762
diff --git a/examples/Animalia_Chordata_Reptilia_Crocodylia_Alligatoridae_Caiman_crocodilus/938aab7b-4509-4de7-afad-2c8ea51f4799.jpg b/examples/Animalia_Chordata_Reptilia_Crocodylia_Alligatoridae_Caiman_crocodilus/938aab7b-4509-4de7-afad-2c8ea51f4799.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..b87fa747ba1d4e0b0630970dbd6b441294191a3b
--- /dev/null
+++ b/examples/Animalia_Chordata_Reptilia_Crocodylia_Alligatoridae_Caiman_crocodilus/938aab7b-4509-4de7-afad-2c8ea51f4799.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:a26f17668646cd25c77483565f6509ca7b21bba09ce92dac0f38d0ecbfdae3b1
+size 86323
diff --git a/examples/Animalia_Chordata_Reptilia_Squamata_Varanidae_Varanus_salvator/410613_5.35573_100.28948.jpg b/examples/Animalia_Chordata_Reptilia_Squamata_Varanidae_Varanus_salvator/410613_5.35573_100.28948.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..53d497a058b2995a75d724a808e74c13553594d4
--- /dev/null
+++ b/examples/Animalia_Chordata_Reptilia_Squamata_Varanidae_Varanus_salvator/410613_5.35573_100.28948.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:3019573a982d10c4791e357a5bebadfbb245f145c57c60c8a53f2241ac8789fe
+size 37040
diff --git a/examples/Animalia_Chordata_Reptilia_Squamata_Varanidae_Varanus_salvator/461d8e6c-0e66-4acc-8ecd-bfd9c218bc14.jpg b/examples/Animalia_Chordata_Reptilia_Squamata_Varanidae_Varanus_salvator/461d8e6c-0e66-4acc-8ecd-bfd9c218bc14.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..9464acdaf32edf40e991cd6d72955d382a34fe5e
--- /dev/null
+++ b/examples/Animalia_Chordata_Reptilia_Squamata_Varanidae_Varanus_salvator/461d8e6c-0e66-4acc-8ecd-bfd9c218bc14.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:41b11b1ea9709a9fefabc2c7ddf8aa58a7881749474bf0ccadca3a02e3a97c76
+size 175798
diff --git a/examples/metadata.json b/examples/metadata.json
new file mode 100644
index 0000000000000000000000000000000000000000..1072d38e6f2a10ad3ca6c1e54b4413f82ebfb4de
--- /dev/null
+++ b/examples/metadata.json
@@ -0,0 +1,173 @@
+{
+ "Animalia Chordata Reptilia Squamata Varanidae Varanus salvator": {
+ "id": 410613,
+ "sat_key": "410613_5.35573_100.28948",
+ "sat_path": "410613_5.35573_100.28948.jpg",
+ "taxonomy": "Animalia Chordata Reptilia Squamata Varanidae Varanus salvator",
+ "count": 6,
+ "spread": 58.00460580210422,
+ "sat_bounds": {
+ "min_lat": 5.344155081363914,
+ "max_lat": 5.367304914271601,
+ "min_lon": 100.27793148340874,
+ "max_lon": 100.30102851659126
+ },
+ "img_ids": [
+ 707815,
+ 411949,
+ 701168,
+ 1619682,
+ 2100008,
+ 1548498
+ ],
+ "target_positions": [
+ [
+ 225,
+ 240
+ ],
+ [
+ 232,
+ 275
+ ],
+ [
+ 277,
+ 449
+ ],
+ [
+ 220,
+ 369
+ ],
+ [
+ 180,
+ 393
+ ],
+ [
+ 294,
+ 478
+ ]
+ ],
+ "num_landmarks": 2
+ },
+ "Animalia Chordata Mammalia Carnivora Canidae Canis aureus": {
+ "id": 1528408,
+ "sat_key": "1528408_13.00422_80.23033",
+ "sat_path": "1528408_13.00422_80.23033.jpg",
+ "taxonomy": "Animalia Chordata Mammalia Carnivora Canidae Canis aureus",
+ "count": 3,
+ "spread": 58.14007011752667,
+ "sat_bounds": {
+ "min_lat": 12.992649951077192,
+ "max_lat": 13.015790038631529,
+ "min_lon": 80.21853090802841,
+ "max_lon": 80.24212909197156
+ },
+ "img_ids": [
+ 1528479,
+ 2555188,
+ 2555189
+ ],
+ "target_positions": [
+ [
+ 309,
+ 128
+ ],
+ [
+ 239,
+ 428
+ ],
+ [
+ 240,
+ 419
+ ]
+ ],
+ "num_landmarks": 3
+ },
+ "Animalia Chordata Reptilia Crocodylia Alligatoridae Caiman crocodilus": {
+ "id": 340271,
+ "sat_key": "340271_10.52832_-83.49678",
+ "sat_path": "340271_10.52832_-83.49678.jpg",
+ "taxonomy": "Animalia Chordata Reptilia Crocodylia Alligatoridae Caiman crocodilus",
+ "count": 7,
+ "spread": 40.13902957324975,
+ "sat_bounds": {
+ "min_lat": 10.516747947357544,
+ "max_lat": 10.53989204420829,
+ "min_lon": -83.50847402265151,
+ "max_lon": -83.48508597734848
+ },
+ "img_ids": [
+ 1683531,
+ 1281855,
+ 223089,
+ 688111,
+ 330757,
+ 2408375,
+ 1955359
+ ],
+ "target_positions": [
+ [
+ 347,
+ 75
+ ],
+ [
+ 47,
+ 22
+ ],
+ [
+ 111,
+ 43
+ ],
+ [
+ 116,
+ 51
+ ],
+ [
+ 86,
+ 108
+ ],
+ [
+ 31,
+ 62
+ ],
+ [
+ 4,
+ 78
+ ]
+ ],
+ "num_landmarks": 3
+ },
+ "Animalia Chordata Mammalia Carnivora Canidae Urocyon littoralis": {
+ "id": 304160,
+ "sat_key": "304160_34.0144_-119.54417",
+ "sat_path": "304160_34.0144_-119.54417.jpg",
+ "taxonomy": "Animalia Chordata Mammalia Carnivora Canidae Urocyon littoralis",
+ "count": 3,
+ "spread": 237.64152837579553,
+ "sat_bounds": {
+ "min_lat": 34.00286041606169,
+ "max_lat": 34.02593956225012,
+ "min_lon": -119.55802743361286,
+ "max_lon": -119.53031256638712
+ },
+ "img_ids": [
+ 304160,
+ 1473173,
+ 384867
+ ],
+ "target_positions": [
+ [
+ 255,
+ 256
+ ],
+ [
+ 19,
+ 22
+ ],
+ [
+ 29,
+ 274
+ ]
+ ],
+ "num_landmarks": 3
+ }
+}
\ No newline at end of file
diff --git a/inference/model/avs_rl_policy.pth b/inference/model/avs_rl_policy.pth
new file mode 100644
index 0000000000000000000000000000000000000000..3721ea2885ee2ad7face94b1a9fecabc823c2044
--- /dev/null
+++ b/inference/model/avs_rl_policy.pth
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:44e642df9aaa2847ba44dd4707985c67ef712f5264272ef7993aeb7805c80f5a
+size 52167246
diff --git a/maps/example/masks_val/MSK_0001.png b/maps/example/masks_val/MSK_0001.png
new file mode 100644
index 0000000000000000000000000000000000000000..21ad5af72bb32265563af29f70ca80ff131568dd
--- /dev/null
+++ b/maps/example/masks_val/MSK_0001.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:318773e2c18275d84b5145d7e69836baa0bedd833f44b49f98e6619357677cff
+size 75884
diff --git a/maps/gpt4o/envs_val/MSK_0001.png b/maps/gpt4o/envs_val/MSK_0001.png
new file mode 100644
index 0000000000000000000000000000000000000000..cdfa027638f778caf542942085be8dffc28d9e1e
--- /dev/null
+++ b/maps/gpt4o/envs_val/MSK_0001.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:7af11bcef1972b7e047f53b597fef2a332d82c7feceb21aac6e14a57469c436b
+size 2337
diff --git a/planner/env.py b/planner/env.py
new file mode 100644
index 0000000000000000000000000000000000000000..7e40201a50084d7532a7c38e6d8b3543f8ea5ca6
--- /dev/null
+++ b/planner/env.py
@@ -0,0 +1,610 @@
+#######################################################################
+# Name: env.py
+#
+# - Reads and processes training and test maps
+# - Processes rewards, new frontiers given action
+# - Updates a graph representation of environment for input into network
+#######################################################################
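+#
+# Example (hypothetical usage sketch; argument values are illustrative, only the
+# signatures below are taken from this file):
+#
+#   env = Env(map_index=0, n_agent=1, k_size=20, plot=True, test=True)
+#   next_positions = [env.node_coords[0]]                      # pick any graph node
+#   rewards, done = env.multi_robot_step(next_positions, dist_list=[0.0],
+#                                        travel_dist_list=[0.0])
+#
+# Note: the TRAINING flag checked below is expected to be placed in sys.modules by
+# the entry script; it selects between parameter.py and test_parameter.py.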
+
+import sys
+if sys.modules['TRAINING']:
+ from .parameter import *
+else:
+ from .test_parameter import *
+
+import os
+import cv2
+import copy
+import matplotlib.image as mpimg
+import matplotlib.pyplot as plt
+from skimage import io
+from skimage.measure import block_reduce
+from scipy.ndimage import label, find_objects
+from .sensor import *
+from .graph_generator import *
+from .node import *
+
+
+class Env():
+ def __init__(self, map_index, n_agent, k_size=20, plot=False, test=False, mask_index=None):
+ self.n_agent = n_agent
+ self.test = test
+ self.map_dir = GRIDMAP_SET_DIR
+
+ # Import environment gridmap
+ self.map_list = os.listdir(self.map_dir)
+ self.map_list.sort(reverse=True)
+
+ # NEW: Import segmentation utility map
+ self.seg_dir = MASK_SET_DIR
+ self.segmentation_mask, self.target_positions, self.target_found_idxs = None, [], []
+ self.segmentation_mask_list = os.listdir(self.seg_dir)
+ self.segmentation_mask_list.sort(reverse=True)
+
+ # NEW: Select the map / mask indices (wrap around when out of range)
+ self.map_index = map_index % len(self.map_list)
+ if mask_index is not None:
+ self.mask_index = mask_index % len(self.segmentation_mask_list)
+ else:
+ self.mask_index = map_index % len(self.segmentation_mask_list)
+
+ # Import ground truth and segmentation mask
+ self.ground_truth, self.map_start_position = self.import_ground_truth(
+ os.path.join(self.map_dir, self.map_list[self.map_index]))
+ self.ground_truth_size = np.shape(self.ground_truth)
+ self.robot_belief = np.ones(self.ground_truth_size) * 127 # unexplored 127
+ self.downsampled_belief = None
+ self.old_robot_belief = copy.deepcopy(self.robot_belief)
+ self.coverage_belief = np.ones(self.ground_truth_size) * 127 # unexplored 127
+
+ # Import segmentation mask
+ mask_filename = self.segmentation_mask_list[self.mask_index]
+ self.segmentation_mask = self.import_segmentation_mask(
+ os.path.join(self.seg_dir, mask_filename))
+
+ # Overwrite target positions if directory specified
+ if self.test and TARGETS_SET_DIR != "":
+ self.target_positions = self.import_targets(
+ os.path.join(TARGETS_SET_DIR, self.map_list[self.map_index]))
+
+ self.segmentation_info_mask = None
+ self.segmentation_info_mask_unnormalized = None
+ self.filtered_seg_info_mask = None
+ self.num_targets_found = 0
+ self.num_new_targets_found = 0
+ self.resolution = 4
+ self.sensor_range = SENSOR_RANGE
+ self.explored_rate = 0
+ self.targets_found_rate = 0
+ self.frontiers = None
+ self.start_positions = []
+ self.plot = plot
+ self.frame_files = []
+ self.graph_generator = Graph_generator(map_size=self.ground_truth_size, sensor_range=self.sensor_range, k_size=k_size, plot=plot)
+ self.node_coords, self.graph, self.node_utility, self.guidepost = None, None, None, None
+
+ self.begin(self.map_start_position)
+
+
+ def find_index_from_coords(self, position):
+ index = np.argmin(np.linalg.norm(self.node_coords - position, axis=1))
+ return index
+
+ def begin(self, start_position):
+ self.robot_belief = self.ground_truth
+ self.downsampled_belief = block_reduce(self.robot_belief.copy(), block_size=(self.resolution, self.resolution), func=np.min)
+ self.frontiers = self.find_frontier()
+ self.old_robot_belief = copy.deepcopy(self.robot_belief)
+
+ self.node_coords, self.graph, self.node_utility, self.guidepost = self.graph_generator.generate_graph(
+ self.robot_belief, self.frontiers)
+
+ # Define start positions
+ if FIX_START_POSITION:
+ coords_res_row = int(self.robot_belief.shape[0]/NUM_COORDS_HEIGHT)
+ coords_res_col = int(self.robot_belief.shape[1]/NUM_COORDS_WIDTH)
+ self.start_positions = [(int(self.robot_belief.shape[1]/2)-coords_res_col/2,int(self.robot_belief.shape[0]/2)-coords_res_row/2) for _ in range(self.n_agent)]
+ else:
+ nearby_coords = self.graph_generator.get_neighbors_grid_coords(start_position)
+ itr = 0
+ for i in range(self.n_agent):
+ if i == 0 or len(nearby_coords) == 0:
+ self.start_positions.append(start_position)
+ else:
+ idx = min(itr, len(nearby_coords)-1)
+ self.start_positions.append(nearby_coords[idx])
+ itr += 1
+
+ for i in range(len(self.start_positions)):
+ self.start_positions[i] = self.node_coords[self.find_index_from_coords(self.start_positions[i])]
+ self.coverage_belief = self.update_robot_belief(self.start_positions[i], self.sensor_range, self.coverage_belief,
+ self.ground_truth)
+
+ for start_position in self.start_positions:
+ self.graph_generator.route_node.append(start_position)
+
+ # Info map from ground truth
+ rng_x = 0.5 * (self.ground_truth.shape[1] / NUM_COORDS_WIDTH)
+ rng_y = 0.5 * (self.ground_truth.shape[0] / NUM_COORDS_HEIGHT)
+ self.segmentation_info_mask = np.zeros((len(self.node_coords), 1))
+ for i, node_coord in enumerate(self.node_coords):
+ max_x = min(node_coord[0] + int(math.ceil(rng_x)), self.ground_truth.shape[1])
+ min_x = max(node_coord[0] - int(math.ceil(rng_x)), 0)
+ max_y = min(node_coord[1] + int(math.ceil(rng_y)), self.ground_truth.shape[0])
+ min_y = max(node_coord[1] - int(math.ceil(rng_y)), 0)
+
+ if TARGETS_SET_DIR == "":
+ exclude = {208} # Exclude target positions
+ else:
+ exclude = set()
+ self.segmentation_info_mask[i] = max(x for x in self.segmentation_mask[min_y:max_y, min_x:max_x].flatten() if x not in exclude) / 100.0
+
+ self.filtered_seg_info_mask = copy.deepcopy(self.segmentation_info_mask)
+ done, num_targets_found = self.check_done()
+ self.num_targets_found = num_targets_found
+
+
+ def multi_robot_step(self, next_position_list, dist_list, travel_dist_list):
+ reward_list = []
+ for dist, robot_position in zip(dist_list, next_position_list):
+ self.graph_generator.route_node.append(robot_position)
+ next_node_index = self.find_index_from_coords(robot_position)
+ self.graph_generator.nodes_list[next_node_index].set_visited()
+ self.coverage_belief = self.update_robot_belief(robot_position, self.sensor_range, self.coverage_belief,
+ self.ground_truth)
+ self.robot_belief = self.ground_truth
+ self.downsampled_belief = block_reduce(self.robot_belief.copy(),
+ block_size=(self.resolution, self.resolution),
+ func=np.min)
+
+ frontiers = self.find_frontier()
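+ # Per-agent reward: a movement cost of -dist/32, an information-gain term of
+ # 1.5 x the (filtered) segmentation score at the new node, and a +0.2 bonus for
+ # landing on a not-yet-visited node; a +40 team bonus is added to every agent
+ # once the episode terminates.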
+ individual_reward = -dist / 32
+
+ info_gain_reward = 0
+ robot_position_idx = self.find_index_from_coords(robot_position)
+ info_gain_reward = self.filtered_seg_info_mask[robot_position_idx][0] * 1.5
+ if self.guidepost[robot_position_idx] == 0.0:
+ info_gain_reward += 0.2
+ individual_reward += info_gain_reward
+
+ reward_list.append(individual_reward)
+
+ self.node_coords, self.graph, self.node_utility, self.guidepost = self.graph_generator.update_graph(self.robot_belief, self.old_robot_belief, frontiers, self.frontiers)
+ self.old_robot_belief = copy.deepcopy(self.robot_belief)
+
+ self.filtered_seg_info_mask = [info[0] if self.guidepost[i] == 0.0 else 0.0 for i, info in enumerate(self.segmentation_info_mask)]
+ self.filtered_seg_info_mask = np.expand_dims(np.array(self.filtered_seg_info_mask), axis=1)
+
+ self.frontiers = frontiers
+ self.explored_rate = self.evaluate_exploration_rate()
+
+ done, num_targets_found = self.check_done()
+ self.num_new_targets_found = num_targets_found - self.num_targets_found
+ team_reward = 0.0
+
+ self.num_targets_found = num_targets_found
+ self.targets_found_rate = self.evaluate_targets_found_rate()
+
+ if done:
+ team_reward += 40
+ for i in range(len(reward_list)):
+ reward_list[i] += team_reward
+
+ return reward_list, done
+
+
+ def import_ground_truth(self, map_index):
+ # occupied 1, free 255, unexplored 127
+
+ try:
+ ground_truth = (io.imread(map_index, 1)).astype(int)
+ if np.all(ground_truth == 0):
+ ground_truth = (io.imread(map_index, 1) * 255).astype(int)
+ except Exception:
+ new_map_index = self.map_dir + '/' + self.map_list[0]
+ ground_truth = (io.imread(new_map_index, 1)).astype(int)
+ print('could not read the map_path ({}), hence skipping it and using ({}).'.format(map_index, new_map_index))
+
+ robot_location = np.nonzero(ground_truth == 208)
+ robot_location = np.array([np.array(robot_location)[1, 127], np.array(robot_location)[0, 127]])
+ ground_truth = (ground_truth > 150)
+ ground_truth = ground_truth * 254 + 1
+ return ground_truth, robot_location
+
+
+ def import_segmentation_mask(self, map_index):
+ mask = cv2.imread(map_index).astype(int)
+ return mask
+
+ def import_targets(self, map_index):
+ # occupied 1, free 255, unexplored 127, target 208
+ mask = cv2.imread(map_index).astype(int)
+ target_positions = self.find_target_locations(mask)
+ return target_positions
+
+
+ def find_target_locations(self, image_array, grey_value=208):
+
+ grey_pixels = np.where(image_array == grey_value)
+ binary_array = np.zeros_like(image_array, dtype=bool)
+ binary_array[grey_pixels] = True
+ labeled_array, num_features = label(binary_array)
+ slices = find_objects(labeled_array)
+
+ # Calculate the center of each box
+ centers = []
+ for sl in slices:
+ row_center = (sl[0].start + sl[0].stop - 1) // 2
+ col_center = (sl[1].start + sl[1].stop - 1) // 2
+ centers.append((col_center, row_center)) # (x, y), matching how target_positions is indexed
+
+ return centers
+
+ def free_cells(self):
+ index = np.where(self.ground_truth == 255)
+ free = np.asarray([index[1], index[0]]).T
+ return free
+
+ def update_robot_belief(self, robot_position, sensor_range, robot_belief, ground_truth):
+ robot_belief = sensor_work(robot_position, sensor_range, robot_belief, ground_truth)
+ return robot_belief
+
+
+ def check_done(self):
+ done = False
+ num_targets_found = 0
+ self.target_found_idxs = []
+ for i, target in enumerate(self.target_positions):
+ if self.coverage_belief[target[1], target[0]] == 255:
+ num_targets_found += 1
+ self.target_found_idxs.append(i)
+
+ if TERMINATE_ON_TGTS_FOUND and num_targets_found >= len(self.target_positions):
+ done = True
+ if not TERMINATE_ON_TGTS_FOUND and np.sum(self.coverage_belief == 255) / np.sum(self.ground_truth == 255) >= 0.99:
+ done = True
+
+ return done, num_targets_found
+
+
+ def calculate_num_observed_frontiers(self, old_frontiers, frontiers):
+ frontiers_to_check = frontiers[:, 0] + frontiers[:, 1] * 1j
+ pre_frontiers_to_check = old_frontiers[:, 0] + old_frontiers[:, 1] * 1j
+ frontiers_num = np.intersect1d(frontiers_to_check, pre_frontiers_to_check).shape[0]
+ pre_frontiers_num = pre_frontiers_to_check.shape[0]
+ delta_num = pre_frontiers_num - frontiers_num
+
+ return delta_num
+
+ def calculate_reward(self, dist, frontiers):
+ reward = 0
+ reward -= dist / 64
+
+ frontiers_to_check = frontiers[:, 0] + frontiers[:, 1] * 1j
+ pre_frontiers_to_check = self.frontiers[:, 0] + self.frontiers[:, 1] * 1j
+ frontiers_num = np.intersect1d(frontiers_to_check, pre_frontiers_to_check).shape[0]
+ pre_frontiers_num = pre_frontiers_to_check.shape[0]
+ delta_num = pre_frontiers_num - frontiers_num
+
+ reward += delta_num / 50
+
+ return reward
+
+ def evaluate_exploration_rate(self):
+ rate = np.sum(self.coverage_belief == 255) / np.sum(self.ground_truth == 255)
+ return rate
+
+ def evaluate_targets_found_rate(self):
+ if len(self.target_positions) == 0:
+ return 0
+ else:
+ rate = self.num_targets_found / len(self.target_positions)
+ return rate
+
+ def calculate_new_free_area(self):
+ old_free_area = self.old_robot_belief == 255
+ current_free_area = self.robot_belief == 255
+
+ new_free_area = (current_free_area.astype(int) - old_free_area.astype(int)) * 255
+
+ return new_free_area, np.sum(old_free_area)
+
+ def calculate_dist_path(self, path):
+ dist = 0
+ start = path[0]
+ end = path[-1]
+ for index in path:
+ if index == end:
+ break
+ dist += np.linalg.norm(self.node_coords[start] - self.node_coords[index])
+ start = index
+ return dist
+
+ def find_frontier(self):
+ y_len = self.downsampled_belief.shape[0]
+ x_len = self.downsampled_belief.shape[1]
+ mapping = self.downsampled_belief.copy()
+ belief = self.downsampled_belief.copy()
+ # 0-1 unknown area map
+ mapping = (mapping == 127) * 1
+ mapping = np.lib.pad(mapping, ((1, 1), (1, 1)), 'constant', constant_values=0)
+ fro_map = mapping[2:][:, 1:x_len + 1] + mapping[:y_len][:, 1:x_len + 1] + mapping[1:y_len + 1][:, 2:] + \
+ mapping[1:y_len + 1][:, :x_len] + mapping[:y_len][:, 2:] + mapping[2:][:, :x_len] + mapping[2:][:,
+ 2:] + \
+ mapping[:y_len][:, :x_len]
+ ind_free = np.where(belief.ravel(order='F') == 255)[0]
+ ind_fron_1 = np.where(1 < fro_map.ravel(order='F'))[0]
+ ind_fron_2 = np.where(fro_map.ravel(order='F') < 8)[0]
+ ind_fron = np.intersect1d(ind_fron_1, ind_fron_2)
+ ind_to = np.intersect1d(ind_free, ind_fron)
+
+ map_x = x_len
+ map_y = y_len
+ x = np.linspace(0, map_x - 1, map_x)
+ y = np.linspace(0, map_y - 1, map_y)
+ t1, t2 = np.meshgrid(x, y)
+ points = np.vstack([t1.T.ravel(), t2.T.ravel()]).T
+
+ f = points[ind_to]
+ f = f.astype(int)
+
+ f = f * self.resolution
+
+ return f
+
+
+
+ def plot_env(self, n, path, step, travel_dist, robots_route, img_path_override=None, sat_path_override=None, msk_name_override=None, sound_id_override=None):
+
+ plt.switch_backend('agg')
+ plt.cla()
+ color_list = ["r", "g", "c", "m", "y", "k"]
+
+ if not LOAD_AVS_BENCH:
+ fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 5))
+ else:
+ fig, (ax1, ax2, ax3, ax4) = plt.subplots(1, 4, figsize=(20, 5.5))
+
+ ### Fig: Ground Image ###
+ if LOAD_AVS_BENCH:
+ ax = ax1
+ image = mpimg.imread(img_path_override)
+ ax.imshow(image)
+ ax.set_title("Ground Image")
+ ax.axis("off")
+
+ ### Fig: Environment ###
+ msk_name = ""
+ if LOAD_AVS_BENCH:
+ image = mpimg.imread(sat_path_override)
+ msk_name = msk_name_override
+
+ ### Fig1: Environment ###
+ ax = ax2
+ ax.imshow(image)
+ ax.axis((0, self.ground_truth_size[1], self.ground_truth_size[0], 0))
+ ax.set_title("Image")
+ for i, route in enumerate(robots_route):
+ robot_marker_color = color_list[i % len(color_list)]
+ xPoints = route[0]
+ yPoints = route[1]
+ ax.plot(xPoints, yPoints, c=robot_marker_color, linewidth=2)
+ ax.plot(xPoints[-1], yPoints[-1], markersize=12, zorder=99, marker="^", ls="-", c=robot_marker_color, mec="black")
+ ax.plot(xPoints[0], yPoints[0], 'co', c=robot_marker_color, markersize=8, zorder=5)
+
+ # Sensor range
+ rng_x = 0.5 * (self.ground_truth.shape[1] / NUM_COORDS_WIDTH)
+ rng_y = 0.5 * (self.ground_truth.shape[0] / NUM_COORDS_HEIGHT)
+ max_x = min(xPoints[-1] + int(math.ceil(rng_x)), self.ground_truth.shape[1])
+ min_x = max(xPoints[-1] - int(math.ceil(rng_x)), 0)
+ max_y = min(yPoints[-1] + int(math.ceil(rng_y)), self.ground_truth.shape[0])
+ min_y = max(yPoints[-1] - int(math.ceil(rng_y)), 0)
+ ax.plot((min_x, min_x), (min_y, max_y), c=robot_marker_color, linewidth=1)
+ ax.plot((min_x, max_x), (max_y, max_y), c=robot_marker_color, linewidth=1)
+ ax.plot((max_x, max_x), (max_y, min_y), c=robot_marker_color, linewidth=1)
+ ax.plot((max_x, min_x), (min_y, min_y), c=robot_marker_color, linewidth=1)
+
+
+ ### Fig: Graph ###
+ ax = ax3 if LOAD_AVS_BENCH else ax1
+ ax.imshow(self.coverage_belief, cmap='gray')
+ ax.axis((0, self.ground_truth_size[1], self.ground_truth_size[0], 0))
+ ax.set_title("Information Graph")
+ if VIZ_GRAPH_EDGES:
+ for i in range(len(self.graph_generator.x)):
+ ax.plot(self.graph_generator.x[i], self.graph_generator.y[i], 'tan', zorder=1)
+ ax.scatter(self.node_coords[:, 0], self.node_coords[:, 1], c=self.filtered_seg_info_mask, zorder=5, s=8)
+
+ for i, route in enumerate(robots_route):
+ robot_marker_color = color_list[i % len(color_list)]
+ xPoints = route[0]
+ yPoints = route[1]
+ ax.plot(xPoints, yPoints, c=robot_marker_color, linewidth=2)
+ ax.plot(xPoints[-1], yPoints[-1], markersize=12, zorder=99, marker="^", ls="-", c=robot_marker_color, mec="black")
+ ax.plot(xPoints[0], yPoints[0], 'co', c=robot_marker_color, markersize=8, zorder=5)
+
+ # Sensor range
+ rng_x = 0.5 * (self.ground_truth.shape[1] / NUM_COORDS_WIDTH)
+ rng_y = 0.5 * (self.ground_truth.shape[0] / NUM_COORDS_HEIGHT)
+ max_x = min(xPoints[-1] + int(math.ceil(rng_x)), self.ground_truth.shape[1])
+ min_x = max(xPoints[-1] - int(math.ceil(rng_x)), 0)
+ max_y = min(yPoints[-1] + int(math.ceil(rng_y)), self.ground_truth.shape[0])
+ min_y = max(yPoints[-1] - int(math.ceil(rng_y)), 0)
+ ax.plot((min_x, min_x), (min_y, max_y), c=robot_marker_color, linewidth=1)
+ ax.plot((min_x, max_x), (max_y, max_y), c=robot_marker_color, linewidth=1)
+ ax.plot((max_x, max_x), (max_y, min_y), c=robot_marker_color, linewidth=1)
+ ax.plot((max_x, min_x), (min_y, min_y), c=robot_marker_color, linewidth=1)
+
+ # Plot target positions
+ for target in self.target_positions:
+ if self.coverage_belief[target[1], target[0]] == 255:
+ ax.plot(target[0], target[1], color='g', marker='x', linestyle='-', markersize=12, markeredgewidth=4, zorder=99)
+ else:
+ ax.plot(target[0], target[1], color='r', marker='x', linestyle='-', markersize=12, markeredgewidth=4, zorder=99)
+
+ ### Fig: Segmentation Mask ###
+ ax = ax4 if LOAD_AVS_BENCH else ax2
+ if LOAD_AVS_BENCH and USE_CLIP_PREDS:
+ H, W = self.ground_truth_size
+ mask_viz = self.segmentation_info_mask.squeeze().reshape((NUM_COORDS_WIDTH, NUM_COORDS_HEIGHT)).T
+ im = ax.imshow(
+ mask_viz,
+ cmap="viridis",
+ origin="upper",
+ extent=[0, W, H, 0],
+ interpolation="nearest",
+ zorder=0,
+ )
+ ax.set_xlim(0, W)
+ ax.set_ylim(H, 0)
+ ax.set_axis_off()
+ else:
+ im = ax.imshow(self.segmentation_mask.mean(axis=-1), cmap='viridis', vmin=0, vmax=100) # cmap='gray'
+ ax.axis((0, self.ground_truth_size[1], self.ground_truth_size[0], 0))
+ ax.set_title("Predicted Mask (Normalized)")
+ for i, route in enumerate(robots_route):
+ robot_marker_color = color_list[i % len(color_list)]
+ xPoints = route[0]
+ yPoints = route[1]
+ ax.plot(xPoints, yPoints, c=robot_marker_color, linewidth=2)
+ ax.plot(xPoints[-1], yPoints[-1], markersize=12, zorder=99, marker="^", ls="-", c=robot_marker_color, mec="black")
+ ax.plot(xPoints[0], yPoints[0], 'co', c=robot_marker_color, markersize=8, zorder=5)
+
+ # Sensor range
+ rng_x = 0.5 * (self.ground_truth.shape[1] / NUM_COORDS_WIDTH)
+ rng_y = 0.5 * (self.ground_truth.shape[0] / NUM_COORDS_HEIGHT)
+ max_x = min(xPoints[-1] + int(math.ceil(rng_x)), self.ground_truth.shape[1])
+ min_x = max(xPoints[-1] - int(math.ceil(rng_x)), 0)
+ max_y = min(yPoints[-1] + int(math.ceil(rng_y)), self.ground_truth.shape[0])
+ min_y = max(yPoints[-1] - int(math.ceil(rng_y)), 0)
+ ax.plot((min_x, min_x), (min_y, max_y), c=robot_marker_color, linewidth=1)
+ ax.plot((min_x, max_x), (max_y, max_y), c=robot_marker_color, linewidth=1)
+ ax.plot((max_x, max_x), (max_y, min_y), c=robot_marker_color, linewidth=1)
+ ax.plot((max_x, min_x), (min_y, min_y), c=robot_marker_color, linewidth=1)
+
+ # Add a colorbar
+ cbar = fig.colorbar(im, ax=ax, fraction=0.046, pad=0.04)
+ cbar.set_label("Normalized Probs")
+
+ if sound_id_override is not None:
+ plt.suptitle('Targets Found: {}/{} Coverage ratio: {:.4g} Travel Dist: {:.4g} \n ({}) \n (Sound ID: {})'.format(self.num_targets_found, len(self.target_positions), self.explored_rate, travel_dist, msk_name, sound_id_override))
+ elif msk_name != "":
+ plt.suptitle('Targets Found: {}/{} Coverage ratio: {:.4g} Travel Dist: {:.4g} \n ({})'.format(self.num_targets_found, len(self.target_positions), self.explored_rate, travel_dist, msk_name))
+ else:
+ plt.suptitle('Targets Found: {}/{} Coverage ratio: {:.4g} Travel Dist: {:.4g}'.format(self.num_targets_found, len(self.target_positions), self.explored_rate, travel_dist))
+
+ plt.tight_layout()
+ plt.savefig('{}/{}_{}_samples.png'.format(path, n, step), dpi=100)
+ frame = '{}/{}_{}_samples.png'.format(path, n, step)
+ self.frame_files.append(frame)
+ plt.close()
+
+
+ ####################
+ # ADDED: For app.py
+ ####################
+
+ def plot_heatmap(self, save_dir, step, travel_dist, robots_route=None):
+ """Plot only the segmentation heatmap and save it as ``{step}.png`` in
+ ``save_dir``. This lightweight helper is meant for asynchronous
+ streaming in the Gradio demo when full `plot_env` is too heavy.
+
+ Parameters
+ ----------
+ save_dir : str
+ Directory to save the generated PNG file.
+ step : int
+ Current timestep; becomes the filename ``{step}.png``.
+ travel_dist : float
+ Total distance travelled so far (not shown in this heatmap figure).
+ robots_route : list | None
+ Optional list of routes (xPoints, yPoints) to overlay.
+
+ Returns
+ -------
+ str
+ Full path to the generated PNG file.
+ """
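+ # Example (hypothetical usage sketch; the directory name and values are illustrative):
+ #   png = env.plot_heatmap(save_dir="results/heatmaps", step=3, travel_dist=12.5,
+ #                          robots_route=[(xPoints, yPoints)])
+ #   # -> "results/heatmaps/3.png", written atomically so the Gradio poller
+ #   #    never reads a partially written frame.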
+ import os
+ plt.switch_backend('agg')
+ # Do not clear the global figure state in case it interferes with
+ # the current figure. Each call creates its own Figure object that
+ # we close explicitly at the end, so a global clear is unnecessary
+ # and may break concurrent drawing.
+ # plt.cla()
+
+ color_list = ["r", "g", "c", "m", "y", "k"]
+ fig, ax = plt.subplots(1, 1, figsize=(6, 6))
+
+ # Select the mask to visualise
+ # if TAXABIND_TTA and USE_CLIP_PREDS:
+ side_dim = int(np.sqrt(self.segmentation_info_mask.shape[0]))
+ mask_viz = self.segmentation_info_mask.squeeze().reshape((side_dim, side_dim)).T
+
+ # Properly map image to pixel coordinates and keep limits fixed
+ H, W = self.ground_truth_size # rows (y), cols (x)
+ im = ax.imshow(
+ mask_viz,
+ cmap="viridis",
+ origin="upper",
+ extent=[0, W, H, 0], # x: 0..W, y: H..0 (origin at top-left)
+ interpolation="nearest", # keep cell edges sharp & aligned
+ zorder=0,
+ )
+ ax.set_xlim(0, W)
+ ax.set_ylim(H, 0)
+ ax.set_axis_off() # hide ticks but keep limits
+ # else:
+ # im = ax.imshow(self.segmentation_mask.mean(axis=-1), cmap='viridis', vmin=0, vmax=100)
+ # ax.axis((0, self.ground_truth_size[1], self.ground_truth_size[0], 0))
+
+ # Optionally overlay robot paths
+ if robots_route is not None:
+ for i, route in enumerate(robots_route):
+ robot_marker_color = color_list[i % len(color_list)]
+ xPoints, yPoints = route
+ ax.plot(xPoints, yPoints, c=robot_marker_color, linewidth=2)
+ ax.plot(xPoints[-1], yPoints[-1], markersize=12, zorder=99, marker="^", ls="-", c=robot_marker_color, mec="black")
+ ax.plot(xPoints[0], yPoints[0], 'co', c=robot_marker_color, markersize=8, zorder=5)
+
+ # Plot target positions
+ for target in self.target_positions:
+ if self.coverage_belief[target[1], target[0]] == 255:
+ # ax.plot(target[0], target[1], 'go', markersize=8, zorder=99)
+ ax.plot(target[0], target[1], color='g', marker='x', linestyle='-', markersize=12, markeredgewidth=4, zorder=99)
+ else:
+ # ax.plot(target[0], target[1], 'ro', markersize=8, zorder=99)
+ ax.plot(target[0], target[1], color='r', marker='x', linestyle='-', markersize=12, markeredgewidth=4, zorder=99)
+
+ # Sensor range
+ rng_x = 0.5 * (self.ground_truth.shape[1] / NUM_COORDS_WIDTH)
+ rng_y = 0.5 * (self.ground_truth.shape[0] / NUM_COORDS_HEIGHT)
+ max_x = min(xPoints[-1] + int(math.ceil(rng_x)), self.ground_truth.shape[1])
+ min_x = max(xPoints[-1] - int(math.ceil(rng_x)), 0)
+ max_y = min(yPoints[-1] + int(math.ceil(rng_y)), self.ground_truth.shape[0])
+ min_y = max(yPoints[-1] - int(math.ceil(rng_y)), 0)
+ ax.plot((min_x, min_x), (min_y, max_y), c=robot_marker_color, linewidth=1)
+ ax.plot((min_x, max_x), (max_y, max_y), c=robot_marker_color, linewidth=1)
+ ax.plot((max_x, max_x), (max_y, min_y), c=robot_marker_color, linewidth=1)
+ ax.plot((max_x, min_x), (min_y, min_y), c=robot_marker_color, linewidth=1)
+
+ # Color bar
+ cbar = fig.colorbar(im, ax=ax, fraction=0.046, pad=0.04)
+ cbar.set_label("Normalized Probs")
+
+ # Change coverage to 1dp
+ plt.suptitle('Targets Found: {}/{} Coverage: {:.1f}% Steps: {}/{}'.format(
+ self.num_targets_found, \
+ len(self.target_positions),
+ self.explored_rate*100,
+ step+1,
+ NUM_EPS_STEPS),
+ y=0.94, # Closer to plot
+ )
+
+ plt.tight_layout()
+ os.makedirs(save_dir, exist_ok=True)
+ out_path = os.path.join(save_dir, f"{step}.png")
+ # Save atomically: write to temp file then move into place so the poller never sees a partial file.
+ tmp_path = out_path + ".tmp"
+ fig.savefig(tmp_path, dpi=100, format='png')
+ os.replace(tmp_path, out_path) # atomic on same filesystem
+ plt.close(fig)
+ return out_path
\ No newline at end of file
diff --git a/planner/graph.py b/planner/graph.py
new file mode 100644
index 0000000000000000000000000000000000000000..68d671478edcf687c8b474160715358584631318
--- /dev/null
+++ b/planner/graph.py
@@ -0,0 +1,167 @@
+#######################################################################
+# Name: graph.py
+#
+# - Adapted from https://gist.github.com/betandr/541a1f6466b6855471de5ca30b74cb31
+# - Simple graph class to perform shortest-path calculations (e.g. A-Star, Dijkstra)
+#######################################################################
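+
+# Example (hypothetical usage sketch; node labels are illustrative, but they are
+# plain strings just like the indices Graph_generator passes to add_node/add_edge):
+#
+#   g = Graph()
+#   for n in ("0", "1", "2"):
+#       g.add_node(n)
+#   g.add_edge("0", "1", 1.0)
+#   g.add_edge("1", "2", 2.5)
+#   dist, prev = dijkstra(g, "0")   # dist["2"] == 3.5
+#   route = to_array(prev, "2")     # ["0", "1", "2"]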
+
+
+class Edge:
+ def __init__(self, to_node, length):
+ self.to_node = to_node
+ self.length = length
+
+
+class Graph:
+ def __init__(self):
+ self.nodes = set()
+ self.edges = dict()
+
+ def add_node(self, node):
+ self.nodes.add(node)
+
+ def add_edge(self, from_node, to_node, length):
+ edge = Edge(to_node, length)
+ if from_node in self.edges:
+ from_node_edges = self.edges[from_node]
+ else:
+ self.edges[from_node] = dict()
+ from_node_edges = self.edges[from_node]
+ from_node_edges[to_node] = edge
+
+ def clear_edge(self, from_node):
+ if from_node in self.edges:
+ self.edges[from_node] = dict()
+
+def min_dist(q, dist):
+ """
+ Returns the node with the smallest distance in q.
+ Implemented to keep the main algorithm clean.
+ """
+ min_node = None
+ for node in q:
+ if min_node is None:
+ min_node = node
+ elif dist[node] < dist[min_node]:
+ min_node = node
+
+ return min_node
+
+
+INFINITY = float('Infinity')
+
+
+def dijkstra(graph, source):
+ q = set()
+ dist = {}
+ prev = {}
+
+ for v in graph.nodes:
+ dist[v] = INFINITY # unknown distance from source to v
+ prev[v] = INFINITY # previous node in optimal path from source
+ q.add(v) # all nodes initially in q (unvisited nodes)
+
+ # distance from source to source
+ dist[source] = 0
+
+ while q:
+ # node with the least distance selected first
+ u = min_dist(q, dist)
+
+ q.remove(u)
+
+ try:
+ if u in graph.edges:
+ for _, v in graph.edges[u].items():
+ alt = dist[u] + v.length
+ if alt < dist[v.to_node]:
+ # a shorter path to v has been found
+ dist[v.to_node] = alt
+ prev[v.to_node] = u
+ except:
+ pass
+
+ return dist, prev
+
+
+def to_array(prev, from_node):
+ """Creates an ordered list of labels as a route."""
+ previous_node = prev[from_node]
+ route = [from_node]
+
+ while previous_node != INFINITY:
+ route.append(previous_node)
+ temp = previous_node
+ previous_node = prev[temp]
+
+ route.reverse()
+ return route
+
+
+def h(index, destination, node_coords):
+ current = node_coords[index]
+ end = node_coords[destination]
+ h = abs(end[0] - current[0]) + abs(end[1] - current[1])
+ return h
+
+
+def a_star(start, destination, node_coords, graph):
+ if start == destination:
+ return [], 0
+ if str(destination) in graph.edges[str(start)].keys():
+ cost = graph.edges[str(start)][str(destination)].length
+ return [start, destination], cost
+ open_list = {start}
+ closed_list = set([])
+
+ g = {start: 0}
+ parents = {start: start}
+
+ while len(open_list) > 0:
+ n = None
+ h_n = 1e5
+ for v in open_list:
+ h_v = h(v, destination, node_coords)
+ if n is not None:
+ h_n = h(n, destination, node_coords)
+ if n is None or g[v] + h_v < g[n] + h_n:
+ n = v
+
+ if n is None:
+ print('Path does not exist!')
+ return None, 1e5
+
+ if n == destination:
+ reconst_path = []
+ while parents[n] != n:
+ reconst_path.append(n)
+ n = parents[n]
+ reconst_path.append(start)
+ reconst_path.reverse()
+ return reconst_path, g[destination]
+
+ for edge in graph.edges[str(n)].values():
+ m = int(edge.to_node)
+ cost = edge.length
+ if m not in open_list and m not in closed_list:
+ open_list.add(m)
+ parents[m] = n
+ g[m] = g[n] + cost
+
+ else:
+ if g[m] > g[n] + cost:
+ g[m] = g[n] + cost
+ parents[m] = n
+
+ if m in closed_list:
+ closed_list.remove(m)
+ open_list.add(m)
+
+ open_list.remove(n)
+ closed_list.add(n)
+
+ print('Path does not exist!')
+ return None, 1e5
+
+
+
diff --git a/planner/graph_generator.py b/planner/graph_generator.py
new file mode 100644
index 0000000000000000000000000000000000000000..89617c2786227681280237138751373b7b30f589
--- /dev/null
+++ b/planner/graph_generator.py
@@ -0,0 +1,300 @@
+#######################################################################
+# Name: graph_generator.py
+#
+# - Wrapper for graph.py
+# - Sends the formatted inputs into graph.py to get useful info
+#######################################################################
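+
+# Example (hypothetical usage sketch; shapes and numbers are illustrative, the return
+# values follow generate_graph() below; robot_belief and frontiers are produced by
+# the Env class in env.py):
+#
+#   gen = Graph_generator(map_size=(480, 640), k_size=20, sensor_range=80, plot=False)
+#   node_coords, edges, node_utility, guidepost = gen.generate_graph(robot_belief, frontiers)
+#   # node_coords : (N, 2) array of (x, y) vertices sampled on a uniform grid
+#   # edges       : dict keyed by stringified node index -> {neighbour index: Edge}
+#   # node_utility: (N,)  number of frontiers observable from each node
+#   # guidepost   : (N, 1) 1 where a node has already been visited, else 0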
+
+import sys
+if sys.modules['TRAINING']:
+ from .parameter import *
+else:
+ from .test_parameter import *
+
+import numpy as np
+import shapely.geometry
+from sklearn.neighbors import NearestNeighbors
+from .node import Node
+from .graph import Graph, a_star
+
+
+class Graph_generator:
+ def __init__(self, map_size, k_size, sensor_range, plot=False):
+ self.k_size = k_size
+ self.graph = Graph()
+ self.node_coords = None
+ self.plot = plot
+ self.x = []
+ self.y = []
+ self.map_x = map_size[1]
+ self.map_y = map_size[0]
+ self.uniform_points, self.grid_coords = self.generate_uniform_points()
+ self.sensor_range = sensor_range
+ self.route_node = []
+ self.nodes_list = []
+ self.node_utility = None
+ self.guidepost = None
+
+
+ def edge_clear_all_nodes(self):
+ self.graph = Graph()
+ self.x = []
+ self.y = []
+
+
+ def edge_clear(self, coords):
+ node_index = str(self.find_index_from_coords(self.node_coords, coords))
+ self.graph.clear_edge(node_index)
+
+
+ def generate_graph(self, robot_belief, frontiers):
+ self.edge_clear_all_nodes()
+ free_area = self.free_area(robot_belief)
+
+ free_area_to_check = free_area[:, 0] + free_area[:, 1] * 1j
+ uniform_points_to_check = self.uniform_points[:, 0] + self.uniform_points[:, 1] * 1j
+ _, _, candidate_indices = np.intersect1d(free_area_to_check, uniform_points_to_check, return_indices=True)
+ node_coords = self.uniform_points[candidate_indices]
+
+ self.node_coords = self.unique_coords(node_coords).reshape(-1, 2)
+ self.find_nearest_neighbor_all_nodes(self.node_coords, robot_belief)
+
+ self.node_utility = []
+ for coords in self.node_coords:
+ node = Node(coords, frontiers, robot_belief)
+ self.nodes_list.append(node)
+ utility = node.utility
+ self.node_utility.append(utility)
+ self.node_utility = np.array(self.node_utility)
+
+ self.guidepost = np.zeros((self.node_coords.shape[0], 1))
+ x = self.node_coords[:,0] + self.node_coords[:,1]*1j
+ for node in self.route_node:
+ index = np.argwhere(x.reshape(-1) == node[0]+node[1]*1j)[0]
+ self.guidepost[index] = 1
+
+ return self.node_coords, self.graph.edges, self.node_utility, self.guidepost
+
+
+ def update_graph(self, robot_belief, old_robot_belief, frontiers, old_frontiers):
+ new_free_area = self.free_area((robot_belief - old_robot_belief > 0) * 255)
+ free_area_to_check = new_free_area[:, 0] + new_free_area[:, 1] * 1j
+ uniform_points_to_check = self.uniform_points[:, 0] + self.uniform_points[:, 1] * 1j
+ _, _, candidate_indices = np.intersect1d(free_area_to_check, uniform_points_to_check, return_indices=True)
+ new_node_coords = self.uniform_points[candidate_indices]
+ self.node_coords = np.concatenate((self.node_coords, new_node_coords))
+
+ old_node_to_update = []
+ for coords in new_node_coords:
+ neighbor_indices = self.find_k_neighbor(coords, self.node_coords, robot_belief)
+ old_node_to_update += neighbor_indices
+ old_node_to_update = set(old_node_to_update)
+ for index in old_node_to_update:
+ coords = self.node_coords[index]
+ self.edge_clear(coords)
+ self.find_k_neighbor(coords, self.node_coords, robot_belief)
+
+ old_frontiers_to_check = old_frontiers[:, 0] + old_frontiers[:, 1] * 1j
+ new_frontiers_to_check = frontiers[:, 0] + frontiers[:, 1] * 1j
+ observed_frontiers_index = np.where(
+ np.isin(old_frontiers_to_check, new_frontiers_to_check, assume_unique=True) == False)
+ new_frontiers_index = np.where(
+ np.isin(new_frontiers_to_check, old_frontiers_to_check, assume_unique=True) == False)
+ observed_frontiers = old_frontiers[observed_frontiers_index]
+ new_frontiers = frontiers[new_frontiers_index]
+ for node in self.nodes_list:
+ if node.zero_utility_node is True:
+ pass
+ else:
+ node.update_observable_frontiers(observed_frontiers, new_frontiers, robot_belief)
+
+ for new_coords in new_node_coords:
+ node = Node(new_coords, frontiers, robot_belief)
+ self.nodes_list.append(node)
+
+ self.node_utility = []
+ for i, coords in enumerate(self.node_coords):
+ utility = self.nodes_list[i].utility
+ self.node_utility.append(utility)
+ self.node_utility = np.array(self.node_utility)
+
+ self.guidepost = np.zeros((self.node_coords.shape[0], 1))
+ x = self.node_coords[:, 0] + self.node_coords[:, 1] * 1j
+ for node in self.route_node:
+ index = np.argwhere(x.reshape(-1) == node[0] + node[1] * 1j)
+ self.guidepost[index] = 1
+
+ return self.node_coords, self.graph.edges, self.node_utility, self.guidepost
+
+
+ def generate_uniform_points(self):
+ padding_x = 0.5 * (self.map_x / NUM_COORDS_WIDTH)
+ padding_y = 0.5 * (self.map_y / NUM_COORDS_HEIGHT)
+ x = np.linspace(padding_x, self.map_x - padding_x - 1, NUM_COORDS_WIDTH).round().astype(int)
+ y = np.linspace(padding_y, self.map_y - padding_y - 1, NUM_COORDS_HEIGHT).round().astype(int)
+
+ t1, t2 = np.meshgrid(x, y)
+ points = np.vstack([t1.T.ravel(), t2.T.ravel()]).T
+ matrix = np.stack((t1, t2), axis=-1)
+ return points, matrix
+
+
+ def free_area(self, robot_belief):
+ index = np.where(robot_belief == 255)
+ free = np.asarray([index[1], index[0]]).T
+ return free
+
+
+ def unique_coords(self, coords):
+ x = coords[:, 0] + coords[:, 1] * 1j
+ indices = np.unique(x, return_index=True)[1]
+ coords = np.array([coords[idx] for idx in sorted(indices)])
+ return coords
+
+
+ def find_k_neighbor(self, coords, node_coords, robot_belief):
+ dist_list = np.linalg.norm((coords-node_coords), axis=-1)
+ sorted_index = np.argsort(dist_list)
+ k = 0
+ neighbor_index_list = []
+ while k < self.k_size and k < node_coords.shape[0]:
+ neighbor_index = sorted_index[k]
+ neighbor_index_list.append(neighbor_index)
+ dist = dist_list[k]
+ start = coords
+ end = node_coords[neighbor_index]
+ if not self.check_collision(start, end, robot_belief):
+ a = str(self.find_index_from_coords(node_coords, start))
+ b = str(neighbor_index)
+ self.graph.add_node(a)
+ self.graph.add_edge(a, b, dist)
+
+ if self.plot:
+ self.x.append([start[0], end[0]])
+ self.y.append([start[1], end[1]])
+ k += 1
+ return neighbor_index_list
+
+
+ def find_k_neighbor_all_nodes(self, node_coords, robot_belief):
+ X = node_coords
+ if len(node_coords) >= self.k_size:
+ knn = NearestNeighbors(n_neighbors=self.k_size)
+ else:
+ knn = NearestNeighbors(n_neighbors=len(node_coords))
+ knn.fit(X)
+ distances, indices = knn.kneighbors(X)
+
+ for i, p in enumerate(X):
+ for j, neighbour in enumerate(X[indices[i][:]]):
+ start = p
+ end = neighbour
+ if not self.check_collision(start, end, robot_belief):
+ a = str(self.find_index_from_coords(node_coords, p))
+ b = str(self.find_index_from_coords(node_coords, neighbour))
+ self.graph.add_node(a)
+ self.graph.add_edge(a, b, distances[i, j])
+
+ if self.plot:
+ self.x.append([p[0], neighbour[0]])
+ self.y.append([p[1], neighbour[1]])
+
+
+ def find_nearest_neighbor_all_nodes(self, node_coords, robot_belief):
+ for i, p in enumerate(node_coords):
+ filtered_coords = self.get_neighbors_grid_coords(p)
+
+ for j, neighbour in enumerate(filtered_coords):
+ start = p
+ end = neighbour
+ if not self.check_collision(start, end, robot_belief):
+ a = str(self.find_index_from_coords(node_coords, p))
+ b = str(self.find_index_from_coords(node_coords, neighbour))
+ self.graph.add_node(a)
+ self.graph.add_edge(a, b, np.linalg.norm(start-end))
+
+ if self.plot:
+ self.x.append([p[0], neighbour[0]])
+ self.y.append([p[1], neighbour[1]])
+
+
+ def find_index_from_coords(self, node_coords, p):
+ return np.where(np.linalg.norm(node_coords - p, axis=1) < 1e-5)[0][0]
+
+
+ def find_closest_index_from_coords(self, node_coords, p):
+ return np.argmin(np.linalg.norm(node_coords - p, axis=1))
+
+
+ def find_index_from_grid_coords_2d(self, p):
+ diffs = np.linalg.norm(self.grid_coords - p, axis=2)
+ indices = np.where(diffs < 1e-5)
+
+ if indices[0].size > 0:
+ return indices[0][0], indices[1][0]
+ else:
+ raise ValueError(f"Coordinate {p} not found in self.grid_coords.")
+
+
+ def find_closest_index_from_grid_coords_2d(self, p):
+ distances = np.linalg.norm(self.grid_coords - p, axis=2)
+ flat_index = np.argmin(distances)
+ return np.unravel_index(flat_index, distances.shape)
+
+
+ def check_collision(self, start, end, robot_belief):
+ collision = False
+ line = shapely.geometry.LineString([start, end])
+
+ sortx = np.sort([start[0], end[0]])
+ sorty = np.sort([start[1], end[1]])
+
+ robot_belief = robot_belief[sorty[0]:sorty[1]+1, sortx[0]:sortx[1]+1]
+
+ occupied_area_index = np.where(robot_belief == 1)
+ occupied_area_coords = np.asarray([occupied_area_index[1]+sortx[0], occupied_area_index[0]+sorty[0]]).T
+ unexplored_area_index = np.where(robot_belief == 127)
+ unexplored_area_coords = np.asarray([unexplored_area_index[1]+sortx[0], unexplored_area_index[0]+sorty[0]]).T
+ unfree_area_coords = np.concatenate((occupied_area_coords, unexplored_area_coords))
+
+ for i in range(unfree_area_coords.shape[0]):
+ coords = ([(unfree_area_coords[i][0], unfree_area_coords[i][1]),
+ (unfree_area_coords[i][0] + 1, unfree_area_coords[i][1]),
+ (unfree_area_coords[i][0], unfree_area_coords[i][1] + 1),
+ (unfree_area_coords[i][0] + 1, unfree_area_coords[i][1] + 1)])
+ obstacle = shapely.geometry.Polygon(coords)
+ collision = line.intersects(obstacle)
+ if collision:
+ break
+
+ return collision
+
+
+ def find_shortest_path(self, current, destination, node_coords):
+ start_node = str(self.find_index_from_coords(node_coords, current))
+ end_node = str(self.find_index_from_coords(node_coords, destination))
+ route, dist = a_star(int(start_node), int(end_node), self.node_coords, self.graph)
+ if start_node != end_node:
+ assert route != []
+ route = list(map(str, route))
+ return dist, route
+
+ def get_neighbors_grid_coords(self, coord):
+ # Return the 8 closest neighbors of a given coordinate
+
+ nearest_coord = self.node_coords[self.find_closest_index_from_coords(self.node_coords, coord)]
+ rows, cols = self.grid_coords.shape[:2]
+ neighbors = []
+ i, j = self.find_index_from_grid_coords_2d(nearest_coord)
+
+ # Create a range of indices for rows and columns
+ row_range = np.clip([i - 1, i, i + 1], 0, rows - 1)
+ col_range = np.clip([j - 1, j, j + 1], 0, cols - 1)
+
+ # Iterate over the valid indices
+ for ni in row_range:
+ for nj in col_range:
+ if (ni, nj) != (i, j): # Skip the center point
+ neighbors.append(tuple(self.grid_coords[ni, nj]))
+
+ return neighbors
\ No newline at end of file
diff --git a/planner/model.py b/planner/model.py
new file mode 100644
index 0000000000000000000000000000000000000000..5cede92ba760d7cc0f789397cc11aa8e0b7a61cb
--- /dev/null
+++ b/planner/model.py
@@ -0,0 +1,312 @@
+#######################################################################
+# Name: model.py
+#
+# - Attention-based encoders & decoders
+# - Policy Net: Input = Augmented Graph, Output = Node to go to
+# - Critic Net: Input = Augmented Graph + Action, Output = Q_Value
+#######################################################################
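+
+# Example (hypothetical tensor shapes, for illustration only; B = batch, N = number of
+# graph nodes, K = neighbours per node, F = per-node feature size chosen by the caller):
+#
+#   policy = PolicyNet(input_dim=F, embedding_dim=128)
+#   # node_inputs       : (B, N, F)  augmented graph node features
+#   # edge_inputs       : (B, N, K)  indices of each node's K neighbours
+#   # current_index     : (B, 1, 1)  index of the node the agent occupies
+#   # edge_padding_mask : (B, N, K)  1 where a neighbour slot is padding
+#   logp = policy(node_inputs, edge_inputs, current_index,
+#                 edge_padding_mask=edge_padding_mask)   # (B, K) log-probs over neighbours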
+
+import torch
+import torch.nn as nn
+import math
+
+
+class SingleHeadAttention(nn.Module):
+ def __init__(self, embedding_dim):
+ super(SingleHeadAttention, self).__init__()
+ self.input_dim = embedding_dim
+ self.embedding_dim = embedding_dim
+ self.value_dim = embedding_dim
+ self.key_dim = self.value_dim
+ self.tanh_clipping = 10
+ self.norm_factor = 1 / math.sqrt(self.key_dim)
+
+ self.w_query = nn.Parameter(torch.Tensor(self.input_dim, self.key_dim))
+ self.w_key = nn.Parameter(torch.Tensor(self.input_dim, self.key_dim))
+
+ self.init_parameters()
+
+ def init_parameters(self):
+ for param in self.parameters():
+ stdv = 1. / math.sqrt(param.size(-1))
+ param.data.uniform_(-stdv, stdv)
+
+ def forward(self, q, k, mask=None):
+
+ n_batch, n_key, n_dim = k.size()
+ n_query = q.size(1)
+
+ k_flat = k.reshape(-1, n_dim)
+ q_flat = q.reshape(-1, n_dim)
+
+ shape_k = (n_batch, n_key, -1)
+ shape_q = (n_batch, n_query, -1)
+
+ Q = torch.matmul(q_flat, self.w_query).view(shape_q)
+ K = torch.matmul(k_flat, self.w_key).view(shape_k)
+
+ U = self.norm_factor * torch.matmul(Q, K.transpose(1, 2))
+ U = self.tanh_clipping * torch.tanh(U)
+
+ if mask is not None:
+ U = U.masked_fill(mask == 1, -1e8)
+ attention = torch.log_softmax(U, dim=-1) # n_batch*n_query*n_key
+
+ return attention
+
+
+class MultiHeadAttention(nn.Module):
+ def __init__(self, embedding_dim, n_heads=8):
+ super(MultiHeadAttention, self).__init__()
+ self.n_heads = n_heads
+ self.input_dim = embedding_dim
+ self.embedding_dim = embedding_dim
+ self.value_dim = self.embedding_dim // self.n_heads
+ self.key_dim = self.value_dim
+ self.norm_factor = 1 / math.sqrt(self.key_dim)
+
+ self.w_query = nn.Parameter(torch.Tensor(self.n_heads, self.input_dim, self.key_dim))
+ self.w_key = nn.Parameter(torch.Tensor(self.n_heads, self.input_dim, self.key_dim))
+ self.w_value = nn.Parameter(torch.Tensor(self.n_heads, self.input_dim, self.value_dim))
+ self.w_out = nn.Parameter(torch.Tensor(self.n_heads, self.value_dim, self.embedding_dim))
+
+ self.init_parameters()
+
+ def init_parameters(self):
+ for param in self.parameters():
+ stdv = 1. / math.sqrt(param.size(-1))
+ param.data.uniform_(-stdv, stdv)
+
+ def forward(self, q, k=None, v=None, key_padding_mask=None, attn_mask=None):
+ if k is None:
+ k = q
+ if v is None:
+ v = q
+
+ n_batch, n_key, n_dim = k.size()
+ n_query = q.size(1)
+ n_value = v.size(1)
+
+ k_flat = k.contiguous().view(-1, n_dim)
+ v_flat = v.contiguous().view(-1, n_dim)
+ q_flat = q.contiguous().view(-1, n_dim)
+ shape_v = (self.n_heads, n_batch, n_value, -1)
+ shape_k = (self.n_heads, n_batch, n_key, -1)
+ shape_q = (self.n_heads, n_batch, n_query, -1)
+
+ Q = torch.matmul(q_flat, self.w_query).view(shape_q) # n_heads*batch_size*n_query*key_dim
+ K = torch.matmul(k_flat, self.w_key).view(shape_k) # n_heads*batch_size*targets_size*key_dim
+ V = torch.matmul(v_flat, self.w_value).view(shape_v) # n_heads*batch_size*targets_size*value_dim
+
+ U = self.norm_factor * torch.matmul(Q, K.transpose(2, 3)) # n_heads*batch_size*n_query*targets_size
+
+ if attn_mask is not None:
+ attn_mask = attn_mask.view(1, n_batch, n_query, n_key).expand_as(U)
+
+ if key_padding_mask is not None:
+ key_padding_mask = key_padding_mask.repeat(1, n_query, 1)
+ key_padding_mask = key_padding_mask.view(1, n_batch, n_query, n_key).expand_as(U) # copy for n_heads times
+
+ if attn_mask is not None and key_padding_mask is not None:
+ mask = (attn_mask + key_padding_mask)
+ elif attn_mask is not None:
+ mask = attn_mask
+ elif key_padding_mask is not None:
+ mask = key_padding_mask
+ else:
+ mask = None
+
+ if mask is not None:
+ U = U.masked_fill(mask > 0, -1e8)
+
+ attention = torch.softmax(U, dim=-1) # n_heads*batch_size*n_query*targets_size
+ heads = torch.matmul(attention, V) # n_heads*batch_size*n_query*value_dim
+ out = torch.mm(
+ heads.permute(1, 2, 0, 3).reshape(-1, self.n_heads * self.value_dim),
+ # batch_size*n_query*n_heads*value_dim
+ self.w_out.view(-1, self.embedding_dim)
+ # n_heads*value_dim*embedding_dim
+ ).view(-1, n_query, self.embedding_dim)
+
+
+ return out, attention # batch_size*n_query*embedding_dim
+
+
+class Normalization(nn.Module):
+ def __init__(self, embedding_dim):
+ super(Normalization, self).__init__()
+ self.normalizer = nn.LayerNorm(embedding_dim)
+
+ def forward(self, input):
+ return self.normalizer(input.view(-1, input.size(-1))).view(*input.size())
+
+
+class EncoderLayer(nn.Module):
+ def __init__(self, embedding_dim, n_head):
+ super(EncoderLayer, self).__init__()
+ self.multiHeadAttention = MultiHeadAttention(embedding_dim, n_head)
+ self.normalization1 = Normalization(embedding_dim)
+ self.feedForward = nn.Sequential(nn.Linear(embedding_dim, 512), nn.ReLU(inplace=True),
+ nn.Linear(512, embedding_dim))
+ self.normalization2 = Normalization(embedding_dim)
+
+ def forward(self, src, key_padding_mask=None, attn_mask=None):
+ h0 = src
+ h = self.normalization1(src)
+ h, _ = self.multiHeadAttention(q=h, key_padding_mask=key_padding_mask, attn_mask=attn_mask)
+ h = h + h0
+ h1 = h
+ h = self.normalization2(h)
+ h = self.feedForward(h)
+ h2 = h + h1
+ return h2
+
+
+class DecoderLayer(nn.Module):
+ def __init__(self, embedding_dim, n_head):
+ super(DecoderLayer, self).__init__()
+ self.multiHeadAttention = MultiHeadAttention(embedding_dim, n_head)
+ self.normalization1 = Normalization(embedding_dim)
+ self.feedForward = nn.Sequential(nn.Linear(embedding_dim, 512),
+ nn.ReLU(inplace=True),
+ nn.Linear(512, embedding_dim))
+ self.normalization2 = Normalization(embedding_dim)
+
+ def forward(self, tgt, memory, key_padding_mask=None, attn_mask=None):
+ h0 = tgt
+ tgt = self.normalization1(tgt)
+ memory = self.normalization1(memory)
+ h, w = self.multiHeadAttention(q=tgt, k=memory, v=memory, key_padding_mask=key_padding_mask, attn_mask=attn_mask)
+ h = h + h0
+ h1 = h
+ h = self.normalization2(h)
+ h = self.feedForward(h)
+ h2 = h + h1
+ return h2, w
+
+
+class Encoder(nn.Module):
+ def __init__(self, embedding_dim=128, n_head=8, n_layer=1):
+ super(Encoder, self).__init__()
+ self.layers = nn.ModuleList(EncoderLayer(embedding_dim, n_head) for i in range(n_layer))
+
+ def forward(self, src, key_padding_mask=None, attn_mask=None):
+ for layer in self.layers:
+ src = layer(src, key_padding_mask=key_padding_mask, attn_mask=attn_mask)
+ return src
+
+
+class Decoder(nn.Module):
+ def __init__(self, embedding_dim=128, n_head=8, n_layer=1):
+ super(Decoder, self).__init__()
+ self.layers = nn.ModuleList([DecoderLayer(embedding_dim, n_head) for i in range(n_layer)])
+
+ def forward(self, tgt, memory, key_padding_mask=None, attn_mask=None):
+ for layer in self.layers:
+ tgt, w = layer(tgt, memory, key_padding_mask=key_padding_mask, attn_mask=attn_mask)
+ return tgt, w
+
+
+class PolicyNet(nn.Module):
+ def __init__(self, input_dim, embedding_dim):
+ super(PolicyNet, self).__init__()
+ self.initial_embedding = nn.Linear(input_dim, embedding_dim) # layer for non-end position
+
+ self.current_embedding = nn.Linear(embedding_dim * 2, embedding_dim)
+
+ self.encoder = Encoder(embedding_dim=embedding_dim, n_head=8, n_layer=6)
+ self.decoder = Decoder(embedding_dim=embedding_dim, n_head=8, n_layer=1)
+ self.pointer = SingleHeadAttention(embedding_dim)
+
+ def encode_graph(self, node_inputs, node_padding_mask, edge_mask):
+ node_feature = self.initial_embedding(node_inputs)
+ enhanced_node_feature = self.encoder(src=node_feature, key_padding_mask=node_padding_mask, attn_mask=edge_mask)
+
+ return enhanced_node_feature
+
+ def output_policy(self, enhanced_node_feature, edge_inputs, current_index, edge_padding_mask, node_padding_mask):
+ k_size = edge_inputs.size()[2]
+ current_edge = torch.gather(edge_inputs, 1, current_index.repeat(1, 1, k_size))
+ current_edge = current_edge.permute(0, 2, 1)
+ embedding_dim = enhanced_node_feature.size()[2]
+
+ neigboring_feature = torch.gather(enhanced_node_feature, 1, current_edge.repeat(1, 1, embedding_dim))
+
+ current_node_feature = torch.gather(enhanced_node_feature, 1, current_index.repeat(1, 1, embedding_dim))
+
+ if edge_padding_mask is not None:
+ current_mask = torch.gather(edge_padding_mask, 1, current_index.repeat(1,1,k_size)).to(enhanced_node_feature.device)
+ else:
+ current_mask = None
+ current_mask[:,:,0] = 1 # don't stay at current position
+
+ if 0 not in current_mask:
+ current_mask[:,:,0] = 0
+
+ enhanced_current_node_feature, _ = self.decoder(current_node_feature, enhanced_node_feature, node_padding_mask)
+ enhanced_current_node_feature = self.current_embedding(torch.cat((enhanced_current_node_feature, current_node_feature), dim=-1))
+ logp = self.pointer(enhanced_current_node_feature, neigboring_feature, current_mask)
+ logp= logp.squeeze(1) # batch_size*k_size
+
+ return logp
+
+ def forward(self, node_inputs, edge_inputs, current_index, node_padding_mask=None, edge_padding_mask=None, edge_mask=None):
+ enhanced_node_feature = self.encode_graph(node_inputs, node_padding_mask, edge_mask)
+ logp = self.output_policy(enhanced_node_feature, edge_inputs, current_index, edge_padding_mask, node_padding_mask)
+ return logp
+
+
+class QNet(nn.Module):
+ def __init__(self, input_dim, embedding_dim):
+ super(QNet, self).__init__()
+ self.initial_embedding = nn.Linear(input_dim, embedding_dim) # layer for non-end position
+ self.action_embedding = nn.Linear(embedding_dim*3, embedding_dim)
+
+ self.encoder = Encoder(embedding_dim=embedding_dim, n_head=8, n_layer=6)
+ self.decoder = Decoder(embedding_dim=embedding_dim, n_head=8, n_layer=1)
+
+ self.q_values_layer = nn.Linear(embedding_dim, 1)
+
+ def encode_graph(self, node_inputs, node_padding_mask, edge_mask):
+ embedding_feature = self.initial_embedding(node_inputs)
+ embedding_feature = self.encoder(src=embedding_feature, key_padding_mask=node_padding_mask, attn_mask=edge_mask)
+
+ return embedding_feature
+
+ def output_q_values(self, enhanced_node_feature, edge_inputs, current_index, edge_padding_mask, node_padding_mask):
+ k_size = edge_inputs.size()[2]
+ current_edge = torch.gather(edge_inputs, 1, current_index.repeat(1, 1, k_size))
+ current_edge = current_edge.permute(0, 2, 1)
+ embedding_dim = enhanced_node_feature.size()[2]
+
+ neigboring_feature = torch.gather(enhanced_node_feature, 1, current_edge.repeat(1, 1, embedding_dim))
+
+ current_node_feature = torch.gather(enhanced_node_feature, 1, current_index.repeat(1, 1, embedding_dim))
+
+ enhanced_current_node_feature, attention_weights = self.decoder(current_node_feature, enhanced_node_feature, node_padding_mask)
+ action_features = torch.cat((enhanced_current_node_feature.repeat(1, k_size, 1), current_node_feature.repeat(1, k_size, 1), neigboring_feature), dim=-1)
+ action_features = self.action_embedding(action_features)
+ q_values = self.q_values_layer(action_features)
+
+ if edge_padding_mask is not None:
+ current_mask = torch.gather(edge_padding_mask, 1, current_index.repeat(1, 1, k_size)).to(
+ enhanced_node_feature.device)
+ else:
+ current_mask = None
+ current_mask[:, :, 0] = 1 # don't stay at current position
+
+ if 0 not in current_mask:
+ current_mask[:,:,0] = 0
+
+ current_mask = current_mask.permute(0, 2, 1)
+ zero = torch.zeros_like(q_values).to(q_values.device)
+ q_values = torch.where(current_mask == 1, zero, q_values)
+
+ return q_values, attention_weights
+
+ def forward(self, node_inputs, edge_inputs, current_index, node_padding_mask=None, edge_padding_mask=None,
+ edge_mask=None):
+ enhanced_node_feature = self.encode_graph(node_inputs, node_padding_mask, edge_mask)
+ q_values, attention_weights = self.output_q_values(enhanced_node_feature, edge_inputs, current_index, edge_padding_mask, node_padding_mask)
+ return q_values, attention_weights
+
diff --git a/planner/node.py b/planner/node.py
new file mode 100644
index 0000000000000000000000000000000000000000..b856ef1487226e8010bbc352ae7ecdb33276eef4
--- /dev/null
+++ b/planner/node.py
@@ -0,0 +1,96 @@
+#######################################################################
+# Name: node.py
+#
+# - Contains info per node on graph (edge)
+# - Contains: Position, Utility, Visitation History
+#######################################################################
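+
+# Worked example (illustrative): a Node whose coords see 12 frontier cells within
+# sensor range and without occlusion stores them in observable_frontiers, so its
+# utility is 12; once every one of them has been observed (or set_visited() is
+# called) the node becomes a zero_utility_node and is skipped by graph updates.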
+
+import sys
+if sys.modules['TRAINING']:
+ from .parameter import *
+else:
+ from .test_parameter import *
+
+import numpy as np
+import shapely.geometry
+
+
+class Node():
+ def __init__(self, coords, frontiers, robot_belief):
+ self.coords = coords
+ self.observable_frontiers = []
+ self.sensor_range = SENSOR_RANGE
+ self.initialize_observable_frontiers(frontiers, robot_belief)
+ self.utility = self.get_node_utility()
+ if self.utility == 0:
+ self.zero_utility_node = True
+ else:
+ self.zero_utility_node = False
+
+ def initialize_observable_frontiers(self, frontiers, robot_belief):
+ dist_list = np.linalg.norm(frontiers - self.coords, axis=-1)
+ frontiers_in_range = frontiers[dist_list < self.sensor_range - 10]
+ for point in frontiers_in_range:
+ collision = self.check_collision(self.coords, point, robot_belief)
+ if not collision:
+ self.observable_frontiers.append(point)
+
+ def get_node_utility(self):
+ return len(self.observable_frontiers)
+
+ def update_observable_frontiers(self, observed_frontiers, new_frontiers, robot_belief):
+ if len(observed_frontiers) != 0:
+ observed_index = []
+ for i, point in enumerate(self.observable_frontiers):
+ if point[0] + point[1] * 1j in observed_frontiers[:, 0] + observed_frontiers[:, 1] * 1j:
+ observed_index.append(i)
+ for index in reversed(observed_index):
+ self.observable_frontiers.pop(index)
+ #
+ if len(new_frontiers) != 0:
+ dist_list = np.linalg.norm(new_frontiers - self.coords, axis=-1)
+ new_frontiers_in_range = new_frontiers[dist_list < self.sensor_range - 15]
+ for point in new_frontiers_in_range:
+ collision = self.check_collision(self.coords, point, robot_belief)
+ if not collision:
+ self.observable_frontiers.append(point)
+
+ self.utility = self.get_node_utility()
+ if self.utility == 0:
+ self.zero_utility_node = True
+ else:
+ self.zero_utility_node = False
+
+ def set_visited(self):
+ self.observable_frontiers = []
+ self.utility = 0
+ self.zero_utility_node = True
+
+ def check_collision(self, start, end, robot_belief):
+ collision = False
+ line = shapely.geometry.LineString([start, end])
+
+ sortx = np.sort([start[0], end[0]])
+ sorty = np.sort([start[1], end[1]])
+
+ robot_belief = robot_belief[sorty[0]:sorty[1] + 1, sortx[0]:sortx[1] + 1]
+
+ occupied_area_index = np.where(robot_belief == 1)
+ occupied_area_coords = np.asarray(
+ [occupied_area_index[1] + sortx[0], occupied_area_index[0] + sorty[0]]).T
+ unexplored_area_index = np.where(robot_belief == 127)
+ unexplored_area_coords = np.asarray(
+ [unexplored_area_index[1] + sortx[0], unexplored_area_index[0] + sorty[0]]).T
+ unfree_area_coords = np.concatenate((occupied_area_coords, unexplored_area_coords))
+
+ for i in range(unfree_area_coords.shape[0]):
+ coords = ([(unfree_area_coords[i][0], unfree_area_coords[i][1]),
+ (unfree_area_coords[i][0] + 1, unfree_area_coords[i][1]),
+ (unfree_area_coords[i][0], unfree_area_coords[i][1] + 1),
+ (unfree_area_coords[i][0] + 1, unfree_area_coords[i][1] + 1)])
+ obstacle = shapely.geometry.Polygon(coords)
+ collision = line.intersects(obstacle)
+ if collision:
+ break
+
+ return collision
diff --git a/planner/robot.py b/planner/robot.py
new file mode 100644
index 0000000000000000000000000000000000000000..898c3146d7d7c0a2c6d4463e4184a6d89b93898d
--- /dev/null
+++ b/planner/robot.py
@@ -0,0 +1,58 @@
+#######################################################################
+# Name: robot.py
+#
+# - Stores S(t), A(t), R(t), S(t+1)
+#######################################################################
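+#
+# episode_buffer layout (indices as used by the save_* methods below):
+#   0-5  : observation  (node_inputs, edge_inputs, current_index,
+#                        node_padding_mask, edge_padding_mask, edge_mask)
+#   6    : action index
+#   7, 8 : reward, done flag
+#   9-14 : next observation (same six tensors as 0-5)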
+
+import torch
+from copy import deepcopy
+
+class Robot:
+ def __init__(self, robot_id, position, plot=False):
+ self.robot_id = robot_id
+ self.plot = plot
+ self.travel_dist = 0
+ self.robot_position = position
+ self.observations = None
+ self.trajectory_coords = []
+ self.targets_found_on_path = []
+
+ self.episode_buffer = []
+ for i in range(15):
+ self.episode_buffer.append([])
+
+ if self.plot:
+ self.xPoints = [self.robot_position[0]]
+ self.yPoints = [self.robot_position[1]]
+
+ def save_observations(self, observations):
+ node_inputs, edge_inputs, current_index, node_padding_mask, edge_padding_mask, edge_mask = observations
+ self.episode_buffer[0] += deepcopy(node_inputs).to('cpu')
+ self.episode_buffer[1] += deepcopy(edge_inputs).to('cpu')
+ self.episode_buffer[2] += deepcopy(current_index).to('cpu')
+ self.episode_buffer[3] += deepcopy(node_padding_mask).to('cpu')
+ self.episode_buffer[4] += deepcopy(edge_padding_mask).to('cpu')
+ self.episode_buffer[5] += deepcopy(edge_mask).to('cpu')
+
+ def save_action(self, action_index):
+ self.episode_buffer[6] += action_index.unsqueeze(0).unsqueeze(0)
+
+ def save_reward_done(self, reward, done):
+ self.episode_buffer[7] += deepcopy(torch.FloatTensor([[[reward]]])).to('cpu')
+ self.episode_buffer[8] += deepcopy(torch.tensor([[[(int(done))]]])).to('cpu')
+ if self.plot:
+ self.xPoints.append(self.robot_position[0])
+ self.yPoints.append(self.robot_position[1])
+
+ def save_next_observations(self, observations):
+ node_inputs, edge_inputs, current_index, node_padding_mask, edge_padding_mask, edge_mask = observations
+ self.episode_buffer[9] += deepcopy(node_inputs).to('cpu')
+ self.episode_buffer[10] += deepcopy(edge_inputs).to('cpu')
+ self.episode_buffer[11] += deepcopy(current_index).to('cpu')
+ self.episode_buffer[12] += deepcopy(node_padding_mask).to('cpu')
+ self.episode_buffer[13] += deepcopy(edge_padding_mask).to('cpu')
+ self.episode_buffer[14] += deepcopy(edge_mask).to('cpu')
+
+ def save_trajectory_coords(self, robot_position_coords, num_target_found):
+ self.trajectory_coords.append(robot_position_coords)
+ self.targets_found_on_path.append(num_target_found)
diff --git a/planner/sensor.py b/planner/sensor.py
new file mode 100644
index 0000000000000000000000000000000000000000..98b010b54268c79cebd077039eab90bb426bc2e8
--- /dev/null
+++ b/planner/sensor.py
@@ -0,0 +1,128 @@
+#######################################################################
+# Name: sensor.py
+#
+# - Computes sensor-related checks (e.g. collision, utility, etc.)
+#######################################################################
+
+import sys
+if sys.modules['TRAINING']:
+ from .parameter import *
+else:
+ from .test_parameter import *
+
+import math
+import numpy as np
+import copy
+
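+# collision_check(): Bresenham-style ray trace from (x0, y0) toward (x1, y1).
+# It copies traversed ground-truth cell values into robot_belief, keeps writing
+# through up to `max_collision` consecutive occupied cells, and stops once the
+# ray exits an obstacle, so walls show up in the belief but are not seen through.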
+def collision_check(x0, y0, x1, y1, ground_truth, robot_belief):
+ x0 = x0.round()
+ y0 = y0.round()
+ x1 = x1.round()
+ y1 = y1.round()
+ dx, dy = abs(x1 - x0), abs(y1 - y0)
+ x, y = x0, y0
+ error = dx - dy
+ x_inc = 1 if x1 > x0 else -1
+ y_inc = 1 if y1 > y0 else -1
+ dx *= 2
+ dy *= 2
+
+ collision_flag = 0
+ max_collision = 10
+
+ while 0 <= x < ground_truth.shape[1] and 0 <= y < ground_truth.shape[0]:
+ k = ground_truth.item(y, x)
+ if k == 1 and collision_flag < max_collision:
+ collision_flag += 1
+ if collision_flag >= max_collision:
+ break
+
+        if k != 1 and collision_flag > 0:
+ break
+
+ if x == x1 and y == y1:
+ break
+
+ robot_belief.itemset((y, x), k)
+
+ if error > 0:
+ x += x_inc
+ error -= dy
+ else:
+ y += y_inc
+ error += dx
+
+ return robot_belief
+
+
+def sensor_work(robot_position, sensor_range, robot_belief, ground_truth, sensor_model=SENSOR_MODEL):
+ x0 = robot_position[0]
+ y0 = robot_position[1]
+ rng_x = 0.5 * (ground_truth.shape[1] / NUM_COORDS_WIDTH)
+ rng_y = 0.5 * (ground_truth.shape[0] / NUM_COORDS_HEIGHT)
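+    # rng_x / rng_y are half a grid cell in pixels; e.g. with a 512-px map and
+    # NUM_COORDS_WIDTH = NUM_COORDS_HEIGHT = 24 (assumed defaults) this is ~10.7 px.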
+
+ if sensor_model == "rectangular": # TODO: add collision check
+ max_x = min(x0 + int(math.ceil(rng_x)), ground_truth.shape[1])
+ min_x = max(x0 - int(math.ceil(rng_x)), 0)
+ max_y = min(y0 + int(math.ceil(rng_y)), ground_truth.shape[0])
+ min_y = max(y0 - int(math.ceil(rng_y)), 0)
+ robot_belief[min_y:max_y, min_x:max_x] = ground_truth[min_y:max_y, min_x:max_x]
+ else:
+ sensor_angle_inc = 0.5 / 180 * np.pi
+ sensor_angle = 0
+ while sensor_angle < 2 * np.pi:
+ x1 = x0 + np.cos(sensor_angle) * sensor_range
+ y1 = y0 + np.sin(sensor_angle) * sensor_range
+ robot_belief = collision_check(x0, y0, x1, y1, ground_truth, robot_belief)
+ sensor_angle += sensor_angle_inc
+ return robot_belief
+
+
+def unexplored_area_check(x0, y0, x1, y1, current_belief):
+ x0 = x0.round()
+ y0 = y0.round()
+ x1 = x1.round()
+ y1 = y1.round()
+ dx, dy = abs(x1 - x0), abs(y1 - y0)
+ x, y = x0, y0
+ error = dx - dy
+ x_inc = 1 if x1 > x0 else -1
+ y_inc = 1 if y1 > y0 else -1
+ dx *= 2
+ dy *= 2
+
+ while 0 <= x < current_belief.shape[1] and 0 <= y < current_belief.shape[0]:
+ k = current_belief.item(y, x)
+ if x == x1 and y == y1:
+ break
+
+ if k == 1:
+ break
+
+ if k == 127:
+ current_belief.itemset((y, x), 0)
+ break
+
+ if error > 0:
+ x += x_inc
+ error -= dy
+ else:
+ y += y_inc
+ error += dx
+
+ return current_belief
+
+
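+# calculate_utility(): casts rays every 5 degrees from a candidate waypoint over a
+# copy of the belief; each ray marks the first unknown cell (value 127) it reaches
+# as observed, and the utility is the number of unknown cells that would be revealed.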
+def calculate_utility(waypoint_position, sensor_range, robot_belief):
+ sensor_angle_inc = 5 / 180 * np.pi
+ sensor_angle = 0
+ x0 = waypoint_position[0]
+ y0 = waypoint_position[1]
+ current_belief = copy.deepcopy(robot_belief)
+ while sensor_angle < 2 * np.pi:
+ x1 = x0 + np.cos(sensor_angle) * sensor_range
+ y1 = y0 + np.sin(sensor_angle) * sensor_range
+ current_belief = unexplored_area_check(x0, y0, x1, y1, current_belief)
+ sensor_angle += sensor_angle_inc
+ utility = np.sum(robot_belief == 127) - np.sum(current_belief == 127)
+ return utility
diff --git a/planner/test_info_surfing.py b/planner/test_info_surfing.py
new file mode 100644
index 0000000000000000000000000000000000000000..306311f2570ffa81838f64b8431885904433647c
--- /dev/null
+++ b/planner/test_info_surfing.py
@@ -0,0 +1,1071 @@
+#######################################################################
+# Name: test_info_surfing.py
+#
+# - Runs robot in environment using Info Surfing Planner
+#######################################################################
+
+import sys
+sys.modules['TRAINING'] = False # False = Inference Testing
+
+import copy
+import os
+import imageio
+import numpy as np
+import matplotlib.pyplot as plt
+from pathlib import Path
+from time import time
+from types import SimpleNamespace
+from skimage.transform import resize
+from taxabind_avs.satbind.kmeans_clustering import CombinedSilhouetteInertiaClusterer
+from .env import Env
+from .test_parameter import *
+
+
+OPPOSITE_ACTIONS = {1: 3, 2: 4, 3: 1, 4: 2, 5: 7, 6: 8, 7: 5, 8: 6}
+# color
+agentColor = (1, 0.2, 0.6)
+agentCommColor = (1, 0.6, 0.2)
+obstacleColor = (0., 0., 0.)
+targetNotFound = (0., 1., 0.)
+targetFound = (0.545, 0.27, 0.075)
+highestProbColor = (1., 0., 0.)
+highestUncertaintyColor = (0., 0., 1.)
+lowestProbColor = (1., 1., 1.)
+
+
+class ISEnv:
+    """Custom environment that loosely follows the gym interface"""
+ metadata = {'render.modes': ['human']}
+
+ def __init__(self, global_step=0, state=None, shape=(24, 24), numAgents=8, observationSize=11, sensorSize=1, diag=False, save_image=False, clip_seg_tta=None):
+
+ self.global_step = global_step
+ self.infoMap = None
+ self.targetMap = None
+ self.agents = []
+ self.targets = []
+ self.numAgents = numAgents
+ self.found_target = []
+ self.shape = shape
+ self.observationSize = observationSize
+ self.sensorSize = sensorSize
+ self.diag = diag
+ self.communicateCircle = 11
+ self.distribs = []
+ self.mask = None
+ self.finished = False
+ self.action_vects = [[-1., 0.], [0., 1.], [1., 0], [0., -1.]] if not diag else [[-1., 0.], [0., 1.], [1., 0], [0., -1.], [-0.707, -0.707], [-0.707, 0.707], [0.707, 0.707], [0.707, -0.707]]
+ self.actionlist = []
+ self.IS_step = 0
+ self.save_image = save_image
+ self.clip_seg_tta = clip_seg_tta
+ self.perf_metrics = dict()
+ self.steps_to_first_tgt = None
+ self.steps_to_mid_tgt = None
+ self.steps_to_last_tgt = None
+ self.targets_found_on_path = []
+ self.step_since_tta = 0
+ self.IS_frame_files = []
+ self.bad_mask_init = False
+
+ # define env
+ self.env = Env(map_index=self.global_step, n_agent=numAgents, k_size=K_SIZE, plot=save_image, test=True)
+
+ # Overwrite state
+ if self.clip_seg_tta is not None:
+ self.clip_seg_tta.reset(sample_idx=self.global_step)
+
+ # Override target positions in env
+ self.env.target_positions = [(pose[1], pose[0]) for pose in self.clip_seg_tta.target_positions]
+
+ # Override segmentation mask
+ if not USE_CLIP_PREDS and OVERRIDE_MASK_DIR != "":
+ score_mask_path = os.path.join(OVERRIDE_MASK_DIR, self.clip_seg_tta.gt_mask_name)
+ print("score_mask_path: ", score_mask_path)
+ if os.path.exists(score_mask_path):
+ self.env.segmentation_mask = self.env.import_segmentation_mask(score_mask_path)
+ self.env.begin(self.env.map_start_position)
+ else:
+ print(f"\n\n{RED}ERROR: Trying to override, but score mask not found at path:{NC} ", score_mask_path)
+ self.bad_mask_init = True
+
+ # Save clustered embeds from sat encoder
+ if USE_CLIP_PREDS:
+ self.kmeans_clusterer = CombinedSilhouetteInertiaClusterer(
+ k_min=1,
+ k_max=8,
+ k_avg_max=4,
+ silhouette_threshold=0.15,
+ relative_threshold=0.15,
+ random_state=0,
+ min_patch_size=5,
+ n_smooth_iter=2,
+ ignore_label=-1,
+ plot=self.save_image,
+ gifs_dir = GIFS_PATH
+ )
+ # Generate kmeans clusters
+ self.kmeans_sat_embeds_clusters = self.kmeans_clusterer.fit_predict(
+ patch_embeds=self.clip_seg_tta.patch_embeds,
+ map_shape=(CLIP_GRIDS_DIMS[0], CLIP_GRIDS_DIMS[1]),
+ )
+
+ if EXECUTE_TTA:
+ print("Will execute TTA...")
+
+ IS_info_map = copy.deepcopy(self.env.segmentation_info_mask)
+ IS_agent_loc = copy.deepcopy(self.env.start_positions)
+ IS_target_loc = copy.deepcopy(self.env.target_positions)
+ state=[IS_info_map, IS_agent_loc, IS_target_loc]
+ self.setWorld(state)
+
+
+ def init_render(self):
+ """
+ Call this once (e.g., in __init__ or just before the scenario loop)
+ to initialize storage for agent paths and turn interactive plotting on.
+ """
+ # Keep track of each agent's trajectory
+ self.trajectories = [[] for _ in range(self.numAgents)]
+ self.trajectories_upscaled = [[] for _ in range(self.numAgents)]
+
+ # Turn on interactive mode so we can update the same figure repeatedly
+ plt.ion()
+ plt.figure(figsize=(6,6))
+ plt.title("Information Map with Agents, Targets, and Sensor Ranges")
+
+
+ def record_positions(self):
+ """
+ Call this after all agents have moved in a step (or whenever you want to update
+ the trajectory). It appends the current positions of each agent to `self.trajectories`.
+ """
+ for idx, agent in enumerate(self.agents):
+ self.trajectories[idx].append((agent.row, agent.col))
+ self.trajectories_upscaled[idx].append(self.env.graph_generator.grid_coords[agent.row, agent.col])
+
+
+ def render(self, episode_num, step_num):
+ """
+ Renders the current state in a single matplotlib plot.
+ Ensures consistent image size for GIF generation.
+ """
+
+ # Completely reset the figure to avoid leftover state
+ plt.close('all')
+ fig = plt.figure(figsize=(6.4, 4.8), dpi=100)
+ ax = fig.add_subplot(111)
+
+ # Plot the information map
+ ax.imshow(self.infoMap, origin='lower', cmap='gray')
+
+ # Show agent positions and their trajectories
+ for idx, agent in enumerate(self.agents):
+ positions = self.trajectories[idx]
+ if len(positions) > 1:
+ rows = [p[0] for p in positions]
+ cols = [p[1] for p in positions]
+ ax.plot(cols, rows, linewidth=1)
+
+ ax.scatter(agent.col, agent.row, marker='o', s=50)
+
+ # Plot target locations
+ for t in self.targets:
+ color = 'green' if np.isnan(t.time_found) else 'red'
+ ax.scatter(t.col, t.row, marker='x', s=100, color=color)
+
+ # Title and axis formatting
+ ax.set_title(f"Step: {self.IS_step}")
+ ax.invert_yaxis()
+
+ # Create output folder if it doesn't exist
+ if not os.path.exists(GIFS_PATH):
+ os.makedirs(GIFS_PATH)
+
+ # Save the frame with consistent canvas
+ frame_path = f'{GIFS_PATH}/IS_{episode_num}_{step_num}.png'
+ plt.savefig(frame_path, bbox_inches='tight', pad_inches=0.1)
+ self.IS_frame_files.append(frame_path)
+
+ # Cleanup
+ plt.close(fig)
+
+
+ def setWorld(self, state=None):
+ """
+        1. Empty all existing elements
+        2. Create the new episode
+ """
+ if state is not None:
+ self.infoMap = copy.deepcopy(state[0].reshape(self.shape).T)
+ agents = []
+ self.numAgents = len(state[1])
+ for a in range(1, self.numAgents + 1):
+ abs_pos = state[1].pop(0)
+ abs_pos = np.array(abs_pos)
+ row, col = self.env.graph_generator.find_closest_index_from_grid_coords_2d(np.array(abs_pos))
+ agents.append(Agent(ID=a, row=row, col=col, sensorSize=self.sensorSize, infoMap=np.copy(self.infoMap),
+ uncertaintyMap=np.copy(self.infoMap), shape=self.shape, numAgents=self.numAgents))
+ self.agents = agents
+
+ targets, n_targets = [], 1
+ for t in range(len(state[2])):
+ abs_pos = state[2].pop(0)
+ abs_pos = np.array(abs_pos)
+ row, col = self.env.graph_generator.find_closest_index_from_grid_coords_2d(abs_pos)
+ targets.append(Target(ID=n_targets, row=row, col=col, time_found=np.nan))
+ n_targets = n_targets + 1
+ self.targets = targets
+
+ def extractObservation(self, agent):
+ """
+ Extract observations from information map
+ """
+
+ transform_row = self.observationSize // 2 - agent.row
+ transform_col = self.observationSize // 2 - agent.col
+
+ observation_layers = np.zeros((1, self.observationSize, self.observationSize))
+ min_row = max((agent.row - self.observationSize // 2), 0)
+ max_row = min((agent.row + self.observationSize // 2 + 1), self.shape[0])
+ min_col = max((agent.col - self.observationSize // 2), 0)
+ max_col = min((agent.col + self.observationSize // 2 + 1), self.shape[1])
+
+ observation = np.full((self.observationSize, self.observationSize), 0.)
+ infoMap = np.full((self.observationSize, self.observationSize), 0.)
+ densityMap = np.full((self.observationSize, self.observationSize), 0.)
+
+ infoMap[(min_row + transform_row):(max_row + transform_row),
+ (min_col + transform_col):(max_col + transform_col)] = self.infoMap[
+ min_row:max_row, min_col:max_col]
+ observation_layers[0] = infoMap
+
+ return observation_layers
+
+
+ def listNextValidActions(self, agent_id, prev_action=0):
+ """
+ No movement: 0
+ North (-1,0): 1
+ East (0,1): 2
+ South (1,0): 3
+ West (0,-1): 4
+ """
+ available_actions = [0]
+ agent = self.agents[agent_id - 1]
+
+ MOVES = [(-1, 0), (0, 1), (1, 0), (0, -1), (-1, -1), (-1, 1), (1, 1), (1, -1)]
+ size = 4 + self.diag * 4
+ for action in range(size):
+ out_of_bounds = agent.row + MOVES[action][0] >= self.shape[0] \
+ or agent.row + MOVES[action][0] < 0\
+ or agent.col + MOVES[action][1] >= self.shape[1] \
+ or agent.col + MOVES[action][1] < 0
+
+ if (not out_of_bounds) and not (prev_action == OPPOSITE_ACTIONS[action + 1]):
+ available_actions.append(action + 1)
+
+ return np.array(available_actions)
+
+
+ def executeAction(self, agentID, action, timeStep):
+ """
+ No movement: 0
+ North (-1,0): 1
+ East (0,1): 2
+ South (1,0): 3
+ West (0,-1): 4
+        UpLeft (-1,-1): 5
+        UpRight (-1,1): 6
+        DownRight (1,1): 7
+        DownLeft (1,-1): 8
+ """
+ agent = self.agents[agentID - 1]
+ origLoc = agent.getLocation()
+
+ if (action >= 1) and (action <= 8):
+ agent.move(action)
+ row, col = agent.getLocation()
+
+ # If the move is not valid, roll it back
+ if (row < 0) or (col < 0) or (row >= self.shape[0]) or (col >= self.shape[1]):
+ self.updateInfoCheckTarget(agentID, timeStep, origLoc)
+ return 0
+
+ elif action == 0:
+ self.updateInfoCheckTarget(agentID, timeStep, origLoc)
+ return 0
+
+ else:
+ print("INVALID ACTION: {}".format(action))
+ sys.exit()
+
+ newLoc = agent.getLocation()
+ self.updateInfoCheckTarget(agentID, timeStep, origLoc)
+ return action
+
+
+ def updateInfoCheckTarget(self, agentID, timeStep, origLoc):
+ """
+ update the self.infoMap and check whether the agent has found a target
+ """
+ agent = self.agents[agentID - 1]
+ transform_row = self.sensorSize // 2 - agent.row
+ transform_col = self.sensorSize // 2 - agent.col
+
+ min_row = max((agent.row - self.sensorSize // 2), 0)
+ max_row = min((agent.row + self.sensorSize // 2 + 1), self.shape[0])
+ min_col = max((agent.col - self.sensorSize // 2), 0)
+ max_col = min((agent.col + self.sensorSize // 2 + 1), self.shape[1])
+ for t in self.targets:
+ if (t.row == agent.row) and (t.col == agent.col):
+ t.updateFound(timeStep)
+ self.found_target.append(t)
+ t.status = True
+
+ self.infoMap[min_row:max_row, min_col:max_col] *= 0.05
+
+
+ def updateInfoEntireTrajectory(self, agentID):
+ """
+        Decay self.infoMap along the agent's entire recorded trajectory (sensor footprint applied at every visited cell)
+ """
+ traj = self.trajectories[agentID - 1]
+
+ for (row,col) in traj:
+ min_row = max((row - self.sensorSize // 2), 0)
+ max_row = min((row + self.sensorSize // 2 + 1), self.shape[0])
+ min_col = max((col - self.sensorSize // 2), 0)
+ max_col = min((col + self.sensorSize // 2 + 1), self.shape[1])
+ self.infoMap[min_row:max_row, min_col:max_col] *= 0.05
+
+
+ # Execute one time step within the environment
+ def step(self, agentID, action, timeStep):
+ """
+ the agents execute the actions
+ No movement: 0
+ North (-1,0): 1
+ East (0,1): 2
+ South (1,0): 3
+ West (0,-1): 4
+ """
+ assert (agentID > 0)
+
+ self.executeAction(agentID, action, timeStep)
+
+
+ def observe(self, agentID):
+ assert (agentID > 0)
+ vectorObs = self.extractObservation(self.agents[agentID - 1])
+ return [vectorObs]
+
+
+ def check_finish(self):
+ if TERMINATE_ON_TGTS_FOUND:
+ found_status = [t.time_found for t in self.targets]
+ d = False
+ if np.isnan(found_status).sum() == 0:
+ d = True
+ return d
+ else:
+ return False
+
+
+ def gradVec(self, observation, agent):
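+        # Info-surfing step: max-pool the local info map at several scales, blend the
+        # scales equally, take the image gradient at the centre cell, and move along the
+        # action vector best aligned with that gradient (falling back to a random valid
+        # action when the gradient vanishes).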
+ a = observation[0]
+
+ # Make info & unc cells with low value as 0
+ a[a < 0.0002] = 0.0
+
+ # Center square from 11x11
+ a_11x11 = a[4:7, 4:7]
+ m_11x11 = np.array((a_11x11))
+
+ # Center square from 9x9
+ a_9x9 = self.pooling(a, (3, 3), stride=(1, 1), method='max', pad=False)
+ a_9x9 = a_9x9[3:6, 3:6]
+ m_9x9 = np.array((a_9x9))
+
+ # Center square from 6x6
+ a_6x6 = self.pooling(a, (6, 6), stride=(1, 1), method='max', pad=False)
+ a_6x6 = a_6x6[1:4, 1:4]
+ m_6x6 = np.array((a_6x6))
+
+ # Center square from 3x3
+ a_3x3 = self.pooling(a, (5, 5), stride=(3, 3), method='max', pad=False)
+ m_3x3 = np.array((a_3x3))
+
+ # Merging multiScales with weights
+ m = m_3x3 * 0.25 + m_6x6 * 0.25 + m_9x9 * 0.25 + m_11x11 * 0.25
+ a = m
+
+ adx, ady = np.gradient(a)
+ den = np.linalg.norm(np.array([adx[1, 1], ady[1, 1]]))
+ if (den != 0) and (not np.isnan(den)):
+ infovec = np.array([adx[1, 1], ady[1, 1]]) / den
+ else:
+ infovec = 0
+ agentvec = []
+
+ if len(agentvec) == 0:
+ den = np.linalg.norm(infovec)
+ if (den != 0) and (not np.isnan(den)):
+ direction = infovec / den
+ else:
+ direction = self.action_vects[np.random.randint(4 + self.diag * 4)]
+ else:
+ den = np.linalg.norm(np.mean(agentvec, 0))
+ if (den != 0) and (not np.isnan(den)):
+ agentvec = np.mean(agentvec, 0) / den
+ else:
+ agentvec = 0
+
+ den = np.linalg.norm(0.6 * infovec + 0.4 * agentvec)
+ if (den != 0) and (not np.isnan(den)):
+ direction = (0.6 * infovec + 0.4 * agentvec) / den
+ else:
+ direction = self.action_vects[np.random.randint(4 + self.diag * 4)]
+
+ action_vec = [[0, 0], [-1, 0], [0, 1], [1, 0], [0, -1]] if not self.diag else [[0., 0.], [-1., 0.], [0., 1.], [1., 0], [0., -1.], [-0.707, -0.707], [-0.707, 0.707], [0.707, 0.707], [0.707, -0.707]]
+ actionid = np.argmax([np.dot(direction, a) for a in action_vec])
+ actionid = self.best_valid_action(actionid, agent, direction)
+ return actionid
+
+
+ def best_valid_action(self, actionid, agent, direction):
+ if len(self.actionlist) > 1:
+ if self.action_invalid(actionid, agent):
+ action_vec = [[0, 0], [-1, 0], [0, 1], [1, 0], [0, -1]] if not self.diag else [[0., 0.], [-1., 0.], [0., 1.], [1., 0], [0., -1.], [-0.707, -0.707], [-0.707, 0.707], [0.707, 0.707], [0.707, -0.707]]
+ actionid = np.array([np.dot(direction, a) for a in action_vec])
+ actionid = actionid.argsort()
+ pi = 3 + self.diag*4
+                while pi >= 0 and self.action_invalid(actionid[pi], agent):
+ pi -= 1
+ if pi == -1:
+ return OPPOSITE_ACTIONS[self.actionlist[self.IS_step - 1][agent - 1]]
+ elif actionid[pi] == 0:
+ return OPPOSITE_ACTIONS[self.actionlist[self.IS_step - 1][agent - 1]]
+ else:
+ return actionid[pi]
+ return actionid
+
+
+ def action_invalid(self, action, agent):
+ # Going back to the previous cell is disabled
+ if action == OPPOSITE_ACTIONS[self.actionlist[self.IS_step - 1][agent - 1]]:
+ return True
+ # Move N,E,S,W
+ if (action >= 1) and (action <= 8):
+ agent = self.agents[agent - 1]
+ agent.move(action)
+ row, col = agent.getLocation()
+
+ # If the move is not valid, roll it back
+ if ((row < 0) or (col < 0) or (row >= self.shape[0]) or (col >= self.shape[1])):
+ agent.reverseMove(action)
+ return True
+
+ agent.reverseMove(action)
+ return False
+ return False
+
+
+ def step_all_parallel(self):
+ actions = []
+ reward = 0
+ # Decide actions for each agent
+ for agent_id in range(1, self.numAgents + 1):
+ o = self.observe(agent_id)
+ actions.append(self.gradVec(o[0], agent_id))
+ self.actionlist.append(actions)
+
+ # Execute those actions
+ for agent_id in range(1, self.numAgents + 1):
+ self.step(agent_id, actions[agent_id - 1], self.IS_step)
+
+ # Record for visualization
+ self.record_positions()
+
+ def is_scenario(self, max_step=512, episode_number=0):
+
+ # Return all metrics as None if faulty mask init
+ if self.bad_mask_init:
+ self.perf_metrics['tax'] = None
+ self.perf_metrics['travel_dist'] = None
+ self.perf_metrics['travel_steps'] = None
+ self.perf_metrics['steps_to_first_tgt'] = None
+ self.perf_metrics['steps_to_mid_tgt'] = None
+ self.perf_metrics['steps_to_last_tgt'] = None
+ self.perf_metrics['explored_rate'] = None
+ self.perf_metrics['targets_found'] = None
+ self.perf_metrics['targets_total'] = None
+ self.perf_metrics['kmeans_k'] = None
+ self.perf_metrics['tgts_gt_score'] = None
+ self.perf_metrics['clip_inference_time'] = None
+ self.perf_metrics['tta_time'] = None
+ self.perf_metrics['success_rate'] = None
+ return
+
+ eps_start = time()
+ self.IS_step = 0
+ self.finished = False
+ reward = 0
+
+ # Initialize the rendering just once before the loop
+ self.init_render()
+ self.record_positions()
+
+ # Initial Setup
+ if LOAD_AVS_BENCH and USE_CLIP_PREDS:
+ if NUM_COORDS_WIDTH != CLIP_GRIDS_DIMS[0] or NUM_COORDS_HEIGHT != CLIP_GRIDS_DIMS[1]: # If heatmap is resized from clip original dims
+ heatmap = self.convert_heatmap_resolution(self.clip_seg_tta.heatmap, full_dims=(512, 512), new_dims=(NUM_COORDS_WIDTH, NUM_COORDS_HEIGHT))
+ self.env.segmentation_info_mask = np.expand_dims(heatmap.T.flatten(), axis=1)
+ unnormalized_heatmap = self.convert_heatmap_resolution(self.clip_seg_tta.heatmap_unnormalized, full_dims=(512, 512), new_dims=(NUM_COORDS_WIDTH, NUM_COORDS_HEIGHT))
+ self.env.segmentation_info_mask_unnormalized = np.expand_dims(unnormalized_heatmap.T.flatten(), axis=1)
+ self.infoMap = copy.deepcopy(heatmap)
+ print("Resized heatmap to", NUM_COORDS_WIDTH, "x", NUM_COORDS_HEIGHT)
+ else:
+ self.env.segmentation_info_mask = np.expand_dims(self.clip_seg_tta.heatmap.T.flatten(), axis=1)
+ self.env.segmentation_info_mask_unnormalized = np.expand_dims(self.clip_seg_tta.heatmap_unnormalized.T.flatten(), axis=1)
+ self.infoMap = copy.deepcopy(self.clip_seg_tta.heatmap)
+
+ self.targets_found_on_path.append(self.env.num_new_targets_found)
+
+ while self.IS_step < max_step and not self.check_finish():
+ self.step_all_parallel()
+ self.IS_step += 1
+
+ # Render after each step
+ if self.save_image:
+ self.render(episode_num=self.global_step, step_num=self.IS_step)
+
+ # Update in env
+ next_position_list = [self.trajectories_upscaled[i][-1] for i, agent in enumerate(self.agents)]
+ dist_list = [0 for _ in range(self.numAgents)]
+ travel_dist_list = [self.compute_travel_distance(traj) for traj in self.trajectories]
+ self.env.multi_robot_step(next_position_list, dist_list, travel_dist_list)
+ self.targets_found_on_path.append(self.env.num_new_targets_found)
+
+ # TTA Update via Poisson Test (with KMeans clustering stats)
+ robot_id = 0 # Assume 1 agent for now
+ robot_traj = self.trajectories[robot_id]
+ if LOAD_AVS_BENCH and USE_CLIP_PREDS and EXECUTE_TTA:
+ flat_traj_coords = [robot_traj[i][1] * self.shape[0] + robot_traj[i][0] for i in range(len(robot_traj))]
+ robot = SimpleNamespace(
+ trajectory_coords=flat_traj_coords,
+ targets_found_on_path=self.targets_found_on_path
+ )
+ self.poisson_tta_update(robot, self.global_step, self.IS_step)
+ self.infoMap = copy.deepcopy(self.env.segmentation_info_mask.reshape((self.shape[1],self.shape[0])).T)
+ self.updateInfoEntireTrajectory(robot_id)
+
+ # Update metrics
+ self.log_metrics(step=self.IS_step-1)
+
+ ### Save a frame to generate gif of robot trajectories ###
+ if self.save_image:
+ robots_route = [ ([], []) ] # Assume 1 robot
+ for point in self.trajectories_upscaled[robot_id]:
+ robots_route[robot_id][0].append(point[0])
+ robots_route[robot_id][1].append(point[1])
+ if not os.path.exists(GIFS_PATH):
+ os.makedirs(GIFS_PATH)
+ if LOAD_AVS_BENCH:
+ sound_id_override = None if self.clip_seg_tta.sound_ids == [] else self.clip_seg_tta.sound_ids[0]
+ self.env.plot_env(
+ self.global_step,
+ GIFS_PATH,
+ self.IS_step-1,
+ max(travel_dist_list),
+ robots_route,
+ img_path_override=self.clip_seg_tta.img_paths[0], # Viz 1st
+ sat_path_override=self.clip_seg_tta.imo_path,
+ msk_name_override=self.clip_seg_tta.species_name,
+ sound_id_override=sound_id_override,
+ )
+ else:
+ self.env.plot_env(
+ self.global_step,
+ GIFS_PATH,
+ self.IS_step-1,
+ max(travel_dist_list),
+ robots_route
+ )
+
+ # Log metrics
+ if LOAD_AVS_BENCH:
+ tax = Path(self.clip_seg_tta.gt_mask_name).stem
+ self.perf_metrics['tax'] = " ".join(tax.split("_")[1:])
+ else:
+ self.perf_metrics['tax'] = None
+ travel_distances = [self.compute_travel_distance(traj) for traj in self.trajectories]
+ self.perf_metrics['travel_dist'] = max(travel_distances)
+ self.perf_metrics['travel_steps'] = self.IS_step
+ self.perf_metrics['steps_to_first_tgt'] = self.steps_to_first_tgt
+ self.perf_metrics['steps_to_mid_tgt'] = self.steps_to_mid_tgt
+ self.perf_metrics['steps_to_last_tgt'] = self.steps_to_last_tgt
+ self.perf_metrics['explored_rate'] = self.env.explored_rate
+ self.perf_metrics['targets_found'] = self.env.targets_found_rate
+ self.perf_metrics['targets_total'] = len(self.env.target_positions)
+ if USE_CLIP_PREDS:
+ self.perf_metrics['kmeans_k'] = self.kmeans_clusterer.final_k
+ self.perf_metrics['tgts_gt_score'] = self.clip_seg_tta.tgts_gt_score
+ self.perf_metrics['clip_inference_time'] = self.clip_seg_tta.clip_inference_time
+ self.perf_metrics['tta_time'] = self.clip_seg_tta.tta_time
+ else:
+ self.perf_metrics['kmeans_k'] = None
+ self.perf_metrics['tgts_gt_score'] = None
+ self.perf_metrics['clip_inference_time'] = None
+ self.perf_metrics['tta_time'] = None
+ if FORCE_LOGGING_DONE_TGTS_FOUND and self.env.targets_found_rate == 1.0:
+ self.perf_metrics['success_rate'] = True
+ else:
+ self.perf_metrics['success_rate'] = self.env.check_done()[0]
+
+ # save gif
+ if self.save_image:
+ path = GIFS_PATH
+ self.make_gif(path, self.global_step)
+
+ print(YELLOW, f"[Eps {episode_number} Completed] Time Taken: {time()-eps_start:.2f}s, Steps: {self.IS_step}", NC)
+
+
+ def asStride(self, arr, sub_shape, stride):
+ """
+ Get a strided sub-matrices view of an ndarray.
+ See also skimage.util.shape.view_as_windows()
+ """
+ s0, s1 = arr.strides[:2]
+ m1, n1 = arr.shape[:2]
+ m2, n2 = sub_shape
+ view_shape = (1+(m1-m2)//stride[0], 1+(n1-n2)//stride[1], m2, n2)+arr.shape[2:]
+ strides = (stride[0]*s0, stride[1]*s1, s0, s1)+arr.strides[2:]
+ subs = np.lib.stride_tricks.as_strided(arr, view_shape, strides=strides)
+ return subs
+
+
+ def pooling(self, mat, ksize, stride=None, method='max', pad=False):
+ """
+ Overlapping pooling on 2D or 3D data.
+
+        mat: ndarray, input array to pool.
+        ksize: tuple of 2, kernel size in (ky, kx).
+        stride: tuple of 2 or None, stride of the pooling window.
+                If None, same as ksize (non-overlapping pooling).
+        method: str, 'max' for max-pooling, 'mean' for mean-pooling.
+        pad: bool, pad or not. If no pad, output has size
+             (n-f)//s+1, n being input size, f kernel size, s stride.
+             If pad, output has size ceil(n/s).
+
+        Return: pooled matrix.
+ """
+
+ m, n = mat.shape[:2]
+ ky, kx = ksize
+ if stride is None:
+ stride = (ky, kx)
+ sy, sx = stride
+
+ _ceil = lambda x, y: int(np.ceil(x/float(y)))
+
+ if pad:
+ ny = _ceil(m,sy)
+ nx = _ceil(n,sx)
+ size = ((ny-1)*sy+ky, (nx-1)*sx+kx) + mat.shape[2:]
+ mat_pad = np.full(size,np.nan)
+ mat_pad[:m,:n,...] = mat
+ else:
+ mat_pad = mat[:(m-ky)//sy*sy+ky, :(n-kx)//sx*sx+kx, ...]
+
+ view = self.asStride(mat_pad,ksize,stride)
+
+ if method == 'max':
+ result = np.nanmax(view,axis=(2,3))
+ else:
+ result = np.nanmean(view,axis=(2,3))
+
+ return result
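+    # Worked example (matches the gradVec usage above): pooling an 11x11 observation
+    # with ksize=(3, 3), stride=(1, 1), pad=False yields (11 - 3) // 1 + 1 = 9 per axis,
+    # i.e. a 9x9 output.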
+
+
+ def compute_travel_distance(self, trajectory):
+ distance = 0.0
+ for i in range(1, len(trajectory)):
+ # Convert the tuple positions to numpy arrays for easy computation.
+ prev_pos = np.array(trajectory[i-1])
+ curr_pos = np.array(trajectory[i])
+ # Euclidean distance between consecutive positions.
+ distance += np.linalg.norm(curr_pos - prev_pos)
+ return distance
+
+ ################################################################################
+ # SPPP Related Fns
+ ################################################################################
+
+ def log_metrics(self, step):
+ # Update tgt found metrics
+ if self.steps_to_first_tgt is None and self.env.num_targets_found == 1:
+ self.steps_to_first_tgt = step + 1
+ if self.steps_to_mid_tgt is None and self.env.num_targets_found == int(len(self.env.target_positions) / 2):
+ self.steps_to_mid_tgt = step + 1
+ if self.steps_to_last_tgt is None and self.env.num_targets_found == len(self.env.target_positions):
+ self.steps_to_last_tgt = step + 1
+
+
+ def transpose_flat_idx(self, idx, H=NUM_COORDS_HEIGHT, W=NUM_COORDS_WIDTH):
+ """
+        Transpose a flat index from an ``H×W`` grid to the equivalent
+        position in the ``W×H`` transposed grid while **keeping the result
+        in 1-D**.
+ """
+ # --- Safety check to catch out-of-range indices ---
+ assert 0 <= idx < H * W, f"idx {idx} out of bounds for shape ({H}, {W})"
+
+ # Original (row, col)
+ row, col = divmod(idx, W)
+ # After transpose these coordinates swap
+ row_T, col_T = col, row
+
+ # Flatten back into 1-D (row-major) for the WΓH grid
+ return row_T * H + col_T
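+    # Worked example (hypothetical non-square grid): with H=3, W=4, idx=5 maps to
+    # (row, col) = (1, 1), which flattens to 1 * 3 + 1 = 4 in the transposed 4x3 grid.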
+
+
+ def poisson_tta_update(self, robot, episode, step):
+
+ # Generate Kmeans Clusters Stats
+ # Scale index back to CLIP_GRIDS_DIMS to be compatible with CLIP patch size
+ if NUM_COORDS_WIDTH != CLIP_GRIDS_DIMS[0] or NUM_COORDS_HEIGHT != CLIP_GRIDS_DIMS[1]:
+ # High-res remap via pixel coordinates preserves exact neighbourhood
+ filt_traj_coords, filt_targets_found_on_path = self.scale_trajectory(
+ robot.trajectory_coords,
+ self.env.target_positions,
+ old_dims=(NUM_COORDS_HEIGHT, NUM_COORDS_WIDTH),
+ full_dims=(512, 512),
+ new_dims=(CLIP_GRIDS_DIMS[0], CLIP_GRIDS_DIMS[1])
+ )
+ else:
+ filt_traj_coords = [self.transpose_flat_idx(idx) for idx in robot.trajectory_coords]
+ filt_targets_found_on_path = robot.targets_found_on_path
+
+ region_stats_dict = self.kmeans_clusterer.compute_region_statistics(
+ self.kmeans_sat_embeds_clusters,
+ self.clip_seg_tta.heatmap_unnormalized,
+ filt_traj_coords,
+ episode_num=episode,
+ step_num=step
+ )
+
+ # Prep & execute TTA
+ self.step_since_tta += 1
+ if robot.targets_found_on_path[-1] or self.step_since_tta % STEPS_PER_TTA == 0:
+
+ num_cells = self.clip_seg_tta.heatmap.shape[0] * self.clip_seg_tta.heatmap.shape[1]
+ pos_sample_weight_scale, neg_sample_weight_scale = [], []
+
+ for i, sample_loc in enumerate(filt_traj_coords):
+ label = self.kmeans_clusterer.get_label_id(self.kmeans_sat_embeds_clusters, sample_loc)
+ num_patches = region_stats_dict[label]['num_patches']
+ patches_visited = region_stats_dict[label]['patches_visited']
+ expectation = region_stats_dict[label]['expectation']
+
+ # Exponent like focal loss to wait for more samples before confidently decreasing
+ pos_weight = 4.0
+ neg_weight = min(1.0, (patches_visited/(3*num_patches))**GAMMA_EXPONENT)
+ pos_sample_weight_scale.append(pos_weight)
+ neg_sample_weight_scale.append(neg_weight)
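+                # e.g. with GAMMA_EXPONENT = 2, visiting 3 of a region's 30 patches gives
+                # neg_weight = min(1, 3 / 90) ** 2 ~= 0.0011, so sparse negative evidence
+                # barely down-weights the heatmap until more of the region has been covered.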
+
+            # Adaptive LR (as samples increase, increase LR to fit more datapoints)
+ adaptive_lr = MIN_LR + (MAX_LR - MIN_LR) * (step / num_cells)
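+            # e.g. with the defaults MIN_LR = 1e-6, MAX_LR = 1e-5 and a 24x24 heatmap
+            # (576 cells), step 200 gives lr = 1e-6 + 9e-6 * 200 / 576 ~= 4.1e-6.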
+
+ # TTA Update
+ self.clip_seg_tta.execute_tta(
+ filt_traj_coords,
+ filt_targets_found_on_path,
+ tta_steps=NUM_TTA_STEPS,
+ lr=adaptive_lr,
+ pos_sample_weight=pos_sample_weight_scale,
+ neg_sample_weight=neg_sample_weight_scale,
+ reset_weights=RESET_WEIGHTS
+ )
+ if NUM_COORDS_WIDTH != CLIP_GRIDS_DIMS[0] or NUM_COORDS_HEIGHT != CLIP_GRIDS_DIMS[1]: # If heatmap is resized from clip original dims
+ heatmap = self.convert_heatmap_resolution(self.clip_seg_tta.heatmap, full_dims=(512, 512), new_dims=(NUM_COORDS_WIDTH, NUM_COORDS_HEIGHT))
+ self.env.segmentation_info_mask = np.expand_dims(heatmap.T.flatten(), axis=1)
+ unnormalized_heatmap = self.convert_heatmap_resolution(self.clip_seg_tta.heatmap_unnormalized, full_dims=(512, 512), new_dims=(NUM_COORDS_WIDTH, NUM_COORDS_HEIGHT))
+ self.env.segmentation_info_mask_unnormalized = np.expand_dims(unnormalized_heatmap.T.flatten(), axis=1)
+ print("~Resized heatmap to", NUM_COORDS_WIDTH, "x", NUM_COORDS_HEIGHT)
+ else:
+ self.env.segmentation_info_mask = np.expand_dims(self.clip_seg_tta.heatmap.T.flatten(), axis=1)
+ self.env.segmentation_info_mask_unnormalized = np.expand_dims(self.clip_seg_tta.heatmap_unnormalized.T.flatten(), axis=1)
+
+ self.step_since_tta = 0
+
+
+ def convert_heatmap_resolution(self, heatmap, full_dims=(512, 512), new_dims=(24, 24)):
+        heatmap_large = resize(heatmap, full_dims, order=1,  # order=1: bilinear
+ mode='reflect', anti_aliasing=True)
+
+ coords = self.env.graph_generator.grid_coords # (N, N, 2)
+ rows, cols = coords[...,1], coords[...,0]
+ heatmap_resized = heatmap_large[rows, cols]
+ heatmap_resized = heatmap_resized.reshape(new_dims[1], new_dims[0])
+ return heatmap_resized
+
+
+ def convert_labelmap_resolution(self, labelmap, full_dims=(512, 512), new_dims=(24, 24)):
+ """
+        1) Upsample via nearest-neighbor to full_dims
+ 2) Sample back down to your graph grid using grid_coords
+ """
+        # 1) Upsample with nearest-neighbor, preserving integer labels
+ up = resize(
+ labelmap,
+ full_dims,
+            order=0,              # nearest-neighbor
+ mode='edge', # padding mode
+ preserve_range=True, # don't normalize labels
+ anti_aliasing=False # must be False for labels
+ ).astype(labelmap.dtype) # back to original integer dtype
+
+        # 2) Downsample via your precomputed grid coords (N×N×2)
+ coords = self.env.graph_generator.grid_coords # shape (N, N, 2)
+ rows = coords[...,1].astype(int)
+ cols = coords[...,0].astype(int)
+
+ small = up[rows, cols] # shape (N, N)
+ small = small.reshape(new_dims[0], new_dims[1])
+ return small
+
+
+ def scale_trajectory(self,
+ flat_indices,
+ targets,
+ old_dims=(17, 17),
+ full_dims=(512, 512),
+ new_dims=(24, 24)):
+ """
+ Args:
+ flat_indices: list of ints in [0..old_H*old_W-1]
+ targets: list of (y_pix, x_pix) in [0..full_H-1]
+ old_dims: (old_H, old_W)
+ full_dims: (full_H, full_W)
+ new_dims: (new_H, new_W)
+
+ Returns:
+            new_flat_traj: list of unique flattened indices in new_H×new_W
+ counts: list of ints, same length as new_flat_traj
+ """
+ old_H, old_W = old_dims
+ full_H, full_W = full_dims
+ new_H, new_W = new_dims
+
+ # 1) bin targets into new grid
+ cell_h_new = full_H / new_H
+ cell_w_new = full_W / new_W
+ grid_counts = [[0]*new_W for _ in range(new_H)]
+ for x_pix, y_pix in targets: # note (x, y) order as in original implementation
+ i_t = min(int(y_pix / cell_h_new), new_H - 1)
+ j_t = min(int(x_pix / cell_w_new), new_W - 1)
+ grid_counts[i_t][j_t] += 1
+
+ # 2) Walk the trajectory indices and project each old cell's *entire
+        #    pixel footprint* onto the finer 24×24 grid.
+ cell_h_full = full_H / old_H
+ cell_w_full = full_W / old_W
+
+ seen = set()
+ new_flat_traj = []
+
+ for node_idx in flat_indices:
+ if node_idx < 0 or node_idx >= len(self.env.graph_generator.node_coords):
+ continue
+
+ coord_xy = self.env.graph_generator.node_coords[node_idx]
+ try:
+ row_old, col_old = self.env.graph_generator.find_index_from_grid_coords_2d(coord_xy)
+ except Exception:
+ continue
+
+ # Bounding box of the old cell in full-resolution pixel space
+ y0 = row_old * cell_h_full
+ y1 = (row_old + 1) * cell_h_full
+ x0 = col_old * cell_w_full
+ x1 = (col_old + 1) * cell_w_full
+
+ # Which new-grid rows & cols overlap? (inclusive ranges)
+ i_start = max(0, min(int(y0 / cell_h_new), new_H - 1))
+ i_end = max(0, min(int((y1 - 1) / cell_h_new), new_H - 1))
+ j_start = max(0, min(int(x0 / cell_w_new), new_W - 1))
+ j_end = max(0, min(int((x1 - 1) / cell_w_new), new_W - 1))
+
+ for ii in range(i_start, i_end + 1):
+ for jj in range(j_start, j_end + 1):
+ f_new = ii * new_W + jj
+ if f_new not in seen:
+ seen.add(f_new)
+ new_flat_traj.append(f_new)
+
+ # 3) annotate counts
+ counts = []
+ for f in new_flat_traj:
+ i_new, j_new = divmod(f, new_W)
+ counts.append(grid_counts[i_new][j_new])
+
+ return new_flat_traj, counts
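+    # Worked example (assuming the default old_dims=(17, 17), full_dims=(512, 512),
+    # new_dims=(24, 24)): the old cell (row 0, col 0) spans pixels [0, 30.1) in each axis
+    # (512 / 17 ~= 30.1); with new cells of 512 / 24 ~= 21.3 px it overlaps new rows 0-1
+    # and cols 0-1, i.e. flat indices {0, 1, 24, 25}.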
+
+
+ ################################################################################
+
+ def make_gif(self, path, n):
+ """ Generate a gif given list of images """
+ with imageio.get_writer('{}/{}_target_rate_{:.2f}.gif'.format(path, n, self.env.targets_found_rate), mode='I',
+ fps=5) as writer:
+ for frame in self.env.frame_files:
+ image = imageio.imread(frame)
+ writer.append_data(image)
+ print('gif complete\n')
+
+ # Remove files
+ for filename in self.env.frame_files[:-1]:
+ os.remove(filename)
+
+ # For KMeans gif
+ if LOAD_AVS_BENCH and USE_CLIP_PREDS:
+ with imageio.get_writer('{}/{}_kmeans_stats.gif'.format(path, n), mode='I',
+ fps=5) as writer:
+ for frame in self.kmeans_clusterer.kmeans_frame_files:
+ image = imageio.imread(frame)
+ writer.append_data(image)
+ print('Kmeans Clusterer gif complete\n')
+
+ # Remove files
+ for filename in self.kmeans_clusterer.kmeans_frame_files[:-1]:
+ os.remove(filename)
+
+
+ # IS gif
+ with imageio.get_writer('{}/{}_IS.gif'.format(path, n), mode='I',
+ fps=5) as writer:
+ for frame in self.IS_frame_files:
+ image = imageio.imread(frame)
+ writer.append_data(image)
+            print('IS gif complete\n')
+
+ # Remove files
+ for filename in self.IS_frame_files[:-1]:
+ os.remove(filename)
+
+ ################################################################################
+
+
+class Agent:
+ def __init__(self, ID, infoMap=None, uncertaintyMap=None, shape=None, row=0, col=0, sensorSize=9, numAgents=8):
+ self.ID = ID
+ self.row = row
+ self.col = col
+ self.numAgents = numAgents
+ self.sensorSize = sensorSize
+
+ def setLocation(self, row, col):
+ self.row = row
+ self.col = col
+
+ def getLocation(self):
+ return [self.row, self.col]
+
+ def move(self, action):
+ """
+ No movement: 0
+ North (-1,0): 1
+ East (0,1): 2
+ South (1,0): 3
+ West (0,-1): 4
+        UpLeft (-1,-1): 5
+        UpRight (-1,1): 6
+        DownRight (1,1): 7
+        DownLeft (1,-1): 8
+        Note: boundary checks are handled by the caller; this method only updates (row, col).
+ """
+ if action == 0:
+ return 0
+ elif action == 1:
+ self.row -= 1
+ elif action == 2:
+ self.col += 1
+ elif action == 3:
+ self.row += 1
+ elif action == 4:
+ self.col -= 1
+ elif action == 5:
+ self.row -= 1
+ self.col -= 1
+ elif action == 6:
+ self.row -= 1
+ self.col += 1
+ elif action == 7:
+ self.row += 1
+ self.col += 1
+ elif action == 8:
+ self.row += 1
+ self.col -= 1
+
+ def reverseMove(self, action):
+ if action == 0:
+ return 0
+ elif action == 1:
+ self.row += 1
+ elif action == 2:
+ self.col -= 1
+ elif action == 3:
+ self.row -= 1
+ elif action == 4:
+ self.col += 1
+ elif action == 5:
+ self.row += 1
+ self.col += 1
+ elif action == 6:
+ self.row += 1
+ self.col -= 1
+ elif action == 7:
+ self.row -= 1
+ self.col -= 1
+ elif action == 8:
+ self.row -= 1
+ self.col += 1
+ else:
+            print("reverseMove only supports actions 0-8")
+ sys.exit()
+
+
+class Target:
+ def __init__(self, row, col, ID, time_found=np.nan):
+ self.row = row
+ self.col = col
+ self.ID = ID
+ self.time_found = time_found
+ self.status = None
+ self.time_visited = time_found
+
+ def getLocation(self):
+ return self.row, self.col
+
+ def updateFound(self, timeStep):
+ if np.isnan(self.time_found):
+ self.time_found = timeStep
+
+ def updateVisited(self, timeStep):
+ if np.isnan(self.time_visited):
+ self.time_visited = timeStep
+
+
+if __name__ == "__main__":
+
+ search_env = Env(map_index=1, k_size=K_SIZE, n_agent=NUM_ROBOTS, plot=SAVE_GIFS)
+
+ IS_info_map = search_env.segmentation_info_mask
+ IS_agent_loc = search_env.start_positions
+ IS_target_loc = [[312, 123], [123, 312], [312, 312], [123, 123]]
+
+ env = ISEnv(state=[IS_info_map, IS_agent_loc, IS_target_loc], shape=(NUM_COORDS_HEIGHT, NUM_COORDS_WIDTH))
+ env.is_scenario(NUM_EPS_STEPS)
+ print()
diff --git a/planner/test_parameter.py b/planner/test_parameter.py
new file mode 100644
index 0000000000000000000000000000000000000000..0fd0808a446b306d7e43397d2674ca160f03158a
--- /dev/null
+++ b/planner/test_parameter.py
@@ -0,0 +1,118 @@
+############################################################################################
+# Name: test_parameter.py
+#
+# NOTE: Change all your hyper-params here for eval
+# Simple How-To Guide:
+# 1. CLIP TTA: USE_CLIP_PREDS=True, EXECUTE_TTA=True
+# 2. CLIP (No TTA): USE_CLIP_PREDS=True, EXECUTE_TTA=False
+# 3. Custom masks (e.g. LISA): USE_CLIP_PREDS=False, EXECUTE_TTA=False
+############################################################################################
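+#
+# Example (assumed shell usage; the flags below are read from the environment via getenv()):
+#   USE_CLIP_PREDS=true EXECUTE_TTA=false python <your eval entry point>   # -> mode 2, "CLIP (No TTA)"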
+
+import os
+import sys
+sys.modules['TRAINING'] = False # False = Inference Testing
+
+###############################################################
+# Overload Params
+###############################################################
+
+OPT_VARS = {}
+def getenv(var_name, default=None, cast_type=str):
+ try:
+ value = os.environ.get(var_name, None)
+ if value is None:
+ result = default
+ elif cast_type == bool:
+ result = value.lower() in ("true", "1", "yes")
+ else:
+ result = cast_type(value)
+ except (ValueError, TypeError):
+ result = default
+
+ OPT_VARS[var_name] = result # Log the result
+ return result
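+# e.g. running `EXECUTE_TTA=false python <script>` makes getenv("EXECUTE_TTA", default=True,
+# cast_type=bool) return False: for bool casts, any value other than "true"/"1"/"yes"
+# (case-insensitive) is treated as False.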
+
+###############################################################
+# General
+###############################################################
+
+# --- GENERAL --- #
+USE_GPU = False
+NUM_GPU = getenv("NUM_GPU", default=1, cast_type=int) # the number of GPUs
+NUM_META_AGENT = getenv("NUM_META_AGENT", default=2, cast_type=int) # the number of concurrent processes
+NUM_EPS_STEPS = getenv("NUM_EPS_STEPS", default=400, cast_type=int)
+FIX_START_POSITION = getenv("FIX_START_POSITION", default=True, cast_type=bool) # Whether to fix the starting position of the robots (middle index)
+NUM_ROBOTS = 1 # Only allow for 1 robot
+NUM_COORDS_WIDTH=24 # How many node coords across width?
+NUM_COORDS_HEIGHT=24 # How many node coords across height?
+CLIP_GRIDS_DIMS=[24,24] # [16,16] if 'openai/clip-vit-large-patch14-336'
+SENSOR_RANGE=80 # Only applicable to 'circle' sensor model
+SENSOR_MODEL="rectangular" # "rectangular", "circle" (NOTE: no collision check for rectangular)
+TERMINATE_ON_TGTS_FOUND = True # Whether to terminate episode when all targets found
+FORCE_LOGGING_DONE_TGTS_FOUND = True # Whether to force csv logging when all targets found
+
+
+# --- Planner Params --- #
+POLICY = getenv("POLICY", default="RL", cast_type=str)
+NUM_TEST = 800 # Overridden if LOAD_AVS_BENCH
+NUM_RUN = 1
+MODEL_NAME = "avs_rl_policy.pth"
+INPUT_DIM = 4
+EMBEDDING_DIM = 128
+K_SIZE = 8
+
+
+# --- Folders & Visualizations --- #
+GRIDMAP_SET_DIR = "maps/gpt4o/envs_val"
+MASK_SET_DIR = "maps/example/masks_val" # Overridden if LOAD_AVS_BENCH
+TARGETS_SET_DIR = ""
+# TARGETS_SET_DIR = "maps/example/gt_masks_val_with_tgts" # Overridden if LOAD_AVS_BENCH
+OVERRIDE_MASK_DIR = getenv("OVERRIDE_MASK_DIR", default="", cast_type=str) # Override initial score mask from CLIP
+SAVE_GIFS = getenv("SAVE_GIFS", default=True, cast_type=bool) # do you want to save GIFs
+FOLDER_NAME = 'avs_search'
+MODEL_PATH = f'inference/model'
+GIFS_PATH = f'inference/test_results/gifs/{FOLDER_NAME}'
+LOG_PATH = f'inference/test_results/log/{FOLDER_NAME}'
+LOG_TEMPLATE_XLSX = f'inference/template.xlsx'
+CSV_EXPT_NAME = getenv("CSV_EXPT_NAME", default="data", cast_type=str)
+VIZ_GRAPH_EDGES = False # do you want to visualize the graph edges
+
+
+#######################################################################
+# AVS Params
+#######################################################################
+
+# General PARAMS
+USE_CLIP_PREDS = getenv("USE_CLIP_PREDS", default=True, cast_type=bool) # If false, use custom masks from OVERRIDE_MASK_DIR
+QUERY_TAX = getenv("QUERY_TAX", default="", cast_type=str) # "" = Test all tax (can accept taxonomy substrings)
+EXECUTE_TTA = getenv("EXECUTE_TTA", default=True, cast_type=bool) # Whether to execute TTA mask updates
+QUERY_MODALITY = getenv("QUERY_MODALITY", default="image", cast_type=str) # "image", "text", "sound"
+STEPS_PER_TTA = 20 # no. steps before each TTA series
+NUM_TTA_STEPS = 1 # no. of TTA steps during each series
+RESET_WEIGHTS = True
+MIN_LR = 1e-6
+MAX_LR = 1e-5
+GAMMA_EXPONENT = 2
+
+# Paths related to AVS (TRAIN w/ TARGETS)
+LOAD_AVS_BENCH = True # Whether to init AVS datasets
+AVS_IMG_DIR = '/mnt/hdd/avs_bench_ds/inat21'
+AVS_IMO_DIR = '/mnt/hdd/avs_bench_ds/sat_jpg/train_512px'
+AVS_INAT_JSON_PATH = '/mnt/hdd/avs_bench_ds/inat21/train.json'
+AVS_SOUND_DIR = '/mnt/hdd/avs_bench_ds/sound_mp3/test'
+AVS_GAUSSIAN_BLUR_KERNEL = (5,5)
+AVS_SAT_TO_IMG_IDS_PATH = getenv("AVS_SAT_TO_IMG_IDS_PATH", default="search_tri_modal|val_in_domain", cast_type=str)
+AVS_LOAD_PRETRAINED_HF_CHECKPOINT = getenv("AVS_LOAD_PRETRAINED_HF_CHECKPOINT", default=True, cast_type=bool) # If false, load locally using CHECKPOINT_PATHs
+AVS_SAT_CHECKPOINT_PATH = getenv("AVS_SAT_CHECKPOINT_PATH", default="", cast_type=str)
+AVS_SOUND_CHECKPOINT_PATH = getenv("AVS_SOUND_CHECKPOINT_PATH", default="", cast_type=str)
+
+#######################################################################
+# UTILS
+#######################################################################
+
+# COLORS (for printing)
+RED='\033[1;31m'
+GREEN='\033[1;32m'
+YELLOW='\033[1;93m'
+NC_BOLD='\033[1m' # Bold, No Color
+NC='\033[0m' # No Color
diff --git a/planner/test_worker.py b/planner/test_worker.py
new file mode 100644
index 0000000000000000000000000000000000000000..f96c4b935e779761b4ce323822755d66e8676ba3
--- /dev/null
+++ b/planner/test_worker.py
@@ -0,0 +1,590 @@
+#######################################################################
+# Name: test_worker.py
+#
+# - Runs robot in environment using RL Planner
+#######################################################################
+
+from .test_parameter import *
+
+import imageio
+import os
+import copy
+import numpy as np
+import torch
+from time import time
+from pathlib import Path
+from skimage.transform import resize
+from taxabind_avs.satbind.kmeans_clustering import CombinedSilhouetteInertiaClusterer
+from .env import Env
+from .robot import Robot
+
+np.seterr(invalid='raise', divide='raise')
+
+
+class TestWorker:
+ def __init__(self, meta_agent_id, n_agent, policy_net, global_step, device='cuda', greedy=False, save_image=False, clip_seg_tta=None):
+ self.device = device
+ self.greedy = greedy
+ self.n_agent = n_agent
+ self.metaAgentID = meta_agent_id
+ self.global_step = global_step
+ self.k_size = K_SIZE
+ self.save_image = save_image
+ self.clip_seg_tta = clip_seg_tta
+ self.execute_tta = EXECUTE_TTA # Added to interface with app.py
+
+ self.env = Env(map_index=self.global_step, n_agent=n_agent, k_size=self.k_size, plot=save_image, test=True)
+ self.local_policy_net = policy_net
+
+ self.robot_list = []
+ self.all_robot_positions = []
+ for i in range(self.n_agent):
+ robot_position = self.env.start_positions[i]
+ robot = Robot(robot_id=i, position=robot_position, plot=save_image)
+ self.robot_list.append(robot)
+ self.all_robot_positions.append(robot_position)
+
+ self.perf_metrics = dict()
+ self.bad_mask_init = False
+
+ # NOTE: Option to override gifs_path to interface with app.py
+ self.gifs_path = GIFS_PATH
+
+ # NOTE: updated due to app.py (hf does not allow heatmap to persist)
+ if LOAD_AVS_BENCH:
+ if clip_seg_tta is not None:
+ heatmap, heatmap_unnormalized, heatmap_unnormalized_initial, patch_embeds = self.clip_seg_tta.reset(sample_idx=self.global_step)
+ self.clip_seg_tta.heatmap = heatmap
+ self.clip_seg_tta.heatmap_unnormalized = heatmap_unnormalized
+ self.clip_seg_tta.heatmap_unnormalized_initial = heatmap_unnormalized_initial
+ self.clip_seg_tta.patch_embeds = patch_embeds
+
+ # Override target positions in env
+ self.env.target_positions = [(pose[1], pose[0]) for pose in self.clip_seg_tta.target_positions]
+
+ # Override segmentation mask
+ if not USE_CLIP_PREDS and OVERRIDE_MASK_DIR != "":
+ score_mask_path = os.path.join(OVERRIDE_MASK_DIR, self.clip_seg_tta.gt_mask_name)
+ print("score_mask_path: ", score_mask_path)
+ if os.path.exists(score_mask_path):
+ self.env.segmentation_mask = self.env.import_segmentation_mask(score_mask_path)
+ self.env.begin(self.env.map_start_position)
+ else:
+ print(f"\n\n{RED}ERROR: Trying to override, but score mask not found at path:{NC} ", score_mask_path)
+ self.bad_mask_init = True
+
+ # Save clustered embeds from sat encoder
+ if USE_CLIP_PREDS:
+ self.kmeans_clusterer = CombinedSilhouetteInertiaClusterer(
+ k_min=1,
+ k_max=8,
+ k_avg_max=4,
+ silhouette_threshold=0.15,
+ relative_threshold=0.15,
+ random_state=0,
+ min_patch_size=5,
+ n_smooth_iter=2,
+ ignore_label=-1,
+ plot=self.save_image,
+ gifs_dir = GIFS_PATH
+ )
+ # Generate kmeans clusters
+ self.kmeans_sat_embeds_clusters = self.kmeans_clusterer.fit_predict(
+ patch_embeds=self.clip_seg_tta.patch_embeds,
+ map_shape=(CLIP_GRIDS_DIMS[0], CLIP_GRIDS_DIMS[1]),
+ )
+ print("Chosen k:", self.kmeans_clusterer.final_k)
+
+ # if EXECUTE_TTA:
+ # print("Will execute TTA...")
+
+ # Define Poisson TTA params
+ self.step_since_tta = 0
+ self.steps_to_first_tgt = None
+ self.steps_to_mid_tgt = None
+ self.steps_to_last_tgt = None
+
+
+ def run_episode(self, curr_episode):
+
+ # Return all metrics as None if faulty mask init
+ if self.bad_mask_init:
+ self.perf_metrics['tax'] = None
+ self.perf_metrics['travel_dist'] = None
+ self.perf_metrics['travel_steps'] = None
+ self.perf_metrics['steps_to_first_tgt'] = None
+ self.perf_metrics['steps_to_mid_tgt'] = None
+ self.perf_metrics['steps_to_last_tgt'] = None
+ self.perf_metrics['explored_rate'] = None
+ self.perf_metrics['targets_found'] = None
+ self.perf_metrics['targets_total'] = None
+ self.perf_metrics['kmeans_k'] = None
+ self.perf_metrics['tgts_gt_score'] = None
+ self.perf_metrics['clip_inference_time'] = None
+ self.perf_metrics['tta_time'] = None
+ self.perf_metrics['success_rate'] = None
+ return
+
+ eps_start = time()
+ done = False
+ for robot_id, deciding_robot in enumerate(self.robot_list):
+ deciding_robot.observations = self.get_observations(deciding_robot.robot_position)
+ if LOAD_AVS_BENCH and USE_CLIP_PREDS:
+ if NUM_COORDS_WIDTH != CLIP_GRIDS_DIMS[0] or NUM_COORDS_HEIGHT != CLIP_GRIDS_DIMS[1]: # If heatmap is resized from clip original dims
+ heatmap = self.convert_heatmap_resolution(self.clip_seg_tta.heatmap, full_dims=(512, 512), new_dims=(NUM_COORDS_WIDTH, NUM_COORDS_HEIGHT))
+ self.env.segmentation_info_mask = np.expand_dims(heatmap.T.flatten(), axis=1)
+ unnormalized_heatmap = self.convert_heatmap_resolution(self.clip_seg_tta.heatmap_unnormalized, full_dims=(512, 512), new_dims=(NUM_COORDS_WIDTH, NUM_COORDS_HEIGHT))
+ self.env.segmentation_info_mask_unnormalized = np.expand_dims(unnormalized_heatmap.T.flatten(), axis=1)
+ print("Resized heatmap to", NUM_COORDS_WIDTH, "x", NUM_COORDS_HEIGHT)
+ else:
+ self.env.segmentation_info_mask = np.expand_dims(self.clip_seg_tta.heatmap.T.flatten(), axis=1)
+ self.env.segmentation_info_mask_unnormalized = np.expand_dims(self.clip_seg_tta.heatmap_unnormalized.T.flatten(), axis=1)
+
+ ### Run episode ###
+ for step in range(NUM_EPS_STEPS):
+
+ next_position_list = []
+ dist_list = []
+ travel_dist_list = []
+ dist_array = np.zeros((self.n_agent, 1))
+ for robot_id, deciding_robot in enumerate(self.robot_list):
+ observations = deciding_robot.observations
+
+ ### Forward pass through policy to get next position ###
+ next_position, action_index = self.select_node(observations)
+ dist = np.linalg.norm(next_position - deciding_robot.robot_position)
+
+ ### Log results of action (e.g. distance travelled) ###
+ dist_array[robot_id] = dist
+ dist_list.append(dist)
+ travel_dist_list.append(deciding_robot.travel_dist)
+ next_position_list.append(next_position)
+ self.all_robot_positions[robot_id] = next_position
+
+ arriving_sequence = np.argsort(dist_list)
+ next_position_list = np.array(next_position_list)
+ dist_list = np.array(dist_list)
+ travel_dist_list = np.array(travel_dist_list)
+ next_position_list = next_position_list[arriving_sequence]
+ dist_list = dist_list[arriving_sequence]
+ travel_dist_list = travel_dist_list[arriving_sequence]
+
+ ### Take Action (Deconflict if 2 agents choose the same target position) ###
+ next_position_list, dist_list = self.solve_conflict(arriving_sequence, next_position_list, dist_list)
+ reward_list, done = self.env.multi_robot_step(next_position_list, dist_list, travel_dist_list)
+
+ ### Update observations + rewards from action ###
+ for reward, robot_id in zip(reward_list, arriving_sequence):
+ robot = self.robot_list[robot_id]
+ robot.save_trajectory_coords(self.env.find_index_from_coords(robot.robot_position), self.env.num_new_targets_found)
+
+ # # TTA Update via Poisson Test (with KMeans clustering stats)
+ if LOAD_AVS_BENCH and USE_CLIP_PREDS and self.execute_tta:
+ self.poisson_tta_update(robot, self.global_step, step)
+
+ robot.observations = self.get_observations(robot.robot_position)
+ robot.save_reward_done(reward, done)
+
+ # Update metrics
+ self.log_metrics(step=step)
+
+ ### Save a frame to generate gif of robot trajectories ###
+ if self.save_image:
+ robots_route = []
+ for robot in self.robot_list:
+ robots_route.append([robot.xPoints, robot.yPoints])
+ if not os.path.exists(self.gifs_path):
+ os.makedirs(self.gifs_path)
+ if LOAD_AVS_BENCH:
+ # NOTE: Replaced since using app.py
+ self.env.plot_heatmap(self.gifs_path, step, max(travel_dist_list), robots_route)
+
+ if done:
+ break
+
+ if LOAD_AVS_BENCH:
+ tax = Path(self.clip_seg_tta.gt_mask_name).stem
+ self.perf_metrics['tax'] = " ".join(tax.split("_")[1:])
+ else:
+ self.perf_metrics['tax'] = None
+ self.perf_metrics['travel_dist'] = max(travel_dist_list)
+ self.perf_metrics['travel_steps'] = step + 1
+ self.perf_metrics['steps_to_first_tgt'] = self.steps_to_first_tgt
+ self.perf_metrics['steps_to_mid_tgt'] = self.steps_to_mid_tgt
+ self.perf_metrics['steps_to_last_tgt'] = self.steps_to_last_tgt
+ self.perf_metrics['explored_rate'] = self.env.explored_rate
+ self.perf_metrics['targets_found'] = self.env.targets_found_rate
+ self.perf_metrics['targets_total'] = len(self.env.target_positions)
+ if USE_CLIP_PREDS:
+ self.perf_metrics['kmeans_k'] = self.kmeans_clusterer.final_k
+ self.perf_metrics['tgts_gt_score'] = self.clip_seg_tta.tgts_gt_score
+ self.perf_metrics['clip_inference_time'] = self.clip_seg_tta.clip_inference_time
+ self.perf_metrics['tta_time'] = self.clip_seg_tta.tta_time
+ else:
+ self.perf_metrics['kmeans_k'] = None
+ self.perf_metrics['tgts_gt_score'] = None
+ self.perf_metrics['clip_inference_time'] = None
+ self.perf_metrics['tta_time'] = None
+ if FORCE_LOGGING_DONE_TGTS_FOUND and self.env.targets_found_rate == 1.0:
+ self.perf_metrics['success_rate'] = True
+ else:
+ self.perf_metrics['success_rate'] = done
+
+ # save gif
+ if self.save_image:
+ path = self.gifs_path # NOTE: Set to self.gifs_path since using app.py
+ self.make_gif(path, curr_episode)
+
+ print(YELLOW, f"[Eps {curr_episode} Completed] Time Taken: {time()-eps_start:.2f}s, Steps: {step+1}", NC)
+
+ def get_observations(self, robot_position):
+ """ Get robot's sensor observation of environment given position """
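+        # Each node's feature vector is [x, y, segmentation score, guidepost], which is
+        # assumed to match INPUT_DIM = 4 in test_parameter.py (guidepost taken as one
+        # per-node column).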
+ current_node_index = self.env.find_index_from_coords(robot_position)
+ current_index = torch.tensor([current_node_index]).unsqueeze(0).unsqueeze(0).to(self.device) # (1,1,1)
+
+ node_coords = copy.deepcopy(self.env.node_coords)
+ graph = copy.deepcopy(self.env.graph)
+ node_utility = copy.deepcopy(self.env.node_utility)
+ guidepost = copy.deepcopy(self.env.guidepost)
+ segmentation_info_mask = copy.deepcopy(self.env.filtered_seg_info_mask)
+
+ n_nodes = node_coords.shape[0]
+ node_coords = node_coords / 640
+ node_utility = node_utility / 50
+ node_utility_inputs = node_utility.reshape((n_nodes, 1))
+
+ occupied_node = np.zeros((n_nodes, 1))
+ for position in self.all_robot_positions:
+ index = self.env.find_index_from_coords(position)
+ if index == current_index.item():
+ occupied_node[index] = -1
+ else:
+ occupied_node[index] = 1
+
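+ # Per-node input features (sketch of the layout below): normalized (x, y) coords,
+ # segmentation/heatmap score, and the guidepost indicator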
+ node_inputs = np.concatenate((node_coords, segmentation_info_mask, guidepost), axis=1)
+ node_inputs = torch.FloatTensor(node_inputs).unsqueeze(0).to(self.device)
+ node_padding_mask = None
+
+ graph = list(graph.values())
+ edge_inputs = []
+ for node in graph:
+ node_edges = list(map(int, node))
+ edge_inputs.append(node_edges)
+
+ bias_matrix = self.calculate_edge_mask(edge_inputs)
+ edge_mask = torch.from_numpy(bias_matrix).float().unsqueeze(0).to(self.device)
+
+ for edges in edge_inputs:
+ while len(edges) < self.k_size:
+ edges.append(0)
+
+ edge_inputs = torch.tensor(edge_inputs).unsqueeze(0).to(self.device)
+ edge_padding_mask = torch.zeros((1, len(edge_inputs), K_SIZE), dtype=torch.int64).to(self.device)
+ one = torch.ones_like(edge_padding_mask, dtype=torch.int64).to(self.device)
+ edge_padding_mask = torch.where(edge_inputs == 0, one, edge_padding_mask)
+
+ observations = node_inputs, edge_inputs, current_index, node_padding_mask, edge_padding_mask, edge_mask
+ return observations
+
+
+ def select_node(self, observations):
+ """ Forward pass through policy to get next position to go to on map """
+ node_inputs, edge_inputs, current_index, node_padding_mask, edge_padding_mask, edge_mask = observations
+ with torch.no_grad():
+ logp_list = self.local_policy_net(node_inputs, edge_inputs, current_index, node_padding_mask, edge_padding_mask, edge_mask)
+
+ if self.greedy:
+ action_index = torch.argmax(logp_list, dim=1).long()
+ else:
+ action_index = torch.multinomial(logp_list.exp(), 1).long().squeeze(1)
+
+ next_node_index = edge_inputs[:, current_index.item(), action_index.item()]
+
+ next_position = self.env.node_coords[next_node_index]
+
+ return next_position, action_index
+
+ def solve_conflict(self, arriving_sequence, next_position_list, dist_list):
+ """ Deconflict if 2 agents choose the same target position """
+ for j, [robot_id, next_position] in enumerate(zip(arriving_sequence, next_position_list)):
+ moving_robot = self.robot_list[robot_id]
+ # if next_position[0] + next_position[1] * 1j in (next_position_list[:, 0] + next_position_list[:, 1] * 1j)[:j]:
+ # dist_to_next_position = np.argsort(np.linalg.norm(self.env.node_coords - next_position, axis=1))
+ # k = 0
+ # while next_position[0] + next_position[1] * 1j in (next_position_list[:, 0] + next_position_list[:, 1] * 1j)[:j]:
+ # k += 1
+ # next_position = self.env.node_coords[dist_to_next_position[k]]
+
+ dist = np.linalg.norm(next_position - moving_robot.robot_position)
+ next_position_list[j] = next_position
+ dist_list[j] = dist
+ moving_robot.travel_dist += dist
+ moving_robot.robot_position = next_position
+
+ return next_position_list, dist_list
+
+ def work(self, currEpisode):
+ '''
+ Interacts with the environment; the agent returns either gradients or an experience buffer.
+ '''
+ self.run_episode(currEpisode)
+
+ def calculate_edge_mask(self, edge_inputs):
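+ """ Attention bias matrix: 0 where node j is a graph neighbor of node i, 1 otherwise """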
+ size = len(edge_inputs)
+ bias_matrix = np.ones((size, size))
+ for i in range(size):
+ for j in range(size):
+ if j in edge_inputs[i]:
+ bias_matrix[i][j] = 0
+ return bias_matrix
+
+ def make_gif(self, path, n):
+ """ Generate a gif given list of images """
+ with imageio.get_writer('{}/{}_target_rate_{:.2f}.gif'.format(path, n, self.env.targets_found_rate), mode='I',
+ fps=5) as writer:
+ for frame in self.env.frame_files:
+ image = imageio.imread(frame)
+ writer.append_data(image)
+ print('gif complete\n')
+
+ # Remove files
+ for filename in self.env.frame_files[:-1]:
+ os.remove(filename)
+
+ # For gif during TTA
+ if LOAD_AVS_BENCH:
+ with imageio.get_writer('{}/{}_kmeans_stats.gif'.format(path, n), mode='I',
+ fps=5) as writer:
+ for frame in self.kmeans_clusterer.kmeans_frame_files:
+ image = imageio.imread(frame)
+ writer.append_data(image)
+ print('Kmeans Clusterer gif complete\n')
+
+ # Remove files
+ for filename in self.kmeans_clusterer.kmeans_frame_files[:-1]:
+ os.remove(filename)
+
+ ################################################################################
+ # SPPP Related Fns
+ ################################################################################
+
+ def log_metrics(self, step):
+ # Update tgt found metrics
+ if self.steps_to_first_tgt is None and self.env.num_targets_found == 1:
+ self.steps_to_first_tgt = step + 1
+ if self.steps_to_mid_tgt is None and self.env.num_targets_found == int(len(self.env.target_positions) / 2):
+ self.steps_to_mid_tgt = step + 1
+ if self.steps_to_last_tgt is None and self.env.num_targets_found == len(self.env.target_positions):
+ self.steps_to_last_tgt = step + 1
+
+
+ def transpose_flat_idx(self, idx, H=NUM_COORDS_HEIGHT, W=NUM_COORDS_WIDTH):
+ """
+ Transpose a flat index from an ``H×W`` grid to the equivalent
+ position in the ``W×H`` transposed grid while **keeping the result
+ in 1-D**.
+ """
+ # --- Safety check to catch out-of-range indices ---
+ assert 0 <= idx < H * W, f"idx {idx} out of bounds for shape ({H}, {W})"
+
+ # Original (row, col)
+ row, col = divmod(idx, W)
+ # After transpose these coordinates swap
+ row_T, col_T = col, row
+
+ # Flatten back into 1-D (row-major) for the WΓH grid
+ return row_T * H + col_T
+
+
+ def poisson_tta_update(self, robot, episode, step):
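+ """
+ TTA update of the satellite heatmap: KMeans cluster visit statistics set
+ per-sample weights, and an update is triggered when a target was just found
+ or every STEPS_PER_TTA steps.
+ """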
+
+ # Generate Kmeans Clusters Stats
+ # Scale index back to CLIP_GRIDS_DIMS to be compatible with CLIP patch size
+ if NUM_COORDS_WIDTH != CLIP_GRIDS_DIMS[0] or NUM_COORDS_HEIGHT != CLIP_GRIDS_DIMS[1]:
+ # High-res remap via pixel coordinates preserves exact neighbourhood
+ filt_traj_coords, filt_targets_found_on_path = self.scale_trajectory(
+ robot.trajectory_coords,
+ self.env.target_positions,
+ old_dims=(NUM_COORDS_HEIGHT, NUM_COORDS_WIDTH),
+ full_dims=(512, 512),
+ new_dims=(CLIP_GRIDS_DIMS[0], CLIP_GRIDS_DIMS[1])
+ )
+ else:
+ filt_traj_coords = [self.transpose_flat_idx(idx) for idx in robot.trajectory_coords]
+ filt_targets_found_on_path = robot.targets_found_on_path
+
+ region_stats_dict = self.kmeans_clusterer.compute_region_statistics(
+ self.kmeans_sat_embeds_clusters,
+ self.clip_seg_tta.heatmap_unnormalized,
+ filt_traj_coords,
+ episode_num=episode,
+ step_num=step
+ )
+
+ # Prep & execute TTA
+ self.step_since_tta += 1
+ if robot.targets_found_on_path[-1] or self.step_since_tta % STEPS_PER_TTA == 0:
+
+ # NOTE: integration with app.py on hf
+ self.clip_seg_tta.executing_tta = True
+
+ num_cells = self.clip_seg_tta.heatmap.shape[0] * self.clip_seg_tta.heatmap.shape[1]
+ pos_sample_weight_scale, neg_sample_weight_scale = [], []
+
+ for i, sample_loc in enumerate(filt_traj_coords):
+ label = self.kmeans_clusterer.get_label_id(self.kmeans_sat_embeds_clusters, sample_loc)
+ num_patches = region_stats_dict[label]['num_patches']
+ patches_visited = region_stats_dict[label]['patches_visited']
+ expectation = region_stats_dict[label]['expectation']
+
+ # Exponent like focal loss to wait for more samples before confidently decreasing
+ pos_weight = 4.0
+ neg_weight = min(1.0, (patches_visited/(3*num_patches))**GAMMA_EXPONENT)
+ pos_sample_weight_scale.append(pos_weight)
+ neg_sample_weight_scale.append(neg_weight)
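+ # pos_weight uniformly boosts samples where targets were found, while neg_weight
+ # ramps up with the fraction of the cluster's patches visited, so sparsely
+ # sampled regions are not suppressed prematurely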
+
+ # Adaptive LR (as samples increase, increase LR to fit more datapoints)
+ adaptive_lr = MIN_LR + (MAX_LR - MIN_LR) * (step / num_cells)
+
+ # TTA Update
+ # NOTE: updated due to app.py (hf does not allow heatmap to persist)
+ heatmap = self.clip_seg_tta.execute_tta(
+ filt_traj_coords,
+ filt_targets_found_on_path,
+ tta_steps=NUM_TTA_STEPS,
+ lr=adaptive_lr,
+ pos_sample_weight=pos_sample_weight_scale,
+ neg_sample_weight=neg_sample_weight_scale,
+ reset_weights=RESET_WEIGHTS
+ )
+ self.clip_seg_tta.heatmap = heatmap
+
+ if NUM_COORDS_WIDTH != CLIP_GRIDS_DIMS[0] or NUM_COORDS_HEIGHT != CLIP_GRIDS_DIMS[1]: # If heatmap is resized from clip original dims
+ heatmap = self.convert_heatmap_resolution(self.clip_seg_tta.heatmap, full_dims=(512, 512), new_dims=(NUM_COORDS_WIDTH, NUM_COORDS_HEIGHT))
+ self.env.segmentation_info_mask = np.expand_dims(heatmap.T.flatten(), axis=1)
+ unnormalized_heatmap = self.convert_heatmap_resolution(self.clip_seg_tta.heatmap_unnormalized, full_dims=(512, 512), new_dims=(NUM_COORDS_WIDTH, NUM_COORDS_HEIGHT))
+ self.env.segmentation_info_mask_unnormalized = np.expand_dims(unnormalized_heatmap.T.flatten(), axis=1)
+ print("~Resized heatmap to", NUM_COORDS_WIDTH, "x", NUM_COORDS_HEIGHT)
+ else:
+ self.env.segmentation_info_mask = np.expand_dims(self.clip_seg_tta.heatmap.T.flatten(), axis=1)
+ self.env.segmentation_info_mask_unnormalized = np.expand_dims(self.clip_seg_tta.heatmap_unnormalized.T.flatten(), axis=1)
+
+ self.step_since_tta = 0
+
+ # NOTE: integration with app.py on hf
+ self.clip_seg_tta.executing_tta = False
+
+
+ def convert_heatmap_resolution(self, heatmap, full_dims=(512, 512), new_dims=(24, 24)):
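+ """
+ Resample a heatmap onto the planner grid: bilinearly upsample to full_dims
+ (pixel space), then sample the pixel under each graph node via
+ graph_generator.grid_coords and reshape to new_dims.
+ """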
+ heatmap_large = resize(heatmap, full_dims, order=1, # order=1 -> bilinear
+ mode='reflect', anti_aliasing=True)
+
+ coords = self.env.graph_generator.grid_coords # (N, N, 2)
+ rows, cols = coords[...,1], coords[...,0]
+ heatmap_resized = heatmap_large[rows, cols]
+ heatmap_resized = heatmap_resized.reshape(new_dims[1], new_dims[0])
+ return heatmap_resized
+
+
+ def convert_labelmap_resolution(self, labelmap, full_dims=(512, 512), new_dims=(24, 24)):
+ """
+ 1) Upsample via nearest-neighbor to full_dims
+ 2) Sample back down to your graph grid using grid_coords
+ """
+ # 1) Upsample with nearest-neighbor, preserving integer labels
+ up = resize(
+ labelmap,
+ full_dims,
+ order=0, # nearest-neighbor
+ mode='edge', # padding mode
+ preserve_range=True, # don't normalize labels
+ anti_aliasing=False # must be False for labels
+ ).astype(labelmap.dtype) # back to original integer dtype
+
+ # 2) Downsample via your precomputed grid coords
+ coords = self.env.graph_generator.grid_coords # shape (N, N, 2)
+ rows = coords[...,1].astype(int)
+ cols = coords[...,0].astype(int)
+
+ small = up[rows, cols] # shape (N, N)
+ small = small.reshape(new_dims[0], new_dims[1])
+ return small
+
+
+ def scale_trajectory(self,
+ flat_indices,
+ targets,
+ old_dims=(17, 17),
+ full_dims=(512, 512),
+ new_dims=(24, 24)):
+ """
+ Args:
+ flat_indices: list of ints in [0..old_H*old_W-1]
+ targets: list of (x_pix, y_pix) pixel coords in [0..full_W-1] x [0..full_H-1]
+ old_dims: (old_H, old_W)
+ full_dims: (full_H, full_W)
+ new_dims: (new_H, new_W)
+
+ Returns:
+ new_flat_traj: list of unique flattened indices in new_H×new_W
+ counts: list of ints, same length as new_flat_traj
+ """
+ old_H, old_W = old_dims
+ full_H, full_W = full_dims
+ new_H, new_W = new_dims
+
+ # 1) bin targets into new grid
+ cell_h_new = full_H / new_H
+ cell_w_new = full_W / new_W
+ grid_counts = [[0]*new_W for _ in range(new_H)]
+ for x_pix, y_pix in targets: # note (x, y) order as in original implementation
+ i_t = min(int(y_pix / cell_h_new), new_H - 1)
+ j_t = min(int(x_pix / cell_w_new), new_W - 1)
+ grid_counts[i_t][j_t] += 1
+
+ # 2) Walk the trajectory indices and project each old cell's *entire
+ # pixel footprint* onto the finer 24×24 grid.
+ cell_h_full = full_H / old_H
+ cell_w_full = full_W / old_W
+
+ seen = set()
+ new_flat_traj = []
+
+ for node_idx in flat_indices:
+ if node_idx < 0 or node_idx >= len(self.env.graph_generator.node_coords):
+ continue
+
+ coord_xy = self.env.graph_generator.node_coords[node_idx]
+ try:
+ row_old, col_old = self.env.graph_generator.find_index_from_grid_coords_2d(coord_xy)
+ except Exception:
+ continue
+
+ # Bounding box of the old cell in full-resolution pixel space
+ y0 = row_old * cell_h_full
+ y1 = (row_old + 1) * cell_h_full
+ x0 = col_old * cell_w_full
+ x1 = (col_old + 1) * cell_w_full
+
+ # Which new-grid rows & cols overlap? (inclusive ranges)
+ i_start = max(0, min(int(y0 / cell_h_new), new_H - 1))
+ i_end = max(0, min(int((y1 - 1) / cell_h_new), new_H - 1))
+ j_start = max(0, min(int(x0 / cell_w_new), new_W - 1))
+ j_end = max(0, min(int((x1 - 1) / cell_w_new), new_W - 1))
+
+ for ii in range(i_start, i_end + 1):
+ for jj in range(j_start, j_end + 1):
+ f_new = ii * new_W + jj
+ if f_new not in seen:
+ seen.add(f_new)
+ new_flat_traj.append(f_new)
+
+ # 3) annotate counts
+ counts = []
+ for f in new_flat_traj:
+ i_new, j_new = divmod(f, new_W)
+ counts.append(grid_counts[i_new][j_new])
+
+ return new_flat_traj, counts
+
+ ################################################################################
diff --git a/planner/worker.py b/planner/worker.py
new file mode 100644
index 0000000000000000000000000000000000000000..6abc2909d0aa4cacc4ab38b7d00a874e3bb3d263
--- /dev/null
+++ b/planner/worker.py
@@ -0,0 +1,272 @@
+#######################################################################
+# Name: worker.py
+#
+# - Runs robot in environment for N steps
+# - Collects & Returns S(t), A(t), R(t), S(t+1)
+#######################################################################
+
+from .parameter import *
+
+import os
+import json
+import copy
+import imageio
+import numpy as np
+import torch
+from time import time
+from .env import Env
+from .robot import Robot
+
+class Worker:
+ def __init__(self, meta_agent_id, n_agent, policy_net, q_net, global_step, device='cuda', greedy=False, save_image=False, clip_seg_tta=None):
+ self.device = device
+ self.greedy = greedy
+ self.n_agent = n_agent
+ self.metaAgentID = meta_agent_id
+ self.global_step = global_step
+ self.node_padding_size = NODE_PADDING_SIZE
+ self.k_size = K_SIZE
+ self.save_image = save_image
+ self.clip_seg_tta = clip_seg_tta
+
+ # Randomize map_index
+ mask_index = None
+ if MASKS_RAND_INDICES_PATH != "":
+ with open(MASKS_RAND_INDICES_PATH, 'r') as f:
+ mask_index_rand_json = json.load(f)
+ mask_index = mask_index_rand_json[self.global_step % len(mask_index_rand_json)]
+ print("mask_index: ", mask_index)
+
+ self.env = Env(map_index=self.global_step, n_agent=n_agent, k_size=self.k_size, plot=save_image, mask_index=mask_index)
+ self.local_policy_net = policy_net
+ self.local_q_net = q_net
+
+ self.robot_list = []
+ self.all_robot_positions = []
+
+ for i in range(self.n_agent):
+ robot_position = self.env.start_positions[i]
+ robot = Robot(robot_id=i, position=robot_position, plot=save_image)
+ self.robot_list.append(robot)
+ self.all_robot_positions.append(robot_position)
+
+ self.perf_metrics = dict()
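+ # Episode buffer: 15 parallel lists, one per stored transition component,
+ # filled from each robot's buffer at the end of run_episode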
+ self.episode_buffer = []
+ for i in range(15):
+ self.episode_buffer.append([])
+
+
+ def run_episode(self, curr_episode):
+
+ eps_start = time()
+ done = False
+ for robot_id, deciding_robot in enumerate(self.robot_list):
+ deciding_robot.observations = self.get_observations(deciding_robot.robot_position)
+
+ ### Run episode ###
+ for step in range(NUM_EPS_STEPS):
+
+ next_position_list = []
+ dist_list = []
+ travel_dist_list = []
+ dist_array = np.zeros((self.n_agent, 1))
+ for robot_id, deciding_robot in enumerate(self.robot_list):
+ observations = deciding_robot.observations
+ deciding_robot.save_observations(observations)
+
+ ### Forward pass through policy to get next position ###
+ next_position, action_index = self.select_node(observations)
+ deciding_robot.save_action(action_index)
+
+ dist = np.linalg.norm(next_position - deciding_robot.robot_position)
+
+ ### Log results of action (e.g. distance travelled) ###
+ dist_array[robot_id] = dist
+ dist_list.append(dist)
+ travel_dist_list.append(deciding_robot.travel_dist)
+ next_position_list.append(next_position)
+ self.all_robot_positions[robot_id] = next_position
+
+ arriving_sequence = np.argsort(dist_list)
+ next_position_list = np.array(next_position_list)
+ dist_list = np.array(dist_list)
+ travel_dist_list = np.array(travel_dist_list)
+ next_position_list = next_position_list[arriving_sequence]
+ dist_list = dist_list[arriving_sequence]
+ travel_dist_list = travel_dist_list[arriving_sequence]
+
+ ### Take Action (Deconflict if 2 agents choose the same target position) ###
+ next_position_list, dist_list = self.solve_conflict(arriving_sequence, next_position_list, dist_list)
+ reward_list, done = self.env.multi_robot_step(next_position_list, dist_list, travel_dist_list)
+
+ ### Update observations + rewards from action ###
+ for reward, robot_id in zip(reward_list, arriving_sequence):
+ robot = self.robot_list[robot_id]
+ robot.observations = self.get_observations(robot.robot_position)
+ robot.save_trajectory_coords(self.env.find_index_from_coords(robot.robot_position), self.env.num_new_targets_found)
+ robot.save_reward_done(reward, done)
+ robot.save_next_observations(robot.observations)
+
+ ### Save a frame to generate gif of robot trajectories ###
+ if self.save_image:
+ robots_route = []
+ for robot in self.robot_list:
+ robots_route.append([robot.xPoints, robot.yPoints])
+ if not os.path.exists(GIFS_PATH):
+ os.makedirs(GIFS_PATH)
+ self.env.plot_env(self.global_step, GIFS_PATH, step, max(travel_dist_list), robots_route)
+
+ if done:
+ break
+
+ for robot in self.robot_list:
+ for i in range(15):
+ self.episode_buffer[i] += robot.episode_buffer[i]
+
+ self.perf_metrics['travel_dist'] = max(travel_dist_list)
+ self.perf_metrics['explored_rate'] = self.env.explored_rate
+ self.perf_metrics['targets_found'] = self.env.targets_found_rate
+ self.perf_metrics['success_rate'] = done
+
+ # save gif
+ if self.save_image:
+ path = GIFS_PATH
+ self.make_gif(path, curr_episode)
+
+ print(YELLOW, f"[Eps {curr_episode} Completed] Time Taken: {time()-eps_start:.2f}s, Steps: {step+1}", NC)
+
+
+ def get_observations(self, robot_position):
+ """ Get robot's sensor observation of environment given position """
+ current_node_index = self.env.find_index_from_coords(robot_position)
+ current_index = torch.tensor([current_node_index]).unsqueeze(0).unsqueeze(0).to(self.device) # (1,1,1)
+
+ node_coords = copy.deepcopy(self.env.node_coords)
+ graph = copy.deepcopy(self.env.graph)
+ node_utility = copy.deepcopy(self.env.node_utility)
+ guidepost = copy.deepcopy(self.env.guidepost)
+ segmentation_info_mask = copy.deepcopy(self.env.filtered_seg_info_mask)
+
+ n_nodes = node_coords.shape[0]
+ node_coords = node_coords / 640
+ node_utility = node_utility / 50
+
+ node_utility_inputs = node_utility.reshape((n_nodes, 1))
+
+ occupied_node = np.zeros((n_nodes, 1))
+ for position in self.all_robot_positions:
+ index = self.env.find_index_from_coords(position)
+ if index == current_index.item():
+ occupied_node[index] = -1
+ else:
+ occupied_node[index] = 1
+
+ node_inputs = np.concatenate((node_coords, segmentation_info_mask, guidepost), axis=1)
+ node_inputs = torch.FloatTensor(node_inputs).unsqueeze(0).to(self.device) # (1, n_nodes, num_node_features)
+
+ assert node_coords.shape[0] < self.node_padding_size
+ padding = torch.nn.ZeroPad2d((0, 0, 0, self.node_padding_size - node_coords.shape[0]))
+ node_inputs = padding(node_inputs)
+
+ node_padding_mask = torch.zeros((1, 1, node_coords.shape[0]), dtype=torch.int64).to(self.device)
+ node_padding = torch.ones((1, 1, self.node_padding_size - node_coords.shape[0]), dtype=torch.int64).to(
+ self.device)
+ node_padding_mask = torch.cat((node_padding_mask, node_padding), dim=-1)
+
+ graph = list(graph.values())
+ edge_inputs = []
+ for node in graph:
+ node_edges = list(map(int, node))
+ edge_inputs.append(node_edges)
+
+ bias_matrix = self.calculate_edge_mask(edge_inputs)
+ edge_mask = torch.from_numpy(bias_matrix).float().unsqueeze(0).to(self.device)
+
+ assert len(edge_inputs) < self.node_padding_size
+ padding = torch.nn.ConstantPad2d(
+ (0, self.node_padding_size - len(edge_inputs), 0, self.node_padding_size - len(edge_inputs)), 1)
+ edge_mask = padding(edge_mask)
+ padding2 = torch.nn.ZeroPad2d((0, 0, 0, self.node_padding_size - len(edge_inputs)))
+
+ for edges in edge_inputs:
+ while len(edges) < self.k_size:
+ edges.append(0)
+
+ edge_inputs = torch.tensor(edge_inputs).unsqueeze(0).to(self.device) # (1, node_padding_size+1, k_size)
+ edge_inputs = padding2(edge_inputs)
+
+ edge_padding_mask = torch.zeros((1, len(edge_inputs), K_SIZE), dtype=torch.int64).to(self.device)
+ one = torch.ones_like(edge_padding_mask, dtype=torch.int64).to(self.device)
+ edge_padding_mask = torch.where(edge_inputs == 0, one, edge_padding_mask)
+
+ observations = node_inputs, edge_inputs, current_index, node_padding_mask, edge_padding_mask, edge_mask
+ return observations
+
+
+ def select_node(self, observations):
+ """ Forward pass through policy to get next position to go to on map """
+ node_inputs, edge_inputs, current_index, node_padding_mask, edge_padding_mask, edge_mask = observations
+ with torch.no_grad():
+ logp_list = self.local_policy_net(node_inputs, edge_inputs, current_index, node_padding_mask,
+ edge_padding_mask, edge_mask)
+
+ if self.greedy:
+ action_index = torch.argmax(logp_list, dim=1).long()
+ else:
+ action_index = torch.multinomial(logp_list.exp(), 1).long().squeeze(1)
+
+ next_node_index = edge_inputs[:, current_index.item(), action_index.item()]
+
+ next_position = self.env.node_coords[next_node_index]
+
+ return next_position, action_index
+
+
+ def solve_conflict(self, arriving_sequence, next_position_list, dist_list):
+ """ Deconflict if 2 agents choose the same target position """
+ for j, [robot_id, next_position] in enumerate(zip(arriving_sequence, next_position_list)):
+ moving_robot = self.robot_list[robot_id]
+ # if next_position[0] + next_position[1] * 1j in (next_position_list[:, 0] + next_position_list[:, 1] * 1j)[:j]:
+ # dist_to_next_position = np.argsort(np.linalg.norm(self.env.node_coords - next_position, axis=1))
+ # k = 0
+ # while next_position[0] + next_position[1] * 1j in (next_position_list[:, 0] + next_position_list[:, 1] * 1j)[:j]:
+ # k += 1
+ # next_position = self.env.node_coords[dist_to_next_position[k]]
+
+ dist = np.linalg.norm(next_position - moving_robot.robot_position)
+ next_position_list[j] = next_position
+ dist_list[j] = dist
+ moving_robot.travel_dist += dist
+ moving_robot.robot_position = next_position
+
+ return next_position_list, dist_list
+
+
+ def work(self, currEpisode):
+ '''
+ Interacts with the environment; the agent returns either gradients or an experience buffer.
+ '''
+ self.run_episode(currEpisode)
+
+ def calculate_edge_mask(self, edge_inputs):
+ size = len(edge_inputs)
+ bias_matrix = np.ones((size, size))
+ for i in range(size):
+ for j in range(size):
+ if j in edge_inputs[i]:
+ bias_matrix[i][j] = 0
+ return bias_matrix
+
+
+ def make_gif(self, path, n):
+ """ Generate a gif given list of images """
+ with imageio.get_writer('{}/{}_target_rate_{:.2f}.gif'.format(path, n, self.env.targets_found_rate), mode='I',
+ fps=5) as writer:
+ for frame in self.env.frame_files:
+ image = imageio.imread(frame)
+ writer.append_data(image)
+ print('gif complete\n')
+
+ # Remove files
+ for filename in self.env.frame_files[:-1]:
+ os.remove(filename)
\ No newline at end of file
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..92dc9c5e3fcf1940f69f063679662be9590408a2
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,24 @@
+# python 3.10.14
+
+# Search-TTA
+scipy==1.14.1
+scikit-learn==1.6.1
+scikit-image==0.24.0
+matplotlib==3.9.1
+imageio==2.36.0
+shapely==2.0.7
+rasterio==1.4.1
+kneed==0.8.5
+easydict==1.13
+
+# Taxabind
+numpy==1.26.3
+torch==2.4.0
+torchvision==0.19.0
+torchaudio==2.4.0
+pytorch-lightning==2.2.1
+open_clip_torch==2.30.0
+transformers==4.45.1
+tokenizers==0.20.3
+opencv-python==4.10.0.84
+gradio==4.44.1
\ No newline at end of file
diff --git a/taxabind_avs/LICENSE b/taxabind_avs/LICENSE
new file mode 100644
index 0000000000000000000000000000000000000000..261eeb9e9f8b2b4b0d119366dda99c6fd7d35c64
--- /dev/null
+++ b/taxabind_avs/LICENSE
@@ -0,0 +1,201 @@
+ Apache License
+ Version 2.0, January 2004
+ http://www.apache.org/licenses/
+
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+ 1. Definitions.
+
+ "License" shall mean the terms and conditions for use, reproduction,
+ and distribution as defined by Sections 1 through 9 of this document.
+
+ "Licensor" shall mean the copyright owner or entity authorized by
+ the copyright owner that is granting the License.
+
+ "Legal Entity" shall mean the union of the acting entity and all
+ other entities that control, are controlled by, or are under common
+ control with that entity. For the purposes of this definition,
+ "control" means (i) the power, direct or indirect, to cause the
+ direction or management of such entity, whether by contract or
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
+ outstanding shares, or (iii) beneficial ownership of such entity.
+
+ "You" (or "Your") shall mean an individual or Legal Entity
+ exercising permissions granted by this License.
+
+ "Source" form shall mean the preferred form for making modifications,
+ including but not limited to software source code, documentation
+ source, and configuration files.
+
+ "Object" form shall mean any form resulting from mechanical
+ transformation or translation of a Source form, including but
+ not limited to compiled object code, generated documentation,
+ and conversions to other media types.
+
+ "Work" shall mean the work of authorship, whether in Source or
+ Object form, made available under the License, as indicated by a
+ copyright notice that is included in or attached to the work
+ (an example is provided in the Appendix below).
+
+ "Derivative Works" shall mean any work, whether in Source or Object
+ form, that is based on (or derived from) the Work and for which the
+ editorial revisions, annotations, elaborations, or other modifications
+ represent, as a whole, an original work of authorship. For the purposes
+ of this License, Derivative Works shall not include works that remain
+ separable from, or merely link (or bind by name) to the interfaces of,
+ the Work and Derivative Works thereof.
+
+ "Contribution" shall mean any work of authorship, including
+ the original version of the Work and any modifications or additions
+ to that Work or Derivative Works thereof, that is intentionally
+ submitted to Licensor for inclusion in the Work by the copyright owner
+ or by an individual or Legal Entity authorized to submit on behalf of
+ the copyright owner. For the purposes of this definition, "submitted"
+ means any form of electronic, verbal, or written communication sent
+ to the Licensor or its representatives, including but not limited to
+ communication on electronic mailing lists, source code control systems,
+ and issue tracking systems that are managed by, or on behalf of, the
+ Licensor for the purpose of discussing and improving the Work, but
+ excluding communication that is conspicuously marked or otherwise
+ designated in writing by the copyright owner as "Not a Contribution."
+
+ "Contributor" shall mean Licensor and any individual or Legal Entity
+ on behalf of whom a Contribution has been received by Licensor and
+ subsequently incorporated within the Work.
+
+ 2. Grant of Copyright License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ copyright license to reproduce, prepare Derivative Works of,
+ publicly display, publicly perform, sublicense, and distribute the
+ Work and such Derivative Works in Source or Object form.
+
+ 3. Grant of Patent License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ (except as stated in this section) patent license to make, have made,
+ use, offer to sell, sell, import, and otherwise transfer the Work,
+ where such license applies only to those patent claims licensable
+ by such Contributor that are necessarily infringed by their
+ Contribution(s) alone or by combination of their Contribution(s)
+ with the Work to which such Contribution(s) was submitted. If You
+ institute patent litigation against any entity (including a
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
+ or a Contribution incorporated within the Work constitutes direct
+ or contributory patent infringement, then any patent licenses
+ granted to You under this License for that Work shall terminate
+ as of the date such litigation is filed.
+
+ 4. Redistribution. You may reproduce and distribute copies of the
+ Work or Derivative Works thereof in any medium, with or without
+ modifications, and in Source or Object form, provided that You
+ meet the following conditions:
+
+ (a) You must give any other recipients of the Work or
+ Derivative Works a copy of this License; and
+
+ (b) You must cause any modified files to carry prominent notices
+ stating that You changed the files; and
+
+ (c) You must retain, in the Source form of any Derivative Works
+ that You distribute, all copyright, patent, trademark, and
+ attribution notices from the Source form of the Work,
+ excluding those notices that do not pertain to any part of
+ the Derivative Works; and
+
+ (d) If the Work includes a "NOTICE" text file as part of its
+ distribution, then any Derivative Works that You distribute must
+ include a readable copy of the attribution notices contained
+ within such NOTICE file, excluding those notices that do not
+ pertain to any part of the Derivative Works, in at least one
+ of the following places: within a NOTICE text file distributed
+ as part of the Derivative Works; within the Source form or
+ documentation, if provided along with the Derivative Works; or,
+ within a display generated by the Derivative Works, if and
+ wherever such third-party notices normally appear. The contents
+ of the NOTICE file are for informational purposes only and
+ do not modify the License. You may add Your own attribution
+ notices within Derivative Works that You distribute, alongside
+ or as an addendum to the NOTICE text from the Work, provided
+ that such additional attribution notices cannot be construed
+ as modifying the License.
+
+ You may add Your own copyright statement to Your modifications and
+ may provide additional or different license terms and conditions
+ for use, reproduction, or distribution of Your modifications, or
+ for any such Derivative Works as a whole, provided Your use,
+ reproduction, and distribution of the Work otherwise complies with
+ the conditions stated in this License.
+
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
+ any Contribution intentionally submitted for inclusion in the Work
+ by You to the Licensor shall be under the terms and conditions of
+ this License, without any additional terms or conditions.
+ Notwithstanding the above, nothing herein shall supersede or modify
+ the terms of any separate license agreement you may have executed
+ with Licensor regarding such Contributions.
+
+ 6. Trademarks. This License does not grant permission to use the trade
+ names, trademarks, service marks, or product names of the Licensor,
+ except as required for reasonable and customary use in describing the
+ origin of the Work and reproducing the content of the NOTICE file.
+
+ 7. Disclaimer of Warranty. Unless required by applicable law or
+ agreed to in writing, Licensor provides the Work (and each
+ Contributor provides its Contributions) on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ implied, including, without limitation, any warranties or conditions
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+ PARTICULAR PURPOSE. You are solely responsible for determining the
+ appropriateness of using or redistributing the Work and assume any
+ risks associated with Your exercise of permissions under this License.
+
+ 8. Limitation of Liability. In no event and under no legal theory,
+ whether in tort (including negligence), contract, or otherwise,
+ unless required by applicable law (such as deliberate and grossly
+ negligent acts) or agreed to in writing, shall any Contributor be
+ liable to You for damages, including any direct, indirect, special,
+ incidental, or consequential damages of any character arising as a
+ result of this License or out of the use or inability to use the
+ Work (including but not limited to damages for loss of goodwill,
+ work stoppage, computer failure or malfunction, or any and all
+ other commercial damages or losses), even if such Contributor
+ has been advised of the possibility of such damages.
+
+ 9. Accepting Warranty or Additional Liability. While redistributing
+ the Work or Derivative Works thereof, You may choose to offer,
+ and charge a fee for, acceptance of support, warranty, indemnity,
+ or other liability obligations and/or rights consistent with this
+ License. However, in accepting such obligations, You may act only
+ on Your own behalf and on Your sole responsibility, not on behalf
+ of any other Contributor, and only if You agree to indemnify,
+ defend, and hold each Contributor harmless for any liability
+ incurred by, or claims asserted against, such Contributor by reason
+ of your accepting any such warranty or additional liability.
+
+ END OF TERMS AND CONDITIONS
+
+ APPENDIX: How to apply the Apache License to your work.
+
+ To apply the Apache License to your work, attach the following
+ boilerplate notice, with the fields enclosed by brackets "[]"
+ replaced with your own identifying information. (Don't include
+ the brackets!) The text should be enclosed in the appropriate
+ comment syntax for the file format. We also recommend that a
+ file or class name and description of purpose be included on the
+ same "printed page" as the copyright notice for easier
+ identification within third-party archives.
+
+ Copyright [yyyy] [name of copyright owner]
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
diff --git a/taxabind_avs/README.md b/taxabind_avs/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..bc6410a407f561c691db935a6ed0701be6181375
--- /dev/null
+++ b/taxabind_avs/README.md
@@ -0,0 +1 @@
+This folder is adapted from the [Taxabind repository](https://github.com/mvrl/TaxaBind).
\ No newline at end of file
diff --git a/taxabind_avs/satbind/clip_seg_tta.py b/taxabind_avs/satbind/clip_seg_tta.py
new file mode 100644
index 0000000000000000000000000000000000000000..f17334f2b2c150b22ce959909e981757294b96a5
--- /dev/null
+++ b/taxabind_avs/satbind/clip_seg_tta.py
@@ -0,0 +1,564 @@
+##############################################################################
+# Name: clip_seg_tta.py
+#
+# - Performs TTA on the sat encoder using collected measurements
+###############################################################################
+
+import os
+import sys
+sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
+sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
+
+import cv2
+import time
+import torch
+import numpy as np
+import open_clip
+import torch.nn as nn
+import spaces
+from PIL import Image
+from matplotlib import pyplot as plt
+from dataset import SatNatDataset
+from model_sat import SatBind
+from transformers import ClapAudioModelWithProjection
+from clip_vision_per_patch_model import CLIPVisionPerPatchModel
+from types import SimpleNamespace
+from config_sat import config
+
+
+class ClipSegTTA:
+ def __init__(
+ self,
+ img_dir: str,
+ imo_dir: str,
+ json_path: str,
+ sat_to_img_ids_path: str,
+ sat_checkpoint_path: str,
+ load_pretrained_hf_ckpt: bool = True,
+ sample_index: int = 0, # Set using 'reset'
+ blur_kernel = (5,5), # (0,0) for no gaussian blur
+ batch_size: int = 1,
+ num_workers: int = 1,
+ device: str = "cuda",
+ sat_to_img_ids_json_is_train_dict: bool = True,
+ tax_to_filter_val: str = "",
+ load_model: bool = True,
+ query_modality: str = "image", # image, text, sound
+ sound_dir: str = None,
+ sound_checkpoint_path: str = None,
+ ):
+
+ self.img_dir = img_dir
+ self.imo_dir = imo_dir
+ self.json_path = json_path
+ self.sat_to_img_ids_path = sat_to_img_ids_path
+ self.sat_checkpoint_path = sat_checkpoint_path
+ self.pretrained_hf_ckpt = load_pretrained_hf_ckpt
+ self.sample_index = sample_index
+ self.blur_kernel = blur_kernel
+ self.batch_size = batch_size
+ self.num_workers = num_workers
+ self.device = device
+ self.sat_to_img_ids_json_is_train_dict = sat_to_img_ids_json_is_train_dict
+ self.tax_to_filter_val = tax_to_filter_val
+ self.load_model = load_model
+ self.query_modality = query_modality
+ self.sound_dir = sound_dir
+ self.sound_checkpoint_path = sound_checkpoint_path
+
+ # Prepare the dataset
+ start_time = time.time()
+ self.load_data()
+ print(f"Dataset loaded in {(time.time()-start_time):.2f}s.")
+
+ if self.load_model:
+ start_time = time.time()
+
+ # Load the global model (original/frozen checkpoint)
+ self.load_global_model()
+ self.tokenizer = open_clip.get_tokenizer(config.image_encoder_finetuned)
+ print(f"Global model loaded in {(time.time()-start_time):.2f}s.")
+
+ # Create the local model that will be adapted for TTA
+ if self.pretrained_hf_ckpt:
+ imo_encoder = CLIPVisionPerPatchModel.from_pretrained(config.sat_encoder_finetuned)
+ imo_encoder.to(self.device)
+ imo_encoder.eval()
+ bio_model, *_ = open_clip.create_model_and_transforms(config.image_encoder_finetuned)
+ bio_model.to(self.device)
+ bio_model.eval()
+ logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
+ self.model_local = SimpleNamespace(imo_encoder=imo_encoder, bio_model=bio_model, logit_scale=logit_scale)
+
+ # Load sound model if provided
+ if self.query_modality =="sound":
+ self.sound_model = ClapAudioModelWithProjection.from_pretrained(config.sound_encoder_finetuned)
+ self.sound_model.to(self.device)
+ self.sound_model.eval()
+ print("~Loaded HF checkpoint")
+ else:
+ self.model_local = SatBind(train_dataset=None, val_dataset=None)
+ self.model_local.to(self.device)
+ self.model_local.eval()
+
+ # Load sound model if provided
+ if self.query_modality =="sound" and self.sound_checkpoint_path:
+ from soundbind.model_sound import AudioBind
+ self.sound_model = AudioBind.load_from_checkpoint(self.sound_checkpoint_path, train_dataset=None, val_dataset=None)
+ self.sound_model.to(self.device)
+ self.sound_model.eval()
+ print("~Loaded local checkpoint")
+
+ self.clip_inference_time = 0.0
+ self.tta_time = 0.0
+
+ # NOTE: integration with app.py on hf
+ if sample_index >= 0:
+ self.reset(sample_idx=self.sample_index)
+ self.executing_tta = False
+
+
+ def load_data(self):
+ """Load or initialize the dataset."""
+ self.dataset = SatNatDataset(
+ img_dir=self.img_dir,
+ imo_dir=self.imo_dir,
+ json_path=self.json_path,
+ sat_to_img_ids_path=self.sat_to_img_ids_path,
+ patch_size=config.patch_size,
+ mode="val",
+ get_img_path=True,
+ sat_to_img_ids_json_is_train_dict=self.sat_to_img_ids_json_is_train_dict,
+ tax_to_filter_val=self.tax_to_filter_val,
+ sound_dir=self.sound_dir
+ )
+
+ def reset(self, sample_idx):
+ """Reset the parameters & local model for the current sample."""
+ if self.load_model:
+ self.reset_local_model() # Reset to global weights as init
+
+ # NOTE: integration with app.py on hf
+ if sample_idx >= 0:
+ self.img_paths, self.imo_path, self.imgs, self.imo, self.sounds, self.sound_ids, self.species_name, self.target_positions, self.gt_mask_name = self.dataset.get_search_ds_data(sample_idx)
+ self.imgs = self.imgs.to(self.device)
+
+ self.tgts_gt_score = None
+ if self.load_model:
+ self.heatmap, self.heatmap_unnormalized, self.heatmap_unnormalized_initial, self.patch_embeds = None, None, None, None
+ img = self.imgs[0].unsqueeze(0).to(self.device)
+ imo = self.imo.unsqueeze(0).to(self.device)
+ txt = [self.species_name]
+ if self.sounds != []:
+ sound = self.sounds[0].to(self.device)
+ for k in sound.keys():
+ sound[k] = sound[k].to(self.device)
+ else:
+ sound = None
+ self.generate_heatmap(img, imo, txt, sound=sound, modality=self.query_modality)
+
+ # Find avg heatmap score for target positions
+ scores = []
+ imo_orig = Image.open(self.imo_path)
+ for pos in self.target_positions:
+ row_trans = int(pos[0] * self.heatmap.shape[0] / imo_orig.size[0])
+ col_trans = int(pos[1] * self.heatmap.shape[1] / imo_orig.size[1])
+ scores.append(self.heatmap[row_trans, col_trans])
+ self.tgts_gt_score = np.mean(scores)
+
+ # NOTE: integration with app.py on hf
+ return self.heatmap, self.heatmap_unnormalized, self.heatmap_unnormalized_initial, self.patch_embeds
+
+
+ def load_global_model(self):
+ """Load the global SatBind model from checkpoint, move to device, and eval."""
+ if self.pretrained_hf_ckpt:
+ print("Downloading HF checkpoint (if not already downloaded)...")
+ self.model_global = CLIPVisionPerPatchModel.from_pretrained(config.sat_encoder_finetuned)
+ else:
+ self.model_global = SatBind.load_from_checkpoint(
+ self.sat_checkpoint_path, train_dataset=None, val_dataset=None
+ )
+ self.model_global = self.model_global.to(self.device)
+ self.model_global.eval()
+
+
+ def reset_local_model(self):
+ """
+ Reset the local model to match the global model's parameters
+ and freeze/unfreeze layers for TTA.
+ """
+ start_time = time.time()
+ with torch.no_grad():
+ local_params = self.model_local.imo_encoder.parameters() \
+ if self.pretrained_hf_ckpt else self.model_local.parameters()
+ for param_global, param_local in zip(
+ self.model_global.parameters(), local_params
+ ):
+ param_local.data.copy_(param_global.data)
+
+ if self.pretrained_hf_ckpt:
+ for param in self.model_local.imo_encoder.parameters():
+ param.requires_grad = True
+ self.model_local.imo_encoder.eval()
+ else:
+ # Freeze everything except the satellite encoder & custom projection
+ for name, param in self.model_local.named_parameters():
+ if "imo_encoder" in name or "visual_projection_custom" in name:
+ param.requires_grad = True
+ else:
+ param.requires_grad = False
+ self.model_local.eval()
+
+
+ # NOTE: integration with app.py on hf
+ @spaces.GPU(duration=120)
+ def execute_tta(
+ self,
+ patch_indices: list,
+ patch_is_pos: list,
+ pos_sample_weight, # float or list of per-sample weights
+ neg_sample_weight, # float or list of per-sample weights
+ tta_steps: int = 10,
+ lr: float = 2e-6,
+ reset_weights: bool = True,
+ num_viz_steps: int = 1,
+ viz_heatmap: bool = False,
+ ):
+ """
+ Run test-time adaptation using the local model. The local model is first
+ reset to the global weights. After TTA, the global model remains
+ unchanged; only the local model is updated.
+ """
+
+ ### Option 1: SAMPLE FROM DATASET
+ # 1) Reset the local model to global weights
+ if reset_weights:
+ self.reset_local_model()
+
+ ## NOTE: Added due to app.py (to allocate to GPU only when needed on HF)
+ print("Allocating models to GPU...")
+ self.device = torch.device("cuda")
+ self.model_local.imo_encoder.to(self.device)
+ self.model_local.bio_model.to(self.device)
+
+ # 2) Prepare the sample(s) for TTA
+ img = self.imgs[0].unsqueeze(0).to(self.device)
+ imo = self.imo.unsqueeze(0).to(self.device) # vectorize
+ txt = [self.species_name]
+ if self.sounds != []:
+ sound = self.sounds[0].to(self.device)
+ for k in sound.keys():
+ sound[k] = sound[k].to(self.device)
+ else:
+ sound = None
+ patch_indices = [idx+1 for idx in patch_indices] # Consider the [CLS] token offset
+ patch_idx = torch.tensor(patch_indices).to(self.device)
+
+ # ---------------------------------------------------------------------
+
+ # 5) Set up optimizer
+ local_params = self.model_local.imo_encoder.parameters() \
+ if self.pretrained_hf_ckpt else self.model_local.parameters()
+ optimizer = torch.optim.Adam(
+ [p for p in local_params if p.requires_grad], lr=lr
+ )
+ start_time = time.time()
+
+ # 6) TTA loop
+ for step in range(tta_steps):
+ batch_size = imo.shape[0]
+
+ # Query embeds
+ query_embeds = self.generate_query_embeds(img, imo, txt, sound=sound, modality=self.query_modality)
+
+ # Sat Embeds
+ if self.pretrained_hf_ckpt:
+ imo_embeds = self.model_local.imo_encoder.vision_model(imo, return_dict=True).last_hidden_state
+ imo_embeds = imo_embeds[torch.arange(batch_size), patch_idx]
+ imo_embeds = self.model_local.imo_encoder.visual_projection(imo_embeds)
+ else:
+ imo_embeds = self.model_local.imo_encoder(imo).last_hidden_state # (batch, Patches, hidden_dim)
+ imo_embeds = imo_embeds[torch.arange(batch_size), patch_idx] # (batch, hidden_dim)
+ imo_embeds = self.model_local.visual_projection_custom(imo_embeds) # (batch_size, proj_dim)
+ imo_embeds = torch.nn.functional.normalize(imo_embeds, dim=-1)
+
+ # Compute Similarity Loss
+ logit_scale = self.model_local.logit_scale.exp()
+ similarity = imo_embeds @ query_embeds.t() * logit_scale
+
+ # Negative Log Likelihood loss for spatial poisson point process
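+ # Per visited patch i: loss_i = w_neg_i * p_i - w_pos_i * c_i * log(p_i + eps),
+ # where p_i = sigmoid(similarity) is the patch intensity and c_i the target count
+ # observed there (Poisson NLL up to additive constants)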
+ patch_probs = similarity.squeeze().sigmoid()
+ counts = torch.tensor(patch_is_pos, dtype=torch.float32, device=similarity.device)
+ pos_weights = torch.tensor(pos_sample_weight, dtype=torch.float32, device=similarity.device)
+ neg_weights = torch.tensor(neg_sample_weight, dtype=torch.float32, device=similarity.device)
+ loss = (neg_weights * patch_probs - pos_weights * counts * torch.log(patch_probs + 1e-6))
+ loss = loss.sum()
+
+ # Backprop and update
+ optimizer.zero_grad()
+ loss.backward()
+ optimizer.step()
+
+ self.tta_time = time.time() - start_time
+
+ # Visualization every 'num_viz_steps' steps (if enabled)
+ if (step + 1) % num_viz_steps == 0 and viz_heatmap:
+ self.generate_heatmap(img, imo, txt, sound=sound, modality=self.query_modality)
+ self.visualize_heatmap(
+ step=step,
+ img_path_viz=self.img_paths[0], # Viz 1st image
+ imo_path_viz=self.imo_path,
+ patch_idx_viz=patch_idx,
+ patch_is_pos=patch_is_pos,
+ species_name=self.species_name
+ )
+
+ # Save final heatmap after TTA steps
+ self.generate_heatmap(img, imo, txt, sound=sound, modality=self.query_modality)
+
+ ## NOTE: Added due to app.py (to allocate to GPU only when needed on HF)
+ print("Deallocating models from GPU...")
+ self.device = torch.device("cpu")
+ self.model_local.imo_encoder.to(self.device)
+ self.model_local.bio_model.to(self.device)
+
+ return self.heatmap
+
+
+ def generate_query_embeds(self, img, imo, txt, sound=None, modality="image"):
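+ """ Encode the query (ground image, text prompt, or sound clip) into the shared embedding space """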
+
+ # Query Embeds
+ if modality == "image":
+ query_embeds, *_ = self.model_local.bio_model(img) # (batch_size, proj_dim)
+ if query_embeds.shape[0] > 1:
+ query_embeds = query_embeds.mean(dim=0, keepdim=True) # (1, proj_dim)
+ elif modality == "text":
+ txt_tokenized = self.tokenizer(txt).to(imo.device)
+ _, query_embeds, _ = self.model_local.bio_model(text=txt_tokenized)
+ elif modality == "sound":
+ if sound is None:
+ print("!!!! Sound modality requires sound input !!!")
+ exit(1)
+ if self.pretrained_hf_ckpt:
+ unnormalized_audio_embeds = self.sound_model(**sound).audio_embeds
+ else:
+ unnormalized_audio_embeds = self.sound_model.audio_encoder(sound)
+ query_embeds = torch.nn.functional.normalize(unnormalized_audio_embeds, dim=-1)
+ else:
+ raise ValueError("Invalid modality")
+
+ return query_embeds
+
+
+ def generate_heatmap(self, img, imo, txt, sound=None, modality="image"):
+
+ start_time = time.time()
+
+ # Satellite encoder outputs
+ if self.pretrained_hf_ckpt:
+ imo_embeds = self.model_local.imo_encoder(imo)
+ else:
+ imo_embeds = self.model_local.imo_encoder(imo).last_hidden_state
+ imo_embeds = self.model_local.visual_projection_custom(imo_embeds)
+ imo_embeds = torch.nn.functional.normalize(imo_embeds, dim=-1)
+
+ # Remove batch dimension -> (num_tokens, proj_dim)
+ imo_embeds = imo_embeds.squeeze(0)
+ self.patch_embeds = imo_embeds.clone()[1:].cpu().detach().numpy()
+
+ # Ground image embedding (bio CLIP model)
+ query_embeds = self.generate_query_embeds(img, imo, txt, sound=sound, modality=modality)
+
+ # Same logit scale as in SatBind
+ logit_scale = self.model_local.logit_scale.exp()
+ sim = query_embeds @ imo_embeds.t() * logit_scale
+ # Sigmoid to get similarity scores
+ scores = sim.t().sigmoid() # (num_tokens, 1)
+
+ # Exclude [CLS] token at index 0
+ score_no_cls = scores[1:].squeeze() # shape: (num_tokens-1,)
+ num_tokens = score_no_cls.shape[0]
+ side_dim = int(num_tokens**0.5)
+ sim_scores = score_no_cls.reshape(side_dim, side_dim).clone()
+ sim_scores = sim_scores.cpu().detach().numpy()
+
+ self.clip_inference_time = time.time() - start_time
+
+ # Gaussian Smoothing
+ if self.blur_kernel != (0,0):
+ sim_scores = cv2.GaussianBlur(sim_scores, self.blur_kernel, 0)
+
+ # Normalize to expectation
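+ # (scale so the unnormalized map sums to the number of targets, i.e. it can be
+ # read as a Poisson intensity whose integral equals the expected target count)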
+ self.heatmap_unnormalized = sim_scores
+ scale = len(self.target_positions) / (self.heatmap_unnormalized.sum())
+ self.heatmap_unnormalized *= scale
+ if self.heatmap_unnormalized_initial is None:
+ self.heatmap_unnormalized_initial = self.heatmap_unnormalized.copy()
+
+ # Standard normalization to (0,1)
+ self.heatmap = sim_scores.copy()
+ self.heatmap = (self.heatmap - self.heatmap.min()) / (self.heatmap.max() - self.heatmap.min())
+
+
+ def visualize_heatmap(
+ self,
+ step: int,
+ img_path_viz: str,
+ imo_path_viz: str,
+ patch_idx_viz: torch.Tensor,
+ patch_is_pos: list,
+ species_name: str
+ ):
+ """
+ Visualization function that plots the ground image, satellite image with
+ highlighted patch, and the learned heatmap.
+ """
+
+ # Switch off gradients for visualization
+ with torch.no_grad():
+ side_dim = self.heatmap.shape[0]
+
+ # -----------------------------------------------------------------
+ # Highlight the patch in the satellite image
+ sat_img_orig = Image.open(imo_path_viz)
+ sat_highlight = np.array(
+ self.dataset.debug_imo_viz_transform(sat_img_orig.copy())
+ )
+
+ for idx, patch_idx in enumerate(patch_idx_viz):
+
+ # Because patch_idx includes the [CLS] offset, subtract 1
+ patch_idx_actual = patch_idx - 1
+
+ # Get dimensions (H x W)
+ H, W = sat_highlight.shape[0], sat_highlight.shape[1]
+
+ # Number of patches in each dimension
+ patches_per_col = W // config.patch_size
+ patches_per_row = H // config.patch_size
+
+ # Determine row/col in the patch grid
+ patch_row = patch_idx_actual // patches_per_col
+ patch_col = patch_idx_actual % patches_per_col # modulo by the number of patch columns (W // patch_size)
+
+ # Pixel boundaries
+ x_start = patch_col * config.patch_size
+ x_end = (patch_col + 1) * config.patch_size
+ y_start = patch_row * config.patch_size
+ y_end = (patch_row + 1) * config.patch_size
+
+ # Blue color for positive patches (transparent)
+ if patch_is_pos[idx]:
+ sat_highlight[y_start:y_end, x_start:x_end, 0] = 0
+ sat_highlight[y_start:y_end, x_start:x_end, 1] = 0
+ sat_highlight[y_start:y_end, x_start:x_end, 2] = 255
+ # Red color for negative patches (transparent)
+ else:
+ sat_highlight[y_start:y_end, x_start:x_end, 0] = 255
+ sat_highlight[y_start:y_end, x_start:x_end, 1] = 0
+ sat_highlight[y_start:y_end, x_start:x_end, 2] = 0
+
+
+ # -----------------------------------------------------------------
+ # Plot results
+ fig, axes = plt.subplots(1, 3, figsize=(12, 6))
+ fig.suptitle(f"Query: {species_name}")
+
+ # Ground image
+ img_orig = Image.open(img_path_viz)
+ axes[0].imshow(img_orig)
+ axes[0].set_title("Ground Image")
+ axes[0].axis("off")
+
+ # Satellite image
+ axes[1].imshow(sat_highlight)
+ axes[1].set_title("Sat Image")
+ axes[1].axis("off")
+
+ # Heatmap
+ heatmap_np = self.heatmap_unnormalized
+ im = axes[2].imshow(heatmap_np, cmap="viridis")
+ axes[2].set_title(
+ f"Heatmap at TTA Step {step:03d} ({side_dim}x{side_dim})"
+ )
+ axes[2].axis("off")
+ fig.colorbar(im, ax=axes[2], fraction=0.046, pad=0.04)
+
+ plt.tight_layout()
+ plt.show()
+
+
+if __name__ == "__main__":
+
+ ###########################################
+ # Example with Image Modality
+ ###########################################
+
+ clip_seg_tta = ClipSegTTA(
+ img_dir="/mnt/hdd/avs_bench_ds/inat21",
+ imo_dir="/mnt/hdd/avs_bench_ds/sat_jpg/train_512px",
+ json_path="/mnt/hdd/avs_bench_ds/inat21/train.json",
+ sat_to_img_ids_path="search_tri_modal|val_in_domain",
+ load_pretrained_hf_ckpt=True,
+ sat_checkpoint_path="",
+ sample_index=261,
+ batch_size=1,
+ num_workers=1,
+ device="cuda",
+ sat_to_img_ids_json_is_train_dict=False,
+ query_modality="image"
+ )
+
+ # Image modality test
+ patch_indices = [50, 357]
+ patch_is_pos = [True, False]
+ pos_sample_weight = 1.0
+ neg_sample_weight = 1.0
+ clip_seg_tta.execute_tta(
+ patch_indices,
+ patch_is_pos,
+ pos_sample_weight,
+ neg_sample_weight,
+ tta_steps=10, # for sanity check
+ num_viz_steps=2,
+ viz_heatmap=True
+ )
+
+ ###########################################
+ # Example with Sound Modality
+ ###########################################
+
+ # # Sound Modality Test
+ # clip_seg_tta = ClipSegTTA(
+ # img_dir="/mnt/hdd/avs_bench_ds/inat21",
+ # imo_dir="/mnt/hdd/avs_bench_ds/sat_jpg/train_512px",
+ # json_path="/mnt/hdd/avs_bench_ds/inat21/train.json",
+ # sat_to_img_ids_path="search_quad_modal|val_in_domain",
+ # sound_dir='/mnt/hdd/avs_bench_ds/sound_mp3/test',
+ # patch_size=14,
+ # sat_checkpoint_path="",
+ # sound_checkpoint_path = "",
+ # sample_index=120,
+ # batch_size=1,
+ # num_workers=1,
+ # device="cuda",
+ # sat_to_img_ids_json_is_train_dict=False,
+ # query_modality="sound"
+ # )
+
+ # # Sound modality test
+ # patch_indices = [422, 32]
+ # patch_is_pos = [True, False]
+ # pos_sample_weight = 1.0
+ # neg_sample_weight = 1.0
+ # clip_seg_tta.execute_tta(
+ # patch_indices,
+ # patch_is_pos,
+ # pos_sample_weight,
+ # neg_sample_weight,
+ # tta_steps=30,
+ # num_viz_steps=2,
+ # viz_heatmap=True
+ # )
\ No newline at end of file
diff --git a/taxabind_avs/satbind/clip_vision_per_patch_model.py b/taxabind_avs/satbind/clip_vision_per_patch_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..d7461295d79f9cf82904bdf59120827e0d954c53
--- /dev/null
+++ b/taxabind_avs/satbind/clip_vision_per_patch_model.py
@@ -0,0 +1,32 @@
+##############################################################################
+# Name: clip_vision_per_patch_model.py
+#
+# - Overloads CLIP template with custom functions
+###############################################################################
+
+import torch
+from transformers import CLIPVisionModelWithProjection, CLIPVisionConfig
+
+class CLIPVisionPerPatchModel(CLIPVisionModelWithProjection):
+ """
+ Like CLIPVisionModelWithProjection but returns
+ per-patch embeddings instead of pooled CLS tokens.
+ """
+ def __init__(self, config: CLIPVisionConfig):
+ super().__init__(config)
+ # everything else (self.vision_model, self.visual_projection)
+ # is set up for you by the parent class
+
+ def forward(self, pixel_values, **kwargs):
+ # 1) run the ViT backbone -> last_hidden_state [B, n_patches, hidden_size]
+ outputs = self.vision_model(pixel_values, return_dict=True, **kwargs)
+ hidden_states = outputs.last_hidden_state
+
+ # 2) project every patch token -> [B, n_patches, projection_dim]
+ patch_embeds = self.visual_projection(hidden_states)
+
+ # 3) Postprocessing embeds
+ patch_embeds = torch.nn.functional.normalize(patch_embeds, dim=-1)
+ patch_embeds = patch_embeds.squeeze() # (Patches, proj_dim)
+
+ return patch_embeds
\ No newline at end of file
diff --git a/taxabind_avs/satbind/config_sat.py b/taxabind_avs/satbind/config_sat.py
new file mode 100644
index 0000000000000000000000000000000000000000..f1b0cd3ecfd52860c9ea3e208ad0bc0e967aa25c
--- /dev/null
+++ b/taxabind_avs/satbind/config_sat.py
@@ -0,0 +1,45 @@
+##############################################################################
+# Name: config_sat.py
+#
+# - Parameters for train/eval sat encoder
+###############################################################################
+
+from easydict import EasyDict as edict
+
+config = edict()
+
+# Pixel level CLIP training
+config.img_dir = '/mnt/hdd/avs_bench_ds/inat21'
+config.imo_dir = '/mnt/hdd/avs_bench_ds/sat_jpg/train_512px'
+config.imo_dir_val = '/mnt/hdd/avs_bench_ds/sat_jpg/test_512px'
+config.train_json_path = '/mnt/hdd/avs_bench_ds/inat21/train.json'
+config.val_json_path = '/mnt/hdd/avs_bench_ds/inat21/val.json'
+config.sat_to_img_ids_train_json_path = 'clip_tri_modal|train'
+config.sat_to_img_ids_val_json_path = 'clip_tri_modal|val'
+
+# batch_size * accumulate_grad_batches * devices = constant (i.e. 256 * 8 * 2 = 4096)
+config.batch_size = 32
+config.lr = 1e-4
+config.accumulate_grad_batches = 64
+config.max_epochs = 20
+config.num_workers = 16
+config.devices = 2
+config.val_check_interval = 0.5
+config.sat_encoder = 'openai/clip-vit-large-patch14-336'
+config.avs_dataset = 'derektan95/avs-bench'
+config.patch_size = 14
+
+config.save_dir = 'checkpoints'
+config.filename = 'satbind-{epoch:02d}-{val_loss:.2f}'
+
+config.locked_tuning = True
+
+config.resume_from_checkpoint = False
+config.resume_checkpoint_name = 'satbind-resume'
+
+# huggingface finetuned
+config.image_encoder_finetuned = 'hf-hub:imageomics/bioclip'
+config.sound_encoder_finetuned = 'derektan95/search-tta-sound' # For eval only
+config.sat_encoder_finetuned = 'derektan95/search-tta-sat' # For eval only
+
+print("config: \n", config)
\ No newline at end of file
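The gradient-accumulation comment above implies a fixed effective batch size; a quick check (illustrative only) confirms the current values reproduce the 4096 baseline.

```python
# Effective batch size = batch_size * accumulate_grad_batches * devices.
from config_sat import config

effective_batch = config.batch_size * config.accumulate_grad_batches * config.devices
print(effective_batch)   # 32 * 64 * 2 = 4096, same as the 256 * 8 * 2 baseline in the comment
```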
diff --git a/taxabind_avs/satbind/dataset.py b/taxabind_avs/satbind/dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..96ae92418070dfe56072d123abcba1a35a4e4acb
--- /dev/null
+++ b/taxabind_avs/satbind/dataset.py
@@ -0,0 +1,215 @@
+##############################################################################
+# Name: dataset.py
+#
+# - Handles loading of trimodal dataset
+# - https://huggingface.co/datasets/derektan95/avs-bench
+###############################################################################
+
+import os
+import sys
+sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
+
+import math
+import json
+import torch
+from pathlib import Path
+from torch.utils.data import Dataset
+from torchvision import transforms
+from PIL import Image
+from datasets import load_dataset
+from config_sat import config
+
+
+class SatNatDataset(Dataset):
+ def __init__(self, img_dir, imo_dir, json_path, sat_to_img_ids_path, patch_size, mode='train', get_img_path=False, sat_to_img_ids_json_is_train_dict=True, tax_to_filter_val="", sound_dir=None):
+ self.img_dir = img_dir
+ self.imo_dir = imo_dir
+ self.patch_size = patch_size
+ self.get_img_path = get_img_path
+ self.mode = mode
+ self.sat_to_img_ids_json_is_train_dict = sat_to_img_ids_json_is_train_dict
+ self.tax_to_filter_val = tax_to_filter_val
+ self.sound_dir = sound_dir
+ self.current_epoch = 0
+
+ ## NOTE: Removed as unnecessary for app.py
+ # self.json = json.load(open(json_path, 'r'))
+ # self.images = self.json['images']
+ # self.annot = self.json['annotations']
+ # for i in range(len(self.images)):
+ # assert self.images[i]['id'] == self.annot[i]['id']
+ # self.images[i]['label'] = self.annot[i]['category_id']
+ # self.filtered_json = [d for d in self.images if d['latitude'] is not None and d['longitude'] is not None]
+ # self.species_text = list(set([" ".join(d['file_name'].split("/")[1].split("_")[1:]) for d in self.filtered_json]))
+ # self.inat_json_dict = {
+ # "images": {img["id"]: img for img in self.images},
+ # "annotations": {ann["id"]: ann for ann in self.annot},
+ # }
+
+ # # Load from huggingface dataset
+ # ds_config = sat_to_img_ids_path.split("|")[0]
+ # ds_split = sat_to_img_ids_path.split("|")[1]
+ # self.sat_to_img_ids_json = load_dataset(config.avs_dataset, name=ds_config, split=ds_split)
+ # print("Loaded huggingface dataset: ", ds_config, ds_split)
+
+ # # Expand dict
+ # self.sat_to_img_ids_tuples = []
+ # if self.sat_to_img_ids_json_is_train_dict:
+ # # Convert from a huggingface list of dicts into dict of dicts (no duplicate keys)
+ # self.sat_to_img_ids_json = {sat_sample["sat_key"]: sat_sample for sat_sample in self.sat_to_img_ids_json}
+ # for sat_key, sat_sample in self.sat_to_img_ids_json.items():
+ # id = sat_sample["id"]
+ # sat_path = sat_sample["sat_path"]
+ # img_ids = sat_sample["img_ids"]
+ # for img_id in img_ids:
+ # self.sat_to_img_ids_tuples.append((id, sat_path, img_id))
+ # print("len(self.sat_to_img_ids_json): ", len(self.sat_to_img_ids_json))
+ # print("len(self.sat_to_img_ids_tuples): ", len(self.sat_to_img_ids_tuples))
+ # else:
+ # self.filtered_val_ds_by_tax = [d for d in self.sat_to_img_ids_json if self.tax_to_filter_val in d['taxonomy']]
+
+ if mode == 'train':
+ self.img_transform = transforms.Compose([
+ transforms.Resize((256, 256)),
+ transforms.RandomCrop((224, 224)),
+ transforms.RandomHorizontalFlip(0.5),
+ transforms.GaussianBlur(5, (0.01, 1.0)),
+ transforms.ToTensor(),
+ transforms.Normalize(mean=[0.485, 0.456, 0.406],
+ std=[0.229, 0.224, 0.225])
+ ])
+ self.imo_transform = transforms.Compose([
+ transforms.Resize((336,336)),
+ transforms.GaussianBlur(5, (0.01, 1.0)),
+ transforms.ToTensor(),
+ transforms.Normalize(mean=[0.485, 0.456, 0.406],
+ std=[0.229, 0.224, 0.225])
+ ])
+ else:
+ self.img_transform = transforms.Compose([
+ transforms.Resize((256, 256)),
+ transforms.CenterCrop((224, 224)),
+ transforms.ToTensor(),
+ transforms.Normalize(mean=[0.485, 0.456, 0.406],
+ std=[0.229, 0.224, 0.225])
+ ])
+ self.imo_transform = transforms.Compose([
+ transforms.Resize((336,336)),
+ transforms.ToTensor(),
+ transforms.Normalize(mean=[0.485, 0.456, 0.406],
+ std=[0.229, 0.224, 0.225])
+ ])
+ self.debug_img_viz_transform = transforms.Compose([
+ transforms.Resize((256, 256)),
+ transforms.CenterCrop((224, 224))
+ ])
+ self.debug_imo_viz_transform = transforms.Compose([
+ transforms.Resize((336,336))
+ ])
+
+
+ def __len__(self):
+ return len(self.sat_to_img_ids_tuples)
+
+
+ def __getitem__(self, idx):
+
+ if not self.sat_to_img_ids_json_is_train_dict:
+ print("Json is not dict. Please reformat for training!")
+ exit()
+
+ ## Pixel-level CLIP
+ id, sat_path, img_id = self.sat_to_img_ids_tuples[idx]
+ imo_path = os.path.join(self.imo_dir, sat_path)
+ imo = self.imo_transform(Image.open(imo_path))
+ sat_id = Path(sat_path).stem
+
+ img_path = os.path.join(self.img_dir, self.inat_json_dict["images"][img_id]["file_name"])
+ img = self.img_transform(Image.open(img_path))
+
+ # # Map lat-lon to pixel in sat img
+ sat_min_lon = self.sat_to_img_ids_json[sat_id]["sat_bounds"]["min_lon"]
+ sat_min_lat = self.sat_to_img_ids_json[sat_id]["sat_bounds"]["min_lat"]
+ sat_max_lon = self.sat_to_img_ids_json[sat_id]["sat_bounds"]["max_lon"]
+ sat_max_lat = self.sat_to_img_ids_json[sat_id]["sat_bounds"]["max_lat"]
+
+ img_lon = self.inat_json_dict["images"][img_id]["longitude"]
+ img_lat = self.inat_json_dict["images"][img_id]["latitude"]
+ row, col = self.latlon_to_pixel(img_lat, img_lon, sat_min_lat, sat_max_lat, sat_min_lon, sat_max_lon, imo.shape[2], imo.shape[1])
+
+ patch_idx = self.pixel_to_patch_idx(row, col, self.patch_size, imo.shape[2], imo.shape[1])
+ patch_idx += 1 # account for [CLS] token at the start of ViT input sequence
+
+ species_text = " ".join(self.inat_json_dict["images"][img_id]['file_name'].split("/")[1].split("_")[1:])
+
+ if self.get_img_path:
+ return img_path, imo_path, img, imo, self.inat_json_dict["annotations"][img_id]['category_id'], patch_idx, species_text, self.species_text.index(species_text)
+ else:
+ return img, imo, self.inat_json_dict["annotations"][img_id]['category_id'], patch_idx, species_text, self.species_text.index(species_text)
+
+
+ def latlon_to_pixel(self, lat, lon, lat_min, lat_max, lon_min, lon_max, img_width, img_height):
+ lat_res = (lat_max - lat_min) / img_height
+ lon_res = (lon_max - lon_min) / img_width
+ col = int(math.floor((lon - lon_min) / lon_res))
+ row = int(math.floor((lat_max - lat) / lat_res))
+ return row, col
+
+
+ def pixel_to_patch_idx(self, row, col, patch_size, img_width, img_height):
+ patch_size_width = patch_size
+ patch_size_height = patch_size
+ patch_row = row // patch_size_height
+ patch_col = col // patch_size_width
+ patch_idx = patch_row * (img_width // patch_size) + patch_col
+ return patch_idx
+
+
+ def set_epoch(self, epoch):
+ self.current_epoch = epoch
+
+
+ def get_search_ds_data(self, idx):
+
+ if self.sat_to_img_ids_json_is_train_dict:
+ print("Json is dict. Please reformat for target search!")
+ exit()
+
+ bounded_idx = idx % len(self.filtered_val_ds_by_tax)
+ sat_sample = self.filtered_val_ds_by_tax[bounded_idx]
+ target_positions = sat_sample["target_positions"]
+ imo_path = os.path.join(self.imo_dir, sat_sample["sat_path"])
+ imo = self.imo_transform(Image.open(imo_path))
+
+ img_paths = []
+ imgs = []
+ species_texts = []
+ for img_id in sat_sample["img_ids"]:
+ img_path = os.path.join(self.img_dir, self.inat_json_dict["images"][img_id]["file_name"])
+ img = self.img_transform(Image.open(img_path))
+ img_paths.append(img_path)
+ imgs.append(img)
+ species_text = " ".join(self.inat_json_dict["images"][img_id]['file_name'].split("/")[1].split("_")[1:])
+ species_texts.append(species_text)
+ imgs = torch.stack(imgs)
+
+ if len(set(species_texts)) > 1:
+ print("Species mismatch in search dataset!")
+ exit()
+ else:
+ species_name = species_texts[0]
+ gt_mask_name = str(sat_sample["id"]) + "_" + sat_sample["taxonomy"] + ".png"
+ gt_mask_name = gt_mask_name.replace(" ", "_")
+
+ # Consider sound if valid
+ sounds, sound_ids = [], []
+ if self.sound_dir is not None and "sound_ids" in sat_sample:
+ sound_id = sat_sample["sound_ids"][0]
+ sound_path = os.path.join(self.sound_dir,"sounds_mp3",str(sound_id)+"."+'mp3')
+
+ from soundbind.sound_encoder import get_audio_clap
+ sound = get_audio_clap(sound_path)
+ sounds.append(sound)
+ sound_ids.append(sound_id)
+
+ return img_paths, imo_path, imgs, imo, sounds, sound_ids, species_name, target_positions, gt_mask_name
\ No newline at end of file
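A small worked example of the coordinate mapping used in `__getitem__` above (bounds and coordinates are illustrative): a 336x336 tile with patch size 14 gives a 24x24 grid, and the resulting index is shifted by one for the [CLS] token.

```python
import math

def latlon_to_pixel(lat, lon, lat_min, lat_max, lon_min, lon_max, img_width, img_height):
    lat_res = (lat_max - lat_min) / img_height
    lon_res = (lon_max - lon_min) / img_width
    col = int(math.floor((lon - lon_min) / lon_res))
    row = int(math.floor((lat_max - lat) / lat_res))
    return row, col

def pixel_to_patch_idx(row, col, patch_size, img_width):
    return (row // patch_size) * (img_width // patch_size) + (col // patch_size)

# Illustrative tile bounds spanning 0.01 degrees on each axis.
row, col = latlon_to_pixel(
    lat=1.3051, lon=103.7754,
    lat_min=1.3000, lat_max=1.3100,
    lon_min=103.7700, lon_max=103.7800,
    img_width=336, img_height=336,
)
patch_idx = pixel_to_patch_idx(row, col, patch_size=14, img_width=336) + 1  # +1 for [CLS]
print(row, col, patch_idx)  # 164 181 277
```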
diff --git a/taxabind_avs/satbind/kmeans_clustering.py b/taxabind_avs/satbind/kmeans_clustering.py
new file mode 100644
index 0000000000000000000000000000000000000000..e4f4657e6570900555fa3928d864e3c2a5b19493
--- /dev/null
+++ b/taxabind_avs/satbind/kmeans_clustering.py
@@ -0,0 +1,324 @@
+##############################################################################
+# Name: kmeans_clustering.py
+#
+# - Performs k-means clustering on a patch embedding matrix
+###############################################################################
+
+import math
+import os
+import numpy as np
+import matplotlib.pyplot as plt
+import matplotlib.patches as mpatches
+from sklearn.cluster import KMeans
+from sklearn.metrics import silhouette_score
+from kneed import KneeLocator
+
+
+class CombinedSilhouetteInertiaClusterer:
+ def __init__(
+ self,
+ k_min=1,
+ k_max=8,
+ k_avg_max=4,
+ silhouette_threshold=0.15,
+ relative_threshold=0.15,
+ random_state=0,
+ min_patch_size=5,
+ n_smooth_iter=2,
+ ignore_label=-1,
+ plot=False,
+ gifs_dir = "./"
+ ):
+ """
+ Parameters
+ ----------
+ k_min : int
+ Minimum number of clusters for KMeans.
+ k_max : int
+ Maximum number of clusters for KMeans.
+ k_avg_max : int
+ Upper bound on k after combining elbow & silhouette if they disagree.
+ silhouette_threshold : float
+ Minimum silhouette score at k=2 to justify splitting.
+ relative_threshold : float
+ Minimum % improvement in inertia from k=1 to k=2 to justify splitting.
+ random_state : int
+ RNG seed for KMeans.
+ min_patch_size : int
+ Patches smaller than this threshold are smoothed.
+ n_smooth_iter : int
+ Number of smoothing iterations.
+ ignore_label : int
+ Label to ignore in smoothing step.
+ """
+ self.k_min = k_min
+ self.k_max = k_max
+ self.k_avg_max = k_avg_max
+ self.silhouette_threshold = silhouette_threshold
+ self.relative_threshold = relative_threshold
+ self.random_state = random_state
+
+ self.min_patch_size = min_patch_size
+ self.n_smooth_iter = n_smooth_iter
+ self.ignore_label = ignore_label
+ self.plot = False #plot
+ self.gifs_dir = gifs_dir
+
+ self.final_k = None
+ self.final_labels_1d = None
+ self.smoothed_labels_2d = None
+ self.kmeans_frame_files = []
+
+ ##############################
+ # Helper functions
+ ##############################
+
+ def combined_silhouette_inertia_clustering(
+ self,
+ X,
+ k_min=1,
+ k_max=8,
+ k_avg_max=4,
+ silhouette_threshold=0.2,
+ relative_threshold=0.05,
+ random_state=0
+ ):
+ """
+ Runs KMeans for k in [k_min..k_max] exactly once each,
+ collects silhouette scores & inertias, and returns best_k.
+ """
+ n_samples = len(X)
+ if n_samples < 2:
+ return 1, np.zeros(n_samples, dtype=int), [None], [None]
+
+ # --- Fit once for k=1 ---
+ km1 = KMeans(n_clusters=1, random_state=random_state).fit(X)
+ inertia_k1 = km1.inertia_ / n_samples
+ silhouette_k1 = None # undefined for k=1
+
+ # If k_max=1, no reason to check further
+ if k_max < 2:
+ return 1, km1.labels_, [silhouette_k1], [inertia_k1]
+
+ # --- Fit once for k=2 ---
+ km2 = KMeans(n_clusters=2, random_state=random_state).fit(X)
+ inertia_k2 = km2.inertia_ / n_samples
+ sil_k2 = silhouette_score(X, km2.labels_)
+ relative_improvement = (inertia_k1 - inertia_k2) / inertia_k1
+
+ # If improvement is too small or silhouette is too low => remain at k=1
+ if (relative_improvement < relative_threshold) or (sil_k2 < silhouette_threshold):
+ return 1, km1.labels_, [silhouette_k1, sil_k2], [inertia_k1, inertia_k2]
+
+ # --- Otherwise fit k=2..k_max and gather inertias & silhouettes ---
+ all_k = range(2, k_max + 1)
+ kmeans_models = {}
+ inertias = []
+ silhouettes = []
+
+ # We already have k=2
+ kmeans_models[2] = km2
+ inertias.append(inertia_k2)
+ silhouettes.append(sil_k2)
+
+ for k in range(3, k_max + 1):
+ km = KMeans(n_clusters=k, random_state=random_state).fit(X)
+ kmeans_models[k] = km
+
+ norm_inertia = km.inertia_ / n_samples
+ inertias.append(norm_inertia)
+
+ # If k > n_samples the silhouette score is meaningless, but in normal usage k << n_samples
+ silhouettes.append(silhouette_score(X, km.labels_))
+
+ # Silhouette-based k (highest score) and elbow-based k (KneeLocator on normalized inertias)
+ best_k_sil = list(all_k)[int(np.argmax(silhouettes))]
+ knee = KneeLocator(list(all_k), inertias, curve="convex", direction="decreasing")
+ if knee.elbow is None:
+ print("No clear elbow found => default to k=1.")
+ best_k_elbow = 1 # fallback
+ else:
+ best_k_elbow = int(knee.elbow)
+
+ print(f"Silhouette-based best_k={best_k_sil}, elbow-based best_k={best_k_elbow}")
+
+ # Combine if there's disagreement
+ if best_k_sil == best_k_elbow:
+ final_k = max(1, min(best_k_sil, k_avg_max)) # best_k_sil
+ else:
+ avg_k = 0.5 * (best_k_sil + best_k_elbow)
+ final_k = int(math.ceil(avg_k))
+ final_k = max(1, min(final_k, k_avg_max))
+
+ assert (final_k <= k_avg_max), f"Final k={final_k} is greater than k_avg_max={k_avg_max}"
+
+ # Get final labels from the chosen KMeans model
+ if final_k == 1:
+ final_labels = km1.labels_
+ else:
+ final_labels = kmeans_models[final_k].labels_
+
+ return final_k, final_labels, [silhouette_k1] + silhouettes, [inertia_k1] + inertias
+
+
+ def compute_region_statistics(self, label_map, heatmap, visited_indices, episode_num=0, step_num=0):
+ """
+ Computes region statistics for the current smoothed label map.
+ """
+ # Flatten the cluster map and the heatmap to handle indexing uniformly
+ label_map_2d = label_map
+ label_map_1d = label_map.ravel()
+ heatmap_1d = heatmap.ravel()
+
+ # Identify unique labels (excluding ignore_label if present)
+ unique_labels = np.unique(label_map_1d)
+ region_dict = {}
+ for lbl in unique_labels:
+ if lbl == self.ignore_label:
+ continue
+ region_dict[lbl] = {
+ 'num_patches': 0,
+ 'patches_visited': 0,
+ 'expectation': 0.0
+ }
+
+ # Accumulate totals for all patches
+ total_patches = len(label_map_1d)
+ for i in range(total_patches):
+ lbl = label_map_1d[i]
+ if lbl == self.ignore_label:
+ continue
+ region_dict[lbl]['num_patches'] += 1
+ region_dict[lbl]['expectation'] += float(heatmap_1d[i])
+
+ # # Exponential distribution (waiting time) = num_patches / expected_num_tgts
+ for lbl in region_dict:
+ region_dict[lbl]['expectation'] = region_dict[lbl]['num_patches'] / region_dict[lbl]['expectation']
+
+ # Count only unique visited patches by converting to a set.
+ unique_visited = set(visited_indices)
+ for vi in unique_visited:
+ if vi < 0 or vi >= total_patches:
+ continue
+ lbl = label_map_1d[vi]
+ if lbl == self.ignore_label:
+ continue
+ region_dict[lbl]['patches_visited'] += 1
+
+ if self.plot:
+ self.plot_cluster_map(label_map_2d, heatmap, visited_indices, region_dict, episode_num, step_num)
+
+ return region_dict
+
+
+ def plot_cluster_map(self, cluster_map, heatmap, path_taken, region_stats_dict, episode_num, step_num, cmap='tab20'):
+
+ # 4) Plot (side-by-side) if requested
+ fig, axes = plt.subplots(1, 3, figsize=(12, 6))
+
+ axes[0].imshow(cluster_map, cmap='tab20')
+ axes[0].set_title(f"Raw KMeans Clusters")
+ axes[0].axis('off')
+
+ axes[1].imshow(heatmap, cmap="viridis")
+ axes[1].set_title("Heatmap")
+ axes[1].axis('off')
+
+ axes[2].imshow(cluster_map, cmap='tab20')
+ axes[2].set_title("Raw KMeans Clusters")
+ axes[2].axis('off')
+
+ path_rows, path_cols = [], []
+ for i, idx in enumerate(path_taken):
+ rr = idx // cluster_map.shape[1]
+ cc = idx % cluster_map.shape[1]
+ path_rows.append(rr)
+ path_cols.append(cc)
+ axes[2].plot(path_cols, path_rows, c="r", linewidth=2)
+ axes[2].plot(path_cols[-1], path_rows[-1], markersize=12, zorder=99, marker="^", ls="-", c="r", mec="black")
+ axes[2].plot(path_cols[0], path_rows[0], 'co', c="r", markersize=8, zorder=5)
+
+ # Create legend patches for each region.
+ unique_labels = sorted(region_stats_dict.keys())
+ max_label = max(unique_labels) if unique_labels else 1
+ cm = plt.get_cmap(cmap)
+ legend_patches = []
+ for lbl in unique_labels:
+ # Normalize the label to [0,1] for colormap lookup.
+ norm_value = lbl / max_label if max_label > 0 else 0.5
+ color = cm(norm_value)
+ patch = mpatches.Patch(color=color, label=f"R{lbl}")
+ legend_patches.append(patch)
+
+ # Add legends to both subplots.
+ axes[0].legend(handles=legend_patches, title="Regions", loc='upper right')
+ axes[2].legend(handles=legend_patches, title="Regions", loc='upper right')
+
+ # Build the legend text for each region using the provided format:
+ legend_lines = []
+ for label, stats in region_stats_dict.items():
+ line = f"R{label}: patches={stats['num_patches']}, E={stats['expectation']:.3f}, visited={stats['patches_visited']}"
+ legend_lines.append(line)
+ legend_text = "\n".join(legend_lines)
+
+ # Add the legend text as a subtitle at the bottom of the figure
+ fig.text(0.5, 0.05, legend_text, ha='center', va='bottom', fontsize=10,
+ bbox=dict(facecolor='white', alpha=0.8, edgecolor='gray'))
+
+ # Adjust layout to reserve space for the legend text
+ plt.tight_layout()
+ plt.subplots_adjust(bottom=0.1)
+
+ if not os.path.exists(self.gifs_dir):
+ os.makedirs(self.gifs_dir)
+
+ plt.savefig(f'{self.gifs_dir}/kmeans_{episode_num}_{step_num}.png', dpi=150)
+ self.kmeans_frame_files.append(f'{self.gifs_dir}/kmeans_{episode_num}_{step_num}.png')
+ plt.close()
+
+
+ def get_label_id(self, label_map, patch_idx):
+ return label_map.ravel()[patch_idx]
+
+
+ def get_probs(self, patch_idx, heatmap):
+ return heatmap.ravel()[patch_idx]
+
+
+ ##############################
+ # Main functions
+ ##############################
+ def fit_predict(self, patch_embeds, map_shape):
+ """
+ Main function to obtain smoothed labelmap
+ """
+ # 1) Run combined silhouette & inertia
+ best_k, final_labels, silhouettes, inertias = self.combined_silhouette_inertia_clustering(
+ X=patch_embeds,
+ k_min=self.k_min,
+ k_max=self.k_max,
+ k_avg_max=self.k_avg_max,
+ silhouette_threshold=self.silhouette_threshold,
+ relative_threshold=self.relative_threshold,
+ random_state=self.random_state
+ )
+ self.final_k = best_k
+ self.final_labels_1d = final_labels.copy()
+
+ # 2) Reshape for display
+ H, W = map_shape
+ cluster_map = final_labels.reshape(H, W)
+
+ # 3) Smoothing step (pass-through here: small-patch smoothing is not applied)
+ cluster_map_smoothed = cluster_map
+ self.smoothed_labels_2d = cluster_map_smoothed.copy()
+
+ # 4) Return the smoothed 2D labels
+ return self.smoothed_labels_2d
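A minimal usage sketch for the clusterer above on synthetic patch embeddings (shapes and values are illustrative): the top half of a 24x24 grid is shifted so two regions should emerge, and per-region statistics are then computed against a random heatmap.

```python
import numpy as np
from kmeans_clustering import CombinedSilhouetteInertiaClusterer

rng = np.random.default_rng(0)
H, W, D = 24, 24, 64
patch_embeds = rng.normal(size=(H * W, D))
patch_embeds[: H * W // 2] += 3.0                 # top half of the grid gets a shifted mean

clusterer = CombinedSilhouetteInertiaClusterer(k_min=1, k_max=6, k_avg_max=4)
label_map = clusterer.fit_predict(patch_embeds, map_shape=(H, W))   # (24, 24) region labels

heatmap = rng.random((H, W))
stats = clusterer.compute_region_statistics(label_map, heatmap, visited_indices=[0, 1, 25])
print(clusterer.final_k, {lbl: s["num_patches"] for lbl, s in stats.items()})
```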
diff --git a/taxabind_avs/satbind/model_sat.py b/taxabind_avs/satbind/model_sat.py
new file mode 100644
index 0000000000000000000000000000000000000000..8b00d7cd0b05cc712b7dc776d5cb54db4acc7ffb
--- /dev/null
+++ b/taxabind_avs/satbind/model_sat.py
@@ -0,0 +1,190 @@
+##############################################################################
+# Name: model_sat.py
+#
+# - Training wrapper for satellite image CLIP model
+###############################################################################
+
+import os
+import sys
+sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
+
+import open_clip
+import pytorch_lightning as pl
+import torch
+import torch.nn as nn
+import numpy as np
+from torch.utils.data import DataLoader
+from transformers import CLIPVisionModelWithProjection
+from dataset import SatNatDataset
+from pytorch_lightning.callbacks import ModelCheckpoint
+from config_sat import config
+
+
+def create_pairwise_mask(labels):
+ labels = labels.reshape(-1)
+ num_samples = len(labels)
+ pairwise_mask = torch.zeros(num_samples, num_samples).to(labels.device)
+
+ for i in range(num_samples):
+ pairwise_mask[i, :] = (labels == labels[i])
+
+ return pairwise_mask
+
+def clip_loss(similarity: torch.Tensor, label) -> torch.Tensor:
+ overhead_img_loss = contrastive_loss(similarity, label)
+ ground_img_loss = contrastive_loss(similarity.t(), label.t())
+ return 0.5*torch.mean(torch.sum(overhead_img_loss, dim=-1)) + 0.5*torch.mean(torch.sum(ground_img_loss, dim=-1))
+
+def contrastive_loss(logits: torch.Tensor, label) -> torch.Tensor:
+ gt = create_pairwise_mask(label)
+ return -gt*torch.log(logits.softmax(-1)+1e-6)
+
+
+class SatBind(pl.LightningModule):
+ def __init__(self, train_dataset, val_dataset, **kwargs):
+ super().__init__()
+ self.train_dataset = train_dataset
+ self.val_dataset = val_dataset
+
+ #initialize bio CLIP with frozen weights
+ self.bio_model, *_ = open_clip.create_model_and_transforms(config.image_encoder_finetuned)
+ if config.locked_tuning:
+ for param in self.bio_model.parameters():
+ param.requires_grad = False
+
+ #initialize CLIP with trainable weights
+ self.imo_encoder = CLIPVisionModelWithProjection.from_pretrained(config.sat_encoder).train()
+ for layer in self.imo_encoder.children():
+ if hasattr(layer, 'reset_parameters'):
+ layer.reset_parameters()
+
+ self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
+ self.batch_size = kwargs.get('batch_size', config.batch_size)
+ self.lr = kwargs.get('lr', config.lr)
+
+ # Custom
+ clip_cfg = self.imo_encoder.config
+ self.visual_projection_custom = nn.Linear(clip_cfg.hidden_size, 512, bias=False) # clip_cfg.projection_dim)
+
+
+ def forward(self, batch):
+ img, imo, label, patch_idx, *_ = batch
+ batch_size = img.shape[0]
+
+ #compute bioclip embeddings
+ img_embeds, *_ = self.bio_model(img) # (batch_size, proj_dim)
+
+ # Similarity computation
+ imo_embeds = self.imo_encoder(imo).last_hidden_state # (batch, Patches, hidden_dim)
+ imo_embeds = imo_embeds[torch.arange(batch_size), patch_idx] # (batch, hidden_dim)
+ imo_embeds = self.visual_projection_custom(imo_embeds) # (batch_size, proj_dim)
+
+ return img_embeds, imo_embeds, label
+
+
+ def shared_step(self, batch, return_sim_matrix=False):
+
+ img_embeds, imo_embeds, label, *_ = self(batch)
+ imo_embeds = torch.nn.functional.normalize(imo_embeds, dim=-1)
+
+ #exponentiate the log of temperature
+ logit_scale = self.logit_scale.exp()
+
+ #compute similarity
+ img_to_imo_sim = img_embeds @ imo_embeds.t() * logit_scale
+
+ if return_sim_matrix:
+ img_to_imo_sim_copy = img_to_imo_sim.clone().detach()
+
+ loss = clip_loss(img_to_imo_sim, label)
+
+ if return_sim_matrix:
+ return loss, img_to_imo_sim_copy
+ else:
+ return loss
+
+
+ def training_step(self, batch, batch_idx):
+ loss = self.shared_step(batch)
+ self.log('train_loss', loss, sync_dist=True, prog_bar=True, on_epoch=True, batch_size=self.batch_size)
+ self.log('temperature', self.logit_scale.data, prog_bar=True, on_epoch=True, batch_size=self.batch_size)
+ return loss
+
+ def validation_step(self, batch, batch_idx):
+ loss = self.shared_step(batch)
+ self.log('val_loss', loss, sync_dist=True, prog_bar=True, on_epoch=True, batch_size=self.batch_size)
+ return loss
+
+ def train_dataloader(self):
+ return DataLoader(self.train_dataset,
+ batch_size=self.batch_size,
+ num_workers=config.num_workers,
+ shuffle=True, # True
+ persistent_workers=False)
+
+ def val_dataloader(self):
+ return DataLoader(self.val_dataset,
+ batch_size=self.batch_size,
+ num_workers=config.num_workers,
+ shuffle=False,
+ persistent_workers=False)
+
+ def configure_optimizers(self):
+ params = self.parameters()
+ self.optim = torch.optim.AdamW(params,
+ lr=self.lr,
+ betas=(0.9,0.98),
+ eps=1e-6
+ )
+ self.scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
+ optimizer=self.optim,
+ T_0=20,
+ eta_min=1e-6
+ )
+ return [self.optim], [self.scheduler]
+
+
+if __name__ == '__main__':
+ img_dir = config.img_dir
+ imo_dir = config.imo_dir
+ imo_dir_val = config.imo_dir_val
+ train_json_path = config.train_json_path
+ val_json_path = config.val_json_path
+ sat_to_img_ids_train_json_path = config.sat_to_img_ids_train_json_path
+ sat_to_img_ids_val_json_path = config.sat_to_img_ids_val_json_path
+ patch_size = config.patch_size
+
+ #define dataset
+ train_dataset = SatNatDataset(img_dir, imo_dir, train_json_path, sat_to_img_ids_train_json_path, patch_size)
+ val_dataset = SatNatDataset(img_dir, imo_dir_val, val_json_path, sat_to_img_ids_val_json_path, patch_size, mode='val')
+
+ #define model
+ model = SatBind(train_dataset=train_dataset, val_dataset=val_dataset)
+ torch.cuda.empty_cache()
+
+ checkpoint = ModelCheckpoint(
+ monitor='val_loss',
+ dirpath=config.save_dir,
+ filename=config.filename,
+ mode='min',
+ save_top_k=1,
+ save_last=True
+ )
+ checkpoint.CHECKPOINT_NAME_LAST = config.filename + "-LAST"
+
+ trainer = pl.Trainer(
+ accelerator='gpu',
+ strategy='ddp_find_unused_parameters_true', # suppress pl issues with 'unused trainable params'
+ devices=config.devices,
+ max_epochs=config.max_epochs,
+ num_nodes=1,
+ callbacks=[checkpoint],
+ accumulate_grad_batches=config.accumulate_grad_batches,
+ log_every_n_steps=1,
+ val_check_interval=config.val_check_interval,
+ )
+
+ if config.resume_from_checkpoint:
+ trainer.fit(model, ckpt_path=f"{config.save_dir}/{config.resume_checkpoint_name}.ckpt")
+ else:
+ trainer.fit(model)
\ No newline at end of file
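A toy sanity check of the label-aware contrastive objective above (values are illustrative and assume the satbind modules are importable): samples that share a class id are treated as positives by the pairwise mask, so class-consistent similarities give a much lower loss than class-inconsistent ones.

```python
import torch
from model_sat import create_pairwise_mask, clip_loss

torch.manual_seed(0)
labels = torch.tensor([0, 0, 1, 2])                 # first two samples share a class
print(create_pairwise_mask(labels))                 # 1s wherever two samples share a label

protos = torch.nn.functional.normalize(torch.randn(3, 512), dim=-1)
img_embeds = protos[labels]                         # ground-image embeddings set to class prototypes
sim_matched = img_embeds @ protos[labels].t() * 20.0               # sat patches aligned with their class
sim_mismatched = img_embeds @ protos[(labels + 1) % 3].t() * 20.0  # every patch aligned with a wrong class
print(clip_loss(sim_matched, labels).item(), clip_loss(sim_mismatched, labels).item())
```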
diff --git a/taxabind_avs/scripts/download_sat_imgs.py b/taxabind_avs/scripts/download_sat_imgs.py
new file mode 100644
index 0000000000000000000000000000000000000000..394f429fec75d5688e16e1926fa7f87e40bb85ac
--- /dev/null
+++ b/taxabind_avs/scripts/download_sat_imgs.py
@@ -0,0 +1,145 @@
+##############################################################################
+# Name: download_sat_imgs.py
+#
+# - Downloads satellite images and relevant data from huggingface
+# - https://huggingface.co/datasets/MVRL/iSatNat
+###############################################################################
+
+import os
+import itertools
+import requests
+import time
+import threading
+from datasets import load_dataset
+from concurrent.futures import ThreadPoolExecutor, as_completed
+from tqdm import tqdm
+from PIL import Image, UnidentifiedImageError
+from io import BytesIO
+
+# Load the dataset from Hugging Face
+mode = "train" # train or test
+ds = load_dataset("MVRL/iSatNat", split=f"{mode}")
+num_rows = len(ds)
+resize_size = 512 # Resize images to this size (i.e. same FOV - increase resolution)
+
+# Directory where images will be saved
+sat_save_dir = f"/mnt/hdd/inat2021_ds/sat_{mode}"
+os.makedirs(sat_save_dir, exist_ok=True)
+
+# Dictionary to store download failures: {key: sat_url}
+download_failures = {}
+
+# Create a global progress bar for images processed.
+pbar = tqdm(total=num_rows, desc="Images Processed")
+
+# Create a lock for thread-safe updates to the progress bar.
+progress_lock = threading.Lock()
+
+def download_image(row):
+ """
+ Download the image from row['sat_url'] and save it as sat_save_dir/{row['key']}.jpeg.
+ If the download fails, it will retry up to 3 times before recording a failure.
+ Each attempt prints a success or retry message.
+ The progress bar is updated once per image processed.
+ """
+ key = row["key"]
+ sat_url = row["sat_url"]
+ file_path = os.path.join(sat_save_dir, f"{key}.jpg")
+
+ if resize_size != 256:
+ sat_url = sat_url.replace("width=256", f"width={resize_size}")
+ sat_url = sat_url.replace("height=256", f"height={resize_size}")
+
+ # Check if file already exists; if so, skip the download.
+ if os.path.exists(file_path):
+ print(f"SKIPPED: Image for key {key} already exists.")
+ with progress_lock:
+ pbar.update(1)
+ return None
+
+ # Optional: use headers to mimic a browser.
+ headers = {
+ "User-Agent": (
+ "Mozilla/5.0 (Windows NT 10.0; Win64; x64) "
+ "AppleWebKit/537.36 (KHTML, like Gecko) "
+ "Chrome/115.0.0.0 Safari/537.36"
+ )
+ }
+
+ max_retries = 10
+ success = False
+ for attempt in range(1, max_retries + 1):
+ try:
+ response = requests.get(sat_url, headers=headers, timeout=10)
+ response.raise_for_status() # Raise an error for bad HTTP status codes.
+ # Convert the image to JPEG if necessary
+ image = Image.open(BytesIO(response.content))
+ image = image.convert("RGB") # Ensure compatibility with JPEG format
+ image.save(file_path, "JPEG")
+
+ # Verify that the saved file is not corrupted
+ with Image.open(file_path) as img:
+ img.verify() # Does minimal decoding, good for quick validation
+ success = True
+ break # Exit the loop if the download is successful.
+
+ # OSError can catch issues like truncated files, permission errors, etc.
+ except (UnidentifiedImageError, OSError):
+ if attempt < max_retries:
+ print(f"[Corrupted] Retrying: Failed attempt {attempt} for key {key} from URL {sat_url}")
+ time.sleep(2)
+
+ except Exception as e:
+ if attempt < max_retries:
+ print(f"Retrying: Failed attempt {attempt} for key {key} from URL {sat_url}")
+ time.sleep(2)
+
+ if not success:
+ print(f"FAILURE: Could not download image for key {key} from URL: {sat_url} after {max_retries} attempts")
+
+ # Update the progress bar regardless of success or failure.
+ with progress_lock:
+ pbar.update(1)
+ if not success:
+ return (key, sat_url)
+ return None
+
+def chunked_iterator(iterable, chunk_size):
+ """
+ Yield successive chunks of size `chunk_size` from the iterable.
+ """
+ iterator = iter(iterable)
+ while True:
+ chunk = list(itertools.islice(iterator, chunk_size))
+ if not chunk:
+ break
+ yield chunk
+
+# Define chunk size and number of worker threads.
+chunk_size = 1000 # Process 1,000 rows at a time.
+max_workers = 32 # Number of threads to use in parallel.
+
+# Process the dataset in chunks.
+try:
+ for chunk in chunked_iterator(ds, chunk_size):
+ with ThreadPoolExecutor(max_workers=max_workers) as executor:
+ # Submit a download task for each row in the chunk.
+ futures = {executor.submit(download_image, row): row for row in chunk}
+ # As each future completes, record any failures.
+ for future in as_completed(futures):
+ result = future.result()
+ if result:
+ key, sat_url = result
+ download_failures[key] = sat_url
+except:
+ print(f"Download failures: {download_failures}")
+ print("len(download_failures):", len(download_failures))
+
+# Close the progress bar when done.
+pbar.close()
+
+print(f"Download failures: {download_failures}")
+print("len(download_failures):", len(download_failures))
diff --git a/taxabind_avs/scripts/download_sound_and_img_pairs.py b/taxabind_avs/scripts/download_sound_and_img_pairs.py
new file mode 100644
index 0000000000000000000000000000000000000000..effef3afcec085b18ce3087a84770d260f9f44d9
--- /dev/null
+++ b/taxabind_avs/scripts/download_sound_and_img_pairs.py
@@ -0,0 +1,361 @@
+##############################################################################
+# Name: download_sound_and_img_pairs.py
+#
+# - Downloads sound and image pairs from huggingface
+# - https://huggingface.co/datasets/MVRL/iSoundNat
+###############################################################################
+
+import os
+import itertools
+import requests
+import time
+import threading
+import ffmpeg
+import pandas as pd
+from concurrent.futures import ThreadPoolExecutor, as_completed
+from tqdm import tqdm # progress bar
+from PIL import Image
+
+##############################
+# SETUP: DIRECTORIES & DATASET
+##############################
+
+mode = "train" # or "validation" or "test"
+
+# Define which split to use and CSV paths.
+splits = {
+ 'train': 'train_df.csv',
+ 'validation': 'val_df.csv',
+ 'test': 'test_df.csv'
+}
+# Here we load the training CSV; adjust as needed.
+df = pd.read_csv("hf://datasets/MVRL/iSoundNat/" + splits[mode])
+
+# If you want to skip to a specific row (for example, row index 19000),
+# then slice the DataFrame accordingly.
+start_index = 0
+if start_index > 0:
+ df = df.iloc[start_index:].reset_index(drop=True)
+
+# Directories for saving images and audio files
+image_save_dir = f"/mnt/hdd/inat2021_ds/sound_{mode}/images"
+audio_save_dir = f"/mnt/hdd/inat2021_ds/sound_{mode}/sounds_mp3"
+os.makedirs(image_save_dir, exist_ok=True)
+os.makedirs(audio_save_dir, exist_ok=True)
+
+# Convert dataframe rows to a list of dictionaries for iteration.
+rows = df.to_dict("records")
+num_rows = len(rows)
+
+# Dictionaries to record failures for pairs (keyed by id)
+image_failures = {}
+audio_failures = {}
+
+# Global progress bar and lock (one update per pair processed)
+pbar = tqdm(total=num_rows, desc="Pairs Processed")
+progress_lock = threading.Lock()
+
+##############################
+# HELPER FUNCTIONS
+##############################
+
+def convert_image_to_jpeg(temp_path, final_path):
+ """
+ Opens the image at temp_path using Pillow.
+ If its format is not JPEG, converts it.
+ Saves the image as JPEG to final_path.
+ """
+ try:
+ with Image.open(temp_path) as im:
+ if im.format != "JPEG":
+ rgb_im = im.convert("RGB")
+ rgb_im.save(final_path, "JPEG")
+ else:
+ # If already JPEG, simply rename the file.
+ os.rename(temp_path, final_path)
+ except Exception as e:
+ print(f"Error converting image {temp_path}: {e}")
+ raise e
+
+def is_audio_corrupted(file_path):
+ """
+ Uses ffmpeg.probe() to check if an audio file is readable.
+ Returns True if the file is corrupted or unreadable.
+ """
+ try:
+ ffmpeg.probe(file_path)
+ return False
+ except ffmpeg.Error as e:
+ print(f"Error probing audio '{file_path}': {e}")
+ return True
+
+def is_mp3_format(file_path):
+ """
+ Probes the file and checks whether 'mp3' is part of the format name.
+ """
+ try:
+ info = ffmpeg.probe(file_path)
+ format_name = info.get("format", {}).get("format_name", "")
+ return "mp3" in format_name
+ except Exception as e:
+ print(f"Error checking mp3 format for '{file_path}': {e}")
+ return False
+
+def convert_to_mp3(input_file, output_file):
+ """
+ Converts the input audio file to MP3 using the libmp3lame codec.
+ """
+ try:
+ stream = ffmpeg.input(input_file)
+ stream = ffmpeg.output(stream, output_file, acodec="libmp3lame")
+ ffmpeg.run(stream, quiet=True)
+ except ffmpeg.Error as e:
+ print(f"Error converting audio '{input_file}' to MP3: {e}")
+ raise e
+
+##############################
+# DOWNLOAD FUNCTIONS WITH RETRIES
+##############################
+
+def download_image(row, image_url, image_id, image_save_path):
+ """
+ Downloads the image from row["image_url"].
+ Saves a temporary file then converts (if needed) to JPEG as {id}.
+ """
+
+ temp_path = image_save_path + ".temp"
+ headers = {
+ "User-Agent": (
+ "Mozilla/5.0 (Windows NT 10.0; Win64; x64) "
+ "AppleWebKit/537.36 (KHTML, like Gecko) "
+ "Chrome/115.0.0.0 Safari/537.36"
+ )
+ }
+ max_retries = 3
+ success = False
+ for attempt in range(1, max_retries + 1):
+ try:
+ response = requests.get(image_url, headers=headers, timeout=10)
+ response.raise_for_status()
+ with open(temp_path, "wb") as f:
+ f.write(response.content)
+ success = True
+ break # Exit loop on success.
+ except Exception as e:
+ if attempt < max_retries:
+ time.sleep(2)
+ else:
+ print(f"FAILURE: Could not download image {image_id} from {image_url} after {max_retries} attempts")
+ success = False
+ if not success:
+ if os.path.exists(temp_path):
+ try:
+ os.remove(temp_path)
+ except Exception:
+ pass
+ return False
+
+ try:
+ convert_image_to_jpeg(temp_path, image_save_path)
+ except Exception as e:
+ print(f"Error processing image {image_id}: {e}")
+ if os.path.exists(temp_path):
+ try:
+ os.remove(temp_path)
+ except Exception:
+ pass
+ if os.path.exists(image_save_path):
+ try:
+ os.remove(image_save_path)
+ except Exception:
+ pass
+ return False
+ finally:
+ if os.path.exists(temp_path):
+ try:
+ os.remove(temp_path)
+ except Exception:
+ pass
+
+ return os.path.exists(image_save_path)
+
+def download_audio(row, audio_url, audio_id, audio_save_path):
+ """
+ Downloads the audio file from row["sound_url"].
+ Saves it to a temporary file, checks for corruption, and if needed converts it to MP3
+ as {id}.mp3 using ffmpeg-python.
+ """
+
+ # temp_path = os.path.join(audio_save_dir, f"{audio_id}_temp")
+ temp_path = audio_save_path + ".temp"
+ headers = {
+ "User-Agent": (
+ "Mozilla/5.0 (Windows NT 10.0; Win64; x64) "
+ "AppleWebKit/537.36 (KHTML, like Gecko) "
+ "Chrome/115.0.0.0 Safari/537.36"
+ )
+ }
+ max_retries = 5
+ success = False
+ for attempt in range(1, max_retries + 1):
+ try:
+ response = requests.get(audio_url, headers=headers, timeout=10)
+ response.raise_for_status()
+ with open(temp_path, "wb") as f:
+ f.write(response.content)
+ success = True
+ break
+ except Exception as e:
+ if attempt < max_retries:
+ time.sleep(2)
+ else:
+ print(f"FAILURE: Could not download audio {audio_id} from {audio_url} after {max_retries} attempts")
+ success = False
+ if not success:
+ if os.path.exists(temp_path):
+ try:
+ os.remove(temp_path)
+ except Exception:
+ pass
+ return False
+
+ # Check if the downloaded audio is corrupted.
+ if is_audio_corrupted(temp_path):
+ print(f"Audio file {audio_id} is corrupted.")
+ try:
+ os.remove(temp_path)
+ except Exception:
+ pass
+ return False
+
+ # Check if the audio is already in MP3 format.
+ if is_mp3_format(temp_path):
+ try:
+ os.rename(temp_path, audio_save_path)
+ except Exception as e:
+ print(f"Error renaming audio {audio_id}: {e}")
+ try:
+ os.remove(temp_path)
+ except Exception:
+ pass
+ return False
+ return True
+ else:
+ try:
+ convert_to_mp3(temp_path, audio_save_path)
+ except Exception as e:
+ print(f"Error converting audio {audio_id}: {e}")
+ try:
+ os.remove(temp_path)
+ except Exception:
+ pass
+ return False
+ finally:
+ if os.path.exists(temp_path):
+ try:
+ os.remove(temp_path)
+ except Exception:
+ pass
+
+ return os.path.exists(audio_save_path)
+
+def download_pair(row):
+ """
+ Downloads both the image and audio for a given row.
+ If either download/conversion fails, deletes any successfully downloaded file
+ and marks the pair as a failure.
+ """
+
+ # If the final image already exists, assume it's already downloaded.
+ image_url = row["image_url"]
+ image_id = row["id"]
+ img_save_path = os.path.join(image_save_dir, f"{image_id}.jpg")
+
+ img_exists = False
+ if os.path.exists(img_save_path):
+ img_exists = True
+
+
+ # If the final audio already exists, assume it's already downloaded.
+ audio_url = row["sound_url"]
+ audio_id = row["id"]
+ audio_save_path = os.path.join(audio_save_dir, f"{audio_id}.mp3")
+
+ audio_exists = False
+ if os.path.exists(audio_save_path):
+ audio_exists = True
+
+ # Skip the download if both files already exist.
+ if not (img_exists and audio_exists):
+
+ image_success = download_image(row, image_url, image_id, img_save_path)
+ audio_success = download_audio(row, audio_url, audio_id, audio_save_path)
+
+ # If either download failed, delete any successfully downloaded file.
+ if not (image_success and audio_success):
+ image_failures[row["id"]] = row["image_url"]
+ audio_failures[row["id"]] = row["sound_url"]
+ success = False
+ else:
+ success = True
+ else:
+ success = True
+ print(f"SKIPPED: Image {image_id} and Audio {audio_id} already exists.")
+
+ with progress_lock:
+ pbar.update(1)
+ return success
+
+def chunked_iterator(iterable, chunk_size):
+ """
+ Yields successive chunks of size chunk_size from the iterable.
+ """
+ iterator = iter(iterable)
+ while True:
+ chunk = list(itertools.islice(iterator, chunk_size))
+ if not chunk:
+ break
+ yield chunk
+
+##############################
+# PROCESS THE DATASET IN CHUNKS
+##############################
+
+chunk_size = 999999 # Adjust based on memory and dataset size.
+max_workers = 8 # Number of threads for parallel downloads.
+
+# Process rows in chunks using multi-threading.
+try:
+ for chunk in chunked_iterator(rows, chunk_size):
+ with ThreadPoolExecutor(max_workers=max_workers) as executor:
+ futures = {executor.submit(download_pair, row): row for row in chunk}
+ for future in as_completed(futures):
+ try:
+ future.result() # True if both downloads succeeded.
+ except Exception as e:
+ row = futures[future]
+ print(f"Error processing row {row['id']}: {e}")
+except Exception as e:
+ print(f"An error occurred during processing: {e}")
+
+pbar.close()
+
+print("Image download failures:", image_failures)
+print("Audio download failures:", audio_failures)
+print("len(image_failures):", len(image_failures))
+print("len(audio_failures):", len(audio_failures))
+
+##############################
+# REMOVE FAILURE ROWS FROM ORIGINAL DATAFRAME AND EXPORT
+##############################
+
+# Combine IDs from both failure dictionaries.
+failure_ids = set(image_failures.keys()).union(set(audio_failures.keys()))
+print(f"Total failed pairs: {len(failure_ids)}")
+
+# Remove failed rows from the original dataframe (preserving original order).
+successful_df = df[~df["id"].isin(failure_ids)]
+
+output_csv = f"/mnt/hdd/inat2021_ds/sound_{mode}/sound_image_pairs_filtered.csv"
+successful_df.to_csv(output_csv, index=False)
+print(f"Exported {len(successful_df)} successful rows to {output_csv}")
diff --git a/taxabind_avs/soundbind/config_sound.py b/taxabind_avs/soundbind/config_sound.py
new file mode 100644
index 0000000000000000000000000000000000000000..109378567ff076a5cf948c16f704efa2b9ba11a3
--- /dev/null
+++ b/taxabind_avs/soundbind/config_sound.py
@@ -0,0 +1,31 @@
+##############################################################################
+# Name: config_sound.py
+#
+# - Parameters for train/eval sound encoder
+###############################################################################
+
+from easydict import EasyDict as edict
+
+config = edict()
+config.train_df = 'clip_quad_modal|train'
+config.val_df = 'clip_quad_modal|val'
+config.data_path_train = '/mnt/hdd/avs_bench_ds/sound_mp3/train'
+config.data_path_val = '/mnt/hdd/avs_bench_ds/sound_mp3/test'
+
+config.batch_size = 256
+config.lr = 1e-4
+config.accumulate_grad_batches = 8
+config.max_epochs = 20
+config.num_workers = 16
+config.devices = 2
+config.val_check_interval = 0.5
+config.sound_encoder = 'laion/clap-htsat-fused'
+config.avs_dataset = 'derektan95/avs-bench'
+config.save_dir = 'checkpoints'
+config.filename = 'soundbind-{epoch:02d}-{val_loss:.2f}'
+config.locked_tuning = True
+
+# huggingface finetuned
+config.image_encoder_finetuned = 'hf-hub:imageomics/bioclip'
+
+print("config: \n", config)
\ No newline at end of file
diff --git a/taxabind_avs/soundbind/dataloader.py b/taxabind_avs/soundbind/dataloader.py
new file mode 100644
index 0000000000000000000000000000000000000000..f1829489e001732e6fd5c26852ab29f02ae4cd69
--- /dev/null
+++ b/taxabind_avs/soundbind/dataloader.py
@@ -0,0 +1,72 @@
+##############################################################################
+# Name: dataloader.py
+#
+# - Handles loading of quadmodal dataset
+# - https://huggingface.co/datasets/derektan95/avs-bench
+###############################################################################
+
+import os
+from PIL import Image
+from torch.utils.data import Dataset
+from torchvision import transforms
+from config_sound import config
+from sound_encoder import get_audio_clap
+from datasets import load_dataset
+
+
+class INatDataset(Dataset):
+ def __init__(self,
+ data_file,
+ mode='train'):
+
+ # Load from huggingface dataset
+ ds_config = data_file.split("|")[0]
+ ds_split = data_file.split("|")[1]
+ self.data_file = load_dataset(config.avs_dataset, name=ds_config, split=ds_split).to_pandas()
+ print("Loaded huggingface dataset: ", ds_config, ds_split)
+
+ self.mode = mode
+ if mode=='train':
+ self.transform = transforms.Compose([
+ transforms.Resize((256, 256)),
+ transforms.RandomCrop((224, 224)),
+ transforms.RandomHorizontalFlip(0.5),
+ transforms.GaussianBlur(5, (0.01, 1.0)),
+ transforms.ToTensor(),
+ transforms.Normalize(mean=[0.485, 0.456, 0.406],
+ std=[0.229, 0.224, 0.225])
+ ])
+ else:
+ self.transform = transforms.Compose([
+ transforms.Resize((256, 256)),
+ transforms.CenterCrop((224, 224)),
+ transforms.ToTensor(),
+ transforms.Normalize(mean=[0.485, 0.456, 0.406],
+ std=[0.229, 0.224, 0.225])
+ ])
+ self.species_text = self.data_file['scientific_name'].tolist()
+ self.species_classes = list(set(self.species_text))
+ print("mode: ", self.mode)
+ print("len(self.data_file): ", len(self.data_file))
+
+ def __len__(self):
+ return len(self.data_file)
+
+ def get_sample(self,idx):
+ sample = self.data_file.iloc[idx]
+ id = sample.id
+ sound_format = sample.sound_format
+ data_path = config['data_path_train'] if self.mode == 'train' else config['data_path_val']
+ image_path = os.path.join(data_path,"images",str(id)+".jpg")
+ sound_path = os.path.join(data_path,"sounds_mp3",str(id)+"."+'mp3')
+ sound = get_audio_clap(sound_path)
+
+ for k in sound.keys():
+ sound[k] = sound[k].squeeze(0)
+ image = self.transform(Image.open(image_path).convert("RGB"))
+
+ return image, sound
+
+ def __getitem__(self, idx):
+ image, sound = self.get_sample(idx)
+ return image, sound, self.species_classes.index(self.data_file.iloc[idx]['scientific_name'])
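A minimal loading sketch for the dataset above (assumes the image/audio files referenced by the split exist under the paths in config_sound.py): the CLAP processor returns a dict per clip, which the default collate stacks into batched tensors.

```python
from torch.utils.data import DataLoader
from dataloader import INatDataset
from config_sound import config

val_dataset = INatDataset(data_file=config.val_df, mode='val')
loader = DataLoader(val_dataset, batch_size=4, num_workers=0, shuffle=False)

images, sounds, class_ids = next(iter(loader))
print(images.shape)                         # (4, 3, 224, 224)
print(sounds['input_features'].shape)       # batched CLAP spectrogram features
print(sounds['is_longer'].shape, class_ids)
```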
diff --git a/taxabind_avs/soundbind/model_sound.py b/taxabind_avs/soundbind/model_sound.py
new file mode 100644
index 0000000000000000000000000000000000000000..10efb1d6f90cf3e4548d0da8721a6934291359f3
--- /dev/null
+++ b/taxabind_avs/soundbind/model_sound.py
@@ -0,0 +1,148 @@
+##############################################################################
+# Name: model_sound.py
+#
+# - Training wrapper for sound CLAP model
+###############################################################################
+
+import os
+import sys
+sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
+
+import open_clip
+import pytorch_lightning as pl
+import torch
+import torch.nn as nn
+import numpy as np
+import os
+import random
+from sound_encoder import CLAP_audiomodel_withProjection as AudioEncoder
+from torch.utils.data import DataLoader
+from config_sound import config
+from dataloader import INatDataset
+from pytorch_lightning.callbacks import ModelCheckpoint
+
+
+def clip_loss(similarity: torch.Tensor) -> torch.Tensor:
+ audio_loss = contrastive_loss(similarity)
+ ground_img_loss = contrastive_loss(similarity.t())
+ return 0.5*audio_loss + 0.5*ground_img_loss
+
+def contrastive_loss(logits: torch.Tensor) -> torch.Tensor:
+
+ return nn.functional.cross_entropy(logits[:logits.shape[1]], torch.arange(logits.shape[1], device=logits.device))
+
+class AudioBind(pl.LightningModule):
+ def __init__(self, train_dataset, val_dataset, **kwargs):
+ super().__init__()
+ self.train_dataset = train_dataset
+ self.val_dataset = val_dataset
+ self.model, *_ = open_clip.create_model_and_transforms(config.image_encoder_finetuned)
+ if config.locked_tuning:
+ for param in self.model.parameters():
+ param.requires_grad = False
+ self.audio_encoder = AudioEncoder(freeze=False)
+ self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
+ self.batch_size = kwargs.get('batch_size')
+ self.num_workers = kwargs.get('num_workers')
+ self.lr = kwargs.get('lr', 1e-4)
+
+ def forward(self, image, audio):
+ with torch.no_grad():
+ image_embeds, *_ = self.model(image)
+ unnormalized_audio_embeds = self.audio_encoder(audio)
+ audio_embeds = torch.nn.functional.normalize(unnormalized_audio_embeds, dim=-1)
+ return image_embeds, audio_embeds
+
+ def shared_step(self, batch):
+ image, audio, *_ = batch
+ image_embeds, audio_embeds = self(image, audio)
+ logit_scale = self.logit_scale.exp()
+ logits_per_img = torch.matmul(image_embeds,audio_embeds.t())*logit_scale
+ cross_contrastive_loss = clip_loss(logits_per_img)
+ return cross_contrastive_loss
+
+ def training_step(self, batch, batch_idx):
+ loss = self.shared_step(batch)
+ self.log('train_loss', loss, sync_dist=True, prog_bar=True, on_epoch=True, batch_size=self.batch_size)
+ self.log('temperature', self.logit_scale.data, prog_bar=True, on_epoch=True, batch_size=self.batch_size)
+ return loss
+
+ def on_train_batch_end(self,outputs,batch, batch_idx):
+ if self.logit_scale.data > np.log(100):
+ self.logit_scale.data = torch.clamp(self.logit_scale.data, 0, np.log(100))
+
+ def validation_step(self, batch, batch_idx):
+ loss = self.shared_step(batch)
+ self.log('val_loss', loss, sync_dist=True, prog_bar=True, on_epoch=True, batch_size=self.batch_size)
+ return loss
+
+ def train_dataloader(self):
+ return DataLoader(self.train_dataset,
+ batch_size=self.batch_size,
+ num_workers=self.num_workers,
+ shuffle=True,
+ persistent_workers=False)
+
+ def val_dataloader(self):
+ return DataLoader(self.val_dataset,
+ batch_size=self.batch_size,
+ num_workers=self.num_workers,
+ shuffle=False,
+ persistent_workers=False)
+
+ def configure_optimizers(self):
+ params = self.parameters()
+ self.optim = torch.optim.AdamW(params,
+ lr=self.lr,
+ betas=(0.9,0.98),
+ eps=1e-6
+ )
+ self.scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
+ optimizer=self.optim,
+ T_0=20
+ )
+ return [self.optim], [self.scheduler]
+
+def seed_everything(seed=42):
+ """
+ seed: int
+ """
+ torch.manual_seed(seed)
+ torch.cuda.manual_seed_all(seed)
+ np.random.seed(seed)
+ random.seed(seed)
+ torch.backends.cudnn.deterministic = True
+ torch.backends.cudnn.benchmark = False
+ os.environ["PYTHONHASHSEED"] = str(seed)
+
+if __name__=='__main__':
+ import warnings
+ warnings.filterwarnings("ignore")
+ torch.set_warn_always(False)
+
+ seed_everything()
+ train_dataset = INatDataset(data_file=config.train_df, mode='train')
+ val_dataset = INatDataset(data_file=config.val_df, mode='val')
+ kwargs = {'batch_size':config.batch_size, 'num_workers': config.num_workers}
+
+ model = AudioBind(train_dataset, val_dataset, **kwargs)
+ torch.cuda.empty_cache()
+
+ checkpoint = ModelCheckpoint(
+ monitor='val_loss',
+ dirpath=config.save_dir,
+ filename=config.filename,
+ mode='min',
+ save_top_k=3
+ )
+ trainer = pl.Trainer(
+ accelerator='gpu',
+ devices=config.devices,
+ strategy='ddp',
+ max_epochs=config.max_epochs,
+ num_nodes=1,
+ callbacks=[checkpoint],
+ accumulate_grad_batches=config.accumulate_grad_batches,
+ log_every_n_steps=1
+ )
+ trainer.fit(model)
\ No newline at end of file
diff --git a/taxabind_avs/soundbind/sound_encoder.py b/taxabind_avs/soundbind/sound_encoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..4ed354c2485cfb77bf9863b500097de172e56c91
--- /dev/null
+++ b/taxabind_avs/soundbind/sound_encoder.py
@@ -0,0 +1,53 @@
+##############################################################################
+# Name: sound_encoder.py
+#
+# - Wrapper for sound CLAP model
+###############################################################################
+
+import os
+import sys
+sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
+
+import pytorch_lightning as pl
+import torch
+import torchaudio
+from transformers import ClapProcessor
+from transformers import ClapAudioModelWithProjection
+from config_sound import config
+
+
+processor = ClapProcessor.from_pretrained(config.sound_encoder)
+SAMPLE_RATE = 48000
+
+def get_audio_clap(path_to_audio,format="mp3",padding="repeatpad",truncation="fusion"):
+ track, sr = torchaudio.load(path_to_audio, format=format)
+ track = track.mean(axis=0)
+ track = torchaudio.functional.resample(track, orig_freq=sr, new_freq=SAMPLE_RATE)
+ output = processor(audios=track, sampling_rate=SAMPLE_RATE, max_length_s=10, return_tensors="pt",padding=padding,truncation=truncation)
+ return output
+
+
+class CLAP_audiomodel_withProjection(pl.LightningModule):
+ def __init__(self,freeze=False):
+ super().__init__()
+ if freeze:
+ self.model = ClapAudioModelWithProjection.from_pretrained(config.sound_encoder).eval()
+ for params in self.model.parameters():
+ params.requires_grad=False
+ else:
+ self.model = ClapAudioModelWithProjection.from_pretrained(config.sound_encoder).train()
+ def forward(self,audio):
+ batch_embeddings_audio = self.model(**audio)['audio_embeds']
+ return batch_embeddings_audio
+
+if __name__ == '__main__':
+ path_to_audio ="/mnt/hdd/inat2021_ds/sound_train/sounds_mp3/165878447.mp3"
+ sample = get_audio_clap(path_to_audio)
+ print(sample.keys())
+
+ sample['input_features'] = torch.concat([sample['input_features'],sample['input_features']],axis=0)
+ sample['is_longer'] = torch.concat([sample['is_longer'],sample['is_longer']],axis=0)
+ print(sample['input_features'].shape,sample['is_longer'].shape)
+ model = CLAP_audiomodel_withProjection(freeze=False)
+ audio_feat = model(sample)
+ print(audio_feat.shape)
\ No newline at end of file