derektan committed
Commit e330ebf · 0 Parent(s)

Initial Commit

This view is limited to 50 files because it contains too many changes. See raw diff.
Files changed (50)
  1. .gitattributes +40 -0
  2. .gitignore +38 -0
  3. .vscode/launch.json +15 -0
  4. README.md +13 -0
  5. app.py +528 -0
  6. app_multimodal_inference.py +350 -0
  7. examples/Animalia_Arthropoda_Malacostraca_Decapoda_Ocypodidae_Ocypode_quadrata/0b9cc264-a2ba-44bd-8e41-0d01a6edd1e8.jpg +3 -0
  8. examples/Animalia_Arthropoda_Malacostraca_Decapoda_Ocypodidae_Ocypode_quadrata/12372063.mp3 +3 -0
  9. examples/Animalia_Arthropoda_Malacostraca_Decapoda_Ocypodidae_Ocypode_quadrata/277303_38.72364_-75.07749.jpg +3 -0
  10. examples/Animalia_Chordata_Aves_Charadriiformes_Laridae_Larus_marinus/80645_39.76079_-74.10316.jpg +3 -0
  11. examples/Animalia_Chordata_Aves_Charadriiformes_Laridae_Larus_marinus/89758229.mp3 +3 -0
  12. examples/Animalia_Chordata_Aves_Charadriiformes_Laridae_Larus_marinus/cc1ebaf9-899d-49f2-81c8-d452249a8087.jpg +3 -0
  13. examples/Animalia_Chordata_Elasmobranchii_Carcharhiniformes_Carcharhinidae_Triaenodon_obesus/5041_-0.28573_-90.54837.jpg +3 -0
  14. examples/Animalia_Chordata_Elasmobranchii_Carcharhiniformes_Carcharhinidae_Triaenodon_obesus/c834edf7-b073-4cd5-8726-9c88ebe943c0.jpg +3 -0
  15. examples/Animalia_Chordata_Mammalia_Artiodactyla_Cervidae_Cervus_nippon/1a51f978-c76d-4816-a716-473c21b5495c.jpg +3 -0
  16. examples/Animalia_Chordata_Mammalia_Artiodactyla_Cervidae_Cervus_nippon/245767_53.0076_-6.35201.jpg +3 -0
  17. examples/Animalia_Chordata_Mammalia_Carnivora_Canidae_Canis_aureus/1528408_13.00422_80.23033.jpg +3 -0
  18. examples/Animalia_Chordata_Mammalia_Carnivora_Canidae_Canis_aureus/189318716.mp3 +3 -0
  19. examples/Animalia_Chordata_Mammalia_Carnivora_Canidae_Canis_aureus/37faabd2-a613-4461-b27e-82fe5955ecaf.jpg +3 -0
  20. examples/Animalia_Chordata_Mammalia_Carnivora_Canidae_Urocyon_littoralis/0cbdfbf2-6cfe-4d61-9602-c949f24d0293.jpg +3 -0
  21. examples/Animalia_Chordata_Mammalia_Carnivora_Canidae_Urocyon_littoralis/304160_34.0144_-119.54417.jpg +3 -0
  22. examples/Animalia_Chordata_Mammalia_Carnivora_Phocidae_Mirounga_angustirostris/27423_35.64005_-121.17595.jpg +3 -0
  23. examples/Animalia_Chordata_Mammalia_Carnivora_Phocidae_Mirounga_angustirostris/3123948.mp3 +3 -0
  24. examples/Animalia_Chordata_Mammalia_Carnivora_Phocidae_Mirounga_angustirostris/3aac526d-c921-452a-af6a-cb4f2f52e2c4.jpg +3 -0
  25. examples/Animalia_Chordata_Mammalia_Rodentia_Caviidae_Hydrochoerus_hydrochaeris/166631961.mp3 +3 -0
  26. examples/Animalia_Chordata_Mammalia_Rodentia_Caviidae_Hydrochoerus_hydrochaeris/1b8064f8-7deb-4b30-98cd-69da98ba6a3d.jpg +3 -0
  27. examples/Animalia_Chordata_Mammalia_Rodentia_Caviidae_Hydrochoerus_hydrochaeris/28871_-12.80255_-69.29999.jpg +3 -0
  28. examples/Animalia_Chordata_Mammalia_Rodentia_Sciuridae_Marmota_marmota/327e1f07-692b-4140-8a3e-bd098bc064ff.jpg +3 -0
  29. examples/Animalia_Chordata_Mammalia_Rodentia_Sciuridae_Marmota_marmota/388246_45.49036_7.14796.jpg +3 -0
  30. examples/Animalia_Chordata_Mammalia_Rodentia_Sciuridae_Marmota_marmota/59677071.mp3 +3 -0
  31. examples/Animalia_Chordata_Reptilia_Crocodylia_Alligatoridae_Caiman_crocodilus/340271_10.52832_-83.49678.jpg +3 -0
  32. examples/Animalia_Chordata_Reptilia_Crocodylia_Alligatoridae_Caiman_crocodilus/45193295.mp3 +3 -0
  33. examples/Animalia_Chordata_Reptilia_Crocodylia_Alligatoridae_Caiman_crocodilus/938aab7b-4509-4de7-afad-2c8ea51f4799.jpg +3 -0
  34. examples/Animalia_Chordata_Reptilia_Squamata_Varanidae_Varanus_salvator/410613_5.35573_100.28948.jpg +3 -0
  35. examples/Animalia_Chordata_Reptilia_Squamata_Varanidae_Varanus_salvator/461d8e6c-0e66-4acc-8ecd-bfd9c218bc14.jpg +3 -0
  36. examples/metadata.json +173 -0
  37. inference/model/avs_rl_policy.pth +3 -0
  38. maps/example/masks_val/MSK_0001.png +3 -0
  39. maps/gpt4o/envs_val/MSK_0001.png +3 -0
  40. planner/env.py +610 -0
  41. planner/graph.py +167 -0
  42. planner/graph_generator.py +300 -0
  43. planner/model.py +312 -0
  44. planner/node.py +96 -0
  45. planner/robot.py +58 -0
  46. planner/sensor.py +128 -0
  47. planner/test_info_surfing.py +1071 -0
  48. planner/test_parameter.py +118 -0
  49. planner/test_worker.py +590 -0
  50. planner/worker.py +272 -0
.gitattributes ADDED
@@ -0,0 +1,40 @@
+ *.7z filter=lfs diff=lfs merge=lfs -text
+ *.arrow filter=lfs diff=lfs merge=lfs -text
+ *.bin filter=lfs diff=lfs merge=lfs -text
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
+ *.ftz filter=lfs diff=lfs merge=lfs -text
+ *.gz filter=lfs diff=lfs merge=lfs -text
+ *.h5 filter=lfs diff=lfs merge=lfs -text
+ *.joblib filter=lfs diff=lfs merge=lfs -text
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
+ *.model filter=lfs diff=lfs merge=lfs -text
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
+ *.npy filter=lfs diff=lfs merge=lfs -text
+ *.npz filter=lfs diff=lfs merge=lfs -text
+ *.onnx filter=lfs diff=lfs merge=lfs -text
+ *.ot filter=lfs diff=lfs merge=lfs -text
+ *.parquet filter=lfs diff=lfs merge=lfs -text
+ *.pb filter=lfs diff=lfs merge=lfs -text
+ *.pickle filter=lfs diff=lfs merge=lfs -text
+ *.pkl filter=lfs diff=lfs merge=lfs -text
+ *.pt filter=lfs diff=lfs merge=lfs -text
+ *.pth filter=lfs diff=lfs merge=lfs -text
+ *.rar filter=lfs diff=lfs merge=lfs -text
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
+ *.tar filter=lfs diff=lfs merge=lfs -text
+ *.tflite filter=lfs diff=lfs merge=lfs -text
+ *.tgz filter=lfs diff=lfs merge=lfs -text
+ *.wasm filter=lfs diff=lfs merge=lfs -text
+ *.xz filter=lfs diff=lfs merge=lfs -text
+ *.zip filter=lfs diff=lfs merge=lfs -text
+ *.zst filter=lfs diff=lfs merge=lfs -text
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
+ *.png filter=lfs diff=lfs merge=lfs -text
+ *.jpg filter=lfs diff=lfs merge=lfs -text
+ *.jpeg filter=lfs diff=lfs merge=lfs -text
+ *.mp3 filter=lfs diff=lfs merge=lfs -text
+ *.pth filter=lfs diff=lfs merge=lfs -text
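These attribute rules route matching files through Git LFS rather than plain git storage, which is why the media and weight files later in this commit appear as pointer files. As a quick sanity check, here is a stdlib-only Python sketch (ours, not part of the repo) that tests which of this commit's files the patterns would catch:

import fnmatch

# A few of the filter=lfs patterns copied from the .gitattributes above.
LFS_PATTERNS = ["*.pth", "*.png", "*.jpg", "*.mp3", "*.safetensors"]

# Sample paths taken from this commit's file list.
paths = [
    "inference/model/avs_rl_policy.pth",
    "maps/example/masks_val/MSK_0001.png",
    "planner/env.py",
]

for p in paths:
    # A path is LFS-tracked here if its basename matches any pattern.
    tracked = any(fnmatch.fnmatch(p.split("/")[-1], pat) for pat in LFS_PATTERNS)
    print(f"{p}: {'LFS' if tracked else 'regular git'}")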
.gitignore ADDED
@@ -0,0 +1,38 @@
+ *.idea/
+ *.vscode/
+
+ *__pycache__/
+ *.pyc
+ *.pyo
+ *.pyd
+
+ *.pt
+ *.pth
+ *.tar.gz
+ *.zip
+
+ !**/train/
+ **/train/*
+ !**/train/saved
+
+ !**/inference/
+ **/inference/*
+ !**/inference/saved
+
+ !**/maps/
+ **/maps/*
+ !**/maps/example
+ !**/maps/gpt4o
+ !**/maps/lisa
+
+ # For taxabind_avs
+ **/dataset/
+ **/checkpoints/
+
+ !**/lightning_logs/
+ **/lightning_logs/*
+ !**/lightning_logs/saved
+
+ # Saved weights & logs
+ **avs_rl_policy.pth
+ **/avs_rl_policy_21.5k/*
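The `!dir/`, `dir/*`, `!dir/child` triplets above use the standard gitignore whitelisting idiom: re-include a directory, ignore everything inside it, then re-include selected children, so under maps/ only example, gpt4o, and lisa survive. A small sketch to verify that behavior, assuming the third-party pathspec package (an assumption on our part; it is not a dependency of this repo):

import pathspec  # pip install pathspec (assumed; not bundled with this repo)

rules = [
    "!**/maps/",
    "**/maps/*",
    "!**/maps/example",
]
spec = pathspec.PathSpec.from_lines("gitwildmatch", rules)

# match_file() returns True when the last matching rule ignores the path.
print(spec.match_file("maps/scratch"))  # True  -> ignored
print(spec.match_file("maps/example"))  # False -> kept, as intended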
.vscode/launch.json ADDED
@@ -0,0 +1,15 @@
+ {
+     "version": "0.2.0",
+     "configurations": [
+         {
+             "name": "Debug app.py",
+             "type": "debugpy",
+             "request": "launch",
+             "program": "${workspaceFolder}/app.py",
+             "cwd": "${workspaceFolder}",
+             "console": "integratedTerminal",
+             "justMyCode": false,
+             "python": "/home/user/anaconda3/envs/vlm-search/bin/python3"
+         }
+     ]
+ }
README.md ADDED
@@ -0,0 +1,13 @@
+ ---
+ title: Search-TTA
+ emoji: 🦁
+ colorFrom: green
+ colorTo: gray
+ sdk: gradio
+ sdk_version: 5.31.0
+ app_file: app.py
+ pinned: false
+ short_description: Multimodal Test-time Adaptation Framework for Visual Search
+ ---
+
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,528 @@
+ """
+ Simplified Gradio demo for Search-TTA evaluation.
+ """
+
+ # ────────────────────────── imports ───────────────────────────────────
+ from pathlib import Path
+ import matplotlib
+ matplotlib.use("Agg", force=True)
+
+ import gradio as gr
+ import ctypes  # for safely stopping background threads
+ import os, glob, threading, time
+ import torch
+ from PIL import Image
+ import json
+ import shutil
+ import spaces  # integration with ZeroGPU on hf
+ from planner.test_parameter import *
+ from planner.model import PolicyNet
+ from planner.test_worker import TestWorker
+ from taxabind_avs.satbind.clip_seg_tta import ClipSegTTA
+
+
+ # Helper to kill a Python thread by injecting SystemExit
+ def _stop_thread(thread: threading.Thread):
+     """Forcefully raise SystemExit in the given thread (best-effort)."""
+     if thread is None or not thread.is_alive():
+         return
+     tid = thread.ident
+     if tid is None:
+         return
+     # Ask CPython to raise SystemExit in the thread context
+     res = ctypes.pythonapi.PyThreadState_SetAsyncExc(ctypes.c_long(tid), ctypes.py_object(SystemExit))
+     if res > 1:
+         # If it returned >1, cleanup and fail safe
+         ctypes.pythonapi.PyThreadState_SetAsyncExc(ctypes.c_long(tid), None)
+
+ # ──────────── Thread Registry for Cleanup on Tab Switch ─────────────
+ _running_threads: list[threading.Thread] = []
+ _running_threads_lock = threading.Lock()
+
+ # Map worker threads to their ClipSegTTA instance so UI can read executing_tta flag
+ _thread_clip_map: dict[threading.Thread, ClipSegTTA] = {}
+
+ # ──────────── Run directory rotation ─────────────
+ RUN_HISTORY_LIMIT = 30  # keep at most this many timestamped run directories per instance
+
+ def _prune_old_run_dirs(base_dir: str, limit: int = RUN_HISTORY_LIMIT):
+     """Delete the oldest timestamp-named run directories, leaving only *limit* of the newest ones."""
+     try:
+         dirs = [d for d in os.listdir(base_dir) if os.path.isdir(os.path.join(base_dir, d))]
+         dirs.sort()
+         if len(dirs) > limit:
+             for obsolete in dirs[:-limit]:
+                 shutil.rmtree(os.path.join(base_dir, obsolete), ignore_errors=True)
+     except Exception:
+         pass
+
+
+ # CHANGE ME!
+ POLL_INTERVAL = 1.0  # For visualization
+
+ # Prepare the model
+ device = torch.device('cuda') if USE_GPU and torch.cuda.is_available() else torch.device('cpu')
+ policy_net = PolicyNet(INPUT_DIM, EMBEDDING_DIM).to(device)
+ script_dir = Path(__file__).resolve().parent
+ print("real_script_dir: ", script_dir)
+ checkpoint = torch.load(f'{MODEL_PATH}/{MODEL_NAME}')
+ policy_net.load_state_dict(checkpoint['policy_model'])
+ print('Model loaded!')
+
+ # Load metadata json
+ tgts_metadata_json_path = os.path.join(script_dir, "examples/metadata.json")
+ tgts_metadata = json.load(open(tgts_metadata_json_path))
+
+
+ # ────────────────────────── Gradio process fn ─────────────────────────
+
+ ### integration with ZeroGPU on hf
+ # @spaces.GPU
+ def process_search_tta(
+     sat_path: str | None,
+     ground_path: str | None,
+     taxonomy: str | None = None,
+     session_threads: list[threading.Thread] | None = None,
+ ):
+     """Run both TTA and non-TTA search episodes concurrently and stream both heat-maps."""
+
+     if session_threads is None:
+         session_threads = []
+
+     # Disable Run button and clear image/status outputs, hide sliders, clear frame states
+     yield (
+         gr.update(interactive=False),
+         gr.update(value=None),
+         gr.update(value=None),
+         gr.update(value="Initializing model…", visible=True),
+         gr.update(value="Initializing model…", visible=True),
+         gr.update(visible=False),
+         gr.update(visible=False),
+         [],
+         [],
+         session_threads,
+     )
+
+     # Bail early if satellite image missing
+     if sat_path is None:
+         yield (
+             gr.update(interactive=True),
+             gr.update(value=None),
+             gr.update(value=None),
+             gr.update(value="No satellite image provided.", visible=True),
+             gr.update(value="", visible=True),
+             gr.update(visible=False),
+             gr.update(visible=False),
+             [],
+             [],
+             session_threads,
+         )
+         return
+
+     # Prepare PIL images
+     sat_img = Image.open(sat_path).convert("RGB")
+     ground_img_pil = Image.open(ground_path).convert("RGB") if ground_path else None
+
+     # Lookup target positions metadata (may be empty)
+     tgt_positions = []
+     if taxonomy and taxonomy in tgts_metadata:
+         tgt_positions = [tuple(t) for t in tgts_metadata[taxonomy]["target_positions"]]
+
+     # Helper to build a TestWorker with/without TTA
+     def build_planner(enable_tta: bool, save_dir: str, clip_obj):
+         # Lazily (re)create a ClipSegTTA instance per thread if not provided
+         local_clip = clip_obj
+         if LOAD_AVS_BENCH and local_clip is None:
+             local_clip = ClipSegTTA(
+                 img_dir=AVS_IMG_DIR,
+                 imo_dir=AVS_IMO_DIR,
+                 json_path=AVS_INAT_JSON_PATH,
+                 sat_to_img_ids_path=AVS_SAT_TO_IMG_IDS_PATH,
+                 sat_checkpoint_path=AVS_SAT_CHECKPOINT_PATH,
+                 load_pretrained_hf_ckpt=AVS_LOAD_PRETRAINED_HF_CHECKPOINT,
+                 blur_kernel=AVS_GAUSSIAN_BLUR_KERNEL,
+                 sample_index=-1,
+                 device=device,
+                 sat_to_img_ids_json_is_train_dict=False,
+                 tax_to_filter_val=QUERY_TAX,
+                 load_model=USE_CLIP_PREDS,
+                 query_modality=QUERY_MODALITY,
+                 sound_dir=AVS_SOUND_DIR,
+                 sound_checkpoint_path=AVS_SOUND_CHECKPOINT_PATH,
+             )
+
+         if local_clip is not None:
+             # Feed inputs to ClipSegTTA copy
+             local_clip.img_paths = [ground_path] if ground_path else []
+             local_clip.imo_path = sat_path
+             local_clip.imgs = ([local_clip.dataset.img_transform(ground_img_pil).to(device)] if ground_img_pil else [])
+             local_clip.imo = local_clip.dataset.imo_transform(sat_img).to(device)
+             local_clip.sounds = []
+             local_clip.sound_ids = []
+             local_clip.species_name = taxonomy or ""
+             local_clip.gt_mask_name = taxonomy.replace(" ", "_") if taxonomy else ""
+             local_clip.target_positions = tgt_positions if tgt_positions else [(0, 0)]
+
+         planner = TestWorker(
+             meta_agent_id=0,
+             n_agent=1,
+             policy_net=policy_net,
+             global_step=-1,
+             device=device,
+             greedy=True,
+             save_image=SAVE_GIFS,
+             clip_seg_tta=local_clip,
+         )
+         planner.execute_tta = enable_tta
+         planner.gifs_path = save_dir
+         return planner
+
+     # ────────────── Per-run output directories ──────────────
+     # Ensure base directory exists
+     os.makedirs(GIFS_PATH, exist_ok=True)
+
+     run_id = time.strftime("%Y%m%d_%H%M%S")  # unique timestamp
+     run_root = os.path.join(GIFS_PATH, run_id)
+     gifs_dir_tta = os.path.join(run_root, "with_tta")
+     gifs_dir_no = os.path.join(run_root, "no_tta")
+
+     os.makedirs(gifs_dir_tta, exist_ok=True)
+     os.makedirs(gifs_dir_no, exist_ok=True)
+
+     # House-keep old runs so we never keep more than RUN_HISTORY_LIMIT
+     _prune_old_run_dirs(GIFS_PATH, RUN_HISTORY_LIMIT)
+
+     # Shared dict to record if a thread hit an exception
+     error_flags = {"tta": False, "no": False}
+
+     def _planner_thread(enable_tta: bool, save_dir: str, clip_obj, key: str):
+         """Prepare directory, build planner, run an episode, record errors."""
+         try:
+             planner = build_planner(enable_tta, save_dir, clip_obj)
+             _thread_clip_map[threading.current_thread()] = planner.clip_seg_tta
+             planner.run_episode(0)
+         except Exception:
+             # Mark that this planner crashed so UI can show an error status
+             error_flags[key] = True
+             # Log full traceback so developers can debug via console logs
+             import traceback
+             traceback.print_exc()
+             # Still exit the thread
+             return
+
+     # Launch both planners in background threads – preparation included
+     thread_tta = threading.Thread(
+         target=_planner_thread,
+         args=(True, gifs_dir_tta, None, "tta"),
+         daemon=True,
+     )
+     thread_no = threading.Thread(
+         target=_planner_thread,
+         args=(False, gifs_dir_no, None, "no"),
+         daemon=True,
+     )
+     # Track threads for this user session
+     session_threads.extend([thread_tta, thread_no])
+     thread_tta.start()
+     thread_no.start()
+
+
+     sent_tta: set[str] = set()
+     sent_no: set[str] = set()
+     last_tta = None
+     last_no = None
+     # Previous status strings, so we can emit an update when only the status
+     # (Running…/Done.) changes even if no new frame was produced.
+     prev_status_tta = "Initializing model…"
+     prev_status_no = "Initializing model…"
+
+     try:
+         while thread_tta.is_alive() or thread_no.is_alive():
+             updated = False
+             # Collect new frames from TTA dir
+             pngs = glob.glob(os.path.join(gifs_dir_tta, "*.png"))
+             pngs.sort(key=lambda p: int(os.path.splitext(os.path.basename(p))[0]))
+             for fp in pngs:
+                 if fp not in sent_tta:
+                     # Ensure file is fully written (non-empty & readable)
+                     try:
+                         if os.path.getsize(fp) == 0:
+                             continue
+                         with open(fp, "rb") as fh:
+                             fh.read(1)
+                     except Exception:
+                         # Skip this round; we'll retry next poll
+                         continue
+                     sent_tta.add(fp)
+                     last_tta = fp
+                     updated = True
+             # Collect new frames from no-TTA dir
+             pngs = glob.glob(os.path.join(gifs_dir_no, "*.png"))
+             pngs.sort(key=lambda p: int(os.path.splitext(os.path.basename(p))[0]))
+             for fp in pngs:
+                 if fp not in sent_no:
+                     try:
+                         if os.path.getsize(fp) == 0:
+                             continue
+                         with open(fp, "rb") as fh:
+                             fh.read(1)
+                     except Exception:
+                         continue
+                     sent_no.add(fp)
+                     last_no = fp
+                     updated = True
+
+             # Determine status based on whether we already have a frame and whether
+             # the corresponding thread is still alive.
+             def _mk_status(last_frame, thread_alive, errored: bool, running_tta: bool = False):
+                 if errored:
+                     return "Error!"
+                 if last_frame is None:
+                     return "Initializing model…"
+                 if not thread_alive:
+                     return "Done."
+                 return "Executing TTA (Scheduling GPUs)…" if running_tta else "Executing Planner…"
+
+             exec_tta_flag = False
+             if thread_tta.is_alive():
+                 clip_obj = _thread_clip_map.get(thread_tta)
+                 if clip_obj is not None and getattr(clip_obj, "executing_tta", False):
+                     exec_tta_flag = True
+
+             status_tta = _mk_status(last_tta, thread_tta.is_alive(), error_flags["tta"], exec_tta_flag)
+             status_no = _mk_status(last_no, thread_no.is_alive(), error_flags["no"], False)
+
+             # Determine if we should reveal sliders (once corresponding thread has finished)
+             show_slider_tta = (not thread_tta.is_alive()) and (last_tta is not None)
+             show_slider_no = (not thread_no.is_alive()) and (last_no is not None)
+
+             # Build slider updates
+             slider_tta_upd = gr.update()
+             slider_no_upd = gr.update()
+             frames_tta_upd = gr.update()
+             frames_no_upd = gr.update()
+
+             if show_slider_tta:
+                 n_tta_frames = max(len(sent_tta), 1)
+                 slider_tta_upd = gr.update(visible=True, minimum=1, maximum=n_tta_frames, value=n_tta_frames)
+                 frames_tta_upd = sorted(sent_tta, key=lambda p: int(os.path.splitext(os.path.basename(p))[0]))
+             if show_slider_no:
+                 n_no_frames = max(len(sent_no), 1)
+                 slider_no_upd = gr.update(visible=True, minimum=1, maximum=n_no_frames, value=n_no_frames)
+                 frames_no_upd = sorted(sent_no, key=lambda p: int(os.path.splitext(os.path.basename(p))[0]))
+
+             # Emit update if we have a new frame OR status changed OR slider visibility changed
+             if (
+                 updated
+                 or status_tta != prev_status_tta
+                 or status_no != prev_status_no
+                 or show_slider_tta
+                 or show_slider_no
+             ):
+                 yield (
+                     gr.update(interactive=False),
+                     last_tta,
+                     last_no,
+                     gr.update(value=status_tta, visible=True),
+                     gr.update(value=status_no, visible=True),
+                     slider_tta_upd,
+                     slider_no_upd,
+                     frames_tta_upd,
+                     frames_no_upd,
+                     session_threads,
+                 )
+
+             prev_status_tta = status_tta
+             prev_status_no = status_no
+
+             time.sleep(POLL_INTERVAL)
+     finally:
+         # Ensure background threads are stopped on cancel
+         for th in (thread_tta, thread_no):
+             if th.is_alive():
+                 _stop_thread(th)
+                 th.join(timeout=1)
+
+         # Remove finished threads from global registry
+         with _running_threads_lock:
+             # Clear session thread list
+             session_threads.clear()
+
+     # Small delay to ensure last frame files are fully flushed
+     time.sleep(0.2)
+     # One last scan after both threads have finished to catch any frame
+     # that may have been written just before termination but after the last
+     # polling iteration.
+     for fp in sorted(glob.glob(os.path.join(gifs_dir_tta, "*.png")), key=lambda p: int(os.path.splitext(os.path.basename(p))[0])):
+         if fp not in sent_tta:
+             sent_tta.add(fp)
+             last_tta = fp
+     for fp in sorted(glob.glob(os.path.join(gifs_dir_no, "*.png")), key=lambda p: int(os.path.splitext(os.path.basename(p))[0])):
+         if fp not in sent_no:
+             sent_no.add(fp)
+             last_no = fp
+
+     # Prepare frames list and slider configs
+     frames_tta = sorted(glob.glob(os.path.join(gifs_dir_tta, "*.png")), key=lambda p: int(os.path.splitext(os.path.basename(p))[0]))
+     frames_no = sorted(glob.glob(os.path.join(gifs_dir_no, "*.png")), key=lambda p: int(os.path.splitext(os.path.basename(p))[0]))
+     if last_tta is None and frames_tta:
+         last_tta = frames_tta[-1]
+     if last_no is None and frames_no:
+         last_no = frames_no[-1]
+     n_tta = len(frames_tta) or 1  # prevent zero-range slider
+     n_no = len(frames_no) or 1
+
+     # Final emit: re-enable button, hide statuses, show sliders set to last frame
+     yield (
+         gr.update(interactive=True),
+         last_tta,
+         last_no,
+         gr.update(visible=False),
+         gr.update(visible=False),
+         gr.update(visible=True, minimum=1, maximum=n_tta, value=n_tta),
+         gr.update(visible=True, minimum=1, maximum=n_no, value=n_no),
+         frames_tta,
+         frames_no,
+         session_threads,
+     )
+
+
+ # ────────────────────────── Gradio UI ─────────────────────────────────
+ with gr.Blocks(title="Search-TTA (Simplified)", theme=gr.themes.Base()) as demo:
+
+     gr.Markdown(
+         """
+         # Search-TTA: A Multimodal Test-Time Adaptation Framework for Visual Search in the Wild Demo
+         Click on any of the <b>examples below</b> and run the <b>TTA demo</b>. Check out the <b>multimodal heatmap generation feature</b> by switching to the other tab above. <br>
+         Note that the model initialization, RL planner, and TTA updates are not fully optimized on GPU for this huggingface demo, and hence may experience some lag during execution. <br>
+         If you encounter an 'Error' status, refresh the browser and rerun the demo, or try again the next day. We will improve this in the future. <br>
+         <a href="https://search-tta.github.io">Project Website</a>
+         """
+     )
+
+     with gr.Row(variant="panel"):
+         with gr.Column():
+             gr.Markdown("### Model Inputs")
+             sat_input = gr.Image(
+                 label="Satellite Image",
+                 sources=["upload"],
+                 type="filepath",
+                 height=320,
+             )
+             taxonomy_input = gr.Textbox(
+                 label="Full Taxonomy Name (optional)",
+                 placeholder="e.g. Animalia Chordata Mammalia Rodentia Sciuridae Marmota marmota",
+             )
+             ground_input = gr.Image(
+                 label="Ground-level Image (optional)",
+                 sources=["upload"],
+                 type="filepath",
+                 height=320,
+             )
+             run_btn = gr.Button("Run Search-TTA", variant="primary")
+
+         with gr.Column():
+             gr.Markdown("### Live Heatmap Output")
+             display_img_tta = gr.Image(label="Heatmap (TTA per 20 steps)", type="filepath", height=400)  # 512
+             status_tta = gr.Markdown("")
+             slider_tta = gr.Slider(label="TTA Frame", minimum=1, maximum=1, step=1, value=1, visible=False)
+
+             display_img_no_tta = gr.Image(label="Heatmap (no TTA)", type="filepath", height=400)  # 512
+             status_no_tta = gr.Markdown("")
+             slider_no = gr.Slider(label="No-TTA Frame", minimum=1, maximum=1, step=1, value=1, visible=False)
+
+     frames_state_tta = gr.State([])
+     frames_state_no = gr.State([])
+     session_threads_state = gr.State([])
+
+     # Slider callbacks (updates image when user drags slider)
+     def _show_frame(idx: int, frames: list[str]):
+         # Slider is 1-indexed; convert to 0-indexed list access
+         if 1 <= idx <= len(frames):
+             return frames[idx - 1]
+         return gr.update()
+
+     slider_tta.change(_show_frame, inputs=[slider_tta, frames_state_tta], outputs=display_img_tta)
+     slider_no.change(_show_frame, inputs=[slider_no, frames_state_no], outputs=display_img_no_tta)
+
+     # EXAMPLES
+     with gr.Row():
+         gr.Markdown("### Taxonomy")
+     with gr.Row():
+         gr.Examples(
+             examples=[
+                 [
+                     "examples/Animalia_Chordata_Reptilia_Squamata_Varanidae_Varanus_salvator/410613_5.35573_100.28948.jpg",
+                     "examples/Animalia_Chordata_Reptilia_Squamata_Varanidae_Varanus_salvator/461d8e6c-0e66-4acc-8ecd-bfd9c218bc14.jpg",
+                     "Animalia Chordata Reptilia Squamata Varanidae Varanus salvator",
+                 ],
+                 [
+                     "examples/Animalia_Chordata_Mammalia_Carnivora_Canidae_Canis_aureus/1528408_13.00422_80.23033.jpg",
+                     "examples/Animalia_Chordata_Mammalia_Carnivora_Canidae_Canis_aureus/37faabd2-a613-4461-b27e-82fe5955ecaf.jpg",
+                     "Animalia Chordata Mammalia Carnivora Canidae Canis aureus",
+                 ],
+                 [
+                     "examples/Animalia_Chordata_Reptilia_Crocodylia_Alligatoridae_Caiman_crocodilus/340271_10.52832_-83.49678.jpg",
+                     "examples/Animalia_Chordata_Reptilia_Crocodylia_Alligatoridae_Caiman_crocodilus/938aab7b-4509-4de7-afad-2c8ea51f4799.jpg",
+                     "Animalia Chordata Reptilia Crocodylia Alligatoridae Caiman crocodilus",
+                 ],
+                 [
+                     "examples/Animalia_Chordata_Mammalia_Carnivora_Canidae_Urocyon_littoralis/304160_34.0144_-119.54417.jpg",
+                     "examples/Animalia_Chordata_Mammalia_Carnivora_Canidae_Urocyon_littoralis/0cbdfbf2-6cfe-4d61-9602-c949f24d0293.jpg",
+                     "Animalia Chordata Mammalia Carnivora Canidae Urocyon littoralis",
+                 ],
+             ],
+             inputs=[sat_input, ground_input, taxonomy_input],
+             outputs=[run_btn, display_img_tta, display_img_no_tta, status_tta, status_no_tta, slider_tta, slider_no, frames_state_tta, frames_state_no],
+             fn=process_search_tta,
+             cache_examples=False,
+         )
+
+     run_btn.click(
+         fn=process_search_tta,
+         inputs=[sat_input, ground_input, taxonomy_input, session_threads_state],
+         outputs=[run_btn, display_img_tta, display_img_no_tta, status_tta, status_no_tta, slider_tta, slider_no, frames_state_tta, frames_state_no, session_threads_state],
+     )
+
+     # Footer pointing to the model and data used in this app.
+     gr.Markdown(
+         """
+         The satellite image CLIP encoder is fine-tuned using [Sentinel-2 Level 2A](https://docs.sentinel-hub.com/api/latest/data/sentinel-2-l2a/) satellite image and taxonomy images (with GPS locations) from [iNaturalist](https://inaturalist.org/). The sound CLIP encoder is fine-tuned with a subset of the same taxonomy images and their corresponding sounds from [iNaturalist](https://inaturalist.org/). Some of these iNaturalist data are also used in [Taxabind](https://arxiv.org/abs/2411.00683). Note that while some of the examples above result in poor probability distributions, they will be improved using our test-time adaptation framework during the search process.
+         """
+     )
+
+
+ if __name__ == "__main__":
+
+     # Build UI with explicit Tabs so we can detect tab selection and clean up
+     from app_multimodal_inference import demo as multimodal_demo
+
+     with gr.Blocks() as root:
+         with gr.Tabs() as tabs:
+             with gr.TabItem("Multimodal Inference"):
+                 multimodal_demo.render()
+             with gr.TabItem("Search-TTA"):
+                 demo.render()
+
+         # Hidden textbox purely to satisfy Gradio's need for an output component.
+         _cleanup_status = gr.Textbox(visible=False)
+
+         outputs_on_tab = [_cleanup_status]
+
+         def _on_tab_change(evt: gr.SelectData, session_threads: list[threading.Thread]):
+             # evt.value contains the name of the newly-selected tab.
+             if evt.value == "Multimodal Inference":
+                 # Stop only threads started in this session
+                 for th in list(session_threads):
+                     if th is not None and th.is_alive():
+                         _stop_thread(th)
+                         th.join(timeout=1)
+                 session_threads.clear()
+                 return "Stopped running Search-TTA threads."
+             return ""
+
+         tabs.select(_on_tab_change, inputs=[session_threads_state], outputs=outputs_on_tab)
+
+     root.queue(max_size=15)
+     root.launch(share=True)
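app.py's `_stop_thread` leans on CPython's `PyThreadState_SetAsyncExc` to inject SystemExit into a running worker. Below is a self-contained sketch of the same mechanism, with a hypothetical worker loop standing in for a planner episode:

import ctypes
import threading
import time

def worker():
    try:
        while True:  # stand-in for a long-running planner episode
            time.sleep(0.1)
    except SystemExit:
        print("worker: received SystemExit, cleaning up")

t = threading.Thread(target=worker, daemon=True)
t.start()
time.sleep(0.5)

# Ask CPython to raise SystemExit in the worker's context (best-effort:
# the exception is only delivered when the thread next runs bytecode,
# so a thread blocked inside C code may not stop promptly).
res = ctypes.pythonapi.PyThreadState_SetAsyncExc(
    ctypes.c_long(t.ident), ctypes.py_object(SystemExit)
)
if res > 1:  # >1 means an invalid thread id was affected; undo the request
    ctypes.pythonapi.PyThreadState_SetAsyncExc(ctypes.c_long(t.ident), None)
t.join(timeout=2)
print("stopped:", not t.is_alive())

Because delivery is best-effort, app.py also marks its planner threads as daemons so a missed injection cannot keep the process alive.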
app_multimodal_inference.py ADDED
@@ -0,0 +1,350 @@
+ """
+ Search-TTA multimodal heatmap generation demo
+ """
+
+ # ────────────────────────── imports ───────────────────────────────────
+ import cv2
+ import gradio as gr
+ import torch
+ import numpy as np
+ from PIL import Image
+ import matplotlib.pyplot as plt
+ import io
+ import torchaudio
+ import spaces  # integration with ZeroGPU on hf
+
+ from torchvision import transforms
+ import open_clip
+ from taxabind_avs.satbind.clip_vision_per_patch_model import CLIPVisionPerPatchModel
+ from transformers import ClapAudioModelWithProjection
+ from transformers import ClapProcessor
+
+ # ────────────────────────── global config & models ────────────────────
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+ # BioCLIP (ground-image & text encoder)
+ bio_model, _, _ = open_clip.create_model_and_transforms("hf-hub:imageomics/bioclip")
+ bio_model = bio_model.to(device).eval()
+ bio_tokenizer = open_clip.get_tokenizer("hf-hub:imageomics/bioclip")
+
+ # Satellite patch encoder (CLIP-L-336, per-patch)
+ sat_model: CLIPVisionPerPatchModel = (
+     CLIPVisionPerPatchModel.from_pretrained("derektan95/search-tta-sat")
+     .to(device)
+     .eval()
+ )
+
+ # Sound CLAP model
+ sound_model: ClapAudioModelWithProjection = (
+     ClapAudioModelWithProjection.from_pretrained("derektan95/search-tta-sound")
+     .to(device)
+     .eval()
+ )
+ sound_processor: ClapProcessor = ClapProcessor.from_pretrained("derektan95/search-tta-sound")
+ SAMPLE_RATE = 48000
+
+ logit_scale = torch.nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
+ logit_scale = logit_scale.exp()
+ blur_kernel = (5, 5)
+
+ # ────────────────────────── transforms (exact spec) ───────────────────
+ img_transform = transforms.Compose(
+     [
+         transforms.Resize((256, 256)),
+         transforms.CenterCrop((224, 224)),
+         transforms.ToTensor(),
+         transforms.Normalize(
+             mean=[0.485, 0.456, 0.406],
+             std=[0.229, 0.224, 0.225],
+         ),
+     ]
+ )
+
+ imo_transform = transforms.Compose(
+     [
+         transforms.Resize((336, 336)),
+         transforms.ToTensor(),
+         transforms.Normalize(
+             mean=[0.485, 0.456, 0.406],
+             std=[0.229, 0.224, 0.225],
+         ),
+     ]
+ )
+
+ def get_audio_clap(path_to_audio, format="mp3", padding="repeatpad", truncation="fusion"):
+     track, sr = torchaudio.load(path_to_audio, format=format)
+     track = track.mean(axis=0)
+     track = torchaudio.functional.resample(track, orig_freq=sr, new_freq=SAMPLE_RATE)
+     output = sound_processor(audios=track, sampling_rate=SAMPLE_RATE, max_length_s=10, return_tensors="pt", padding=padding, truncation=truncation)
+     return output
+
+ # ────────────────────────── helpers ───────────────────────────────────
+
+ @torch.no_grad()
+ def _encode_ground(img_pil: Image.Image) -> torch.Tensor:
+     img = img_transform(img_pil).unsqueeze(0).to(device)
+     img_embeds, *_ = bio_model(img)
+     return img_embeds
+
+
+ @torch.no_grad()
+ def _encode_text(text: str) -> torch.Tensor:
+     toks = bio_tokenizer(text).to(device)
+     _, txt_embeds, _ = bio_model(text=toks)
+     return txt_embeds
+
+
+ @torch.no_grad()
+ def _encode_sat(img_pil: Image.Image) -> torch.Tensor:
+     imo = imo_transform(img_pil).unsqueeze(0).to(device)
+     imo_embeds = sat_model(imo)
+     return imo_embeds
+
+
+ @torch.no_grad()
+ def _encode_sound(sound) -> torch.Tensor:
+     processed_sound = get_audio_clap(sound)
+     for k in processed_sound.keys():
+         processed_sound[k] = processed_sound[k].to(device)
+     unnormalized_audio_embeds = sound_model(**processed_sound).audio_embeds
+     sound_embeds = torch.nn.functional.normalize(unnormalized_audio_embeds, dim=-1)
+     return sound_embeds
+
+
+ def _similarity_heatmap(query: torch.Tensor, patches: torch.Tensor) -> np.ndarray:
+     sims = torch.matmul(query, patches.t()) * logit_scale
+     sims = sims.t().sigmoid()
+     sims = sims[1:].squeeze()  # drop CLS token
+     side = int(np.sqrt(len(sims)))
+     sims = sims.reshape(side, side)
+     return sims.cpu().detach().numpy()
+
+
+ def _array_to_pil(arr: np.ndarray) -> Image.Image:
+     """
+     Render arr with viridis, automatically stretching its own min→max to 0→1
+     so that the most-similar patches appear yellow.
+     """
+
+     # Gaussian smoothing
+     if blur_kernel != (0, 0):
+         arr = cv2.GaussianBlur(arr, blur_kernel, 0)
+
+     # --- contrast-stretch to local 0-1 range --------------------------
+     arr_min, arr_max = float(arr.min()), float(arr.max())
+     if arr_max - arr_min < 1e-6:  # avoid /0 when the heat-map is flat
+         arr_scaled = np.zeros_like(arr)
+     else:
+         arr_scaled = (arr - arr_min) / (arr_max - arr_min)
+     # ------------------------------------------------------------------
+     fig, ax = plt.subplots(figsize=(2.6, 2.6), dpi=96)
+     ax.imshow(arr_scaled, cmap="viridis", vmin=0.0, vmax=1.0)
+     ax.axis("off")
+     buf = io.BytesIO()
+     plt.tight_layout(pad=0)
+     fig.savefig(buf, format="png", bbox_inches="tight", pad_inches=0)
+     plt.close(fig)
+     buf.seek(0)
+     return Image.open(buf)
+
+ # ────────────────────────── main inference ────────────────────────────
+ # integration with ZeroGPU on hf
+ @spaces.GPU(duration=5)
+ def process(
+     sat_img: Image.Image,
+     taxonomy: str,
+     ground_img: Image.Image | None,
+     sound: torch.Tensor | None,
+ ):
+     if sat_img is None:
+         return None, None, None
+
+     patches = _encode_sat(sat_img)
+
+     heat_ground, heat_text, heat_sound = None, None, None
+
+     if ground_img is not None:
+         q_img = _encode_ground(ground_img)
+         heat_ground = _array_to_pil(_similarity_heatmap(q_img, patches))
+
+     if taxonomy.strip():
+         q_txt = _encode_text(taxonomy.strip())
+         heat_text = _array_to_pil(_similarity_heatmap(q_txt, patches))
+
+     if sound is not None:
+         q_sound = _encode_sound(sound)
+         heat_sound = _array_to_pil(_similarity_heatmap(q_sound, patches))
+
+     return heat_ground, heat_text, heat_sound
+
+
+ # ────────────────────────── Gradio UI ─────────────────────────────────
+ with gr.Blocks(title="Search-TTA", theme=gr.themes.Base()) as demo:
+
+     gr.Markdown(
+         """
+         # Search-TTA: A Multimodal Test-Time Adaptation Framework for Visual Search in the Wild Demo
+         Click on any of the <b>examples below</b> and run the <b>multimodal inference demo</b>. Check out the <b>test-time adaptation feature</b> by switching to the other tab above. <br>
+         If you encounter any errors, refresh the browser and rerun the demo, or try again the next day. We will improve this in the future. <br>
+         <a href="https://search-tta.github.io">Project Website</a>
+         """
+     )
+
+     with gr.Row(variant="panel"):
+
+         # LEFT COLUMN (satellite, taxonomy, run)
+         with gr.Column():
+             sat_input = gr.Image(
+                 label="Satellite Image",
+                 sources=["upload"],
+                 type="pil",
+                 height=320,
+             )
+             taxonomy_input = gr.Textbox(
+                 label="Full Taxonomy Name (optional)",
+                 placeholder="e.g. Animalia Chordata Mammalia Rodentia Sciuridae Marmota marmota",
+             )
+
+             # ─── NEW: sound input ───────────────────────────
+             sound_input = gr.Audio(
+                 label="Sound Input (optional)",
+                 sources=["upload"],
+                 type="filepath",
+             )
+             run_btn = gr.Button("Run", variant="primary")
+
+         # RIGHT COLUMN (ground image + three heat-maps)
+         with gr.Column():
+             ground_input = gr.Image(
+                 label="Ground-level Image (optional)",
+                 sources=["upload"],
+                 type="pil",
+                 height=320,
+             )
+             gr.Markdown("### Heat-map Results")
+             with gr.Row():
+                 # Separate label and image to avoid overlap
+                 with gr.Column(scale=1, min_width=100):
+                     gr.Markdown("**Ground Image Query**", elem_id="label-ground")
+                     heat_ground_out = gr.Image(
+                         show_label=False,
+                         height=160,
+                     )
+                 with gr.Column(scale=1, min_width=100):
+                     gr.Markdown("**Text Query**", elem_id="label-text")
+                     heat_text_out = gr.Image(
+                         show_label=False,
+                         height=160,
+                     )
+                 with gr.Column(scale=1, min_width=100):
+                     gr.Markdown("**Sound Query**", elem_id="label-sound")
+                     heat_sound_out = gr.Image(
+                         show_label=False,
+                         height=160,
+                     )
+
+
+     # EXAMPLES
+     with gr.Row():
+         gr.Markdown("### In-Domain Taxonomy")
+     with gr.Row():
+         gr.Examples(
+             examples=[
+                 [
+                     "examples/Animalia_Chordata_Aves_Charadriiformes_Laridae_Larus_marinus/80645_39.76079_-74.10316.jpg",
+                     "examples/Animalia_Chordata_Aves_Charadriiformes_Laridae_Larus_marinus/cc1ebaf9-899d-49f2-81c8-d452249a8087.jpg",
+                     "Animalia Chordata Aves Charadriiformes Laridae Larus marinus",
+                     "examples/Animalia_Chordata_Aves_Charadriiformes_Laridae_Larus_marinus/89758229.mp3"
+                 ],
+                 [
+                     "examples/Animalia_Chordata_Mammalia_Rodentia_Caviidae_Hydrochoerus_hydrochaeris/28871_-12.80255_-69.29999.jpg",
+                     "examples/Animalia_Chordata_Mammalia_Rodentia_Caviidae_Hydrochoerus_hydrochaeris/1b8064f8-7deb-4b30-98cd-69da98ba6a3d.jpg",
+                     "Animalia Chordata Mammalia Rodentia Caviidae Hydrochoerus hydrochaeris",
+                     "examples/Animalia_Chordata_Mammalia_Rodentia_Caviidae_Hydrochoerus_hydrochaeris/166631961.mp3"
+                 ],
+                 [
+                     "examples/Animalia_Arthropoda_Malacostraca_Decapoda_Ocypodidae_Ocypode_quadrata/277303_38.72364_-75.07749.jpg",
+                     "examples/Animalia_Arthropoda_Malacostraca_Decapoda_Ocypodidae_Ocypode_quadrata/0b9cc264-a2ba-44bd-8e41-0d01a6edd1e8.jpg",
+                     "Animalia Arthropoda Malacostraca Decapoda Ocypodidae Ocypode quadrata",
+                     "examples/Animalia_Arthropoda_Malacostraca_Decapoda_Ocypodidae_Ocypode_quadrata/12372063.mp3"
+                 ],
+                 [
+                     "examples/Animalia_Chordata_Mammalia_Rodentia_Sciuridae_Marmota_marmota/388246_45.49036_7.14796.jpg",
+                     "examples/Animalia_Chordata_Mammalia_Rodentia_Sciuridae_Marmota_marmota/327e1f07-692b-4140-8a3e-bd098bc064ff.jpg",
+                     "Animalia Chordata Mammalia Rodentia Sciuridae Marmota marmota",
+                     "examples/Animalia_Chordata_Mammalia_Rodentia_Sciuridae_Marmota_marmota/59677071.mp3"
+                 ],
+                 [
+                     "examples/Animalia_Chordata_Reptilia_Squamata_Varanidae_Varanus_salvator/410613_5.35573_100.28948.jpg",
+                     "examples/Animalia_Chordata_Reptilia_Squamata_Varanidae_Varanus_salvator/461d8e6c-0e66-4acc-8ecd-bfd9c218bc14.jpg",
+                     "Animalia Chordata Reptilia Squamata Varanidae Varanus salvator",
+                     None
+                 ],
+             ],
+             inputs=[sat_input, ground_input, taxonomy_input, sound_input],
+             outputs=[heat_ground_out, heat_text_out, heat_sound_out],
+             fn=process,
+             cache_examples=False,
+         )
+
+     # EXAMPLES
+     with gr.Row():
+         gr.Markdown("### Out-Domain Taxonomy")
+     with gr.Row():
+         gr.Examples(
+             examples=[
+                 [
+                     "examples/Animalia_Chordata_Mammalia_Carnivora_Phocidae_Mirounga_angustirostris/27423_35.64005_-121.17595.jpg",
+                     "examples/Animalia_Chordata_Mammalia_Carnivora_Phocidae_Mirounga_angustirostris/3aac526d-c921-452a-af6a-cb4f2f52e2c4.jpg",
+                     "Animalia Chordata Mammalia Carnivora Phocidae Mirounga angustirostris",
+                     "examples/Animalia_Chordata_Mammalia_Carnivora_Phocidae_Mirounga_angustirostris/3123948.mp3"
+                 ],
+                 [
+                     "examples/Animalia_Chordata_Mammalia_Carnivora_Canidae_Canis_aureus/1528408_13.00422_80.23033.jpg",
+                     "examples/Animalia_Chordata_Mammalia_Carnivora_Canidae_Canis_aureus/37faabd2-a613-4461-b27e-82fe5955ecaf.jpg",
+                     "Animalia Chordata Mammalia Carnivora Canidae Canis aureus",
+                     "examples/Animalia_Chordata_Mammalia_Carnivora_Canidae_Canis_aureus/189318716.mp3"
+                 ],
+                 [
+                     "examples/Animalia_Chordata_Reptilia_Crocodylia_Alligatoridae_Caiman_crocodilus/340271_10.52832_-83.49678.jpg",
+                     "examples/Animalia_Chordata_Reptilia_Crocodylia_Alligatoridae_Caiman_crocodilus/938aab7b-4509-4de7-afad-2c8ea51f4799.jpg",
+                     "Animalia Chordata Reptilia Crocodylia Alligatoridae Caiman crocodilus",
+                     "examples/Animalia_Chordata_Reptilia_Crocodylia_Alligatoridae_Caiman_crocodilus/45193295.mp3"
+                 ],
+                 [
+                     "examples/Animalia_Chordata_Mammalia_Carnivora_Canidae_Urocyon_littoralis/304160_34.0144_-119.54417.jpg",
+                     "examples/Animalia_Chordata_Mammalia_Carnivora_Canidae_Urocyon_littoralis/0cbdfbf2-6cfe-4d61-9602-c949f24d0293.jpg",
+                     "Animalia Chordata Mammalia Carnivora Canidae Urocyon littoralis",
+                     None
+                 ],
+                 [
+                     "examples/Animalia_Chordata_Elasmobranchii_Carcharhiniformes_Carcharhinidae_Triaenodon_obesus/5041_-0.28573_-90.54837.jpg",
+                     "examples/Animalia_Chordata_Elasmobranchii_Carcharhiniformes_Carcharhinidae_Triaenodon_obesus/c834edf7-b073-4cd5-8726-9c88ebe943c0.jpg",
+                     "Animalia Chordata Elasmobranchii Carcharhiniformes Carcharhinidae Triaenodon obesus",
+                     None
+                 ],
+             ],
+             inputs=[sat_input, ground_input, taxonomy_input, sound_input],
+             outputs=[heat_ground_out, heat_text_out, heat_sound_out],
+             fn=process,
+             cache_examples=False,
+         )
+
+     # CALLBACK
+     run_btn.click(
+         fn=process,
+         inputs=[sat_input, taxonomy_input, ground_input, sound_input],
+         outputs=[heat_ground_out, heat_text_out, heat_sound_out],
+     )
+
+     # Footer pointing to the model and data used in this app.
+     gr.Markdown(
+         """
+         The satellite image CLIP encoder is fine-tuned using [Sentinel-2 Level 2A](https://docs.sentinel-hub.com/api/latest/data/sentinel-2-l2a/) satellite image and taxonomy images (with GPS locations) from [iNaturalist](https://inaturalist.org/). The sound CLIP encoder is fine-tuned with a subset of the same taxonomy images and their corresponding sounds from [iNaturalist](https://inaturalist.org/). Some of these iNaturalist data are also used in [Taxabind](https://arxiv.org/abs/2411.00683). Note that while some of the examples above result in poor probability distributions, they will be improved using our test-time adaptation framework during the search process.
+         """
+     )
+
+ # LAUNCH
+ if __name__ == "__main__":
+     demo.queue(max_size=15)
+     demo.launch(share=True)
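The heart of `_similarity_heatmap` above is a query-versus-patch dot product scaled by the CLIP temperature, squashed with a sigmoid, and reshaped into a square grid after dropping the CLS token. Here is a shape-only sketch with random tensors (the dimensions are illustrative, not the model's actual ones):

import torch

D, SIDE = 768, 24                           # illustrative embedding dim and patch-grid side
query = torch.randn(1, D)                   # one query embedding (text/image/sound)
patches = torch.randn(SIDE * SIDE + 1, D)   # CLS token + per-patch embeddings

logit_scale = torch.tensor(1 / 0.07).log().exp()  # same temperature as above

sims = (query @ patches.t()) * logit_scale  # [1, N+1] similarity logits
sims = sims.t().sigmoid()                   # [N+1, 1], each in (0, 1)
sims = sims[1:].squeeze()                   # drop CLS token -> [N]
heat = sims.reshape(SIDE, SIDE)             # square heat-map grid
print(heat.shape)                           # torch.Size([24, 24])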
examples/Animalia_Arthropoda_Malacostraca_Decapoda_Ocypodidae_Ocypode_quadrata/0b9cc264-a2ba-44bd-8e41-0d01a6edd1e8.jpg ADDED

Git LFS Details

  • SHA256: 08aee38091dbb62f0862a184acbc9432f50c03c63fdf357592df8efcacaab485
  • Pointer size: 131 Bytes
  • Size of remote file: 135 kB
examples/Animalia_Arthropoda_Malacostraca_Decapoda_Ocypodidae_Ocypode_quadrata/12372063.mp3 ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:575959883981159f2e40593bf5be87be006026c41da36a34d1e40783de648116
+ size 54027
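Each media file in this commit is stored as a three-line Git LFS pointer like the one above, not as raw bytes. A minimal sketch that parses such a pointer into its fields (the helper is ours for illustration, not part of git-lfs):

def parse_lfs_pointer(text: str) -> dict:
    """Parse a 'key value' Git LFS pointer file into a dict."""
    fields = {}
    for line in text.strip().splitlines():
        key, _, value = line.partition(" ")
        fields[key] = value
    return fields

pointer = """version https://git-lfs.github.com/spec/v1
oid sha256:575959883981159f2e40593bf5be87be006026c41da36a34d1e40783de648116
size 54027"""

info = parse_lfs_pointer(pointer)
print(info["oid"], int(info["size"]))  # content hash and byte size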
examples/Animalia_Arthropoda_Malacostraca_Decapoda_Ocypodidae_Ocypode_quadrata/277303_38.72364_-75.07749.jpg ADDED

Git LFS Details

  • SHA256: 1caa5d8bab960f559065f79ca554bed63e6b02764096874be6a58b34389855f6
  • Pointer size: 130 Bytes
  • Size of remote file: 25.6 kB
examples/Animalia_Chordata_Aves_Charadriiformes_Laridae_Larus_marinus/80645_39.76079_-74.10316.jpg ADDED

Git LFS Details

  • SHA256: 8350770efa7d8e38b91670e757bb82df26167f8989f946132ad978d238baa916
  • Pointer size: 130 Bytes
  • Size of remote file: 26.1 kB
examples/Animalia_Chordata_Aves_Charadriiformes_Laridae_Larus_marinus/89758229.mp3 ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:e2c7ad6df49668d29f9b7f9f9f0739b97ef4edc5219413a41d01983a9863cccc
+ size 2601487
examples/Animalia_Chordata_Aves_Charadriiformes_Laridae_Larus_marinus/cc1ebaf9-899d-49f2-81c8-d452249a8087.jpg ADDED

Git LFS Details

  • SHA256: e346a1c1424e62a040c7e97f17c2e5ccb4c923422682105b2ccedd0ead736170
  • Pointer size: 130 Bytes
  • Size of remote file: 28.4 kB
examples/Animalia_Chordata_Elasmobranchii_Carcharhiniformes_Carcharhinidae_Triaenodon_obesus/5041_-0.28573_-90.54837.jpg ADDED

Git LFS Details

  • SHA256: 624443bdb62b8d35e6e63be33e04404a85ad8902b70af67d878a013893656dc2
  • Pointer size: 130 Bytes
  • Size of remote file: 15.3 kB
examples/Animalia_Chordata_Elasmobranchii_Carcharhiniformes_Carcharhinidae_Triaenodon_obesus/c834edf7-b073-4cd5-8726-9c88ebe943c0.jpg ADDED

Git LFS Details

  • SHA256: e7f8be8790e7c5837d8d8e0d9285adad45138598caef21f528a591a0ab13ee9b
  • Pointer size: 130 Bytes
  • Size of remote file: 58 kB
examples/Animalia_Chordata_Mammalia_Artiodactyla_Cervidae_Cervus_nippon/1a51f978-c76d-4816-a716-473c21b5495c.jpg ADDED

Git LFS Details

  • SHA256: a47758183723ba17f72dee9acdaf4bcfba2b4d07d0af2e50c125b3fac665ca04
  • Pointer size: 131 Bytes
  • Size of remote file: 117 kB
examples/Animalia_Chordata_Mammalia_Artiodactyla_Cervidae_Cervus_nippon/245767_53.0076_-6.35201.jpg ADDED

Git LFS Details

  • SHA256: 7e9e1c3907555774d831b34b9d7ef94b79b7fbe82c3e226baef75e0cf71194e4
  • Pointer size: 130 Bytes
  • Size of remote file: 23 kB
examples/Animalia_Chordata_Mammalia_Carnivora_Canidae_Canis_aureus/1528408_13.00422_80.23033.jpg ADDED

Git LFS Details

  • SHA256: e9f1934026db176cdcea261d37eda0a02309e5f2647ecab336e53d571b40f8f4
  • Pointer size: 130 Bytes
  • Size of remote file: 37.2 kB
examples/Animalia_Chordata_Mammalia_Carnivora_Canidae_Canis_aureus/189318716.mp3 ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:4639c226ea5a0464b98e89b33a1f821b6625c6637d206d3d355e05bc7c89c641
+ size 148019
examples/Animalia_Chordata_Mammalia_Carnivora_Canidae_Canis_aureus/37faabd2-a613-4461-b27e-82fe5955ecaf.jpg ADDED

Git LFS Details

  • SHA256: a84eca02154ed12885c075378b0349d6950586a7887883bce414df48adb59746
  • Pointer size: 130 Bytes
  • Size of remote file: 83.1 kB
examples/Animalia_Chordata_Mammalia_Carnivora_Canidae_Urocyon_littoralis/0cbdfbf2-6cfe-4d61-9602-c949f24d0293.jpg ADDED

Git LFS Details

  • SHA256: ea5f2dffebd69cdded00548f8773c5a8a8849bbdfba04ae6385fbc2d0983d55f
  • Pointer size: 130 Bytes
  • Size of remote file: 75.6 kB
examples/Animalia_Chordata_Mammalia_Carnivora_Canidae_Urocyon_littoralis/304160_34.0144_-119.54417.jpg ADDED

Git LFS Details

  • SHA256: 3b803c3c2e6fa921d9f83ba3aecccac0796a4cd4774c3263aae54fdfc49d13d6
  • Pointer size: 130 Bytes
  • Size of remote file: 23.3 kB
examples/Animalia_Chordata_Mammalia_Carnivora_Phocidae_Mirounga_angustirostris/27423_35.64005_-121.17595.jpg ADDED

Git LFS Details

  • SHA256: 16dd378607b7303515593491a1247785ae49733da24bbc3ce21e85d6c6341ab2
  • Pointer size: 130 Bytes
  • Size of remote file: 22.6 kB
examples/Animalia_Chordata_Mammalia_Carnivora_Phocidae_Mirounga_angustirostris/3123948.mp3 ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:96ca3a92e6f614cce82972dacb04f5c0c170c1aea3d70d15778af56820ed02c9
+ size 276768
examples/Animalia_Chordata_Mammalia_Carnivora_Phocidae_Mirounga_angustirostris/3aac526d-c921-452a-af6a-cb4f2f52e2c4.jpg ADDED

Git LFS Details

  • SHA256: bdda6139885cf54acfb1d6c9a56373fbe39e205dac7eb99cd04dbe5eb206b9d6
  • Pointer size: 130 Bytes
  • Size of remote file: 95.4 kB
examples/Animalia_Chordata_Mammalia_Rodentia_Caviidae_Hydrochoerus_hydrochaeris/166631961.mp3 ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:cc02eca19d0c408d038e205d82f6624c0515858ac374cf7298161a14e169e6a9
+ size 266258
examples/Animalia_Chordata_Mammalia_Rodentia_Caviidae_Hydrochoerus_hydrochaeris/1b8064f8-7deb-4b30-98cd-69da98ba6a3d.jpg ADDED

Git LFS Details

  • SHA256: fe636d18a1e068e85b0bb8cd05ff674eb4b19958cc34d75ef00d385f74254ecb
  • Pointer size: 130 Bytes
  • Size of remote file: 85.3 kB
examples/Animalia_Chordata_Mammalia_Rodentia_Caviidae_Hydrochoerus_hydrochaeris/28871_-12.80255_-69.29999.jpg ADDED

Git LFS Details

  • SHA256: d28dec20d4f9cba13386ab00f13ddd7cb36fee24ee466e8a5437dbfd778bc2d5
  • Pointer size: 130 Bytes
  • Size of remote file: 23.3 kB
examples/Animalia_Chordata_Mammalia_Rodentia_Sciuridae_Marmota_marmota/327e1f07-692b-4140-8a3e-bd098bc064ff.jpg ADDED

Git LFS Details

  • SHA256: 5bffc4c332ae6406bcb1b78cd23170bd7c71d58e2e7dac12fb812fc9aa39b8f0
  • Pointer size: 130 Bytes
  • Size of remote file: 70.3 kB
examples/Animalia_Chordata_Mammalia_Rodentia_Sciuridae_Marmota_marmota/388246_45.49036_7.14796.jpg ADDED

Git LFS Details

  • SHA256: a5aa9ae1a1dc4c59191bc72005fc9904d4c390f07ce5cc5ed435eb5687ae1d64
  • Pointer size: 130 Bytes
  • Size of remote file: 33.3 kB
examples/Animalia_Chordata_Mammalia_Rodentia_Sciuridae_Marmota_marmota/59677071.mp3 ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:bb043991fe851d6a1e12f32c5a9277dad5a77a939cf15ccb4afcb215b4bc08e3
+ size 92876
examples/Animalia_Chordata_Reptilia_Crocodylia_Alligatoridae_Caiman_crocodilus/340271_10.52832_-83.49678.jpg ADDED

Git LFS Details

  • SHA256: b31fbe934b245e7274289836b9eee781b2e33c4121dfbafebc473cd45d638825
  • Pointer size: 130 Bytes
  • Size of remote file: 19.8 kB
examples/Animalia_Chordata_Reptilia_Crocodylia_Alligatoridae_Caiman_crocodilus/45193295.mp3 ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:c4cd2e4fd7094a07d79da7fd54788705e8ce7567e65911d87edfd23ff1c0e484
+ size 247762
examples/Animalia_Chordata_Reptilia_Crocodylia_Alligatoridae_Caiman_crocodilus/938aab7b-4509-4de7-afad-2c8ea51f4799.jpg ADDED

Git LFS Details

  • SHA256: a26f17668646cd25c77483565f6509ca7b21bba09ce92dac0f38d0ecbfdae3b1
  • Pointer size: 130 Bytes
  • Size of remote file: 86.3 kB
examples/Animalia_Chordata_Reptilia_Squamata_Varanidae_Varanus_salvator/410613_5.35573_100.28948.jpg ADDED

Git LFS Details

  • SHA256: 3019573a982d10c4791e357a5bebadfbb245f145c57c60c8a53f2241ac8789fe
  • Pointer size: 130 Bytes
  • Size of remote file: 37 kB
examples/Animalia_Chordata_Reptilia_Squamata_Varanidae_Varanus_salvator/461d8e6c-0e66-4acc-8ecd-bfd9c218bc14.jpg ADDED

Git LFS Details

  • SHA256: 41b11b1ea9709a9fefabc2c7ddf8aa58a7881749474bf0ccadca3a02e3a97c76
  • Pointer size: 131 Bytes
  • Size of remote file: 176 kB
examples/metadata.json ADDED
@@ -0,0 +1,173 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
+ {
+   "Animalia Chordata Reptilia Squamata Varanidae Varanus salvator": {
+     "id": 410613,
+     "sat_key": "410613_5.35573_100.28948",
+     "sat_path": "410613_5.35573_100.28948.jpg",
+     "taxonomy": "Animalia Chordata Reptilia Squamata Varanidae Varanus salvator",
+     "count": 6,
+     "spread": 58.00460580210422,
+     "sat_bounds": {
+       "min_lat": 5.344155081363914,
+       "max_lat": 5.367304914271601,
+       "min_lon": 100.27793148340874,
+       "max_lon": 100.30102851659126
+     },
+     "img_ids": [707815, 411949, 701168, 1619682, 2100008, 1548498],
+     "target_positions": [[225, 240], [232, 275], [277, 449], [220, 369], [180, 393], [294, 478]],
+     "num_landmarks": 2
+   },
+   "Animalia Chordata Mammalia Carnivora Canidae Canis aureus": {
+     "id": 1528408,
+     "sat_key": "1528408_13.00422_80.23033",
+     "sat_path": "1528408_13.00422_80.23033.jpg",
+     "taxonomy": "Animalia Chordata Mammalia Carnivora Canidae Canis aureus",
+     "count": 3,
+     "spread": 58.14007011752667,
+     "sat_bounds": {
+       "min_lat": 12.992649951077192,
+       "max_lat": 13.015790038631529,
+       "min_lon": 80.21853090802841,
+       "max_lon": 80.24212909197156
+     },
+     "img_ids": [1528479, 2555188, 2555189],
+     "target_positions": [[309, 128], [239, 428], [240, 419]],
+     "num_landmarks": 3
+   },
+   "Animalia Chordata Reptilia Crocodylia Alligatoridae Caiman crocodilus": {
+     "id": 340271,
+     "sat_key": "340271_10.52832_-83.49678",
+     "sat_path": "340271_10.52832_-83.49678.jpg",
+     "taxonomy": "Animalia Chordata Reptilia Crocodylia Alligatoridae Caiman crocodilus",
+     "count": 7,
+     "spread": 40.13902957324975,
+     "sat_bounds": {
+       "min_lat": 10.516747947357544,
+       "max_lat": 10.53989204420829,
+       "min_lon": -83.50847402265151,
+       "max_lon": -83.48508597734848
+     },
+     "img_ids": [1683531, 1281855, 223089, 688111, 330757, 2408375, 1955359],
+     "target_positions": [[347, 75], [47, 22], [111, 43], [116, 51], [86, 108], [31, 62], [4, 78]],
+     "num_landmarks": 3
+   },
+   "Animalia Chordata Mammalia Carnivora Canidae Urocyon littoralis": {
+     "id": 304160,
+     "sat_key": "304160_34.0144_-119.54417",
+     "sat_path": "304160_34.0144_-119.54417.jpg",
+     "taxonomy": "Animalia Chordata Mammalia Carnivora Canidae Urocyon littoralis",
+     "count": 3,
+     "spread": 237.64152837579553,
+     "sat_bounds": {
+       "min_lat": 34.00286041606169,
+       "max_lat": 34.02593956225012,
+       "min_lon": -119.55802743361286,
+       "max_lon": -119.53031256638712
+     },
+     "img_ids": [304160, 1473173, 384867],
+     "target_positions": [[255, 256], [19, 22], [29, 274]],
+     "num_landmarks": 3
+   }
+ }
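
Editor's note: for orientation, a minimal sketch of consuming this per-species task metadata. The filename "tasks.json" and the (x, y) pixel interpretation of "target_positions" are assumptions, not confirmed by this commit.

    import json

    with open("tasks.json") as f:              # assumed filename for this metadata file
        tasks = json.load(f)

    for taxonomy, task in tasks.items():
        bounds = task["sat_bounds"]            # lat/lon extent of the satellite tile
        targets = task["target_positions"]     # pixel coordinates (ordering assumed)
        print(taxonomy, "->", task["count"], "observations,", len(targets), "targets in", bounds)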
inference/model/avs_rl_policy.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:44e642df9aaa2847ba44dd4707985c67ef712f5264272ef7993aeb7805c80f5a
+ size 52167246
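
Editor's note: a minimal sketch of restoring this checkpoint into the PolicyNet defined in planner/model.py below. The input_dim/embedding_dim values and the "policy_model" key are assumptions; the real values live in the training parameters and checkpointing code.

    import torch
    from planner.model import PolicyNet

    policy = PolicyNet(input_dim=5, embedding_dim=128)  # dims assumed, see parameter.py
    state_dict = torch.load("inference/model/avs_rl_policy.pth", map_location="cpu")
    # Checkpoints are often dicts holding several nets; adjust the key if needed (assumption).
    policy.load_state_dict(state_dict.get("policy_model", state_dict))
    policy.eval()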
maps/example/masks_val/MSK_0001.png ADDED

Git LFS Details

  • SHA256: 318773e2c18275d84b5145d7e69836baa0bedd833f44b49f98e6619357677cff
  • Pointer size: 130 Bytes
  • Size of remote file: 75.9 kB
maps/gpt4o/envs_val/MSK_0001.png ADDED

Git LFS Details

  • SHA256: 7af11bcef1972b7e047f53b597fef2a332d82c7feceb21aac6e14a57469c436b
  • Pointer size: 129 Bytes
  • Size of remote file: 2.34 kB
planner/env.py ADDED
@@ -0,0 +1,610 @@
+ #######################################################################
+ # Name: env.py
+ #
+ # - Reads and processes training and test maps
+ # - Processes rewards, new frontiers given action
+ # - Updates a graph representation of environment for input into network
+ #######################################################################
+ 
+ import sys
+ if sys.modules['TRAINING']:
+     from .parameter import *
+ else:
+     from .test_parameter import *
+ 
+ import os
+ import cv2
+ import copy
+ import math
+ import numpy as np
+ import matplotlib.image as mpimg
+ import matplotlib.pyplot as plt
+ from skimage import io
+ from skimage.measure import block_reduce
+ from scipy.ndimage import label, find_objects
+ from .sensor import *
+ from .graph_generator import *
+ from .node import *
+ 
+ 
+ class Env():
+     def __init__(self, map_index, n_agent, k_size=20, plot=False, test=False, mask_index=None):
+         self.n_agent = n_agent
+         self.test = test
+         self.map_dir = GRIDMAP_SET_DIR
+ 
+         # Import environment gridmap
+         self.map_list = os.listdir(self.map_dir)
+         self.map_list.sort(reverse=True)
+ 
+         # NEW: Import segmentation utility map
+         self.seg_dir = MASK_SET_DIR
+         self.segmentation_mask, self.target_positions, self.target_found_idxs = None, [], []
+         self.segmentation_mask_list = os.listdir(self.seg_dir)
+         self.segmentation_mask_list.sort(reverse=True)
+ 
+         # NEW: Find common files in both directories
+         self.map_index = map_index % len(self.map_list)
+         if mask_index is not None:
+             self.mask_index = mask_index % len(self.segmentation_mask_list)
+         else:
+             self.mask_index = map_index % len(self.segmentation_mask_list)
+ 
+         # Import ground truth and segmentation mask
+         self.ground_truth, self.map_start_position = self.import_ground_truth(
+             os.path.join(self.map_dir, self.map_list[self.map_index]))
+         self.ground_truth_size = np.shape(self.ground_truth)
+         self.robot_belief = np.ones(self.ground_truth_size) * 127  # unexplored 127
+         self.downsampled_belief = None
+         self.old_robot_belief = copy.deepcopy(self.robot_belief)
+         self.coverage_belief = np.ones(self.ground_truth_size) * 127  # unexplored 127
+ 
+         # Import segmentation mask
+         mask_filename = self.segmentation_mask_list[self.mask_index]
+         self.segmentation_mask = self.import_segmentation_mask(
+             os.path.join(self.seg_dir, mask_filename))
+ 
+         # Overwrite target positions if directory specified
+         if self.test and TARGETS_SET_DIR != "":
+             self.target_positions = self.import_targets(
+                 os.path.join(TARGETS_SET_DIR, self.map_list[self.map_index]))
+ 
+         self.segmentation_info_mask = None
+         self.segmentation_info_mask_unnormalized = None
+         self.filtered_seg_info_mask = None
+         self.num_targets_found = 0
+         self.num_new_targets_found = 0
+         self.resolution = 4
+         self.sensor_range = SENSOR_RANGE
+         self.explored_rate = 0
+         self.targets_found_rate = 0
+         self.frontiers = None
+         self.start_positions = []
+         self.plot = plot
+         self.frame_files = []
+         self.graph_generator = Graph_generator(map_size=self.ground_truth_size, sensor_range=self.sensor_range, k_size=k_size, plot=plot)
+         self.node_coords, self.graph, self.node_utility, self.guidepost = None, None, None, None
+ 
+         self.begin(self.map_start_position)
+ 
+     def find_index_from_coords(self, position):
+         index = np.argmin(np.linalg.norm(self.node_coords - position, axis=1))
+         return index
+ 
+     def begin(self, start_position):
+         self.robot_belief = self.ground_truth
+         self.downsampled_belief = block_reduce(self.robot_belief.copy(), block_size=(self.resolution, self.resolution), func=np.min)
+         self.frontiers = self.find_frontier()
+         self.old_robot_belief = copy.deepcopy(self.robot_belief)
+ 
+         self.node_coords, self.graph, self.node_utility, self.guidepost = self.graph_generator.generate_graph(
+             self.robot_belief, self.frontiers)
+ 
+         # Define start positions
+         if FIX_START_POSITION:
+             coords_res_row = int(self.robot_belief.shape[0] / NUM_COORDS_HEIGHT)
+             coords_res_col = int(self.robot_belief.shape[1] / NUM_COORDS_WIDTH)
+             self.start_positions = [(int(self.robot_belief.shape[1] / 2) - coords_res_col / 2, int(self.robot_belief.shape[0] / 2) - coords_res_row / 2) for _ in range(self.n_agent)]
+         else:
+             nearby_coords = self.graph_generator.get_neighbors_grid_coords(start_position)
+             itr = 0
+             for i in range(self.n_agent):
+                 if i == 0 or len(nearby_coords) == 0:
+                     self.start_positions.append(start_position)
+                 else:
+                     idx = min(itr, len(nearby_coords) - 1)
+                     self.start_positions.append(nearby_coords[idx])
+                     itr += 1
+ 
+         for i in range(len(self.start_positions)):
+             self.start_positions[i] = self.node_coords[self.find_index_from_coords(self.start_positions[i])]
+             self.coverage_belief = self.update_robot_belief(self.start_positions[i], self.sensor_range, self.coverage_belief,
+                                                             self.ground_truth)
+ 
+         for start_position in self.start_positions:
+             self.graph_generator.route_node.append(start_position)
+ 
+         # Info map from ground truth
+         rng_x = 0.5 * (self.ground_truth.shape[1] / NUM_COORDS_WIDTH)
+         rng_y = 0.5 * (self.ground_truth.shape[0] / NUM_COORDS_HEIGHT)
+         self.segmentation_info_mask = np.zeros((len(self.node_coords), 1))
+         for i, node_coord in enumerate(self.node_coords):
+             max_x = min(node_coord[0] + int(math.ceil(rng_x)), self.ground_truth.shape[1])
+             min_x = max(node_coord[0] - int(math.ceil(rng_x)), 0)
+             max_y = min(node_coord[1] + int(math.ceil(rng_y)), self.ground_truth.shape[0])
+             min_y = max(node_coord[1] - int(math.ceil(rng_y)), 0)
+ 
+             if TARGETS_SET_DIR == "":
+                 exclude = {208}  # Exclude target positions
+             else:
+                 exclude = set()  # FIX: was `{}` (an empty dict); use a set for consistent membership tests
+             self.segmentation_info_mask[i] = max(x for x in self.segmentation_mask[min_y:max_y, min_x:max_x].flatten() if x not in exclude) / 100.0
+ 
+         self.filtered_seg_info_mask = copy.deepcopy(self.segmentation_info_mask)
+         done, num_targets_found = self.check_done()
+         self.num_targets_found = num_targets_found
+ 
+     def multi_robot_step(self, next_position_list, dist_list, travel_dist_list):
+         reward_list = []
+         for dist, robot_position in zip(dist_list, next_position_list):
+             self.graph_generator.route_node.append(robot_position)
+             next_node_index = self.find_index_from_coords(robot_position)
+             self.graph_generator.nodes_list[next_node_index].set_visited()
+             self.coverage_belief = self.update_robot_belief(robot_position, self.sensor_range, self.coverage_belief,
+                                                             self.ground_truth)
+             self.robot_belief = self.ground_truth
+             self.downsampled_belief = block_reduce(self.robot_belief.copy(),
+                                                    block_size=(self.resolution, self.resolution),
+                                                    func=np.min)
+ 
+             frontiers = self.find_frontier()
+             individual_reward = -dist / 32
+ 
+             robot_position_idx = self.find_index_from_coords(robot_position)
+             info_gain_reward = self.filtered_seg_info_mask[robot_position_idx][0] * 1.5
+             if self.guidepost[robot_position_idx] == 0.0:
+                 info_gain_reward += 0.2
+             individual_reward += info_gain_reward
+ 
+             reward_list.append(individual_reward)
+ 
+         self.node_coords, self.graph, self.node_utility, self.guidepost = self.graph_generator.update_graph(self.robot_belief, self.old_robot_belief, frontiers, self.frontiers)
+         self.old_robot_belief = copy.deepcopy(self.robot_belief)
+ 
+         self.filtered_seg_info_mask = [info[0] if self.guidepost[i] == 0.0 else 0.0 for i, info in enumerate(self.segmentation_info_mask)]
+         self.filtered_seg_info_mask = np.expand_dims(np.array(self.filtered_seg_info_mask), axis=1)
+ 
+         self.frontiers = frontiers
+         self.explored_rate = self.evaluate_exploration_rate()
+ 
+         done, num_targets_found = self.check_done()
+         self.num_new_targets_found = num_targets_found - self.num_targets_found
+         team_reward = 0.0
+ 
+         self.num_targets_found = num_targets_found
+         self.targets_found_rate = self.evaluate_targets_found_rate()
+ 
+         if done:
+             team_reward += 40
+         for i in range(len(reward_list)):
+             reward_list[i] += team_reward
+ 
+         return reward_list, done
+ 
+     def import_ground_truth(self, map_index):
+         # occupied 1, free 255, unexplored 127
+         try:
+             ground_truth = (io.imread(map_index, 1)).astype(int)
+             if np.all(ground_truth == 0):
+                 ground_truth = (io.imread(map_index, 1) * 255).astype(int)
+         except Exception:
+             new_map_index = self.map_dir + '/' + self.map_list[0]
+             ground_truth = (io.imread(new_map_index, 1)).astype(int)
+             print('could not read the map_path ({}), hence skipping it and using ({}).'.format(map_index, new_map_index))
+ 
+         robot_location = np.nonzero(ground_truth == 208)
+         robot_location = np.array([np.array(robot_location)[1, 127], np.array(robot_location)[0, 127]])
+         ground_truth = (ground_truth > 150)
+         ground_truth = ground_truth * 254 + 1
+         return ground_truth, robot_location
+ 
+     def import_segmentation_mask(self, map_index):
+         mask = cv2.imread(map_index).astype(int)
+         return mask
+ 
+     def import_targets(self, map_index):
+         # occupied 1, free 255, unexplored 127, target 208
+         mask = cv2.imread(map_index).astype(int)
+         target_positions = self.find_target_locations(mask)
+         return target_positions
+ 
+     def find_target_locations(self, image_array, grey_value=208):
+         grey_pixels = np.where(image_array == grey_value)
+         binary_array = np.zeros_like(image_array, dtype=bool)
+         binary_array[grey_pixels] = True
+         labeled_array, num_features = label(binary_array)
+         slices = find_objects(labeled_array)
+ 
+         # Calculate the center of each bounding box
+         centers = []
+         for sl in slices:  # renamed from `slice` to avoid shadowing the builtin
+             row_center = (sl[0].start + sl[0].stop - 1) // 2
+             col_center = (sl[1].start + sl[1].stop - 1) // 2
+             centers.append((col_center, row_center))  # (x, y)
+ 
+         return centers
+ 
+     def free_cells(self):
+         index = np.where(self.ground_truth == 255)
+         free = np.asarray([index[1], index[0]]).T
+         return free
+ 
+     def update_robot_belief(self, robot_position, sensor_range, robot_belief, ground_truth):
+         robot_belief = sensor_work(robot_position, sensor_range, robot_belief, ground_truth)
+         return robot_belief
+ 
+     def check_done(self):
+         done = False
+         num_targets_found = 0
+         self.target_found_idxs = []
+         for i, target in enumerate(self.target_positions):
+             if self.coverage_belief[target[1], target[0]] == 255:
+                 num_targets_found += 1
+                 self.target_found_idxs.append(i)
+ 
+         if TERMINATE_ON_TGTS_FOUND and num_targets_found >= len(self.target_positions):
+             done = True
+         if not TERMINATE_ON_TGTS_FOUND and np.sum(self.coverage_belief == 255) / np.sum(self.ground_truth == 255) >= 0.99:
+             done = True
+ 
+         return done, num_targets_found
+ 
+     def calculate_num_observed_frontiers(self, old_frontiers, frontiers):
+         frontiers_to_check = frontiers[:, 0] + frontiers[:, 1] * 1j
+         pre_frontiers_to_check = old_frontiers[:, 0] + old_frontiers[:, 1] * 1j
+         frontiers_num = np.intersect1d(frontiers_to_check, pre_frontiers_to_check).shape[0]
+         pre_frontiers_num = pre_frontiers_to_check.shape[0]
+         delta_num = pre_frontiers_num - frontiers_num
+ 
+         return delta_num
+ 
+     def calculate_reward(self, dist, frontiers):
+         reward = 0
+         reward -= dist / 64
+ 
+         frontiers_to_check = frontiers[:, 0] + frontiers[:, 1] * 1j
+         pre_frontiers_to_check = self.frontiers[:, 0] + self.frontiers[:, 1] * 1j
+         frontiers_num = np.intersect1d(frontiers_to_check, pre_frontiers_to_check).shape[0]
+         pre_frontiers_num = pre_frontiers_to_check.shape[0]
+         delta_num = pre_frontiers_num - frontiers_num
+ 
+         reward += delta_num / 50
+ 
+         return reward
+ 
+     def evaluate_exploration_rate(self):
+         rate = np.sum(self.coverage_belief == 255) / np.sum(self.ground_truth == 255)
+         return rate
+ 
+     def evaluate_targets_found_rate(self):
+         if len(self.target_positions) == 0:
+             return 0
+         else:
+             rate = self.num_targets_found / len(self.target_positions)
+             return rate
+ 
+     def calculate_new_free_area(self):
+         old_free_area = self.old_robot_belief == 255
+         current_free_area = self.robot_belief == 255
+ 
+         # FIX: np.int was removed in NumPy 1.24; plain int behaves identically here
+         new_free_area = (current_free_area.astype(int) - old_free_area.astype(int)) * 255
+ 
+         return new_free_area, np.sum(old_free_area)
+ 
+     def calculate_dist_path(self, path):
+         dist = 0
+         start = path[0]
+         end = path[-1]
+         for index in path:
+             if index == end:
+                 break
+             dist += np.linalg.norm(self.node_coords[start] - self.node_coords[index])
+             start = index
+         return dist
+ 
+     def find_frontier(self):
+         y_len = self.downsampled_belief.shape[0]
+         x_len = self.downsampled_belief.shape[1]
+         mapping = self.downsampled_belief.copy()
+         belief = self.downsampled_belief.copy()
+         # 0-1 unknown area map
+         mapping = (mapping == 127) * 1
+         mapping = np.lib.pad(mapping, ((1, 1), (1, 1)), 'constant', constant_values=0)
+         # sum of the 8-neighbourhood of each cell in the unknown-area map
+         fro_map = (mapping[2:][:, 1:x_len + 1] + mapping[:y_len][:, 1:x_len + 1] +
+                    mapping[1:y_len + 1][:, 2:] + mapping[1:y_len + 1][:, :x_len] +
+                    mapping[:y_len][:, 2:] + mapping[2:][:, :x_len] +
+                    mapping[2:][:, 2:] + mapping[:y_len][:, :x_len])
+         ind_free = np.where(belief.ravel(order='F') == 255)[0]
+         ind_fron_1 = np.where(1 < fro_map.ravel(order='F'))[0]
+         ind_fron_2 = np.where(fro_map.ravel(order='F') < 8)[0]
+         ind_fron = np.intersect1d(ind_fron_1, ind_fron_2)
+         ind_to = np.intersect1d(ind_free, ind_fron)
+ 
+         map_x = x_len
+         map_y = y_len
+         x = np.linspace(0, map_x - 1, map_x)
+         y = np.linspace(0, map_y - 1, map_y)
+         t1, t2 = np.meshgrid(x, y)
+         points = np.vstack([t1.T.ravel(), t2.T.ravel()]).T
+ 
+         f = points[ind_to]
+         f = f.astype(int)
+ 
+         f = f * self.resolution
+ 
+         return f
+ 
+     def plot_env(self, n, path, step, travel_dist, robots_route, img_path_override=None, sat_path_override=None, msk_name_override=None, sound_id_override=None):
+         plt.switch_backend('agg')
+         plt.cla()
+         color_list = ["r", "g", "c", "m", "y", "k"]
+ 
+         if not LOAD_AVS_BENCH:
+             fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 5))
+         else:
+             fig, (ax1, ax2, ax3, ax4) = plt.subplots(1, 4, figsize=(20, 5.5))
+ 
+         ### Fig: Ground Image ###
+         if LOAD_AVS_BENCH:
+             ax = ax1
+             image = mpimg.imread(img_path_override)
+             ax.imshow(image)
+             ax.set_title("Ground Image")
+             ax.axis("off")
+ 
+         ### Fig: Environment ###
+         msk_name = ""
+         if LOAD_AVS_BENCH:
+             image = mpimg.imread(sat_path_override)
+             msk_name = msk_name_override
+ 
+         ax = ax2
+         ax.imshow(image)
+         ax.axis((0, self.ground_truth_size[1], self.ground_truth_size[0], 0))
+         ax.set_title("Image")
+         for i, route in enumerate(robots_route):
+             robot_marker_color = color_list[i % len(color_list)]
+             xPoints = route[0]
+             yPoints = route[1]
+             ax.plot(xPoints, yPoints, c=robot_marker_color, linewidth=2)
+             ax.plot(xPoints[-1], yPoints[-1], markersize=12, zorder=99, marker="^", ls="-", c=robot_marker_color, mec="black")
+             ax.plot(xPoints[0], yPoints[0], 'co', c=robot_marker_color, markersize=8, zorder=5)
+ 
+             # Sensor range
+             rng_x = 0.5 * (self.ground_truth.shape[1] / NUM_COORDS_WIDTH)
+             rng_y = 0.5 * (self.ground_truth.shape[0] / NUM_COORDS_HEIGHT)
+             max_x = min(xPoints[-1] + int(math.ceil(rng_x)), self.ground_truth.shape[1])
+             min_x = max(xPoints[-1] - int(math.ceil(rng_x)), 0)
+             max_y = min(yPoints[-1] + int(math.ceil(rng_y)), self.ground_truth.shape[0])
+             min_y = max(yPoints[-1] - int(math.ceil(rng_y)), 0)
+             ax.plot((min_x, min_x), (min_y, max_y), c=robot_marker_color, linewidth=1)
+             ax.plot((min_x, max_x), (max_y, max_y), c=robot_marker_color, linewidth=1)
+             ax.plot((max_x, max_x), (max_y, min_y), c=robot_marker_color, linewidth=1)
+             ax.plot((max_x, min_x), (min_y, min_y), c=robot_marker_color, linewidth=1)
+ 
+         ### Fig: Graph ###
+         ax = ax3 if LOAD_AVS_BENCH else ax1
+         ax.imshow(self.coverage_belief, cmap='gray')
+         ax.axis((0, self.ground_truth_size[1], self.ground_truth_size[0], 0))
+         ax.set_title("Information Graph")
+         if VIZ_GRAPH_EDGES:
+             for i in range(len(self.graph_generator.x)):
+                 ax.plot(self.graph_generator.x[i], self.graph_generator.y[i], 'tan', zorder=1)
+         ax.scatter(self.node_coords[:, 0], self.node_coords[:, 1], c=self.filtered_seg_info_mask, zorder=5, s=8)
+ 
+         for i, route in enumerate(robots_route):
+             robot_marker_color = color_list[i % len(color_list)]
+             xPoints = route[0]
+             yPoints = route[1]
+             ax.plot(xPoints, yPoints, c=robot_marker_color, linewidth=2)
+             ax.plot(xPoints[-1], yPoints[-1], markersize=12, zorder=99, marker="^", ls="-", c=robot_marker_color, mec="black")
+             ax.plot(xPoints[0], yPoints[0], 'co', c=robot_marker_color, markersize=8, zorder=5)
+ 
+             # Sensor range
+             rng_x = 0.5 * (self.ground_truth.shape[1] / NUM_COORDS_WIDTH)
+             rng_y = 0.5 * (self.ground_truth.shape[0] / NUM_COORDS_HEIGHT)
+             max_x = min(xPoints[-1] + int(math.ceil(rng_x)), self.ground_truth.shape[1])
+             min_x = max(xPoints[-1] - int(math.ceil(rng_x)), 0)
+             max_y = min(yPoints[-1] + int(math.ceil(rng_y)), self.ground_truth.shape[0])
+             min_y = max(yPoints[-1] - int(math.ceil(rng_y)), 0)
+             ax.plot((min_x, min_x), (min_y, max_y), c=robot_marker_color, linewidth=1)
+             ax.plot((min_x, max_x), (max_y, max_y), c=robot_marker_color, linewidth=1)
+             ax.plot((max_x, max_x), (max_y, min_y), c=robot_marker_color, linewidth=1)
+             ax.plot((max_x, min_x), (min_y, min_y), c=robot_marker_color, linewidth=1)
+ 
+         # Plot target positions
+         for target in self.target_positions:
+             if self.coverage_belief[target[1], target[0]] == 255:
+                 ax.plot(target[0], target[1], color='g', marker='x', linestyle='-', markersize=12, markeredgewidth=4, zorder=99)
+             else:
+                 ax.plot(target[0], target[1], color='r', marker='x', linestyle='-', markersize=12, markeredgewidth=4, zorder=99)
+ 
+         ### Fig: Segmentation Mask ###
+         ax = ax4 if LOAD_AVS_BENCH else ax2
+         if LOAD_AVS_BENCH and USE_CLIP_PREDS:
+             H, W = self.ground_truth_size
+             mask_viz = self.segmentation_info_mask.squeeze().reshape((NUM_COORDS_WIDTH, NUM_COORDS_HEIGHT)).T
+             im = ax.imshow(
+                 mask_viz,
+                 cmap="viridis",
+                 origin="upper",
+                 extent=[0, W, H, 0],
+                 interpolation="nearest",
+                 zorder=0,
+             )
+             ax.set_xlim(0, W)
+             ax.set_ylim(H, 0)
+             ax.set_axis_off()
+         else:
+             im = ax.imshow(self.segmentation_mask.mean(axis=-1), cmap='viridis', vmin=0, vmax=100)  # cmap='gray'
+             ax.axis((0, self.ground_truth_size[1], self.ground_truth_size[0], 0))
+         ax.set_title("Predicted Mask (Normalized)")
+         for i, route in enumerate(robots_route):
+             robot_marker_color = color_list[i % len(color_list)]
+             xPoints = route[0]
+             yPoints = route[1]
+             ax.plot(xPoints, yPoints, c=robot_marker_color, linewidth=2)
+             ax.plot(xPoints[-1], yPoints[-1], markersize=12, zorder=99, marker="^", ls="-", c=robot_marker_color, mec="black")
+             ax.plot(xPoints[0], yPoints[0], 'co', c=robot_marker_color, markersize=8, zorder=5)
+ 
+             # Sensor range
+             rng_x = 0.5 * (self.ground_truth.shape[1] / NUM_COORDS_WIDTH)
+             rng_y = 0.5 * (self.ground_truth.shape[0] / NUM_COORDS_HEIGHT)
+             max_x = min(xPoints[-1] + int(math.ceil(rng_x)), self.ground_truth.shape[1])
+             min_x = max(xPoints[-1] - int(math.ceil(rng_x)), 0)
+             max_y = min(yPoints[-1] + int(math.ceil(rng_y)), self.ground_truth.shape[0])
+             min_y = max(yPoints[-1] - int(math.ceil(rng_y)), 0)
+             ax.plot((min_x, min_x), (min_y, max_y), c=robot_marker_color, linewidth=1)
+             ax.plot((min_x, max_x), (max_y, max_y), c=robot_marker_color, linewidth=1)
+             ax.plot((max_x, max_x), (max_y, min_y), c=robot_marker_color, linewidth=1)
+             ax.plot((max_x, min_x), (min_y, min_y), c=robot_marker_color, linewidth=1)
+ 
+         # Add a colorbar
+         cbar = fig.colorbar(im, ax=ax, fraction=0.046, pad=0.04)
+         cbar.set_label("Normalized Probs")
+ 
+         if sound_id_override is not None:
+             plt.suptitle('Targets Found: {}/{} Coverage ratio: {:.4g} Travel Dist: {:.4g} \n ({}) \n (Sound ID: {})'.format(self.num_targets_found, len(self.target_positions), self.explored_rate, travel_dist, msk_name, sound_id_override))
+         elif msk_name != "":
+             plt.suptitle('Targets Found: {}/{} Coverage ratio: {:.4g} Travel Dist: {:.4g} \n ({})'.format(self.num_targets_found, len(self.target_positions), self.explored_rate, travel_dist, msk_name))
+         else:
+             plt.suptitle('Targets Found: {}/{} Coverage ratio: {:.4g} Travel Dist: {:.4g}'.format(self.num_targets_found, len(self.target_positions), self.explored_rate, travel_dist))
+ 
+         plt.tight_layout()
+         # FIX: dpi was previously passed to str.format() (and silently ignored), not to savefig
+         plt.savefig('{}/{}_{}_samples.png'.format(path, n, step), dpi=100)
+         frame = '{}/{}_{}_samples.png'.format(path, n, step)
+         self.frame_files.append(frame)
+         plt.close()
+ 
+     ####################
+     # ADDED: For app.py
+     ####################
+ 
+     def plot_heatmap(self, save_dir, step, travel_dist, robots_route=None):
+         """Plot only the segmentation heatmap and save it as ``{step}.png`` in
+         ``save_dir``. This lightweight helper is meant for asynchronous
+         streaming in the Gradio demo when the full ``plot_env`` is too heavy.
+ 
+         Parameters
+         ----------
+         save_dir : str
+             Directory to save the generated PNG file.
+         step : int
+             Current timestep; becomes the filename ``{step}.png``.
+         travel_dist : float
+             Total travel distance so far (kept for API symmetry with ``plot_env``).
+         robots_route : list | None
+             Optional list of routes (xPoints, yPoints) to overlay.
+ 
+         Returns
+         -------
+         str
+             Full path to the generated PNG file.
+         """
+         plt.switch_backend('agg')
+         # Each call creates its own Figure object that is closed explicitly at
+         # the end, so clearing the global figure state (plt.cla()) is
+         # unnecessary and could break concurrent drawing.
+ 
+         color_list = ["r", "g", "c", "m", "y", "k"]
+         fig, ax = plt.subplots(1, 1, figsize=(6, 6))
+ 
+         # Select the mask to visualise
+         side_dim = int(np.sqrt(self.segmentation_info_mask.shape[0]))
+         mask_viz = self.segmentation_info_mask.squeeze().reshape((side_dim, side_dim)).T
+ 
+         # Properly map image to pixel coordinates and keep limits fixed
+         H, W = self.ground_truth_size  # rows (y), cols (x)
+         im = ax.imshow(
+             mask_viz,
+             cmap="viridis",
+             origin="upper",
+             extent=[0, W, H, 0],      # x: 0..W, y: H..0 (origin at top-left)
+             interpolation="nearest",  # keep cell edges sharp & aligned
+             zorder=0,
+         )
+         ax.set_xlim(0, W)
+         ax.set_ylim(H, 0)
+         ax.set_axis_off()  # hide ticks but keep limits
+ 
+         # Optionally overlay robot paths
+         if robots_route is not None:
+             for i, route in enumerate(robots_route):
+                 robot_marker_color = color_list[i % len(color_list)]
+                 xPoints, yPoints = route
+                 ax.plot(xPoints, yPoints, c=robot_marker_color, linewidth=2)
+                 ax.plot(xPoints[-1], yPoints[-1], markersize=12, zorder=99, marker="^", ls="-", c=robot_marker_color, mec="black")
+                 ax.plot(xPoints[0], yPoints[0], 'co', c=robot_marker_color, markersize=8, zorder=5)
+ 
+                 # Sensor range (FIX: moved inside the loop; xPoints and
+                 # robot_marker_color are undefined when no routes are given)
+                 rng_x = 0.5 * (self.ground_truth.shape[1] / NUM_COORDS_WIDTH)
+                 rng_y = 0.5 * (self.ground_truth.shape[0] / NUM_COORDS_HEIGHT)
+                 max_x = min(xPoints[-1] + int(math.ceil(rng_x)), self.ground_truth.shape[1])
+                 min_x = max(xPoints[-1] - int(math.ceil(rng_x)), 0)
+                 max_y = min(yPoints[-1] + int(math.ceil(rng_y)), self.ground_truth.shape[0])
+                 min_y = max(yPoints[-1] - int(math.ceil(rng_y)), 0)
+                 ax.plot((min_x, min_x), (min_y, max_y), c=robot_marker_color, linewidth=1)
+                 ax.plot((min_x, max_x), (max_y, max_y), c=robot_marker_color, linewidth=1)
+                 ax.plot((max_x, max_x), (max_y, min_y), c=robot_marker_color, linewidth=1)
+                 ax.plot((max_x, min_x), (min_y, min_y), c=robot_marker_color, linewidth=1)
+ 
+         # Plot target positions
+         for target in self.target_positions:
+             if self.coverage_belief[target[1], target[0]] == 255:
+                 ax.plot(target[0], target[1], color='g', marker='x', linestyle='-', markersize=12, markeredgewidth=4, zorder=99)
+             else:
+                 ax.plot(target[0], target[1], color='r', marker='x', linestyle='-', markersize=12, markeredgewidth=4, zorder=99)
+ 
+         # Color bar
+         cbar = fig.colorbar(im, ax=ax, fraction=0.046, pad=0.04)
+         cbar.set_label("Normalized Probs")
+ 
+         # Coverage shown to 1 d.p.
+         plt.suptitle('Targets Found: {}/{} Coverage: {:.1f}% Steps: {}/{}'.format(
+             self.num_targets_found,
+             len(self.target_positions),
+             self.explored_rate * 100,
+             step + 1,
+             NUM_EPS_STEPS),
+             y=0.94,  # closer to plot
+         )
+ 
+         plt.tight_layout()
+         os.makedirs(save_dir, exist_ok=True)
+         out_path = os.path.join(save_dir, f"{step}.png")
+         # Save atomically: write to a temp file then move into place so the
+         # poller never sees a partial file.
+         tmp_path = out_path + ".tmp"
+         fig.savefig(tmp_path, dpi=100, format='png')
+         os.replace(tmp_path, out_path)  # atomic on the same filesystem
+         plt.close(fig)
+         return out_path
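
Editor's note: a minimal driver sketch for the environment above. It assumes the sys.modules['TRAINING'] flag convention used by its imports, and that test_parameter.py points GRIDMAP_SET_DIR / MASK_SET_DIR at valid folders; none of this is a confirmed entry point.

    import sys
    sys.modules['TRAINING'] = False   # route imports to test_parameter.py (assumed convention)
    from planner.env import Env

    env = Env(map_index=0, n_agent=2, k_size=20, plot=False, test=True)
    print(env.node_coords.shape, env.explored_rate, env.num_targets_found)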
planner/graph.py ADDED
@@ -0,0 +1,167 @@
+ #######################################################################
+ # Name: graph.py
+ #
+ # - Adapted from https://gist.github.com/betandr/541a1f6466b6855471de5ca30b74cb31
+ # - Simple graph class to perform distance calculations (e.g. A-Star, Dijkstra)
+ #######################################################################
+ 
+ 
+ class Edge:
+     def __init__(self, to_node, length):
+         self.to_node = to_node
+         self.length = length
+ 
+ 
+ class Graph:
+     def __init__(self):
+         self.nodes = set()
+         self.edges = dict()
+ 
+     def add_node(self, node):
+         self.nodes.add(node)
+ 
+     def add_edge(self, from_node, to_node, length):
+         edge = Edge(to_node, length)
+         if from_node in self.edges:
+             from_node_edges = self.edges[from_node]
+         else:
+             self.edges[from_node] = dict()
+             from_node_edges = self.edges[from_node]
+         from_node_edges[to_node] = edge
+ 
+     def clear_edge(self, from_node):
+         if from_node in self.edges:
+             self.edges[from_node] = dict()
+ 
+ 
+ def min_dist(q, dist):
+     """
+     Returns the node with the smallest distance in q.
+     Implemented to keep the main algorithm clean.
+     """
+     min_node = None
+     for node in q:
+         if min_node is None:
+             min_node = node
+         elif dist[node] < dist[min_node]:
+             min_node = node
+ 
+     return min_node
+ 
+ 
+ INFINITY = float('Infinity')
+ 
+ 
+ def dijkstra(graph, source):
+     q = set()
+     dist = {}
+     prev = {}
+ 
+     for v in graph.nodes:
+         dist[v] = INFINITY  # unknown distance from source to v
+         prev[v] = INFINITY  # previous node in optimal path from source
+         q.add(v)            # all nodes initially in q (unvisited nodes)
+ 
+     # distance from source to source
+     dist[source] = 0
+ 
+     while q:
+         # node with the least distance selected first
+         u = min_dist(q, dist)
+ 
+         q.remove(u)
+ 
+         try:
+             if u in graph.edges:
+                 for _, v in graph.edges[u].items():
+                     alt = dist[u] + v.length
+                     if alt < dist[v.to_node]:
+                         # a shorter path to v has been found
+                         dist[v.to_node] = alt
+                         prev[v.to_node] = u
+         except KeyError:
+             # an edge may point at a node that was never added to the graph
+             pass
+ 
+     return dist, prev
+ 
+ 
+ def to_array(prev, from_node):
+     """Creates an ordered list of labels as a route."""
+     previous_node = prev[from_node]
+     route = [from_node]
+ 
+     while previous_node != INFINITY:
+         route.append(previous_node)
+         temp = previous_node
+         previous_node = prev[temp]
+ 
+     route.reverse()
+     return route
+ 
+ 
+ def h(index, destination, node_coords):
+     current = node_coords[index]
+     end = node_coords[destination]
+     h = abs(end[0] - current[0]) + abs(end[1] - current[1])
+     return h
+ 
+ 
+ def a_star(start, destination, node_coords, graph):
+     if start == destination:
+         return [], 0
+     if str(destination) in graph.edges[str(start)].keys():
+         cost = graph.edges[str(start)][str(destination)].length
+         return [start, destination], cost
+     open_list = {start}
+     closed_list = set([])
+ 
+     g = {start: 0}
+     parents = {start: start}
+ 
+     while len(open_list) > 0:
+         n = None
+         h_n = 1e5
+         for v in open_list:
+             h_v = h(v, destination, node_coords)
+             if n is not None:
+                 h_n = h(n, destination, node_coords)
+             if n is None or g[v] + h_v < g[n] + h_n:
+                 n = v
+ 
+         if n is None:
+             print('Path does not exist!')
+             return None, 1e5
+ 
+         if n == destination:
+             reconst_path = []
+             while parents[n] != n:
+                 reconst_path.append(n)
+                 n = parents[n]
+             reconst_path.append(start)
+             reconst_path.reverse()
+             return reconst_path, g[destination]
+ 
+         for edge in graph.edges[str(n)].values():
+             m = int(edge.to_node)
+             cost = edge.length
+             if m not in open_list and m not in closed_list:
+                 open_list.add(m)
+                 parents[m] = n
+                 g[m] = g[n] + cost
+             else:
+                 if g[m] > g[n] + cost:
+                     g[m] = g[n] + cost
+                     parents[m] = n
+ 
+                     if m in closed_list:
+                         closed_list.remove(m)
+                         open_list.add(m)
+ 
+         open_list.remove(n)
+         closed_list.add(n)
+ 
+     print('Path does not exist!')
+     return None, 1e5
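
Editor's note: a toy usage sketch of the helpers above (node labels are strings, matching how graph_generator.py stores indices).

    g = Graph()
    for n in ("0", "1", "2"):
        g.add_node(n)
    g.add_edge("0", "1", 1.0)
    g.add_edge("1", "2", 2.5)

    dist, prev = dijkstra(g, "0")
    print(dist["2"])            # 3.5
    print(to_array(prev, "2"))  # ['0', '1', '2']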
planner/graph_generator.py ADDED
@@ -0,0 +1,300 @@
+ #######################################################################
+ # Name: graph_generator.py
+ #
+ # - Wrapper for graph.py
+ # - Sends the formatted inputs into graph.py to get useful info
+ #######################################################################
+ 
+ import sys
+ if sys.modules['TRAINING']:
+     from .parameter import *
+ else:
+     from .test_parameter import *
+ 
+ import numpy as np
+ import shapely.geometry
+ from sklearn.neighbors import NearestNeighbors
+ from .node import Node
+ from .graph import Graph, a_star
+ 
+ 
+ class Graph_generator:
+     def __init__(self, map_size, k_size, sensor_range, plot=False):
+         self.k_size = k_size
+         self.graph = Graph()
+         self.node_coords = None
+         self.plot = plot
+         self.x = []
+         self.y = []
+         self.map_x = map_size[1]
+         self.map_y = map_size[0]
+         self.uniform_points, self.grid_coords = self.generate_uniform_points()
+         self.sensor_range = sensor_range
+         self.route_node = []
+         self.nodes_list = []
+         self.node_utility = None
+         self.guidepost = None
+ 
+     def edge_clear_all_nodes(self):
+         self.graph = Graph()
+         self.x = []
+         self.y = []
+ 
+     def edge_clear(self, coords):
+         node_index = str(self.find_index_from_coords(self.node_coords, coords))
+         self.graph.clear_edge(node_index)
+ 
+     def generate_graph(self, robot_belief, frontiers):
+         self.edge_clear_all_nodes()
+         free_area = self.free_area(robot_belief)
+ 
+         free_area_to_check = free_area[:, 0] + free_area[:, 1] * 1j
+         uniform_points_to_check = self.uniform_points[:, 0] + self.uniform_points[:, 1] * 1j
+         _, _, candidate_indices = np.intersect1d(free_area_to_check, uniform_points_to_check, return_indices=True)
+         node_coords = self.uniform_points[candidate_indices]
+ 
+         self.node_coords = self.unique_coords(node_coords).reshape(-1, 2)
+         self.find_nearest_neighbor_all_nodes(self.node_coords, robot_belief)
+ 
+         self.node_utility = []
+         for coords in self.node_coords:
+             node = Node(coords, frontiers, robot_belief)
+             self.nodes_list.append(node)
+             utility = node.utility
+             self.node_utility.append(utility)
+         self.node_utility = np.array(self.node_utility)
+ 
+         self.guidepost = np.zeros((self.node_coords.shape[0], 1))
+         x = self.node_coords[:, 0] + self.node_coords[:, 1] * 1j
+         for node in self.route_node:
+             index = np.argwhere(x.reshape(-1) == node[0] + node[1] * 1j)[0]
+             self.guidepost[index] = 1
+ 
+         return self.node_coords, self.graph.edges, self.node_utility, self.guidepost
+ 
+     def update_graph(self, robot_belief, old_robot_belief, frontiers, old_frontiers):
+         new_free_area = self.free_area((robot_belief - old_robot_belief > 0) * 255)
+         free_area_to_check = new_free_area[:, 0] + new_free_area[:, 1] * 1j
+         uniform_points_to_check = self.uniform_points[:, 0] + self.uniform_points[:, 1] * 1j
+         _, _, candidate_indices = np.intersect1d(free_area_to_check, uniform_points_to_check, return_indices=True)
+         new_node_coords = self.uniform_points[candidate_indices]
+         self.node_coords = np.concatenate((self.node_coords, new_node_coords))
+ 
+         old_node_to_update = []
+         for coords in new_node_coords:
+             neighbor_indices = self.find_k_neighbor(coords, self.node_coords, robot_belief)
+             old_node_to_update += neighbor_indices
+         old_node_to_update = set(old_node_to_update)
+         for index in old_node_to_update:
+             coords = self.node_coords[index]
+             self.edge_clear(coords)
+             self.find_k_neighbor(coords, self.node_coords, robot_belief)
+ 
+         old_frontiers_to_check = old_frontiers[:, 0] + old_frontiers[:, 1] * 1j
+         new_frontiers_to_check = frontiers[:, 0] + frontiers[:, 1] * 1j
+         observed_frontiers_index = np.where(
+             ~np.isin(old_frontiers_to_check, new_frontiers_to_check, assume_unique=True))
+         new_frontiers_index = np.where(
+             ~np.isin(new_frontiers_to_check, old_frontiers_to_check, assume_unique=True))
+         observed_frontiers = old_frontiers[observed_frontiers_index]
+         new_frontiers = frontiers[new_frontiers_index]
+         for node in self.nodes_list:
+             if node.zero_utility_node is True:
+                 pass
+             else:
+                 node.update_observable_frontiers(observed_frontiers, new_frontiers, robot_belief)
+ 
+         for new_coords in new_node_coords:
+             node = Node(new_coords, frontiers, robot_belief)
+             self.nodes_list.append(node)
+ 
+         self.node_utility = []
+         for i, coords in enumerate(self.node_coords):
+             utility = self.nodes_list[i].utility
+             self.node_utility.append(utility)
+         self.node_utility = np.array(self.node_utility)
+ 
+         self.guidepost = np.zeros((self.node_coords.shape[0], 1))
+         x = self.node_coords[:, 0] + self.node_coords[:, 1] * 1j
+         for node in self.route_node:
+             index = np.argwhere(x.reshape(-1) == node[0] + node[1] * 1j)
+             self.guidepost[index] = 1
+ 
+         return self.node_coords, self.graph.edges, self.node_utility, self.guidepost
+ 
+     def generate_uniform_points(self):
+         padding_x = 0.5 * (self.map_x / NUM_COORDS_WIDTH)
+         padding_y = 0.5 * (self.map_y / NUM_COORDS_HEIGHT)
+         x = np.linspace(padding_x, self.map_x - padding_x - 1, NUM_COORDS_WIDTH).round().astype(int)
+         y = np.linspace(padding_y, self.map_y - padding_y - 1, NUM_COORDS_HEIGHT).round().astype(int)
+ 
+         t1, t2 = np.meshgrid(x, y)
+         points = np.vstack([t1.T.ravel(), t2.T.ravel()]).T
+         matrix = np.stack((t1, t2), axis=-1)
+         return points, matrix
+ 
+     def free_area(self, robot_belief):
+         index = np.where(robot_belief == 255)
+         free = np.asarray([index[1], index[0]]).T
+         return free
+ 
+     def unique_coords(self, coords):
+         x = coords[:, 0] + coords[:, 1] * 1j
+         indices = np.unique(x, return_index=True)[1]
+         coords = np.array([coords[idx] for idx in sorted(indices)])
+         return coords
+ 
+     def find_k_neighbor(self, coords, node_coords, robot_belief):
+         dist_list = np.linalg.norm((coords - node_coords), axis=-1)
+         sorted_index = np.argsort(dist_list)
+         k = 0
+         neighbor_index_list = []
+         while k < self.k_size and k < node_coords.shape[0]:
+             neighbor_index = sorted_index[k]
+             neighbor_index_list.append(neighbor_index)
+             dist = dist_list[neighbor_index]  # FIX: was dist_list[k], the k-th *unsorted* distance
+             start = coords
+             end = node_coords[neighbor_index]
+             if not self.check_collision(start, end, robot_belief):
+                 a = str(self.find_index_from_coords(node_coords, start))
+                 b = str(neighbor_index)
+                 self.graph.add_node(a)
+                 self.graph.add_edge(a, b, dist)
+ 
+                 if self.plot:
+                     self.x.append([start[0], end[0]])
+                     self.y.append([start[1], end[1]])
+             k += 1
+         return neighbor_index_list
+ 
+     def find_k_neighbor_all_nodes(self, node_coords, robot_belief):
+         X = node_coords
+         if len(node_coords) >= self.k_size:
+             knn = NearestNeighbors(n_neighbors=self.k_size)
+         else:
+             knn = NearestNeighbors(n_neighbors=len(node_coords))
+         knn.fit(X)
+         distances, indices = knn.kneighbors(X)
+ 
+         for i, p in enumerate(X):
+             for j, neighbour in enumerate(X[indices[i][:]]):
+                 start = p
+                 end = neighbour
+                 if not self.check_collision(start, end, robot_belief):
+                     a = str(self.find_index_from_coords(node_coords, p))
+                     b = str(self.find_index_from_coords(node_coords, neighbour))
+                     self.graph.add_node(a)
+                     self.graph.add_edge(a, b, distances[i, j])
+ 
+                     if self.plot:
+                         self.x.append([p[0], neighbour[0]])
+                         self.y.append([p[1], neighbour[1]])
+ 
+     def find_nearest_neighbor_all_nodes(self, node_coords, robot_belief):
+         for i, p in enumerate(node_coords):
+             filtered_coords = self.get_neighbors_grid_coords(p)
+ 
+             for j, neighbour in enumerate(filtered_coords):
+                 start = p
+                 end = neighbour
+                 if not self.check_collision(start, end, robot_belief):
+                     a = str(self.find_index_from_coords(node_coords, p))
+                     b = str(self.find_index_from_coords(node_coords, neighbour))
+                     self.graph.add_node(a)
+                     self.graph.add_edge(a, b, np.linalg.norm(start - end))
+ 
+                     if self.plot:
+                         self.x.append([p[0], neighbour[0]])
+                         self.y.append([p[1], neighbour[1]])
+ 
+     def find_index_from_coords(self, node_coords, p):
+         return np.where(np.linalg.norm(node_coords - p, axis=1) < 1e-5)[0][0]
+ 
+     def find_closest_index_from_coords(self, node_coords, p):
+         return np.argmin(np.linalg.norm(node_coords - p, axis=1))
+ 
+     def find_index_from_grid_coords_2d(self, p):
+         diffs = np.linalg.norm(self.grid_coords - p, axis=2)
+         indices = np.where(diffs < 1e-5)
+ 
+         if indices[0].size > 0:
+             return indices[0][0], indices[1][0]
+         else:
+             raise ValueError(f"Coordinate {p} not found in self.grid_coords.")
+ 
+     def find_closest_index_from_grid_coords_2d(self, p):
+         distances = np.linalg.norm(self.grid_coords - p, axis=2)
+         flat_index = np.argmin(distances)
+         return np.unravel_index(flat_index, distances.shape)
+ 
+     def check_collision(self, start, end, robot_belief):
+         collision = False
+         line = shapely.geometry.LineString([start, end])
+ 
+         sortx = np.sort([start[0], end[0]])
+         sorty = np.sort([start[1], end[1]])
+ 
+         robot_belief = robot_belief[sorty[0]:sorty[1] + 1, sortx[0]:sortx[1] + 1]
+ 
+         occupied_area_index = np.where(robot_belief == 1)
+         occupied_area_coords = np.asarray([occupied_area_index[1] + sortx[0], occupied_area_index[0] + sorty[0]]).T
+         unexplored_area_index = np.where(robot_belief == 127)
+         unexplored_area_coords = np.asarray([unexplored_area_index[1] + sortx[0], unexplored_area_index[0] + sorty[0]]).T
+         unfree_area_coords = np.concatenate((occupied_area_coords, unexplored_area_coords))
+ 
+         for i in range(unfree_area_coords.shape[0]):
+             coords = ([(unfree_area_coords[i][0], unfree_area_coords[i][1]),
+                        (unfree_area_coords[i][0] + 1, unfree_area_coords[i][1]),
+                        (unfree_area_coords[i][0], unfree_area_coords[i][1] + 1),
+                        (unfree_area_coords[i][0] + 1, unfree_area_coords[i][1] + 1)])
+             obstacle = shapely.geometry.Polygon(coords)
+             collision = line.intersects(obstacle)
+             if collision:
+                 break
+ 
+         return collision
+ 
+     def find_shortest_path(self, current, destination, node_coords):
+         start_node = str(self.find_index_from_coords(node_coords, current))
+         end_node = str(self.find_index_from_coords(node_coords, destination))
+         route, dist = a_star(int(start_node), int(end_node), self.node_coords, self.graph)
+         if start_node != end_node:
+             assert route != []
+         route = list(map(str, route))
+         return dist, route
+ 
+     def get_neighbors_grid_coords(self, coord):
+         # Return the 8 closest neighbors of a given coordinate
+         nearest_coord = self.node_coords[self.find_closest_index_from_coords(self.node_coords, coord)]
+         rows, cols = self.grid_coords.shape[:2]
+         neighbors = []
+         i, j = self.find_index_from_grid_coords_2d(nearest_coord)
+ 
+         # Create a range of indices for rows and columns
+         row_range = np.clip([i - 1, i, i + 1], 0, rows - 1)
+         col_range = np.clip([j - 1, j, j + 1], 0, cols - 1)
+ 
+         # Iterate over the valid indices
+         for ni in row_range:
+             for nj in col_range:
+                 if (ni, nj) != (i, j):  # skip the center point
+                     neighbors.append(tuple(self.grid_coords[ni, nj]))
+ 
+         return neighbors
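
Editor's note: a sketch of building a roadmap over a toy, fully-free belief map. It assumes the parameter constants (NUM_COORDS_WIDTH, NUM_COORDS_HEIGHT, SENSOR_RANGE) are importable via the TRAINING flag convention, and that Node tolerates an empty frontier set; neither is confirmed here.

    import numpy as np

    belief = np.full((128, 128), 255)        # 255 = free, per the convention above
    gen = Graph_generator(map_size=belief.shape, k_size=20, sensor_range=80, plot=False)
    frontiers = np.empty((0, 2), dtype=int)  # toy map: no frontiers
    coords, edges, utility, guidepost = gen.generate_graph(belief, frontiers)
    print(coords.shape, len(edges))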
planner/model.py ADDED
@@ -0,0 +1,312 @@
1
+ #######################################################################
2
+ # Name: model.py
3
+ #
4
+ # - Attention-based encoders & decoders
5
+ # - Policy Net: Input = Augmented Graph, Output = Node to go to
6
+ # - Critic Net: Input = Augmented Graph + Action, Output = Q_Value
7
+ #######################################################################
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ import math
12
+
13
+
14
+ class SingleHeadAttention(nn.Module):
15
+ def __init__(self, embedding_dim):
16
+ super(SingleHeadAttention, self).__init__()
17
+ self.input_dim = embedding_dim
18
+ self.embedding_dim = embedding_dim
19
+ self.value_dim = embedding_dim
20
+ self.key_dim = self.value_dim
21
+ self.tanh_clipping = 10
22
+ self.norm_factor = 1 / math.sqrt(self.key_dim)
23
+
24
+ self.w_query = nn.Parameter(torch.Tensor(self.input_dim, self.key_dim))
25
+ self.w_key = nn.Parameter(torch.Tensor(self.input_dim, self.key_dim))
26
+
27
+ self.init_parameters()
28
+
29
+ def init_parameters(self):
30
+ for param in self.parameters():
31
+ stdv = 1. / math.sqrt(param.size(-1))
32
+ param.data.uniform_(-stdv, stdv)
33
+
34
+ def forward(self, q, k, mask=None):
35
+
36
+ n_batch, n_key, n_dim = k.size()
37
+ n_query = q.size(1)
38
+
39
+ k_flat = k.reshape(-1, n_dim)
40
+ q_flat = q.reshape(-1, n_dim)
41
+
42
+ shape_k = (n_batch, n_key, -1)
43
+ shape_q = (n_batch, n_query, -1)
44
+
45
+ Q = torch.matmul(q_flat, self.w_query).view(shape_q)
46
+ K = torch.matmul(k_flat, self.w_key).view(shape_k)
47
+
48
+ U = self.norm_factor * torch.matmul(Q, K.transpose(1, 2))
49
+ U = self.tanh_clipping * torch.tanh(U)
50
+
51
+ if mask is not None:
52
+ U = U.masked_fill(mask == 1, -1e8)
53
+ attention = torch.log_softmax(U, dim=-1) # n_batch*n_query*n_key
54
+
55
+ return attention
56
+
57
+
58
+ class MultiHeadAttention(nn.Module):
59
+ def __init__(self, embedding_dim, n_heads=8):
60
+ super(MultiHeadAttention, self).__init__()
61
+ self.n_heads = n_heads
62
+ self.input_dim = embedding_dim
63
+ self.embedding_dim = embedding_dim
64
+ self.value_dim = self.embedding_dim // self.n_heads
65
+ self.key_dim = self.value_dim
66
+ self.norm_factor = 1 / math.sqrt(self.key_dim)
67
+
68
+ self.w_query = nn.Parameter(torch.Tensor(self.n_heads, self.input_dim, self.key_dim))
69
+ self.w_key = nn.Parameter(torch.Tensor(self.n_heads, self.input_dim, self.key_dim))
70
+ self.w_value = nn.Parameter(torch.Tensor(self.n_heads, self.input_dim, self.value_dim))
71
+ self.w_out = nn.Parameter(torch.Tensor(self.n_heads, self.value_dim, self.embedding_dim))
72
+
73
+ self.init_parameters()
74
+
75
+ def init_parameters(self):
76
+ for param in self.parameters():
77
+ stdv = 1. / math.sqrt(param.size(-1))
78
+ param.data.uniform_(-stdv, stdv)
79
+
80
+ def forward(self, q, k=None, v=None, key_padding_mask=None, attn_mask=None):
81
+ if k is None:
82
+ k = q
83
+ if v is None:
84
+ v = q
85
+
86
+ n_batch, n_key, n_dim = k.size()
87
+ n_query = q.size(1)
88
+ n_value = v.size(1)
89
+
90
+ k_flat = k.contiguous().view(-1, n_dim)
91
+ v_flat = v.contiguous().view(-1, n_dim)
92
+ q_flat = q.contiguous().view(-1, n_dim)
93
+ shape_v = (self.n_heads, n_batch, n_value, -1)
94
+ shape_k = (self.n_heads, n_batch, n_key, -1)
95
+ shape_q = (self.n_heads, n_batch, n_query, -1)
96
+
97
+ Q = torch.matmul(q_flat, self.w_query).view(shape_q) # n_heads*batch_size*n_query*key_dim
98
+ K = torch.matmul(k_flat, self.w_key).view(shape_k) # n_heads*batch_size*targets_size*key_dim
99
+ V = torch.matmul(v_flat, self.w_value).view(shape_v) # n_heads*batch_size*targets_size*value_dim
100
+
101
+ U = self.norm_factor * torch.matmul(Q, K.transpose(2, 3)) # n_heads*batch_size*n_query*targets_size
102
+
103
+ if attn_mask is not None:
104
+ attn_mask = attn_mask.view(1, n_batch, n_query, n_key).expand_as(U)
105
+
106
+ if key_padding_mask is not None:
107
+ key_padding_mask = key_padding_mask.repeat(1, n_query, 1)
108
+ key_padding_mask = key_padding_mask.view(1, n_batch, n_query, n_key).expand_as(U) # copy for n_heads times
109
+
110
+ if attn_mask is not None and key_padding_mask is not None:
111
+ mask = (attn_mask + key_padding_mask)
112
+ elif attn_mask is not None:
113
+ mask = attn_mask
114
+ elif key_padding_mask is not None:
115
+ mask = key_padding_mask
116
+ else:
117
+ mask = None
118
+
119
+ if mask is not None:
120
+ U = U.masked_fill(mask > 0, -1e8)
121
+
122
+ attention = torch.softmax(U, dim=-1) # n_heads*batch_size*n_query*targets_size
123
+ heads = torch.matmul(attention, V) # n_heads*batch_size*n_query*value_dim
124
+ out = torch.mm(
125
+ heads.permute(1, 2, 0, 3).reshape(-1, self.n_heads * self.value_dim),
126
+ # batch_size*n_query*n_heads*value_dim
127
+ self.w_out.view(-1, self.embedding_dim)
128
+ # n_heads*value_dim*embedding_dim
129
+ ).view(-1, n_query, self.embedding_dim)
130
+
131
+
132
+ return out, attention # batch_size*n_query*embedding_dim
133
+
134
+
135
+ class Normalization(nn.Module):
136
+ def __init__(self, embedding_dim):
137
+ super(Normalization, self).__init__()
138
+ self.normalizer = nn.LayerNorm(embedding_dim)
139
+
140
+ def forward(self, input):
141
+ return self.normalizer(input.view(-1, input.size(-1))).view(*input.size())
142
+
143
+
144
+ class EncoderLayer(nn.Module):
145
+ def __init__(self, embedding_dim, n_head):
146
+ super(EncoderLayer, self).__init__()
147
+ self.multiHeadAttention = MultiHeadAttention(embedding_dim, n_head)
148
+ self.normalization1 = Normalization(embedding_dim)
149
+ self.feedForward = nn.Sequential(nn.Linear(embedding_dim, 512), nn.ReLU(inplace=True),
150
+ nn.Linear(512, embedding_dim))
151
+ self.normalization2 = Normalization(embedding_dim)
152
+
153
+ def forward(self, src, key_padding_mask=None, attn_mask=None):
154
+ h0 = src
155
+ h = self.normalization1(src)
156
+ h, _ = self.multiHeadAttention(q=h, key_padding_mask=key_padding_mask, attn_mask=attn_mask)
157
+ h = h + h0
158
+ h1 = h
159
+ h = self.normalization2(h)
160
+ h = self.feedForward(h)
161
+ h2 = h + h1
162
+ return h2
163
+
164
+
165
+ class DecoderLayer(nn.Module):
166
+ def __init__(self, embedding_dim, n_head):
167
+ super(DecoderLayer, self).__init__()
168
+ self.multiHeadAttention = MultiHeadAttention(embedding_dim, n_head)
169
+ self.normalization1 = Normalization(embedding_dim)
170
+ self.feedForward = nn.Sequential(nn.Linear(embedding_dim, 512),
171
+ nn.ReLU(inplace=True),
172
+ nn.Linear(512, embedding_dim))
173
+ self.normalization2 = Normalization(embedding_dim)
174
+
175
+ def forward(self, tgt, memory, key_padding_mask=None, attn_mask=None):
176
+ h0 = tgt
177
+ tgt = self.normalization1(tgt)
178
+ memory = self.normalization1(memory)
179
+ h, w = self.multiHeadAttention(q=tgt, k=memory, v=memory, key_padding_mask=key_padding_mask, attn_mask=attn_mask)
180
+ h = h + h0
181
+ h1 = h
182
+ h = self.normalization2(h)
183
+ h = self.feedForward(h)
184
+ h2 = h + h1
185
+ return h2, w
186
+
187
+
188
+ class Encoder(nn.Module):
189
+ def __init__(self, embedding_dim=128, n_head=8, n_layer=1):
190
+ super(Encoder, self).__init__()
191
+ self.layers = nn.ModuleList(EncoderLayer(embedding_dim, n_head) for i in range(n_layer))
192
+
193
+ def forward(self, src, key_padding_mask=None, attn_mask=None):
194
+ for layer in self.layers:
195
+ src = layer(src, key_padding_mask=key_padding_mask, attn_mask=attn_mask)
196
+ return src
197
+
198
+
199
+ class Decoder(nn.Module):
200
+ def __init__(self, embedding_dim=128, n_head=8, n_layer=1):
201
+ super(Decoder, self).__init__()
202
+ self.layers = nn.ModuleList([DecoderLayer(embedding_dim, n_head) for i in range(n_layer)])
203
+
204
+ def forward(self, tgt, memory, key_padding_mask=None, attn_mask=None):
205
+ for layer in self.layers:
206
+ tgt, w = layer(tgt, memory, key_padding_mask=key_padding_mask, attn_mask=attn_mask)
207
+ return tgt, w
208
+
209
+
210
+ class PolicyNet(nn.Module):
211
+ def __init__(self, input_dim, embedding_dim):
212
+ super(PolicyNet, self).__init__()
213
+ self.initial_embedding = nn.Linear(input_dim, embedding_dim) # layer for non-end position
214
+
215
+ self.current_embedding = nn.Linear(embedding_dim * 2, embedding_dim)
216
+
217
+ self.encoder = Encoder(embedding_dim=embedding_dim, n_head=8, n_layer=6)
218
+ self.decoder = Decoder(embedding_dim=embedding_dim, n_head=8, n_layer=1)
219
+ self.pointer = SingleHeadAttention(embedding_dim)
220
+
221
+ def encode_graph(self, node_inputs, node_padding_mask, edge_mask):
222
+ node_feature = self.initial_embedding(node_inputs)
223
+ enhanced_node_feature = self.encoder(src=node_feature, key_padding_mask=node_padding_mask, attn_mask=edge_mask)
224
+
225
+ return enhanced_node_feature
226
+
227
+ def output_policy(self, enhanced_node_feature, edge_inputs, current_index, edge_padding_mask, node_padding_mask):
228
+ k_size = edge_inputs.size()[2]
229
+ current_edge = torch.gather(edge_inputs, 1, current_index.repeat(1, 1, k_size))
230
+ current_edge = current_edge.permute(0, 2, 1)
231
+ embedding_dim = enhanced_node_feature.size()[2]
232
+
233
+ neigboring_feature = torch.gather(enhanced_node_feature, 1, current_edge.repeat(1, 1, embedding_dim))
234
+
235
+ current_node_feature = torch.gather(enhanced_node_feature, 1, current_index.repeat(1, 1, embedding_dim))
236
+
237
+ if edge_padding_mask is not None:
238
+ current_mask = torch.gather(edge_padding_mask, 1, current_index.repeat(1,1,k_size)).to(enhanced_node_feature.device)
239
+ else:
240
+ current_mask = None
241
+ current_mask[:,:,0] = 1 # don't stay at current position
242
+
243
+ if not 0 in current_mask:
244
+ current_mask[:,:,0] = 0
245
+
246
+ enhanced_current_node_feature, _ = self.decoder(current_node_feature, enhanced_node_feature, node_padding_mask)
247
+ enhanced_current_node_feature = self.current_embedding(torch.cat((enhanced_current_node_feature, current_node_feature), dim=-1))
248
+ logp = self.pointer(enhanced_current_node_feature, neigboring_feature, current_mask)
249
+ logp= logp.squeeze(1) # batch_size*k_size
250
+
251
+ return logp
252
+
253
+ def forward(self, node_inputs, edge_inputs, current_index, node_padding_mask=None, edge_padding_mask=None, edge_mask=None):
254
+ enhanced_node_feature = self.encode_graph(node_inputs, node_padding_mask, edge_mask)
255
+ logp = self.output_policy(enhanced_node_feature, edge_inputs, current_index, edge_padding_mask, node_padding_mask)
256
+ return logp
257
+
258
+
259
+ class QNet(nn.Module):
260
+ def __init__(self, input_dim, embedding_dim):
261
+ super(QNet, self).__init__()
262
+ self.initial_embedding = nn.Linear(input_dim, embedding_dim) # layer for non-end position
263
+ self.action_embedding = nn.Linear(embedding_dim*3, embedding_dim)
264
+
265
+ self.encoder = Encoder(embedding_dim=embedding_dim, n_head=8, n_layer=6)
266
+ self.decoder = Decoder(embedding_dim=embedding_dim, n_head=8, n_layer=1)
267
+
268
+ self.q_values_layer = nn.Linear(embedding_dim, 1)
269
+
270
+ def encode_graph(self, node_inputs, node_padding_mask, edge_mask):
271
+ embedding_feature = self.initial_embedding(node_inputs)
272
+ embedding_feature = self.encoder(src=embedding_feature, key_padding_mask=node_padding_mask, attn_mask=edge_mask)
273
+
274
+ return embedding_feature
275
+
276
+ def output_q_values(self, enhanced_node_feature, edge_inputs, current_index, edge_padding_mask, node_padding_mask):
277
+ k_size = edge_inputs.size()[2]
278
+ current_edge = torch.gather(edge_inputs, 1, current_index.repeat(1, 1, k_size))
279
+ current_edge = current_edge.permute(0, 2, 1)
280
+ embedding_dim = enhanced_node_feature.size()[2]
281
+
282
+ neighboring_feature = torch.gather(enhanced_node_feature, 1, current_edge.repeat(1, 1, embedding_dim))
283
+
284
+ current_node_feature = torch.gather(enhanced_node_feature, 1, current_index.repeat(1, 1, embedding_dim))
285
+
286
+ enhanced_current_node_feature, attention_weights = self.decoder(current_node_feature, enhanced_node_feature, node_padding_mask)
287
+ action_features = torch.cat((enhanced_current_node_feature.repeat(1, k_size, 1), current_node_feature.repeat(1, k_size, 1), neighboring_feature), dim=-1)
288
+ action_features = self.action_embedding(action_features)
289
+ q_values = self.q_values_layer(action_features)
290
+
291
+ if edge_padding_mask is not None:
292
+ current_mask = torch.gather(edge_padding_mask, 1, current_index.repeat(1, 1, k_size)).to(
293
+ enhanced_node_feature.device)
294
+ else:
+ current_mask = None
+
+ if current_mask is not None: # the whole masking block only applies when a padding mask was provided
+ current_mask[:, :, 0] = 1 # don't stay at current position
+ if 0 not in current_mask: # every neighbour is masked, so allow staying put
+ current_mask[:, :, 0] = 0
+
+ current_mask = current_mask.permute(0, 2, 1)
+ zero = torch.zeros_like(q_values).to(q_values.device)
+ q_values = torch.where(current_mask == 1, zero, q_values) # zero out Q-values of masked neighbours
304
+
305
+ return q_values, attention_weights
306
+
307
+ def forward(self, node_inputs, edge_inputs, current_index, node_padding_mask=None, edge_padding_mask=None,
308
+ edge_mask=None):
309
+ enhanced_node_feature = self.encode_graph(node_inputs, node_padding_mask, edge_mask)
310
+ q_values, attention_weights = self.output_q_values(enhanced_node_feature, edge_inputs, current_index, edge_padding_mask, node_padding_mask)
311
+ return q_values, attention_weights
312
+
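A minimal usage sketch of the two heads above (illustrative only, not part of the commit; it assumes the `Encoder`, `DecoderLayer`, and `SingleHeadAttention` classes defined earlier in this file are in scope, and the dummy shapes are assumptions based on the signatures):

```python
import torch

# Assumed dimensions: B = batch, N = padded node count, K = neighbours per node
B, N, K, INPUT_DIM, EMBEDDING_DIM = 2, 24, 8, 4, 128

node_inputs = torch.rand(B, N, INPUT_DIM)               # per-node features
edge_inputs = torch.randint(0, N, (B, N, K))            # K neighbour indices per node
current_index = torch.zeros(B, 1, 1, dtype=torch.long)  # index of the current node

policy, qnet = PolicyNet(INPUT_DIM, EMBEDDING_DIM), QNet(INPUT_DIM, EMBEDDING_DIM)

logp = policy(node_inputs, edge_inputs, current_index)           # (B, K) log-probs over neighbours
q_values, attn = qnet(node_inputs, edge_inputs, current_index)   # (B, K, 1) per-neighbour Q-values
action = torch.multinomial(logp.exp(), 1)                        # sample the next neighbour to visit
```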
planner/node.py ADDED
@@ -0,0 +1,96 @@
1
+ #######################################################################
2
+ # Name: node.py
3
+ #
4
+ # - Contains info per node on graph (edge)
5
+ # - Contains: Position, Utility, Visitation History
6
+ #######################################################################
7
+
8
+ import sys
9
+ if sys.modules['TRAINING']:
10
+ from .parameter import *
11
+ else:
12
+ from .test_parameter import *
13
+
14
+ import numpy as np
15
+ import shapely.geometry
16
+
17
+
18
+ class Node():
19
+ def __init__(self, coords, frontiers, robot_belief):
20
+ self.coords = coords
21
+ self.observable_frontiers = []
22
+ self.sensor_range = SENSOR_RANGE
23
+ self.initialize_observable_frontiers(frontiers, robot_belief)
24
+ self.utility = self.get_node_utility()
25
+ if self.utility == 0:
26
+ self.zero_utility_node = True
27
+ else:
28
+ self.zero_utility_node = False
29
+
30
+ def initialize_observable_frontiers(self, frontiers, robot_belief):
31
+ dist_list = np.linalg.norm(frontiers - self.coords, axis=-1)
32
+ frontiers_in_range = frontiers[dist_list < self.sensor_range - 10]
33
+ for point in frontiers_in_range:
34
+ collision = self.check_collision(self.coords, point, robot_belief)
35
+ if not collision:
36
+ self.observable_frontiers.append(point)
37
+
38
+ def get_node_utility(self):
39
+ return len(self.observable_frontiers)
40
+
41
+ def update_observable_frontiers(self, observed_frontiers, new_frontiers, robot_belief):
42
+ if len(observed_frontiers) != 0: # array-safe emptiness check
43
+ observed_index = []
44
+ for i, point in enumerate(self.observable_frontiers):
45
+ if point[0] + point[1] * 1j in observed_frontiers[:, 0] + observed_frontiers[:, 1] * 1j:
46
+ observed_index.append(i)
47
+ for index in reversed(observed_index):
48
+ self.observable_frontiers.pop(index)
49
+ #
50
+ if len(new_frontiers) != 0: # array-safe emptiness check
51
+ dist_list = np.linalg.norm(new_frontiers - self.coords, axis=-1)
52
+ new_frontiers_in_range = new_frontiers[dist_list < self.sensor_range - 15]
53
+ for point in new_frontiers_in_range:
54
+ collision = self.check_collision(self.coords, point, robot_belief)
55
+ if not collision:
56
+ self.observable_frontiers.append(point)
57
+
58
+ self.utility = self.get_node_utility()
59
+ if self.utility == 0:
60
+ self.zero_utility_node = True
61
+ else:
62
+ self.zero_utility_node = False
63
+
64
+ def set_visited(self):
65
+ self.observable_frontiers = []
66
+ self.utility = 0
67
+ self.zero_utility_node = True
68
+
69
+ def check_collision(self, start, end, robot_belief):
70
+ collision = False
71
+ line = shapely.geometry.LineString([start, end])
72
+
73
+ sortx = np.sort([start[0], end[0]])
74
+ sorty = np.sort([start[1], end[1]])
75
+
76
+ robot_belief = robot_belief[sorty[0]:sorty[1] + 1, sortx[0]:sortx[1] + 1]
77
+
78
+ occupied_area_index = np.where(robot_belief == 1)
79
+ occupied_area_coords = np.asarray(
80
+ [occupied_area_index[1] + sortx[0], occupied_area_index[0] + sorty[0]]).T
81
+ unexplored_area_index = np.where(robot_belief == 127)
82
+ unexplored_area_coords = np.asarray(
83
+ [unexplored_area_index[1] + sortx[0], unexplored_area_index[0] + sorty[0]]).T
84
+ unfree_area_coords = np.concatenate((occupied_area_coords, unexplored_area_coords))
85
+
86
+ for i in range(unfree_area_coords.shape[0]):
87
+ coords = ([(unfree_area_coords[i][0], unfree_area_coords[i][1]),
+ (unfree_area_coords[i][0] + 1, unfree_area_coords[i][1]),
+ (unfree_area_coords[i][0] + 1, unfree_area_coords[i][1] + 1),
+ (unfree_area_coords[i][0], unfree_area_coords[i][1] + 1)]) # corners in ring order so the cell polygon is valid
91
+ obstacle = shapely.geometry.Polygon(coords)
92
+ collision = line.intersects(obstacle)
93
+ if collision:
94
+ break
95
+
96
+ return collision
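The collision test above reduces to a segment-vs-cell intersection; below is a tiny self-contained sketch of that geometry (the 1 = occupied convention comes from the code above, while the toy map and values are illustrative):

```python
import numpy as np
import shapely.geometry

belief = np.full((10, 10), 255)  # toy belief map: 255 assumed free
belief[5, :] = 1                 # a wall of occupied cells (value 1) across row 5

# A segment from (2, 2) to (2, 8) in (x, y) order must cross the wall at y = 5
line = shapely.geometry.LineString([(2, 2), (2, 8)])
wall_cell = shapely.geometry.Polygon([(2, 5), (3, 5), (3, 6), (2, 6)])
print(line.intersects(wall_cell))  # True -> the path is blocked
```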
planner/robot.py ADDED
@@ -0,0 +1,58 @@
1
+ #######################################################################
2
+ # Name: robot.py
3
+ #
4
+ # - Stores S(t), A(t), R(t), S(t+1)
5
+ #######################################################################
6
+
7
+ import torch
8
+ from copy import deepcopy
9
+
10
+ class Robot:
11
+ def __init__(self, robot_id, position, plot=False):
12
+ self.robot_id = robot_id
13
+ self.plot = plot
14
+ self.travel_dist = 0
15
+ self.robot_position = position
16
+ self.observations = None
17
+ self.trajectory_coords = []
18
+ self.targets_found_on_path = []
19
+
20
+ self.episode_buffer = []
21
+ for i in range(15):
22
+ self.episode_buffer.append([])
23
+
24
+ if self.plot:
25
+ self.xPoints = [self.robot_position[0]]
26
+ self.yPoints = [self.robot_position[1]]
27
+
28
+ def save_observations(self, observations):
29
+ node_inputs, edge_inputs, current_index, node_padding_mask, edge_padding_mask, edge_mask = observations
30
+ self.episode_buffer[0] += deepcopy(node_inputs).to('cpu')
31
+ self.episode_buffer[1] += deepcopy(edge_inputs).to('cpu')
32
+ self.episode_buffer[2] += deepcopy(current_index).to('cpu')
33
+ self.episode_buffer[3] += deepcopy(node_padding_mask).to('cpu')
34
+ self.episode_buffer[4] += deepcopy(edge_padding_mask).to('cpu')
35
+ self.episode_buffer[5] += deepcopy(edge_mask).to('cpu')
36
+
37
+ def save_action(self, action_index):
38
+ self.episode_buffer[6] += action_index.unsqueeze(0).unsqueeze(0)
39
+
40
+ def save_reward_done(self, reward, done):
41
+ self.episode_buffer[7] += deepcopy(torch.FloatTensor([[[reward]]])).to('cpu')
42
+ self.episode_buffer[8] += deepcopy(torch.tensor([[[(int(done))]]])).to('cpu')
43
+ if self.plot:
44
+ self.xPoints.append(self.robot_position[0])
45
+ self.yPoints.append(self.robot_position[1])
46
+
47
+ def save_next_observations(self, observations):
48
+ node_inputs, edge_inputs, current_index, node_padding_mask, edge_padding_mask, edge_mask = observations
49
+ self.episode_buffer[9] += deepcopy(node_inputs).to('cpu')
50
+ self.episode_buffer[10] += deepcopy(edge_inputs).to('cpu')
51
+ self.episode_buffer[11] += deepcopy(current_index).to('cpu')
52
+ self.episode_buffer[12] += deepcopy(node_padding_mask).to('cpu')
53
+ self.episode_buffer[13] += deepcopy(edge_padding_mask).to('cpu')
54
+ self.episode_buffer[14] += deepcopy(edge_mask).to('cpu')
55
+
56
+ def save_trajectory_coords(self, robot_position_coords, num_target_found):
57
+ self.trajectory_coords.append(robot_position_coords)
58
+ self.targets_found_on_path.append(num_target_found)
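For reference, the slot layout of `episode_buffer` implied by the save methods above (an annotation inferred from the code, not part of the commit):

```python
# Assumed layout inferred from save_observations / save_action /
# save_reward_done / save_next_observations:
#   0-5  : S(t)    node_inputs, edge_inputs, current_index,
#                  node_padding_mask, edge_padding_mask, edge_mask
#   6    : A(t)    action_index
#   7    : R(t)    reward
#   8    : done    episode-termination flag
#   9-14 : S(t+1)  the same six observation tensors as slots 0-5
NUM_BUFFER_SLOTS = 15  # matches the loop in __init__
```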
planner/sensor.py ADDED
@@ -0,0 +1,128 @@
1
+ #######################################################################
2
+ # Name: sensor.py
3
+ #
4
+ # - Computes sensor related checks (e.g. collision, utility etc)
5
+ #######################################################################
6
+
7
+ import sys
8
+ if sys.modules['TRAINING']:
9
+ from .parameter import *
10
+ else:
11
+ from .test_parameter import *
12
+
13
+ import math
14
+ import numpy as np
15
+ import copy
16
+
17
+ def collision_check(x0, y0, x1, y1, ground_truth, robot_belief):
18
+ x0 = x0.round()
19
+ y0 = y0.round()
20
+ x1 = x1.round()
21
+ y1 = y1.round()
22
+ dx, dy = abs(x1 - x0), abs(y1 - y0)
23
+ x, y = x0, y0
24
+ error = dx - dy
25
+ x_inc = 1 if x1 > x0 else -1
26
+ y_inc = 1 if y1 > y0 else -1
27
+ dx *= 2
28
+ dy *= 2
29
+
30
+ collision_flag = 0
31
+ max_collision = 10
32
+
33
+ while 0 <= x < ground_truth.shape[1] and 0 <= y < ground_truth.shape[0]:
34
+ k = ground_truth.item(y, x)
35
+ if k == 1 and collision_flag < max_collision:
36
+ collision_flag += 1
37
+ if collision_flag >= max_collision:
38
+ break
39
+
40
+ if k !=1 and collision_flag > 0:
41
+ break
42
+
43
+ if x == x1 and y == y1:
44
+ break
45
+
46
+ robot_belief.itemset((y, x), k)
47
+
48
+ if error > 0:
49
+ x += x_inc
50
+ error -= dy
51
+ else:
52
+ y += y_inc
53
+ error += dx
54
+
55
+ return robot_belief
56
+
57
+
58
+ def sensor_work(robot_position, sensor_range, robot_belief, ground_truth, sensor_model=SENSOR_MODEL):
59
+ x0 = robot_position[0]
60
+ y0 = robot_position[1]
61
+ rng_x = 0.5 * (ground_truth.shape[1] / NUM_COORDS_WIDTH)
62
+ rng_y = 0.5 * (ground_truth.shape[0] / NUM_COORDS_HEIGHT)
63
+
64
+ if sensor_model == "rectangular": # TODO: add collision check
65
+ max_x = min(x0 + int(math.ceil(rng_x)), ground_truth.shape[1])
66
+ min_x = max(x0 - int(math.ceil(rng_x)), 0)
67
+ max_y = min(y0 + int(math.ceil(rng_y)), ground_truth.shape[0])
68
+ min_y = max(y0 - int(math.ceil(rng_y)), 0)
69
+ robot_belief[min_y:max_y, min_x:max_x] = ground_truth[min_y:max_y, min_x:max_x]
70
+ else:
71
+ sensor_angle_inc = 0.5 / 180 * np.pi
72
+ sensor_angle = 0
73
+ while sensor_angle < 2 * np.pi:
74
+ x1 = x0 + np.cos(sensor_angle) * sensor_range
75
+ y1 = y0 + np.sin(sensor_angle) * sensor_range
76
+ robot_belief = collision_check(x0, y0, x1, y1, ground_truth, robot_belief)
77
+ sensor_angle += sensor_angle_inc
78
+ return robot_belief
79
+
80
+
81
+ def unexplored_area_check(x0, y0, x1, y1, current_belief):
82
+ x0 = x0.round()
83
+ y0 = y0.round()
84
+ x1 = x1.round()
85
+ y1 = y1.round()
86
+ dx, dy = abs(x1 - x0), abs(y1 - y0)
87
+ x, y = x0, y0
88
+ error = dx - dy
89
+ x_inc = 1 if x1 > x0 else -1
90
+ y_inc = 1 if y1 > y0 else -1
91
+ dx *= 2
92
+ dy *= 2
93
+
94
+ while 0 <= x < current_belief.shape[1] and 0 <= y < current_belief.shape[0]:
95
+ k = current_belief.item(y, x)
96
+ if x == x1 and y == y1:
97
+ break
98
+
99
+ if k == 1:
100
+ break
101
+
102
+ if k == 127:
103
+ current_belief.itemset((y, x), 0)
104
+ break
105
+
106
+ if error > 0:
107
+ x += x_inc
108
+ error -= dy
109
+ else:
110
+ y += y_inc
111
+ error += dx
112
+
113
+ return current_belief
114
+
115
+
116
+ def calculate_utility(waypoint_position, sensor_range, robot_belief):
117
+ sensor_angle_inc = 5 / 180 * np.pi
118
+ sensor_angle = 0
119
+ x0 = waypoint_position[0]
120
+ y0 = waypoint_position[1]
121
+ current_belief = copy.deepcopy(robot_belief)
122
+ while sensor_angle < 2 * np.pi:
123
+ x1 = x0 + np.cos(sensor_angle) * sensor_range
124
+ y1 = y0 + np.sin(sensor_angle) * sensor_range
125
+ current_belief = unexplored_area_check(x0, y0, x1, y1, current_belief)
126
+ sensor_angle += sensor_angle_inc
127
+ utility = np.sum(robot_belief == 127) - np.sum(current_belief == 127)
128
+ return utility
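A quick, self-contained exercise of `collision_check` (illustrative; it assumes a NumPy version that still provides `ndarray.itemset`, which was removed in NumPy 2.0, and uses the 1 = occupied / 127 = unexplored conventions from this file):

```python
import numpy as np

ground_truth = np.full((20, 20), 255)   # 255 assumed free
ground_truth[10, :] = 1                 # occupied wall across row 10
robot_belief = np.full((20, 20), 127)   # everything starts unexplored

# Cast one ray from (5, 5) straight down; coordinates are numpy scalars,
# since collision_check calls .round() on them
belief = collision_check(np.int64(5), np.int64(5),
                         np.float64(5), np.float64(18),
                         ground_truth, robot_belief)
print((belief != 127).sum())  # 6: the ray reveals rows 5..10 at x=5, stopping at the wall
```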
planner/test_info_surfing.py ADDED
@@ -0,0 +1,1071 @@
1
+ #######################################################################
2
+ # Name: test_info_surfing.py
3
+ #
4
+ # - Runs robot in environment using Info Surfing Planner
5
+ #######################################################################
6
+
7
+ import sys
8
+ sys.modules['TRAINING'] = False # False = Inference Testing
9
+
10
+ import copy
11
+ import os
12
+ import imageio
13
+ import numpy as np
14
+ import matplotlib.pyplot as plt
15
+ from pathlib import Path
16
+ from time import time
17
+ from types import SimpleNamespace
18
+ from skimage.transform import resize
19
+ from taxabind_avs.satbind.kmeans_clustering import CombinedSilhouetteInertiaClusterer
20
+ from .env import Env
21
+ from .test_parameter import *
22
+
23
+
24
+ OPPOSITE_ACTIONS = {1: 3, 2: 4, 3: 1, 4: 2, 5: 7, 6: 8, 7: 5, 8: 6}
25
+ # color
26
+ agentColor = (1, 0.2, 0.6)
27
+ agentCommColor = (1, 0.6, 0.2)
28
+ obstacleColor = (0., 0., 0.)
29
+ targetNotFound = (0., 1., 0.)
30
+ targetFound = (0.545, 0.27, 0.075)
31
+ highestProbColor = (1., 0., 0.)
32
+ highestUncertaintyColor = (0., 0., 1.)
33
+ lowestProbColor = (1., 1., 1.)
34
+
35
+
36
+ class ISEnv:
37
+ """Custom Environment that follows gym interface"""
38
+ metadata = {'render.modes': ['human']}
39
+
40
+ def __init__(self, global_step=0, state=None, shape=(24, 24), numAgents=8, observationSize=11, sensorSize=1, diag=False, save_image=False, clip_seg_tta=None):
41
+
42
+ self.global_step = global_step
43
+ self.infoMap = None
44
+ self.targetMap = None
45
+ self.agents = []
46
+ self.targets = []
47
+ self.numAgents = numAgents
48
+ self.found_target = []
49
+ self.shape = shape
50
+ self.observationSize = observationSize
51
+ self.sensorSize = sensorSize
52
+ self.diag = diag
53
+ self.communicateCircle = 11
54
+ self.distribs = []
55
+ self.mask = None
56
+ self.finished = False
57
+ self.action_vects = [[-1., 0.], [0., 1.], [1., 0], [0., -1.]] if not diag else [[-1., 0.], [0., 1.], [1., 0], [0., -1.], [-0.707, -0.707], [-0.707, 0.707], [0.707, 0.707], [0.707, -0.707]]
58
+ self.actionlist = []
59
+ self.IS_step = 0
60
+ self.save_image = save_image
61
+ self.clip_seg_tta = clip_seg_tta
62
+ self.perf_metrics = dict()
63
+ self.steps_to_first_tgt = None
64
+ self.steps_to_mid_tgt = None
65
+ self.steps_to_last_tgt = None
66
+ self.targets_found_on_path = []
67
+ self.step_since_tta = 0
68
+ self.IS_frame_files = []
69
+ self.bad_mask_init = False
70
+
71
+ # define env
72
+ self.env = Env(map_index=self.global_step, n_agent=numAgents, k_size=K_SIZE, plot=save_image, test=True)
73
+
74
+ # Overwrite state
75
+ if self.clip_seg_tta is not None:
76
+ self.clip_seg_tta.reset(sample_idx=self.global_step)
77
+
78
+ # Override target positions in env
79
+ self.env.target_positions = [(pose[1], pose[0]) for pose in self.clip_seg_tta.target_positions]
80
+
81
+ # Override segmentation mask
82
+ if not USE_CLIP_PREDS and OVERRIDE_MASK_DIR != "":
83
+ score_mask_path = os.path.join(OVERRIDE_MASK_DIR, self.clip_seg_tta.gt_mask_name)
84
+ print("score_mask_path: ", score_mask_path)
85
+ if os.path.exists(score_mask_path):
86
+ self.env.segmentation_mask = self.env.import_segmentation_mask(score_mask_path)
87
+ self.env.begin(self.env.map_start_position)
88
+ else:
89
+ print(f"\n\n{RED}ERROR: Trying to override, but score mask not found at path:{NC} ", score_mask_path)
90
+ self.bad_mask_init = True
91
+
92
+ # Save clustered embeds from sat encoder
93
+ if USE_CLIP_PREDS:
94
+ self.kmeans_clusterer = CombinedSilhouetteInertiaClusterer(
95
+ k_min=1,
96
+ k_max=8,
97
+ k_avg_max=4,
98
+ silhouette_threshold=0.15,
99
+ relative_threshold=0.15,
100
+ random_state=0,
101
+ min_patch_size=5,
102
+ n_smooth_iter=2,
103
+ ignore_label=-1,
104
+ plot=self.save_image,
105
+ gifs_dir = GIFS_PATH
106
+ )
107
+ # Generate kmeans clusters
108
+ self.kmeans_sat_embeds_clusters = self.kmeans_clusterer.fit_predict(
109
+ patch_embeds=self.clip_seg_tta.patch_embeds,
110
+ map_shape=(CLIP_GRIDS_DIMS[0], CLIP_GRIDS_DIMS[1]),
111
+ )
112
+
113
+ if EXECUTE_TTA:
114
+ print("Will execute TTA...")
115
+
116
+ IS_info_map = copy.deepcopy(self.env.segmentation_info_mask)
117
+ IS_agent_loc = copy.deepcopy(self.env.start_positions)
118
+ IS_target_loc = copy.deepcopy(self.env.target_positions)
119
+ state=[IS_info_map, IS_agent_loc, IS_target_loc]
120
+ self.setWorld(state)
121
+
122
+
123
+ def init_render(self):
124
+ """
125
+ Call this once (e.g., in __init__ or just before the scenario loop)
126
+ to initialize storage for agent paths and turn interactive plotting on.
127
+ """
128
+ # Keep track of each agent's trajectory
129
+ self.trajectories = [[] for _ in range(self.numAgents)]
130
+ self.trajectories_upscaled = [[] for _ in range(self.numAgents)]
131
+
132
+ # Turn on interactive mode so we can update the same figure repeatedly
133
+ plt.ion()
134
+ plt.figure(figsize=(6,6))
135
+ plt.title("Information Map with Agents, Targets, and Sensor Ranges")
136
+
137
+
138
+ def record_positions(self):
139
+ """
140
+ Call this after all agents have moved in a step (or whenever you want to update
141
+ the trajectory). It appends the current positions of each agent to `self.trajectories`.
142
+ """
143
+ for idx, agent in enumerate(self.agents):
144
+ self.trajectories[idx].append((agent.row, agent.col))
145
+ self.trajectories_upscaled[idx].append(self.env.graph_generator.grid_coords[agent.row, agent.col])
146
+
147
+
148
+ def render(self, episode_num, step_num):
149
+ """
150
+ Renders the current state in a single matplotlib plot.
151
+ Ensures consistent image size for GIF generation.
152
+ """
153
+
154
+ # Completely reset the figure to avoid leftover state
155
+ plt.close('all')
156
+ fig = plt.figure(figsize=(6.4, 4.8), dpi=100)
157
+ ax = fig.add_subplot(111)
158
+
159
+ # Plot the information map
160
+ ax.imshow(self.infoMap, origin='lower', cmap='gray')
161
+
162
+ # Show agent positions and their trajectories
163
+ for idx, agent in enumerate(self.agents):
164
+ positions = self.trajectories[idx]
165
+ if len(positions) > 1:
166
+ rows = [p[0] for p in positions]
167
+ cols = [p[1] for p in positions]
168
+ ax.plot(cols, rows, linewidth=1)
169
+
170
+ ax.scatter(agent.col, agent.row, marker='o', s=50)
171
+
172
+ # Plot target locations
173
+ for t in self.targets:
174
+ color = 'green' if np.isnan(t.time_found) else 'red'
175
+ ax.scatter(t.col, t.row, marker='x', s=100, color=color)
176
+
177
+ # Title and axis formatting
178
+ ax.set_title(f"Step: {self.IS_step}")
179
+ ax.invert_yaxis()
180
+
181
+ # Create output folder if it doesn't exist
182
+ if not os.path.exists(GIFS_PATH):
183
+ os.makedirs(GIFS_PATH)
184
+
185
+ # Save the frame with consistent canvas
186
+ frame_path = f'{GIFS_PATH}/IS_{episode_num}_{step_num}.png'
187
+ plt.savefig(frame_path, bbox_inches='tight', pad_inches=0.1)
188
+ self.IS_frame_files.append(frame_path)
189
+
190
+ # Cleanup
191
+ plt.close(fig)
192
+
193
+
194
+ def setWorld(self, state=None):
195
+ """
196
+ 1. empty all the elements
197
+ 2. create the new episode
198
+ """
199
+ if state is not None:
200
+ self.infoMap = copy.deepcopy(state[0].reshape(self.shape).T)
201
+ agents = []
202
+ self.numAgents = len(state[1])
203
+ for a in range(1, self.numAgents + 1):
204
+ abs_pos = state[1].pop(0)
205
+ abs_pos = np.array(abs_pos)
206
+ row, col = self.env.graph_generator.find_closest_index_from_grid_coords_2d(np.array(abs_pos))
207
+ agents.append(Agent(ID=a, row=row, col=col, sensorSize=self.sensorSize, infoMap=np.copy(self.infoMap),
208
+ uncertaintyMap=np.copy(self.infoMap), shape=self.shape, numAgents=self.numAgents))
209
+ self.agents = agents
210
+
211
+ targets, n_targets = [], 1
212
+ for t in range(len(state[2])):
213
+ abs_pos = state[2].pop(0)
214
+ abs_pos = np.array(abs_pos)
215
+ row, col = self.env.graph_generator.find_closest_index_from_grid_coords_2d(abs_pos)
216
+ targets.append(Target(ID=n_targets, row=row, col=col, time_found=np.nan))
217
+ n_targets = n_targets + 1
218
+ self.targets = targets
219
+
220
+ def extractObservation(self, agent):
221
+ """
222
+ Extract observations from information map
223
+ """
224
+
225
+ transform_row = self.observationSize // 2 - agent.row
226
+ transform_col = self.observationSize // 2 - agent.col
227
+
228
+ observation_layers = np.zeros((1, self.observationSize, self.observationSize))
229
+ min_row = max((agent.row - self.observationSize // 2), 0)
230
+ max_row = min((agent.row + self.observationSize // 2 + 1), self.shape[0])
231
+ min_col = max((agent.col - self.observationSize // 2), 0)
232
+ max_col = min((agent.col + self.observationSize // 2 + 1), self.shape[1])
233
+
234
+ observation = np.full((self.observationSize, self.observationSize), 0.)
235
+ infoMap = np.full((self.observationSize, self.observationSize), 0.)
236
+ densityMap = np.full((self.observationSize, self.observationSize), 0.)
237
+
238
+ infoMap[(min_row + transform_row):(max_row + transform_row),
239
+ (min_col + transform_col):(max_col + transform_col)] = self.infoMap[
240
+ min_row:max_row, min_col:max_col]
241
+ observation_layers[0] = infoMap
242
+
243
+ return observation_layers
244
+
245
+
246
+ def listNextValidActions(self, agent_id, prev_action=0):
247
+ """
248
+ No movement: 0
249
+ North (-1,0): 1
250
+ East (0,1): 2
251
+ South (1,0): 3
252
+ West (0,-1): 4
253
+ """
254
+ available_actions = [0]
255
+ agent = self.agents[agent_id - 1]
256
+
257
+ MOVES = [(-1, 0), (0, 1), (1, 0), (0, -1), (-1, -1), (-1, 1), (1, 1), (1, -1)]
258
+ size = 4 + self.diag * 4
259
+ for action in range(size):
260
+ out_of_bounds = agent.row + MOVES[action][0] >= self.shape[0] \
261
+ or agent.row + MOVES[action][0] < 0\
262
+ or agent.col + MOVES[action][1] >= self.shape[1] \
263
+ or agent.col + MOVES[action][1] < 0
264
+
265
+ if (not out_of_bounds) and not (prev_action == OPPOSITE_ACTIONS[action + 1]):
266
+ available_actions.append(action + 1)
267
+
268
+ return np.array(available_actions)
269
+
270
+
271
+ def executeAction(self, agentID, action, timeStep):
272
+ """
273
+ No movement: 0
274
+ North (-1,0): 1
275
+ East (0,1): 2
276
+ South (1,0): 3
277
+ West (0,-1): 4
278
+ LeftUp (-1,-1): 5
+ RightUp (-1,1): 6
+ RightDown (1,1): 7
+ LeftDown (1,-1): 8
282
+ """
283
+ agent = self.agents[agentID - 1]
284
+ origLoc = agent.getLocation()
285
+
286
+ if (action >= 1) and (action <= 8):
287
+ agent.move(action)
288
+ row, col = agent.getLocation()
289
+
290
+ # If the move is not valid, roll it back
291
+ if (row < 0) or (col < 0) or (row >= self.shape[0]) or (col >= self.shape[1]):
+ agent.reverseMove(action) # actually undo the move, not just report the no-op
+ self.updateInfoCheckTarget(agentID, timeStep, origLoc)
+ return 0
294
+
295
+ elif action == 0:
296
+ self.updateInfoCheckTarget(agentID, timeStep, origLoc)
297
+ return 0
298
+
299
+ else:
300
+ print("INVALID ACTION: {}".format(action))
301
+ sys.exit()
302
+
303
+ newLoc = agent.getLocation()
304
+ self.updateInfoCheckTarget(agentID, timeStep, origLoc)
305
+ return action
306
+
307
+
308
+ def updateInfoCheckTarget(self, agentID, timeStep, origLoc):
309
+ """
310
+ update the self.infoMap and check whether the agent has found a target
311
+ """
312
+ agent = self.agents[agentID - 1]
313
+ transform_row = self.sensorSize // 2 - agent.row
314
+ transform_col = self.sensorSize // 2 - agent.col
315
+
316
+ min_row = max((agent.row - self.sensorSize // 2), 0)
317
+ max_row = min((agent.row + self.sensorSize // 2 + 1), self.shape[0])
318
+ min_col = max((agent.col - self.sensorSize // 2), 0)
319
+ max_col = min((agent.col + self.sensorSize // 2 + 1), self.shape[1])
320
+ for t in self.targets:
321
+ if (t.row == agent.row) and (t.col == agent.col):
322
+ t.updateFound(timeStep)
323
+ self.found_target.append(t)
324
+ t.status = True
325
+
326
+ self.infoMap[min_row:max_row, min_col:max_col] *= 0.05
327
+
328
+
329
+ def updateInfoEntireTrajectory(self, agentID):
330
+ """
331
+ decay self.infoMap along the agent's entire trajectory
332
+ """
333
+ traj = self.trajectories[agentID - 1]
334
+
335
+ for (row,col) in traj:
336
+ min_row = max((row - self.sensorSize // 2), 0)
337
+ max_row = min((row + self.sensorSize // 2 + 1), self.shape[0])
338
+ min_col = max((col - self.sensorSize // 2), 0)
339
+ max_col = min((col + self.sensorSize // 2 + 1), self.shape[1])
340
+ self.infoMap[min_row:max_row, min_col:max_col] *= 0.05
341
+
342
+
343
+ # Execute one time step within the environment
344
+ def step(self, agentID, action, timeStep):
345
+ """
346
+ the agents execute the actions
347
+ No movement: 0
348
+ North (-1,0): 1
349
+ East (0,1): 2
350
+ South (1,0): 3
351
+ West (0,-1): 4
352
+ """
353
+ assert (agentID > 0)
354
+
355
+ self.executeAction(agentID, action, timeStep)
356
+
357
+
358
+ def observe(self, agentID):
359
+ assert (agentID > 0)
360
+ vectorObs = self.extractObservation(self.agents[agentID - 1])
361
+ return [vectorObs]
362
+
363
+
364
+ def check_finish(self):
365
+ if TERMINATE_ON_TGTS_FOUND:
366
+ found_status = [t.time_found for t in self.targets]
367
+ d = False
368
+ if np.isnan(found_status).sum() == 0:
369
+ d = True
370
+ return d
371
+ else:
372
+ return False
373
+
374
+
375
+ def gradVec(self, observation, agent):
376
+ a = observation[0]
377
+
378
+ # Make info & unc cells with low value as 0
379
+ a[a < 0.0002] = 0.0
380
+
381
+ # Center square from 11x11
382
+ a_11x11 = a[4:7, 4:7]
383
+ m_11x11 = np.array((a_11x11))
384
+
385
+ # Center square from 9x9
386
+ a_9x9 = self.pooling(a, (3, 3), stride=(1, 1), method='max', pad=False)
387
+ a_9x9 = a_9x9[3:6, 3:6]
388
+ m_9x9 = np.array((a_9x9))
389
+
390
+ # Center square from 6x6
391
+ a_6x6 = self.pooling(a, (6, 6), stride=(1, 1), method='max', pad=False)
392
+ a_6x6 = a_6x6[1:4, 1:4]
393
+ m_6x6 = np.array((a_6x6))
394
+
395
+ # Center square from 3x3
396
+ a_3x3 = self.pooling(a, (5, 5), stride=(3, 3), method='max', pad=False)
397
+ m_3x3 = np.array((a_3x3))
398
+
399
+ # Merging multiScales with weights
400
+ m = m_3x3 * 0.25 + m_6x6 * 0.25 + m_9x9 * 0.25 + m_11x11 * 0.25
401
+ a = m
402
+
403
+ adx, ady = np.gradient(a)
404
+ den = np.linalg.norm(np.array([adx[1, 1], ady[1, 1]]))
405
+ if (den != 0) and (not np.isnan(den)):
406
+ infovec = np.array([adx[1, 1], ady[1, 1]]) / den
407
+ else:
408
+ infovec = 0
409
+ agentvec = []
410
+
411
+ if len(agentvec) == 0:
412
+ den = np.linalg.norm(infovec)
413
+ if (den != 0) and (not np.isnan(den)):
414
+ direction = infovec / den
415
+ else:
416
+ direction = self.action_vects[np.random.randint(4 + self.diag * 4)]
417
+ else:
418
+ den = np.linalg.norm(np.mean(agentvec, 0))
419
+ if (den != 0) and (not np.isnan(den)):
420
+ agentvec = np.mean(agentvec, 0) / den
421
+ else:
422
+ agentvec = 0
423
+
424
+ den = np.linalg.norm(0.6 * infovec + 0.4 * agentvec)
425
+ if (den != 0) and (not np.isnan(den)):
426
+ direction = (0.6 * infovec + 0.4 * agentvec) / den
427
+ else:
428
+ direction = self.action_vects[np.random.randint(4 + self.diag * 4)]
429
+
430
+ action_vec = [[0, 0], [-1, 0], [0, 1], [1, 0], [0, -1]] if not self.diag else [[0., 0.], [-1., 0.], [0., 1.], [1., 0], [0., -1.], [-0.707, -0.707], [-0.707, 0.707], [0.707, 0.707], [0.707, -0.707]]
431
+ actionid = np.argmax([np.dot(direction, a) for a in action_vec])
432
+ actionid = self.best_valid_action(actionid, agent, direction)
433
+ return actionid
434
+
435
+
436
+ def best_valid_action(self, actionid, agent, direction):
437
+ if len(self.actionlist) > 1:
438
+ if self.action_invalid(actionid, agent):
439
+ action_vec = [[0, 0], [-1, 0], [0, 1], [1, 0], [0, -1]] if not self.diag else [[0., 0.], [-1., 0.], [0., 1.], [1., 0], [0., -1.], [-0.707, -0.707], [-0.707, 0.707], [0.707, 0.707], [0.707, -0.707]]
440
+ actionid = np.array([np.dot(direction, a) for a in action_vec])
441
+ actionid = actionid.argsort()
442
+ pi = 3 + self.diag*4
443
+ while pi >= 0 and self.action_invalid(actionid[pi], agent): # bounds check first, so actionid[-1] is never read
444
+ pi -= 1
445
+ if pi == -1:
446
+ return OPPOSITE_ACTIONS[self.actionlist[self.IS_step - 1][agent - 1]]
447
+ elif actionid[pi] == 0:
448
+ return OPPOSITE_ACTIONS[self.actionlist[self.IS_step - 1][agent - 1]]
449
+ else:
450
+ return actionid[pi]
451
+ return actionid
452
+
453
+
454
+ def action_invalid(self, action, agent):
455
+ # Going back to the previous cell is disabled
456
+ if action == OPPOSITE_ACTIONS[self.actionlist[self.IS_step - 1][agent - 1]]:
457
+ return True
458
+ # Move N,E,S,W
459
+ if (action >= 1) and (action <= 8):
460
+ agent = self.agents[agent - 1]
461
+ agent.move(action)
462
+ row, col = agent.getLocation()
463
+
464
+ # If the move is not valid, roll it back
465
+ if ((row < 0) or (col < 0) or (row >= self.shape[0]) or (col >= self.shape[1])):
466
+ agent.reverseMove(action)
467
+ return True
468
+
469
+ agent.reverseMove(action)
470
+ return False
471
+ return False
472
+
473
+
474
+ def step_all_parallel(self):
475
+ actions = []
476
+ reward = 0
477
+ # Decide actions for each agent
478
+ for agent_id in range(1, self.numAgents + 1):
479
+ o = self.observe(agent_id)
480
+ actions.append(self.gradVec(o[0], agent_id))
481
+ self.actionlist.append(actions)
482
+
483
+ # Execute those actions
484
+ for agent_id in range(1, self.numAgents + 1):
485
+ self.step(agent_id, actions[agent_id - 1], self.IS_step)
486
+
487
+ # Record for visualization
488
+ self.record_positions()
489
+
490
+ def is_scenario(self, max_step=512, episode_number=0):
491
+
492
+ # Return all metrics as None if faulty mask init
493
+ if self.bad_mask_init:
494
+ self.perf_metrics['tax'] = None
495
+ self.perf_metrics['travel_dist'] = None
496
+ self.perf_metrics['travel_steps'] = None
497
+ self.perf_metrics['steps_to_first_tgt'] = None
498
+ self.perf_metrics['steps_to_mid_tgt'] = None
499
+ self.perf_metrics['steps_to_last_tgt'] = None
500
+ self.perf_metrics['explored_rate'] = None
501
+ self.perf_metrics['targets_found'] = None
502
+ self.perf_metrics['targets_total'] = None
503
+ self.perf_metrics['kmeans_k'] = None
504
+ self.perf_metrics['tgts_gt_score'] = None
505
+ self.perf_metrics['clip_inference_time'] = None
506
+ self.perf_metrics['tta_time'] = None
507
+ self.perf_metrics['success_rate'] = None
508
+ return
509
+
510
+ eps_start = time()
511
+ self.IS_step = 0
512
+ self.finished = False
513
+ reward = 0
514
+
515
+ # Initialize the rendering just once before the loop
516
+ self.init_render()
517
+ self.record_positions()
518
+
519
+ # Initial Setup
520
+ if LOAD_AVS_BENCH and USE_CLIP_PREDS:
521
+ if NUM_COORDS_WIDTH != CLIP_GRIDS_DIMS[0] or NUM_COORDS_HEIGHT != CLIP_GRIDS_DIMS[1]: # If heatmap is resized from clip original dims
522
+ heatmap = self.convert_heatmap_resolution(self.clip_seg_tta.heatmap, full_dims=(512, 512), new_dims=(NUM_COORDS_WIDTH, NUM_COORDS_HEIGHT))
523
+ self.env.segmentation_info_mask = np.expand_dims(heatmap.T.flatten(), axis=1)
524
+ unnormalized_heatmap = self.convert_heatmap_resolution(self.clip_seg_tta.heatmap_unnormalized, full_dims=(512, 512), new_dims=(NUM_COORDS_WIDTH, NUM_COORDS_HEIGHT))
525
+ self.env.segmentation_info_mask_unnormalized = np.expand_dims(unnormalized_heatmap.T.flatten(), axis=1)
526
+ self.infoMap = copy.deepcopy(heatmap)
527
+ print("Resized heatmap to", NUM_COORDS_WIDTH, "x", NUM_COORDS_HEIGHT)
528
+ else:
529
+ self.env.segmentation_info_mask = np.expand_dims(self.clip_seg_tta.heatmap.T.flatten(), axis=1)
530
+ self.env.segmentation_info_mask_unnormalized = np.expand_dims(self.clip_seg_tta.heatmap_unnormalized.T.flatten(), axis=1)
531
+ self.infoMap = copy.deepcopy(self.clip_seg_tta.heatmap)
532
+
533
+ self.targets_found_on_path.append(self.env.num_new_targets_found)
534
+
535
+ while self.IS_step < max_step and not self.check_finish():
536
+ self.step_all_parallel()
537
+ self.IS_step += 1
538
+
539
+ # Render after each step
540
+ if self.save_image:
541
+ self.render(episode_num=self.global_step, step_num=self.IS_step)
542
+
543
+ # Update in env
544
+ next_position_list = [self.trajectories_upscaled[i][-1] for i, agent in enumerate(self.agents)]
545
+ dist_list = [0 for _ in range(self.numAgents)]
546
+ travel_dist_list = [self.compute_travel_distance(traj) for traj in self.trajectories]
547
+ self.env.multi_robot_step(next_position_list, dist_list, travel_dist_list)
548
+ self.targets_found_on_path.append(self.env.num_new_targets_found)
549
+
550
+ # TTA Update via Poisson Test (with KMeans clustering stats)
551
+ robot_id = 0 # Assume 1 agent for now
552
+ robot_traj = self.trajectories[robot_id]
553
+ if LOAD_AVS_BENCH and USE_CLIP_PREDS and EXECUTE_TTA:
554
+ flat_traj_coords = [robot_traj[i][1] * self.shape[0] + robot_traj[i][0] for i in range(len(robot_traj))]
555
+ robot = SimpleNamespace(
556
+ trajectory_coords=flat_traj_coords,
557
+ targets_found_on_path=self.targets_found_on_path
558
+ )
559
+ self.poisson_tta_update(robot, self.global_step, self.IS_step)
560
+ self.infoMap = copy.deepcopy(self.env.segmentation_info_mask.reshape((self.shape[1],self.shape[0])).T)
561
+ self.updateInfoEntireTrajectory(robot_id)
562
+
563
+ # Update metrics
564
+ self.log_metrics(step=self.IS_step-1)
565
+
566
+ ### Save a frame to generate gif of robot trajectories ###
567
+ if self.save_image:
568
+ robots_route = [ ([], []) ] # Assume 1 robot
569
+ for point in self.trajectories_upscaled[robot_id]:
570
+ robots_route[robot_id][0].append(point[0])
571
+ robots_route[robot_id][1].append(point[1])
572
+ if not os.path.exists(GIFS_PATH):
573
+ os.makedirs(GIFS_PATH)
574
+ if LOAD_AVS_BENCH:
575
+ sound_id_override = None if self.clip_seg_tta.sound_ids == [] else self.clip_seg_tta.sound_ids[0]
576
+ self.env.plot_env(
577
+ self.global_step,
578
+ GIFS_PATH,
579
+ self.IS_step-1,
580
+ max(travel_dist_list),
581
+ robots_route,
582
+ img_path_override=self.clip_seg_tta.img_paths[0], # Viz 1st
583
+ sat_path_override=self.clip_seg_tta.imo_path,
584
+ msk_name_override=self.clip_seg_tta.species_name,
585
+ sound_id_override=sound_id_override,
586
+ )
587
+ else:
588
+ self.env.plot_env(
589
+ self.global_step,
590
+ GIFS_PATH,
591
+ self.IS_step-1,
592
+ max(travel_dist_list),
593
+ robots_route
594
+ )
595
+
596
+ # Log metrics
597
+ if LOAD_AVS_BENCH:
598
+ tax = Path(self.clip_seg_tta.gt_mask_name).stem
599
+ self.perf_metrics['tax'] = " ".join(tax.split("_")[1:])
600
+ else:
601
+ self.perf_metrics['tax'] = None
602
+ travel_distances = [self.compute_travel_distance(traj) for traj in self.trajectories]
603
+ self.perf_metrics['travel_dist'] = max(travel_distances)
604
+ self.perf_metrics['travel_steps'] = self.IS_step
605
+ self.perf_metrics['steps_to_first_tgt'] = self.steps_to_first_tgt
606
+ self.perf_metrics['steps_to_mid_tgt'] = self.steps_to_mid_tgt
607
+ self.perf_metrics['steps_to_last_tgt'] = self.steps_to_last_tgt
608
+ self.perf_metrics['explored_rate'] = self.env.explored_rate
609
+ self.perf_metrics['targets_found'] = self.env.targets_found_rate
610
+ self.perf_metrics['targets_total'] = len(self.env.target_positions)
611
+ if USE_CLIP_PREDS:
612
+ self.perf_metrics['kmeans_k'] = self.kmeans_clusterer.final_k
613
+ self.perf_metrics['tgts_gt_score'] = self.clip_seg_tta.tgts_gt_score
614
+ self.perf_metrics['clip_inference_time'] = self.clip_seg_tta.clip_inference_time
615
+ self.perf_metrics['tta_time'] = self.clip_seg_tta.tta_time
616
+ else:
617
+ self.perf_metrics['kmeans_k'] = None
618
+ self.perf_metrics['tgts_gt_score'] = None
619
+ self.perf_metrics['clip_inference_time'] = None
620
+ self.perf_metrics['tta_time'] = None
621
+ if FORCE_LOGGING_DONE_TGTS_FOUND and self.env.targets_found_rate == 1.0:
622
+ self.perf_metrics['success_rate'] = True
623
+ else:
624
+ self.perf_metrics['success_rate'] = self.env.check_done()[0]
625
+
626
+ # save gif
627
+ if self.save_image:
628
+ path = GIFS_PATH
629
+ self.make_gif(path, self.global_step)
630
+
631
+ print(YELLOW, f"[Eps {episode_number} Completed] Time Taken: {time()-eps_start:.2f}s, Steps: {self.IS_step}", NC)
632
+
633
+
634
+ def asStride(self, arr, sub_shape, stride):
635
+ """
636
+ Get a strided sub-matrices view of an ndarray.
637
+ See also skimage.util.shape.view_as_windows()
638
+ """
639
+ s0, s1 = arr.strides[:2]
640
+ m1, n1 = arr.shape[:2]
641
+ m2, n2 = sub_shape
642
+ view_shape = (1+(m1-m2)//stride[0], 1+(n1-n2)//stride[1], m2, n2)+arr.shape[2:]
643
+ strides = (stride[0]*s0, stride[1]*s1, s0, s1)+arr.strides[2:]
644
+ subs = np.lib.stride_tricks.as_strided(arr, view_shape, strides=strides)
645
+ return subs
646
+
647
+
648
+ def pooling(self, mat, ksize, stride=None, method='max', pad=False):
649
+ """
650
+ Overlapping pooling on 2D or 3D data.
651
+
652
+ <mat>: ndarray, input array to pool.
653
+ <ksize>: tuple of 2, kernel size in (ky, kx).
654
+ <stride>: tuple of 2 or None, stride of pooling window.
655
+ If None, same as <ksize> (non-overlapping pooling).
656
+ <method>: str, 'max' for max-pooling,
657
+ 'mean' for mean-pooling.
658
+ <pad>: bool, pad <mat> or not. If no pad, output has size
659
+ (n-f)//s+1, n being <mat> size, f being kernel size, s stride.
660
+ if pad, output has size ceil(n/s).
661
+
662
+ Return <result>: pooled matrix.
663
+ """
664
+
665
+ m, n = mat.shape[:2]
666
+ ky, kx = ksize
667
+ if stride is None:
668
+ stride = (ky, kx)
669
+ sy, sx = stride
670
+
671
+ _ceil = lambda x, y: int(np.ceil(x/float(y)))
672
+
673
+ if pad:
674
+ ny = _ceil(m,sy)
675
+ nx = _ceil(n,sx)
676
+ size = ((ny-1)*sy+ky, (nx-1)*sx+kx) + mat.shape[2:]
677
+ mat_pad = np.full(size,np.nan)
678
+ mat_pad[:m,:n,...] = mat
679
+ else:
680
+ mat_pad = mat[:(m-ky)//sy*sy+ky, :(n-kx)//sx*sx+kx, ...]
681
+
682
+ view = self.asStride(mat_pad,ksize,stride)
683
+
684
+ if method == 'max':
685
+ result = np.nanmax(view,axis=(2,3))
686
+ else:
687
+ result = np.nanmean(view,axis=(2,3))
688
+
689
+ return result
690
+
691
+
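For example (illustrative; `env` stands for any `ISEnv` instance, though the method never touches `self` and behaves as a pure function):

```python
import numpy as np

mat = np.arange(25, dtype=float).reshape(5, 5)

# Overlapping 3x3 max-pool, stride 1, no padding -> output is (5-3)//1+1 = 3 per axis
out = env.pooling(mat, (3, 3), stride=(1, 1), method='max', pad=False)
print(out.shape)  # (3, 3)
print(out[0, 0])  # 12.0 -- the max of the top-left 3x3 window
```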
692
+ def compute_travel_distance(self, trajectory):
693
+ distance = 0.0
694
+ for i in range(1, len(trajectory)):
695
+ # Convert the tuple positions to numpy arrays for easy computation.
696
+ prev_pos = np.array(trajectory[i-1])
697
+ curr_pos = np.array(trajectory[i])
698
+ # Euclidean distance between consecutive positions.
699
+ distance += np.linalg.norm(curr_pos - prev_pos)
700
+ return distance
701
+
702
+ ################################################################################
703
+ # SPPP Related Fns
704
+ ################################################################################
705
+
706
+ def log_metrics(self, step):
707
+ # Update tgt found metrics
708
+ if self.steps_to_first_tgt is None and self.env.num_targets_found == 1:
709
+ self.steps_to_first_tgt = step + 1
710
+ if self.steps_to_mid_tgt is None and self.env.num_targets_found == int(len(self.env.target_positions) / 2):
711
+ self.steps_to_mid_tgt = step + 1
712
+ if self.steps_to_last_tgt is None and self.env.num_targets_found == len(self.env.target_positions):
713
+ self.steps_to_last_tgt = step + 1
714
+
715
+
716
+ def transpose_flat_idx(self, idx, H=NUM_COORDS_HEIGHT, W=NUM_COORDS_WIDTH):
717
+ """
718
+ Transpose a flat index from an ``H×W`` grid to the equivalent
719
+ position in the ``W×H`` transposed grid while **keeping the result
720
+ in 1-D**.
721
+ """
722
+ # --- Safety check to catch out-of-range indices ---
723
+ assert 0 <= idx < H * W, f"idx {idx} out of bounds for shape ({H}, {W})"
724
+
725
+ # Original (row, col)
726
+ row, col = divmod(idx, W)
727
+ # After transpose these coordinates swap
728
+ row_T, col_T = col, row
729
+
730
+ # Flatten back into 1-D (row-major) for the W×H grid
731
+ return row_T * H + col_T
732
+
733
+
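A worked example on an assumed 2x3 grid (not the repo's 24x24 default) to make the remapping concrete:

```python
# Original 2x3 grid (row-major)     Transposed 3x2 grid
#   0 1 2                             0 3
#   3 4 5                             1 4
#                                     2 5
H, W = 2, 3
for idx in range(H * W):
    row, col = divmod(idx, W)        # position in the H x W grid
    print(idx, "->", col * H + row)  # flat index in the W x H grid
# 0 -> 0, 1 -> 2, 2 -> 4, 3 -> 1, 4 -> 3, 5 -> 5
```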
734
+ def poisson_tta_update(self, robot, episode, step):
735
+
736
+ # Generate Kmeans Clusters Stats
737
+ # Scale index back to CLIP_GRIDS_DIMS to be compatible with CLIP patch size
738
+ if NUM_COORDS_WIDTH != CLIP_GRIDS_DIMS[0] or NUM_COORDS_HEIGHT != CLIP_GRIDS_DIMS[1]:
739
+ # High-res remap via pixel coordinates preserves exact neighbourhood
740
+ filt_traj_coords, filt_targets_found_on_path = self.scale_trajectory(
741
+ robot.trajectory_coords,
742
+ self.env.target_positions,
743
+ old_dims=(NUM_COORDS_HEIGHT, NUM_COORDS_WIDTH),
744
+ full_dims=(512, 512),
745
+ new_dims=(CLIP_GRIDS_DIMS[0], CLIP_GRIDS_DIMS[1])
746
+ )
747
+ else:
748
+ filt_traj_coords = [self.transpose_flat_idx(idx) for idx in robot.trajectory_coords]
749
+ filt_targets_found_on_path = robot.targets_found_on_path
750
+
751
+ region_stats_dict = self.kmeans_clusterer.compute_region_statistics(
752
+ self.kmeans_sat_embeds_clusters,
753
+ self.clip_seg_tta.heatmap_unnormalized,
754
+ filt_traj_coords,
755
+ episode_num=episode,
756
+ step_num=step
757
+ )
758
+
759
+ # Prep & execute TTA
760
+ self.step_since_tta += 1
761
+ if robot.targets_found_on_path[-1] or self.step_since_tta % STEPS_PER_TTA == 0:
762
+
763
+ num_cells = self.clip_seg_tta.heatmap.shape[0] * self.clip_seg_tta.heatmap.shape[1]
764
+ pos_sample_weight_scale, neg_sample_weight_scale = [], []
765
+
766
+ for i, sample_loc in enumerate(filt_traj_coords):
767
+ label = self.kmeans_clusterer.get_label_id(self.kmeans_sat_embeds_clusters, sample_loc)
768
+ num_patches = region_stats_dict[label]['num_patches']
769
+ patches_visited = region_stats_dict[label]['patches_visited']
770
+ expectation = region_stats_dict[label]['expectation']
771
+
772
+ # Exponent like focal loss to wait for more samples before confidently decreasing
773
+ pos_weight = 4.0
774
+ neg_weight = min(1.0, (patches_visited/(3*num_patches))**GAMMA_EXPONENT)
775
+ pos_sample_weight_scale.append(pos_weight)
776
+ neg_sample_weight_scale.append(neg_weight)
777
+
778
+ # Adaptative LR (as samples increase, increase LR to fit more datapoints)
779
+ adaptive_lr = MIN_LR + (MAX_LR - MIN_LR) * (step / num_cells)
780
+
781
+ # TTA Update
782
+ self.clip_seg_tta.execute_tta(
783
+ filt_traj_coords,
784
+ filt_targets_found_on_path,
785
+ tta_steps=NUM_TTA_STEPS,
786
+ lr=adaptive_lr,
787
+ pos_sample_weight=pos_sample_weight_scale,
788
+ neg_sample_weight=neg_sample_weight_scale,
789
+ reset_weights=RESET_WEIGHTS
790
+ )
791
+ if NUM_COORDS_WIDTH != CLIP_GRIDS_DIMS[0] or NUM_COORDS_HEIGHT != CLIP_GRIDS_DIMS[1]: # If heatmap is resized from clip original dims
792
+ heatmap = self.convert_heatmap_resolution(self.clip_seg_tta.heatmap, full_dims=(512, 512), new_dims=(NUM_COORDS_WIDTH, NUM_COORDS_HEIGHT))
793
+ self.env.segmentation_info_mask = np.expand_dims(heatmap.T.flatten(), axis=1)
794
+ unnormalized_heatmap = self.convert_heatmap_resolution(self.clip_seg_tta.heatmap_unnormalized, full_dims=(512, 512), new_dims=(NUM_COORDS_WIDTH, NUM_COORDS_HEIGHT))
795
+ self.env.segmentation_info_mask_unnormalized = np.expand_dims(unnormalized_heatmap.T.flatten(), axis=1)
796
+ print("~Resized heatmap to", NUM_COORDS_WIDTH, "x", NUM_COORDS_HEIGHT)
797
+ else:
798
+ self.env.segmentation_info_mask = np.expand_dims(self.clip_seg_tta.heatmap.T.flatten(), axis=1)
799
+ self.env.segmentation_info_mask_unnormalized = np.expand_dims(self.clip_seg_tta.heatmap_unnormalized.T.flatten(), axis=1)
800
+
801
+ self.step_since_tta = 0
802
+
803
+
804
+ def convert_heatmap_resolution(self, heatmap, full_dims=(512, 512), new_dims=(24, 24)):
805
+ heatmap_large = resize(heatmap, full_dims, order=1, # order=1 → bilinear
806
+ mode='reflect', anti_aliasing=True)
807
+
808
+ coords = self.env.graph_generator.grid_coords # (N, N, 2)
809
+ rows, cols = coords[...,1], coords[...,0]
810
+ heatmap_resized = heatmap_large[rows, cols]
811
+ heatmap_resized = heatmap_resized.reshape(new_dims[1], new_dims[0])
812
+ return heatmap_resized
813
+
814
+
815
+ def convert_labelmap_resolution(self, labelmap, full_dims=(512, 512), new_dims=(24, 24)):
816
+ """
817
+ 1) Upsample via nearest‐neighbor to full_dims
818
+ 2) Sample back down to your graph grid using grid_coords
819
+ """
820
+ # 1) Upsample with nearest‐neighbor, preserving integer labels
821
+ up = resize(
822
+ labelmap,
823
+ full_dims,
824
+ order=0, # nearest‐neighbor
825
+ mode='edge', # padding mode
826
+ preserve_range=True, # don't normalize labels
827
+ anti_aliasing=False # must be False for labels
828
+ ).astype(labelmap.dtype) # back to original integer dtype
829
+
830
+ # 2) Downsample via your precomputed grid coords (N×N×2)
831
+ coords = self.env.graph_generator.grid_coords # shape (N, N, 2)
832
+ rows = coords[...,1].astype(int)
833
+ cols = coords[...,0].astype(int)
834
+
835
+ small = up[rows, cols] # shape (N, N)
836
+ small = small.reshape(new_dims[0], new_dims[1])
837
+ return small
838
+
839
+
840
+ def scale_trajectory(self,
841
+ flat_indices,
842
+ targets,
843
+ old_dims=(17, 17),
844
+ full_dims=(512, 512),
845
+ new_dims=(24, 24)):
846
+ """
847
+ Args:
848
+ flat_indices: list of ints in [0..old_H*old_W-1]
849
+ targets: list of (y_pix, x_pix) in [0..full_H-1]
850
+ old_dims: (old_H, old_W)
851
+ full_dims: (full_H, full_W)
852
+ new_dims: (new_H, new_W)
853
+
854
+ Returns:
855
+ new_flat_traj: list of unique flattened indices in new_H×new_W
856
+ counts: list of ints, same length as new_flat_traj
857
+ """
858
+ old_H, old_W = old_dims
859
+ full_H, full_W = full_dims
860
+ new_H, new_W = new_dims
861
+
862
+ # 1) bin targets into new grid
863
+ cell_h_new = full_H / new_H
864
+ cell_w_new = full_W / new_W
865
+ grid_counts = [[0]*new_W for _ in range(new_H)]
866
+ for x_pix, y_pix in targets: # note (x, y) order as in original implementation
867
+ i_t = min(int(y_pix / cell_h_new), new_H - 1)
868
+ j_t = min(int(x_pix / cell_w_new), new_W - 1)
869
+ grid_counts[i_t][j_t] += 1
870
+
871
+ # 2) Walk the trajectory indices and project each old cell's *entire
872
+ # pixel footprint* onto the finer 24×24 grid.
873
+ cell_h_full = full_H / old_H
874
+ cell_w_full = full_W / old_W
875
+
876
+ seen = set()
877
+ new_flat_traj = []
878
+
879
+ for node_idx in flat_indices:
880
+ if node_idx < 0 or node_idx >= len(self.env.graph_generator.node_coords):
881
+ continue
882
+
883
+ coord_xy = self.env.graph_generator.node_coords[node_idx]
884
+ try:
885
+ row_old, col_old = self.env.graph_generator.find_index_from_grid_coords_2d(coord_xy)
886
+ except Exception:
887
+ continue
888
+
889
+ # Bounding box of the old cell in full-resolution pixel space
890
+ y0 = row_old * cell_h_full
891
+ y1 = (row_old + 1) * cell_h_full
892
+ x0 = col_old * cell_w_full
893
+ x1 = (col_old + 1) * cell_w_full
894
+
895
+ # Which new-grid rows & cols overlap? (inclusive ranges)
896
+ i_start = max(0, min(int(y0 / cell_h_new), new_H - 1))
897
+ i_end = max(0, min(int((y1 - 1) / cell_h_new), new_H - 1))
898
+ j_start = max(0, min(int(x0 / cell_w_new), new_W - 1))
899
+ j_end = max(0, min(int((x1 - 1) / cell_w_new), new_W - 1))
900
+
901
+ for ii in range(i_start, i_end + 1):
902
+ for jj in range(j_start, j_end + 1):
903
+ f_new = ii * new_W + jj
904
+ if f_new not in seen:
905
+ seen.add(f_new)
906
+ new_flat_traj.append(f_new)
907
+
908
+ # 3) annotate counts
909
+ counts = []
910
+ for f in new_flat_traj:
911
+ i_new, j_new = divmod(f, new_W)
912
+ counts.append(grid_counts[i_new][j_new])
913
+
914
+ return new_flat_traj, counts
915
+
916
+
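The footprint arithmetic in step 2 can be checked by hand (illustrative numbers, using the default 17x17 -> 512 px -> 24x24 dims from the signature above):

```python
full, old, new = 512, 17, 24
cell_old = full / old   # ~30.1 px per old cell
cell_new = full / new   # ~21.3 px per new cell

# Old cell row 0 spans pixels [0, 30.1); which new rows does it overlap?
y0, y1 = 0 * cell_old, 1 * cell_old
i_start = int(y0 / cell_new)        # 0
i_end = int((y1 - 1) / cell_new)    # 1
print(i_start, i_end)  # 0 1 -> one coarse cell maps onto a 2-cell band of the finer grid
```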
917
+ ################################################################################
918
+
919
+ def make_gif(self, path, n):
920
+ """ Generate a gif given list of images """
921
+ with imageio.get_writer('{}/{}_target_rate_{:.2f}.gif'.format(path, n, self.env.targets_found_rate), mode='I',
922
+ fps=5) as writer:
923
+ for frame in self.env.frame_files:
924
+ image = imageio.imread(frame)
925
+ writer.append_data(image)
926
+ print('gif complete\n')
927
+
928
+ # Remove files
929
+ for filename in self.env.frame_files[:-1]:
930
+ os.remove(filename)
931
+
932
+ # For KMeans gif
933
+ if LOAD_AVS_BENCH and USE_CLIP_PREDS:
934
+ with imageio.get_writer('{}/{}_kmeans_stats.gif'.format(path, n), mode='I',
935
+ fps=5) as writer:
936
+ for frame in self.kmeans_clusterer.kmeans_frame_files:
937
+ image = imageio.imread(frame)
938
+ writer.append_data(image)
939
+ print('Kmeans Clusterer gif complete\n')
940
+
941
+ # Remove files
942
+ for filename in self.kmeans_clusterer.kmeans_frame_files[:-1]:
943
+ os.remove(filename)
944
+
945
+
946
+ # IS gif
947
+ with imageio.get_writer('{}/{}_IS.gif'.format(path, n), mode='I',
948
+ fps=5) as writer:
949
+ for frame in self.IS_frame_files:
950
+ image = imageio.imread(frame)
951
+ writer.append_data(image)
952
+ print('IS gif complete\n')
953
+
954
+ # Remove files
955
+ for filename in self.IS_frame_files[:-1]:
956
+ os.remove(filename)
957
+
958
+ ################################################################################
959
+
960
+
961
+ class Agent:
962
+ def __init__(self, ID, infoMap=None, uncertaintyMap=None, shape=None, row=0, col=0, sensorSize=9, numAgents=8):
963
+ self.ID = ID
964
+ self.row = row
965
+ self.col = col
966
+ self.numAgents = numAgents
967
+ self.sensorSize = sensorSize
968
+
969
+ def setLocation(self, row, col):
970
+ self.row = row
971
+ self.col = col
972
+
973
+ def getLocation(self):
974
+ return [self.row, self.col]
975
+
976
+ def move(self, action):
977
+ """
978
+ No movement: 0
979
+ North (-1,0): 1
980
+ East (0,1): 2
981
+ South (1,0): 3
982
+ West (0,-1): 4
983
+ LeftUp (-1,-1): 5
+ RightUp (-1,1): 6
+ RightDown (1,1): 7
+ LeftDown (1,-1): 8
987
+ Note: boundary validity is checked by the caller, which rolls back out-of-range moves.
988
+ """
989
+ if action == 0:
990
+ return 0
991
+ elif action == 1:
992
+ self.row -= 1
993
+ elif action == 2:
994
+ self.col += 1
995
+ elif action == 3:
996
+ self.row += 1
997
+ elif action == 4:
998
+ self.col -= 1
999
+ elif action == 5:
1000
+ self.row -= 1
1001
+ self.col -= 1
1002
+ elif action == 6:
1003
+ self.row -= 1
1004
+ self.col += 1
1005
+ elif action == 7:
1006
+ self.row += 1
1007
+ self.col += 1
1008
+ elif action == 8:
1009
+ self.row += 1
1010
+ self.col -= 1
1011
+
1012
+ def reverseMove(self, action):
1013
+ if action == 0:
1014
+ return 0
1015
+ elif action == 1:
1016
+ self.row += 1
1017
+ elif action == 2:
1018
+ self.col -= 1
1019
+ elif action == 3:
1020
+ self.row -= 1
1021
+ elif action == 4:
1022
+ self.col += 1
1023
+ elif action == 5:
1024
+ self.row += 1
1025
+ self.col += 1
1026
+ elif action == 6:
1027
+ self.row += 1
1028
+ self.col -= 1
1029
+ elif action == 7:
1030
+ self.row -= 1
1031
+ self.col -= 1
1032
+ elif action == 8:
1033
+ self.row -= 1
1034
+ self.col += 1
1035
+ else:
1036
+ print("agent can only move NESW/1234")
1037
+ sys.exit()
1038
+
1039
+
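A small sanity check (illustrative) that `reverseMove` exactly undoes `move` for every action:

```python
agent = Agent(ID=1, row=5, col=5)
for a in range(1, 9):          # all eight movement actions
    agent.move(a)
    agent.reverseMove(a)
assert agent.getLocation() == [5, 5]  # back where it started after each pair
```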
1040
+ class Target:
1041
+ def __init__(self, row, col, ID, time_found=np.nan):
1042
+ self.row = row
1043
+ self.col = col
1044
+ self.ID = ID
1045
+ self.time_found = time_found
1046
+ self.status = None
1047
+ self.time_visited = time_found
1048
+
1049
+ def getLocation(self):
1050
+ return self.row, self.col
1051
+
1052
+ def updateFound(self, timeStep):
1053
+ if np.isnan(self.time_found):
1054
+ self.time_found = timeStep
1055
+
1056
+ def updateVisited(self, timeStep):
1057
+ if np.isnan(self.time_visited):
1058
+ self.time_visited = timeStep
1059
+
1060
+
1061
+ if __name__ == "__main__":
1062
+
1063
+ search_env = Env(map_index=1, k_size=K_SIZE, n_agent=NUM_ROBOTS, plot=SAVE_GIFS)
1064
+
1065
+ IS_info_map = search_env.segmentation_info_mask
1066
+ IS_agent_loc = search_env.start_positions
1067
+ IS_target_loc = [[312, 123], [123, 312], [312, 312], [123, 123]]
1068
+
1069
+ env = ISEnv(state=[IS_info_map, IS_agent_loc, IS_target_loc], shape=(NUM_COORDS_HEIGHT, NUM_COORDS_WIDTH))
1070
+ env.is_scenario(NUM_EPS_STEPS)
1071
+ print()
planner/test_parameter.py ADDED
@@ -0,0 +1,118 @@
1
+ ############################################################################################
2
+ # Name: test_parameter.py
3
+ #
4
+ # NOTE: Change all your hyper-params here for eval
5
+ # Simple How-To Guide:
6
+ # 1. CLIP TTA: USE_CLIP_PREDS=True, EXECUTE_TTA=True
7
+ # 2. CLIP (No TTA): USE_CLIP_PREDS=True, EXECUTE_TTA=False
8
+ # 3. Custom masks (e.g. LISA): USE_CLIP_PREDS=False, EXECUTE_TTA=False
9
+ ############################################################################################
10
+
11
+ import os
12
+ import sys
13
+ sys.modules['TRAINING'] = False # False = Inference Testing
14
+
15
+ ###############################################################
16
+ # Overload Params
17
+ ###############################################################
18
+
19
+ OPT_VARS = {}
20
+ def getenv(var_name, default=None, cast_type=str):
21
+ try:
22
+ value = os.environ.get(var_name, None)
23
+ if value is None:
24
+ result = default
25
+ elif cast_type == bool:
26
+ result = value.lower() in ("true", "1", "yes")
27
+ else:
28
+ result = cast_type(value)
29
+ except (ValueError, TypeError):
30
+ result = default
31
+
32
+ OPT_VARS[var_name] = result # Log the result
33
+ return result
34
+
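Usage sketch (the launch command is an assumption for illustration; the variable names are the ones defined below):

```python
# Shell: override parameters through the environment before launching, e.g.
#   NUM_EPS_STEPS=200 SAVE_GIFS=false python -m planner.test_info_surfing
# Inside this module the overrides then resolve as:
steps = getenv("NUM_EPS_STEPS", default=400, cast_type=int)  # -> 200
gifs = getenv("SAVE_GIFS", default=True, cast_type=bool)     # -> False
```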
35
+ ###############################################################
36
+ # General
37
+ ###############################################################
38
+
39
+ # --- GENERAL --- #
40
+ USE_GPU = False
41
+ NUM_GPU = getenv("NUM_GPU", default=1, cast_type=int) # the number of GPUs
42
+ NUM_META_AGENT = getenv("NUM_META_AGENT", default=2, cast_type=int) # the number of concurrent processes
43
+ NUM_EPS_STEPS = getenv("NUM_EPS_STEPS", default=400, cast_type=int)
44
+ FIX_START_POSITION = getenv("FIX_START_POSITION", default=True, cast_type=bool) # Whether to fix the starting position of the robots (middle index)
45
+ NUM_ROBOTS = 1 # Only allow for 1 robot
46
+ NUM_COORDS_WIDTH=24 # How many node coords across width?
47
+ NUM_COORDS_HEIGHT=24 # How many node coords across height?
48
+ CLIP_GRIDS_DIMS=[24,24] # [16,16] if 'openai/clip-vit-large-patch14-336'
49
+ SENSOR_RANGE=80 # Only applicable to 'circle' sensor model
50
+ SENSOR_MODEL="rectangular" # "rectangular", "circle" (NOTE: no colllision check for rectangular)
51
+ TERMINATE_ON_TGTS_FOUND = True # Whether to terminate episode when all targets found
52
+ FORCE_LOGGING_DONE_TGTS_FOUND = True # Whether to force csv logging when all targets found
53
+
54
+
55
+ # --- Planner Params --- #
56
+ POLICY = getenv("POLICY", default="RL", cast_type=str)
57
+ NUM_TEST = 800 # Overriden if LOAD_AVS_BENCH
58
+ NUM_RUN = 1
59
+ MODEL_NAME = "avs_rl_policy.pth"
60
+ INPUT_DIM = 4
61
+ EMBEDDING_DIM = 128
62
+ K_SIZE = 8
63
+
64
+
65
+ # --- Folders & Visualizations --- #
66
+ GRIDMAP_SET_DIR = "maps/gpt4o/envs_val"
67
+ MASK_SET_DIR = "maps/example/masks_val" # Overriden if LOAD_AVS_BENCH
68
+ TARGETS_SET_DIR = ""
69
+ # TARGETS_SET_DIR = "maps/example/gt_masks_val_with_tgts" # Overriden if LOAD_AVS_BENCH
70
+ OVERRIDE_MASK_DIR = getenv("OVERRIDE_MASK_DIR", default="", cast_type=str) # Override initial score mask from CLIP
71
+ SAVE_GIFS = getenv("SAVE_GIFS", default=True, cast_type=bool) # do you want to save GIFs
72
+ FOLDER_NAME = 'avs_search'
73
+ MODEL_PATH = f'inference/model'
74
+ GIFS_PATH = f'inference/test_results/gifs/{FOLDER_NAME}'
75
+ LOG_PATH = f'inference/test_results/log/{FOLDER_NAME}'
76
+ LOG_TEMPLATE_XLSX = f'inference/template.xlsx'
77
+ CSV_EXPT_NAME = getenv("CSV_EXPT_NAME", default="data", cast_type=str)
78
+ VIZ_GRAPH_EDGES = False # do you want to visualize the graph edges
79
+
80
+
81
+ #######################################################################
82
+ # AVS Params
83
+ #######################################################################
84
+
85
+ # General PARAMS
86
+ USE_CLIP_PREDS = getenv("USE_CLIP_PREDS", default=True, cast_type=bool) # If false, use custom masks from OVERRIDE_MASK_DIR
87
+ QUERY_TAX = getenv("QUERY_TAX", default="", cast_type=str) # "" = Test all tax (can accept taxonomy substrings)
88
+ EXECUTE_TTA = getenv("EXECUTE_TTA", default=True, cast_type=bool) # Whether to execute TTA mask updates
89
+ QUERY_MODALITY = getenv("QUERY_MODALITY", default="image", cast_type=str) # "image", "text", "sound"
90
+ STEPS_PER_TTA = 20 # no. steps before each TTA series
91
+ NUM_TTA_STEPS = 1 # no. of TTA steps during each series
92
+ RESET_WEIGHTS = True
93
+ MIN_LR = 1e-6
94
+ MAX_LR = 1e-5
95
+ GAMMA_EXPONENT = 2
96
+
97
+ # Paths related to AVS (TRAIN w/ TARGETS)
98
+ LOAD_AVS_BENCH = True # Whether to init AVS datasets
99
+ AVS_IMG_DIR = '/mnt/hdd/avs_bench_ds/inat21'
100
+ AVS_IMO_DIR = '/mnt/hdd/avs_bench_ds/sat_jpg/train_512px'
101
+ AVS_INAT_JSON_PATH = '/mnt/hdd/avs_bench_ds/inat21/train.json'
102
+ AVS_SOUND_DIR = '/mnt/hdd/avs_bench_ds/sound_mp3/test'
103
+ AVS_GAUSSIAN_BLUR_KERNEL = (5,5)
104
+ AVS_SAT_TO_IMG_IDS_PATH = getenv("AVS_SAT_TO_IMG_IDS_PATH", default="search_tri_modal|val_in_domain", cast_type=str)
105
+ AVS_LOAD_PRETRAINED_HF_CHECKPOINT = getenv("AVS_LOAD_PRETRAINED_HF_CHECKPOINT", default=True, cast_type=bool) # If false, load locally using CHECKPOINT_PATHs
106
+ AVS_SAT_CHECKPOINT_PATH = getenv("AVS_SAT_CHECKPOINT_PATH", default="", cast_type=str)
107
+ AVS_SOUND_CHECKPOINT_PATH = getenv("AVS_SOUND_CHECKPOINT_PATH", default="", cast_type=str)
108
+
109
+ #######################################################################
110
+ # UTILS
111
+ #######################################################################
112
+
113
+ # COLORS (for printing)
114
+ RED='\033[1;31m'
115
+ GREEN='\033[1;32m'
116
+ YELLOW='\033[1;93m'
117
+ NC_BOLD='\033[1m' # Bold, No Color
118
+ NC='\033[0m' # No Color
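A quick usage sketch (illustrative, not part of the commit; the entry-point script name is hypothetical): every parameter wrapped in getenv() above can be overridden per run through environment variables, with booleans parsed case-insensitively from "true"/"1"/"yes":

EXECUTE_TTA=false QUERY_MODALITY=sound NUM_EPS_STEPS=200 python driver.py

which, inside test_parameter.py, resolves as:

os.environ["EXECUTE_TTA"] = "false"
assert getenv("EXECUTE_TTA", default=True, cast_type=bool) is False
assert OPT_VARS["EXECUTE_TTA"] is False   # every lookup is also logged in OPT_VARS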
planner/test_worker.py ADDED
@@ -0,0 +1,590 @@
+ #######################################################################
+ # Name: test_worker.py
+ #
+ # - Runs robot in environment using RL Planner
+ #######################################################################
+
+ from .test_parameter import *
+
+ import imageio
+ import os
+ import copy
+ import numpy as np
+ import torch
+ from time import time
+ from pathlib import Path
+ from skimage.transform import resize
+ from taxabind_avs.satbind.kmeans_clustering import CombinedSilhouetteInertiaClusterer
+ from .env import Env
+ from .robot import Robot
+
+ np.seterr(invalid='raise', divide='raise')
+
+
+ class TestWorker:
+     def __init__(self, meta_agent_id, n_agent, policy_net, global_step, device='cuda', greedy=False, save_image=False, clip_seg_tta=None):
+         self.device = device
+         self.greedy = greedy
+         self.n_agent = n_agent
+         self.metaAgentID = meta_agent_id
+         self.global_step = global_step
+         self.k_size = K_SIZE
+         self.save_image = save_image
+         self.clip_seg_tta = clip_seg_tta
+         self.execute_tta = EXECUTE_TTA  # Added to interface with app.py
+
+         self.env = Env(map_index=self.global_step, n_agent=n_agent, k_size=self.k_size, plot=save_image, test=True)
+         self.local_policy_net = policy_net
+
+         self.robot_list = []
+         self.all_robot_positions = []
+         for i in range(self.n_agent):
+             robot_position = self.env.start_positions[i]
+             robot = Robot(robot_id=i, position=robot_position, plot=save_image)
+             self.robot_list.append(robot)
+             self.all_robot_positions.append(robot_position)
+
+         self.perf_metrics = dict()
+         self.bad_mask_init = False
+
+         # NOTE: Option to override gifs_path to interface with app.py
+         self.gifs_path = GIFS_PATH
+
+         # NOTE: updated due to app.py (hf does not allow heatmap to persist)
+         if LOAD_AVS_BENCH:
+             if clip_seg_tta is not None:
+                 heatmap, heatmap_unnormalized, heatmap_unnormalized_initial, patch_embeds = self.clip_seg_tta.reset(sample_idx=self.global_step)
+                 self.clip_seg_tta.heatmap = heatmap
+                 self.clip_seg_tta.heatmap_unnormalized = heatmap_unnormalized
+                 self.clip_seg_tta.heatmap_unnormalized_initial = heatmap_unnormalized_initial
+                 self.clip_seg_tta.patch_embeds = patch_embeds
+
+                 # Override target positions in env
+                 self.env.target_positions = [(pose[1], pose[0]) for pose in self.clip_seg_tta.target_positions]
+
+                 # Override segmentation mask
+                 if not USE_CLIP_PREDS and OVERRIDE_MASK_DIR != "":
+                     score_mask_path = os.path.join(OVERRIDE_MASK_DIR, self.clip_seg_tta.gt_mask_name)
+                     print("score_mask_path: ", score_mask_path)
+                     if os.path.exists(score_mask_path):
+                         self.env.segmentation_mask = self.env.import_segmentation_mask(score_mask_path)
+                         self.env.begin(self.env.map_start_position)
+                     else:
+                         print(f"\n\n{RED}ERROR: Trying to override, but score mask not found at path:{NC} ", score_mask_path)
+                         self.bad_mask_init = True
+
+                 # Save clustered embeds from sat encoder
+                 if USE_CLIP_PREDS:
+                     self.kmeans_clusterer = CombinedSilhouetteInertiaClusterer(
+                         k_min=1,
+                         k_max=8,
+                         k_avg_max=4,
+                         silhouette_threshold=0.15,
+                         relative_threshold=0.15,
+                         random_state=0,
+                         min_patch_size=5,
+                         n_smooth_iter=2,
+                         ignore_label=-1,
+                         plot=self.save_image,
+                         gifs_dir=GIFS_PATH
+                     )
+                     # Generate kmeans clusters
+                     self.kmeans_sat_embeds_clusters = self.kmeans_clusterer.fit_predict(
+                         patch_embeds=self.clip_seg_tta.patch_embeds,
+                         map_shape=(CLIP_GRIDS_DIMS[0], CLIP_GRIDS_DIMS[1]),
+                     )
+                     print("Chosen k:", self.kmeans_clusterer.final_k)
+
+         # if EXECUTE_TTA:
+         #     print("Will execute TTA...")
+
+         # Define Poisson TTA params
+         self.step_since_tta = 0
+         self.steps_to_first_tgt = None
+         self.steps_to_mid_tgt = None
+         self.steps_to_last_tgt = None
+
+     def run_episode(self, curr_episode):
+
+         # Return all metrics as None if faulty mask init
+         if self.bad_mask_init:
+             self.perf_metrics['tax'] = None
+             self.perf_metrics['travel_dist'] = None
+             self.perf_metrics['travel_steps'] = None
+             self.perf_metrics['steps_to_first_tgt'] = None
+             self.perf_metrics['steps_to_mid_tgt'] = None
+             self.perf_metrics['steps_to_last_tgt'] = None
+             self.perf_metrics['explored_rate'] = None
+             self.perf_metrics['targets_found'] = None
+             self.perf_metrics['targets_total'] = None
+             self.perf_metrics['kmeans_k'] = None
+             self.perf_metrics['tgts_gt_score'] = None
+             self.perf_metrics['clip_inference_time'] = None
+             self.perf_metrics['tta_time'] = None
+             self.perf_metrics['success_rate'] = None
+             return
+
+         eps_start = time()
+         done = False
+         for robot_id, deciding_robot in enumerate(self.robot_list):
+             deciding_robot.observations = self.get_observations(deciding_robot.robot_position)
+         if LOAD_AVS_BENCH and USE_CLIP_PREDS:
+             if NUM_COORDS_WIDTH != CLIP_GRIDS_DIMS[0] or NUM_COORDS_HEIGHT != CLIP_GRIDS_DIMS[1]:  # If heatmap is resized from clip original dims
+                 heatmap = self.convert_heatmap_resolution(self.clip_seg_tta.heatmap, full_dims=(512, 512), new_dims=(NUM_COORDS_WIDTH, NUM_COORDS_HEIGHT))
+                 self.env.segmentation_info_mask = np.expand_dims(heatmap.T.flatten(), axis=1)
+                 unnormalized_heatmap = self.convert_heatmap_resolution(self.clip_seg_tta.heatmap_unnormalized, full_dims=(512, 512), new_dims=(NUM_COORDS_WIDTH, NUM_COORDS_HEIGHT))
+                 self.env.segmentation_info_mask_unnormalized = np.expand_dims(unnormalized_heatmap.T.flatten(), axis=1)
+                 print("Resized heatmap to", NUM_COORDS_WIDTH, "x", NUM_COORDS_HEIGHT)
+             else:
+                 self.env.segmentation_info_mask = np.expand_dims(self.clip_seg_tta.heatmap.T.flatten(), axis=1)
+                 self.env.segmentation_info_mask_unnormalized = np.expand_dims(self.clip_seg_tta.heatmap_unnormalized.T.flatten(), axis=1)
+
+         ### Run episode ###
+         for step in range(NUM_EPS_STEPS):
+
+             next_position_list = []
+             dist_list = []
+             travel_dist_list = []
+             dist_array = np.zeros((self.n_agent, 1))
+             for robot_id, deciding_robot in enumerate(self.robot_list):
+                 observations = deciding_robot.observations
+
+                 ### Forward pass through policy to get next position ###
+                 next_position, action_index = self.select_node(observations)
+                 dist = np.linalg.norm(next_position - deciding_robot.robot_position)
+
+                 ### Log results of action (e.g. distance travelled) ###
+                 dist_array[robot_id] = dist
+                 dist_list.append(dist)
+                 travel_dist_list.append(deciding_robot.travel_dist)
+                 next_position_list.append(next_position)
+                 self.all_robot_positions[robot_id] = next_position
+
+             arriving_sequence = np.argsort(dist_list)
+             next_position_list = np.array(next_position_list)
+             dist_list = np.array(dist_list)
+             travel_dist_list = np.array(travel_dist_list)
+             next_position_list = next_position_list[arriving_sequence]
+             dist_list = dist_list[arriving_sequence]
+             travel_dist_list = travel_dist_list[arriving_sequence]
+
+             ### Take Action (Deconflict if 2 agents choose the same target position) ###
+             next_position_list, dist_list = self.solve_conflict(arriving_sequence, next_position_list, dist_list)
+             reward_list, done = self.env.multi_robot_step(next_position_list, dist_list, travel_dist_list)
+
+             ### Update observations + rewards from action ###
+             for reward, robot_id in zip(reward_list, arriving_sequence):
+                 robot = self.robot_list[robot_id]
+                 robot.save_trajectory_coords(self.env.find_index_from_coords(robot.robot_position), self.env.num_new_targets_found)
+
+                 # TTA Update via Poisson Test (with KMeans clustering stats)
+                 if LOAD_AVS_BENCH and USE_CLIP_PREDS and self.execute_tta:
+                     self.poisson_tta_update(robot, self.global_step, step)
+
+                 robot.observations = self.get_observations(robot.robot_position)
+                 robot.save_reward_done(reward, done)
+
+             # Update metrics
+             self.log_metrics(step=step)
+
+             ### Save a frame to generate gif of robot trajectories ###
+             if self.save_image:
+                 robots_route = []
+                 for robot in self.robot_list:
+                     robots_route.append([robot.xPoints, robot.yPoints])
+                 if not os.path.exists(self.gifs_path):
+                     os.makedirs(self.gifs_path)
+                 if LOAD_AVS_BENCH:
+                     # NOTE: Replaced since using app.py
+                     self.env.plot_heatmap(self.gifs_path, step, max(travel_dist_list), robots_route)
+
+             if done:
+                 break
+
+         if LOAD_AVS_BENCH:
+             tax = Path(self.clip_seg_tta.gt_mask_name).stem
+             self.perf_metrics['tax'] = " ".join(tax.split("_")[1:])
+         else:
+             self.perf_metrics['tax'] = None
+         self.perf_metrics['travel_dist'] = max(travel_dist_list)
+         self.perf_metrics['travel_steps'] = step + 1
+         self.perf_metrics['steps_to_first_tgt'] = self.steps_to_first_tgt
+         self.perf_metrics['steps_to_mid_tgt'] = self.steps_to_mid_tgt
+         self.perf_metrics['steps_to_last_tgt'] = self.steps_to_last_tgt
+         self.perf_metrics['explored_rate'] = self.env.explored_rate
+         self.perf_metrics['targets_found'] = self.env.targets_found_rate
+         self.perf_metrics['targets_total'] = len(self.env.target_positions)
+         if USE_CLIP_PREDS:
+             self.perf_metrics['kmeans_k'] = self.kmeans_clusterer.final_k
+             self.perf_metrics['tgts_gt_score'] = self.clip_seg_tta.tgts_gt_score
+             self.perf_metrics['clip_inference_time'] = self.clip_seg_tta.clip_inference_time
+             self.perf_metrics['tta_time'] = self.clip_seg_tta.tta_time
+         else:
+             self.perf_metrics['kmeans_k'] = None
+             self.perf_metrics['tgts_gt_score'] = None
+             self.perf_metrics['clip_inference_time'] = None
+             self.perf_metrics['tta_time'] = None
+         if FORCE_LOGGING_DONE_TGTS_FOUND and self.env.targets_found_rate == 1.0:
+             self.perf_metrics['success_rate'] = True
+         else:
+             self.perf_metrics['success_rate'] = done
+
+         # save gif
+         if self.save_image:
+             path = self.gifs_path  # NOTE: Set to self.gifs_path since using app.py
+             self.make_gif(path, curr_episode)
+
+         print(YELLOW, f"[Eps {curr_episode} Completed] Time Taken: {time()-eps_start:.2f}s, Steps: {step+1}", NC)
+
+     def get_observations(self, robot_position):
+         """ Get robot's sensor observation of environment given position """
+         current_node_index = self.env.find_index_from_coords(robot_position)
+         current_index = torch.tensor([current_node_index]).unsqueeze(0).unsqueeze(0).to(self.device)  # (1,1,1)
+
+         node_coords = copy.deepcopy(self.env.node_coords)
+         graph = copy.deepcopy(self.env.graph)
+         node_utility = copy.deepcopy(self.env.node_utility)
+         guidepost = copy.deepcopy(self.env.guidepost)
+         segmentation_info_mask = copy.deepcopy(self.env.filtered_seg_info_mask)
+
+         n_nodes = node_coords.shape[0]
+         node_coords = node_coords / 640
+         node_utility = node_utility / 50
+         node_utility_inputs = node_utility.reshape((n_nodes, 1))
+
+         occupied_node = np.zeros((n_nodes, 1))
+         for position in self.all_robot_positions:
+             index = self.env.find_index_from_coords(position)
+             if index == current_index.item():
+                 occupied_node[index] = -1
+             else:
+                 occupied_node[index] = 1
+
+         node_inputs = np.concatenate((node_coords, segmentation_info_mask, guidepost), axis=1)
+         node_inputs = torch.FloatTensor(node_inputs).unsqueeze(0).to(self.device)
+         node_padding_mask = None
+
+         graph = list(graph.values())
+         edge_inputs = []
+         for node in graph:
+             node_edges = list(map(int, node))
+             edge_inputs.append(node_edges)
+
+         bias_matrix = self.calculate_edge_mask(edge_inputs)
+         edge_mask = torch.from_numpy(bias_matrix).float().unsqueeze(0).to(self.device)
+
+         for edges in edge_inputs:
+             while len(edges) < self.k_size:
+                 edges.append(0)
+
+         edge_inputs = torch.tensor(edge_inputs).unsqueeze(0).to(self.device)
+         edge_padding_mask = torch.zeros((1, len(edge_inputs), K_SIZE), dtype=torch.int64).to(self.device)
+         one = torch.ones_like(edge_padding_mask, dtype=torch.int64).to(self.device)
+         edge_padding_mask = torch.where(edge_inputs == 0, one, edge_padding_mask)
+
+         observations = node_inputs, edge_inputs, current_index, node_padding_mask, edge_padding_mask, edge_mask
+         return observations
+
+     def select_node(self, observations):
+         """ Forward pass through policy to get next position to go to on map """
+         node_inputs, edge_inputs, current_index, node_padding_mask, edge_padding_mask, edge_mask = observations
+         with torch.no_grad():
+             logp_list = self.local_policy_net(node_inputs, edge_inputs, current_index, node_padding_mask, edge_padding_mask, edge_mask)
+
+         if self.greedy:
+             action_index = torch.argmax(logp_list, dim=1).long()
+         else:
+             action_index = torch.multinomial(logp_list.exp(), 1).long().squeeze(1)
+
+         next_node_index = edge_inputs[:, current_index.item(), action_index.item()]
+
+         next_position = self.env.node_coords[next_node_index]
+
+         return next_position, action_index
+
+     def solve_conflict(self, arriving_sequence, next_position_list, dist_list):
+         """ Deconflict if 2 agents choose the same target position """
+         for j, [robot_id, next_position] in enumerate(zip(arriving_sequence, next_position_list)):
+             moving_robot = self.robot_list[robot_id]
+             # if next_position[0] + next_position[1] * 1j in (next_position_list[:, 0] + next_position_list[:, 1] * 1j)[:j]:
+             #     dist_to_next_position = np.argsort(np.linalg.norm(self.env.node_coords - next_position, axis=1))
+             #     k = 0
+             #     while next_position[0] + next_position[1] * 1j in (next_position_list[:, 0] + next_position_list[:, 1] * 1j)[:j]:
+             #         k += 1
+             #         next_position = self.env.node_coords[dist_to_next_position[k]]
+
+             dist = np.linalg.norm(next_position - moving_robot.robot_position)
+             next_position_list[j] = next_position
+             dist_list[j] = dist
+             moving_robot.travel_dist += dist
+             moving_robot.robot_position = next_position
+
+         return next_position_list, dist_list
+
+     def work(self, currEpisode):
+         '''
+         Interacts with the environment. The agent gets either gradients or experience buffer
+         '''
+         self.run_episode(currEpisode)
+
+     def calculate_edge_mask(self, edge_inputs):
+         size = len(edge_inputs)
+         bias_matrix = np.ones((size, size))
+         for i in range(size):
+             for j in range(size):
+                 if j in edge_inputs[i]:
+                     bias_matrix[i][j] = 0
+         return bias_matrix
+
+     def make_gif(self, path, n):
+         """ Generate a gif given list of images """
+         with imageio.get_writer('{}/{}_target_rate_{:.2f}.gif'.format(path, n, self.env.targets_found_rate), mode='I',
+                                 fps=5) as writer:
+             for frame in self.env.frame_files:
+                 image = imageio.imread(frame)
+                 writer.append_data(image)
+         print('gif complete\n')
+
+         # Remove files
+         for filename in self.env.frame_files[:-1]:
+             os.remove(filename)
+
+         # For gif during TTA
+         if LOAD_AVS_BENCH:
+             with imageio.get_writer('{}/{}_kmeans_stats.gif'.format(path, n), mode='I',
+                                     fps=5) as writer:
+                 for frame in self.kmeans_clusterer.kmeans_frame_files:
+                     image = imageio.imread(frame)
+                     writer.append_data(image)
+             print('Kmeans Clusterer gif complete\n')
+
+             # Remove files
+             for filename in self.kmeans_clusterer.kmeans_frame_files[:-1]:
+                 os.remove(filename)
+
+     ################################################################################
+     # SPPP Related Fns
+     ################################################################################
+
+     def log_metrics(self, step):
+         # Update tgt found metrics
+         if self.steps_to_first_tgt is None and self.env.num_targets_found == 1:
+             self.steps_to_first_tgt = step + 1
+         if self.steps_to_mid_tgt is None and self.env.num_targets_found == int(len(self.env.target_positions) / 2):
+             self.steps_to_mid_tgt = step + 1
+         if self.steps_to_last_tgt is None and self.env.num_targets_found == len(self.env.target_positions):
+             self.steps_to_last_tgt = step + 1
+
+     def transpose_flat_idx(self, idx, H=NUM_COORDS_HEIGHT, W=NUM_COORDS_WIDTH):
+         """
+         Transpose a flat index from an ``H×W`` grid to the equivalent
+         position in the ``W×H`` transposed grid while **keeping the result
+         in 1-D**.
+         """
+         # --- Safety check to catch out-of-range indices ---
+         assert 0 <= idx < H * W, f"idx {idx} out of bounds for shape ({H}, {W})"
+
+         # Original (row, col)
+         row, col = divmod(idx, W)
+         # After transpose these coordinates swap
+         row_T, col_T = col, row
+
+         # Flatten back into 1-D (row-major) for the W×H grid
+         return row_T * H + col_T
+
+     def poisson_tta_update(self, robot, episode, step):
+
+         # Generate Kmeans Clusters Stats
+         # Scale index back to CLIP_GRIDS_DIMS to be compatible with CLIP patch size
+         if NUM_COORDS_WIDTH != CLIP_GRIDS_DIMS[0] or NUM_COORDS_HEIGHT != CLIP_GRIDS_DIMS[1]:
+             # High-res remap via pixel coordinates preserves exact neighbourhood
+             filt_traj_coords, filt_targets_found_on_path = self.scale_trajectory(
+                 robot.trajectory_coords,
+                 self.env.target_positions,
+                 old_dims=(NUM_COORDS_HEIGHT, NUM_COORDS_WIDTH),
+                 full_dims=(512, 512),
+                 new_dims=(CLIP_GRIDS_DIMS[0], CLIP_GRIDS_DIMS[1])
+             )
+         else:
+             filt_traj_coords = [self.transpose_flat_idx(idx) for idx in robot.trajectory_coords]
+             filt_targets_found_on_path = robot.targets_found_on_path
+
+         region_stats_dict = self.kmeans_clusterer.compute_region_statistics(
+             self.kmeans_sat_embeds_clusters,
+             self.clip_seg_tta.heatmap_unnormalized,
+             filt_traj_coords,
+             episode_num=episode,
+             step_num=step
+         )
+
+         # Prep & execute TTA
+         self.step_since_tta += 1
+         if robot.targets_found_on_path[-1] or self.step_since_tta % STEPS_PER_TTA == 0:
+
+             # NOTE: integration with app.py on hf
+             self.clip_seg_tta.executing_tta = True
+
+             num_cells = self.clip_seg_tta.heatmap.shape[0] * self.clip_seg_tta.heatmap.shape[1]
+             pos_sample_weight_scale, neg_sample_weight_scale = [], []
+
+             for i, sample_loc in enumerate(filt_traj_coords):
+                 label = self.kmeans_clusterer.get_label_id(self.kmeans_sat_embeds_clusters, sample_loc)
+                 num_patches = region_stats_dict[label]['num_patches']
+                 patches_visited = region_stats_dict[label]['patches_visited']
+                 expectation = region_stats_dict[label]['expectation']
+
+                 # Exponent like focal loss to wait for more samples before confidently decreasing
+                 pos_weight = 4.0
+                 neg_weight = min(1.0, (patches_visited / (3 * num_patches)) ** GAMMA_EXPONENT)
+                 pos_sample_weight_scale.append(pos_weight)
+                 neg_sample_weight_scale.append(neg_weight)
+
+             # Adaptive LR (as samples increase, increase LR to fit more datapoints)
+             adaptive_lr = MIN_LR + (MAX_LR - MIN_LR) * (step / num_cells)
+
+             # TTA Update
+             # NOTE: updated due to app.py (hf does not allow heatmap to persist)
+             heatmap = self.clip_seg_tta.execute_tta(
+                 filt_traj_coords,
+                 filt_targets_found_on_path,
+                 tta_steps=NUM_TTA_STEPS,
+                 lr=adaptive_lr,
+                 pos_sample_weight=pos_sample_weight_scale,
+                 neg_sample_weight=neg_sample_weight_scale,
+                 reset_weights=RESET_WEIGHTS
+             )
+             self.clip_seg_tta.heatmap = heatmap
+
+             if NUM_COORDS_WIDTH != CLIP_GRIDS_DIMS[0] or NUM_COORDS_HEIGHT != CLIP_GRIDS_DIMS[1]:  # If heatmap is resized from clip original dims
+                 heatmap = self.convert_heatmap_resolution(self.clip_seg_tta.heatmap, full_dims=(512, 512), new_dims=(NUM_COORDS_WIDTH, NUM_COORDS_HEIGHT))
+                 self.env.segmentation_info_mask = np.expand_dims(heatmap.T.flatten(), axis=1)
+                 unnormalized_heatmap = self.convert_heatmap_resolution(self.clip_seg_tta.heatmap_unnormalized, full_dims=(512, 512), new_dims=(NUM_COORDS_WIDTH, NUM_COORDS_HEIGHT))
+                 self.env.segmentation_info_mask_unnormalized = np.expand_dims(unnormalized_heatmap.T.flatten(), axis=1)
+                 print("~Resized heatmap to", NUM_COORDS_WIDTH, "x", NUM_COORDS_HEIGHT)
+             else:
+                 self.env.segmentation_info_mask = np.expand_dims(self.clip_seg_tta.heatmap.T.flatten(), axis=1)
+                 self.env.segmentation_info_mask_unnormalized = np.expand_dims(self.clip_seg_tta.heatmap_unnormalized.T.flatten(), axis=1)
+
+             self.step_since_tta = 0
+
+             # NOTE: integration with app.py on hf
+             self.clip_seg_tta.executing_tta = False
+
+     def convert_heatmap_resolution(self, heatmap, full_dims=(512, 512), new_dims=(24, 24)):
+         heatmap_large = resize(heatmap, full_dims, order=1,  # order=1 → bilinear
+                                mode='reflect', anti_aliasing=True)
+
+         coords = self.env.graph_generator.grid_coords  # (N, N, 2)
+         rows, cols = coords[..., 1], coords[..., 0]
+         heatmap_resized = heatmap_large[rows, cols]
+         heatmap_resized = heatmap_resized.reshape(new_dims[1], new_dims[0])
+         return heatmap_resized
+
+     def convert_labelmap_resolution(self, labelmap, full_dims=(512, 512), new_dims=(24, 24)):
+         """
+         1) Upsample via nearest-neighbor to full_dims
+         2) Sample back down to your graph grid using grid_coords
+         """
+         # 1) Upsample with nearest-neighbor, preserving integer labels
+         up = resize(
+             labelmap,
+             full_dims,
+             order=0,               # nearest-neighbor
+             mode='edge',           # padding mode
+             preserve_range=True,   # don't normalize labels
+             anti_aliasing=False    # must be False for labels
+         ).astype(labelmap.dtype)   # back to original integer dtype
+
+         # 2) Downsample via your precomputed grid coords
+         coords = self.env.graph_generator.grid_coords  # shape (N, N, 2)
+         rows = coords[..., 1].astype(int)
+         cols = coords[..., 0].astype(int)
+
+         small = up[rows, cols]  # shape (N, N)
+         small = small.reshape(new_dims[0], new_dims[1])
+         return small
+
+     def scale_trajectory(self,
+                          flat_indices,
+                          targets,
+                          old_dims=(17, 17),
+                          full_dims=(512, 512),
+                          new_dims=(24, 24)):
+         """
+         Args:
+             flat_indices: list of ints in [0..old_H*old_W-1]
+             targets: list of (y_pix, x_pix) in [0..full_H-1]
+             old_dims: (old_H, old_W)
+             full_dims: (full_H, full_W)
+             new_dims: (new_H, new_W)
+
+         Returns:
+             new_flat_traj: list of unique flattened indices in new_H×new_W
+             counts: list of ints, same length as new_flat_traj
+         """
+         old_H, old_W = old_dims
+         full_H, full_W = full_dims
+         new_H, new_W = new_dims
+
+         # 1) bin targets into new grid
+         cell_h_new = full_H / new_H
+         cell_w_new = full_W / new_W
+         grid_counts = [[0] * new_W for _ in range(new_H)]
+         for x_pix, y_pix in targets:  # note (x, y) order as in original implementation
+             i_t = min(int(y_pix / cell_h_new), new_H - 1)
+             j_t = min(int(x_pix / cell_w_new), new_W - 1)
+             grid_counts[i_t][j_t] += 1
+
+         # 2) Walk the trajectory indices and project each old cell's *entire
+         #    pixel footprint* onto the finer 24×24 grid.
+         cell_h_full = full_H / old_H
+         cell_w_full = full_W / old_W
+
+         seen = set()
+         new_flat_traj = []
+
+         for node_idx in flat_indices:
+             if node_idx < 0 or node_idx >= len(self.env.graph_generator.node_coords):
+                 continue
+
+             coord_xy = self.env.graph_generator.node_coords[node_idx]
+             try:
+                 row_old, col_old = self.env.graph_generator.find_index_from_grid_coords_2d(coord_xy)
+             except Exception:
+                 continue
+
+             # Bounding box of the old cell in full-resolution pixel space
+             y0 = row_old * cell_h_full
+             y1 = (row_old + 1) * cell_h_full
+             x0 = col_old * cell_w_full
+             x1 = (col_old + 1) * cell_w_full
+
+             # Which new-grid rows & cols overlap? (inclusive ranges)
+             i_start = max(0, min(int(y0 / cell_h_new), new_H - 1))
+             i_end = max(0, min(int((y1 - 1) / cell_h_new), new_H - 1))
+             j_start = max(0, min(int(x0 / cell_w_new), new_W - 1))
+             j_end = max(0, min(int((x1 - 1) / cell_w_new), new_W - 1))
+
+             for ii in range(i_start, i_end + 1):
+                 for jj in range(j_start, j_end + 1):
+                     f_new = ii * new_W + jj
+                     if f_new not in seen:
+                         seen.add(f_new)
+                         new_flat_traj.append(f_new)
+
+         # 3) annotate counts
+         counts = []
+         for f in new_flat_traj:
+             i_new, j_new = divmod(f, new_W)
+             counts.append(grid_counts[i_new][j_new])
+
+         return new_flat_traj, counts
+
+     ################################################################################
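To make the index arithmetic in transpose_flat_idx concrete, here is a standalone sketch (not part of the commit) of the same row-major transposition, with a hand-checked value; the dims mirror NUM_COORDS_HEIGHT × NUM_COORDS_WIDTH:

H, W = 24, 24
idx = 49                     # row 2, col 1 in the H×W grid
row, col = divmod(idx, W)    # (2, 1)
idx_T = col * H + row        # same cell, flattened in the transposed W×H grid
assert idx_T == 26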
planner/worker.py ADDED
@@ -0,0 +1,272 @@
+ #######################################################################
+ # Name: worker.py
+ #
+ # - Runs robot in environment for N steps
+ # - Collects & Returns S(t), A(t), R(t), S(t+1)
+ #######################################################################
+
+ from .parameter import *
+
+ import os
+ import json
+ import copy
+ import imageio
+ import numpy as np
+ import torch
+ from time import time
+ from .env import Env
+ from .robot import Robot
+
+ class Worker:
+     def __init__(self, meta_agent_id, n_agent, policy_net, q_net, global_step, device='cuda', greedy=False, save_image=False, clip_seg_tta=None):
+         self.device = device
+         self.greedy = greedy
+         self.n_agent = n_agent
+         self.metaAgentID = meta_agent_id
+         self.global_step = global_step
+         self.node_padding_size = NODE_PADDING_SIZE
+         self.k_size = K_SIZE
+         self.save_image = save_image
+         self.clip_seg_tta = clip_seg_tta
+
+         # Randomize map_index
+         mask_index = None
+         if MASKS_RAND_INDICES_PATH != "":
+             with open(MASKS_RAND_INDICES_PATH, 'r') as f:
+                 mask_index_rand_json = json.load(f)
+             mask_index = mask_index_rand_json[self.global_step % len(mask_index_rand_json)]
+             print("mask_index: ", mask_index)
+
+         self.env = Env(map_index=self.global_step, n_agent=n_agent, k_size=self.k_size, plot=save_image, mask_index=mask_index)
+         self.local_policy_net = policy_net
+         self.local_q_net = q_net
+
+         self.robot_list = []
+         self.all_robot_positions = []
+
+         for i in range(self.n_agent):
+             robot_position = self.env.start_positions[i]
+             robot = Robot(robot_id=i, position=robot_position, plot=save_image)
+             self.robot_list.append(robot)
+             self.all_robot_positions.append(robot_position)
+
+         self.perf_metrics = dict()
+         self.episode_buffer = []
+         for i in range(15):
+             self.episode_buffer.append([])
+
+     def run_episode(self, curr_episode):
+
+         eps_start = time()
+         done = False
+         for robot_id, deciding_robot in enumerate(self.robot_list):
+             deciding_robot.observations = self.get_observations(deciding_robot.robot_position)
+
+         ### Run episode ###
+         for step in range(NUM_EPS_STEPS):
+
+             next_position_list = []
+             dist_list = []
+             travel_dist_list = []
+             dist_array = np.zeros((self.n_agent, 1))
+             for robot_id, deciding_robot in enumerate(self.robot_list):
+                 observations = deciding_robot.observations
+                 deciding_robot.save_observations(observations)
+
+                 ### Forward pass through policy to get next position ###
+                 next_position, action_index = self.select_node(observations)
+                 deciding_robot.save_action(action_index)
+
+                 dist = np.linalg.norm(next_position - deciding_robot.robot_position)
+
+                 ### Log results of action (e.g. distance travelled) ###
+                 dist_array[robot_id] = dist
+                 dist_list.append(dist)
+                 travel_dist_list.append(deciding_robot.travel_dist)
+                 next_position_list.append(next_position)
+                 self.all_robot_positions[robot_id] = next_position
+
+             arriving_sequence = np.argsort(dist_list)
+             next_position_list = np.array(next_position_list)
+             dist_list = np.array(dist_list)
+             travel_dist_list = np.array(travel_dist_list)
+             next_position_list = next_position_list[arriving_sequence]
+             dist_list = dist_list[arriving_sequence]
+             travel_dist_list = travel_dist_list[arriving_sequence]
+
+             ### Take Action (Deconflict if 2 agents choose the same target position) ###
+             next_position_list, dist_list = self.solve_conflict(arriving_sequence, next_position_list, dist_list)
+             reward_list, done = self.env.multi_robot_step(next_position_list, dist_list, travel_dist_list)
+
+             ### Update observations + rewards from action ###
+             for reward, robot_id in zip(reward_list, arriving_sequence):
+                 robot = self.robot_list[robot_id]
+                 robot.observations = self.get_observations(robot.robot_position)
+                 robot.save_trajectory_coords(self.env.find_index_from_coords(robot.robot_position), self.env.num_new_targets_found)
+                 robot.save_reward_done(reward, done)
+                 robot.save_next_observations(robot.observations)
+
+             ### Save a frame to generate gif of robot trajectories ###
+             if self.save_image:
+                 robots_route = []
+                 for robot in self.robot_list:
+                     robots_route.append([robot.xPoints, robot.yPoints])
+                 if not os.path.exists(GIFS_PATH):
+                     os.makedirs(GIFS_PATH)
+                 self.env.plot_env(self.global_step, GIFS_PATH, step, max(travel_dist_list), robots_route)
+
+             if done:
+                 break
+
+         for robot in self.robot_list:
+             for i in range(15):
+                 self.episode_buffer[i] += robot.episode_buffer[i]
+
+         self.perf_metrics['travel_dist'] = max(travel_dist_list)
+         self.perf_metrics['explored_rate'] = self.env.explored_rate
+         self.perf_metrics['targets_found'] = self.env.targets_found_rate
+         self.perf_metrics['success_rate'] = done
+
+         # save gif
+         if self.save_image:
+             path = GIFS_PATH
+             self.make_gif(path, curr_episode)
+
+         print(YELLOW, f"[Eps {curr_episode} Completed] Time Taken: {time()-eps_start:.2f}s, Steps: {step+1}", NC)
+
+     def get_observations(self, robot_position):
+         """ Get robot's sensor observation of environment given position """
+         current_node_index = self.env.find_index_from_coords(robot_position)
+         current_index = torch.tensor([current_node_index]).unsqueeze(0).unsqueeze(0).to(self.device)  # (1,1,1)
+
+         node_coords = copy.deepcopy(self.env.node_coords)
+         graph = copy.deepcopy(self.env.graph)
+         node_utility = copy.deepcopy(self.env.node_utility)
+         guidepost = copy.deepcopy(self.env.guidepost)
+         segmentation_info_mask = copy.deepcopy(self.env.filtered_seg_info_mask)
+
+         n_nodes = node_coords.shape[0]
+         node_coords = node_coords / 640
+         node_utility = node_utility / 50
+
+         node_utility_inputs = node_utility.reshape((n_nodes, 1))
+
+         occupied_node = np.zeros((n_nodes, 1))
+         for position in self.all_robot_positions:
+             index = self.env.find_index_from_coords(position)
+             if index == current_index.item():
+                 occupied_node[index] = -1
+             else:
+                 occupied_node[index] = 1
+
+         node_inputs = np.concatenate((node_coords, segmentation_info_mask, guidepost), axis=1)
+         node_inputs = torch.FloatTensor(node_inputs).unsqueeze(0).to(self.device)  # (1, node_padding_size+1, 3)
+
+         assert node_coords.shape[0] < self.node_padding_size
+         padding = torch.nn.ZeroPad2d((0, 0, 0, self.node_padding_size - node_coords.shape[0]))
+         node_inputs = padding(node_inputs)
+
+         node_padding_mask = torch.zeros((1, 1, node_coords.shape[0]), dtype=torch.int64).to(self.device)
+         node_padding = torch.ones((1, 1, self.node_padding_size - node_coords.shape[0]), dtype=torch.int64).to(
+             self.device)
+         node_padding_mask = torch.cat((node_padding_mask, node_padding), dim=-1)
+
+         graph = list(graph.values())
+         edge_inputs = []
+         for node in graph:
+             node_edges = list(map(int, node))
+             edge_inputs.append(node_edges)
+
+         bias_matrix = self.calculate_edge_mask(edge_inputs)
+         edge_mask = torch.from_numpy(bias_matrix).float().unsqueeze(0).to(self.device)
+
+         assert len(edge_inputs) < self.node_padding_size
+         padding = torch.nn.ConstantPad2d(
+             (0, self.node_padding_size - len(edge_inputs), 0, self.node_padding_size - len(edge_inputs)), 1)
+         edge_mask = padding(edge_mask)
+         padding2 = torch.nn.ZeroPad2d((0, 0, 0, self.node_padding_size - len(edge_inputs)))
+
+         for edges in edge_inputs:
+             while len(edges) < self.k_size:
+                 edges.append(0)
+
+         edge_inputs = torch.tensor(edge_inputs).unsqueeze(0).to(self.device)  # (1, node_padding_size+1, k_size)
+         edge_inputs = padding2(edge_inputs)
+
+         edge_padding_mask = torch.zeros((1, len(edge_inputs), K_SIZE), dtype=torch.int64).to(self.device)
+         one = torch.ones_like(edge_padding_mask, dtype=torch.int64).to(self.device)
+         edge_padding_mask = torch.where(edge_inputs == 0, one, edge_padding_mask)
+
+         observations = node_inputs, edge_inputs, current_index, node_padding_mask, edge_padding_mask, edge_mask
+         return observations
+
+     def select_node(self, observations):
+         """ Forward pass through policy to get next position to go to on map """
+         node_inputs, edge_inputs, current_index, node_padding_mask, edge_padding_mask, edge_mask = observations
+         with torch.no_grad():
+             logp_list = self.local_policy_net(node_inputs, edge_inputs, current_index, node_padding_mask,
+                                               edge_padding_mask, edge_mask)
+
+         if self.greedy:
+             action_index = torch.argmax(logp_list, dim=1).long()
+         else:
+             action_index = torch.multinomial(logp_list.exp(), 1).long().squeeze(1)
+
+         next_node_index = edge_inputs[:, current_index.item(), action_index.item()]
+
+         next_position = self.env.node_coords[next_node_index]
+
+         return next_position, action_index
+
+     def solve_conflict(self, arriving_sequence, next_position_list, dist_list):
+         """ Deconflict if 2 agents choose the same target position """
+         for j, [robot_id, next_position] in enumerate(zip(arriving_sequence, next_position_list)):
+             moving_robot = self.robot_list[robot_id]
+             # if next_position[0] + next_position[1] * 1j in (next_position_list[:, 0] + next_position_list[:, 1] * 1j)[:j]:
+             #     dist_to_next_position = np.argsort(np.linalg.norm(self.env.node_coords - next_position, axis=1))
+             #     k = 0
+             #     while next_position[0] + next_position[1] * 1j in (next_position_list[:, 0] + next_position_list[:, 1] * 1j)[:j]:
+             #         k += 1
+             #         next_position = self.env.node_coords[dist_to_next_position[k]]
+
+             dist = np.linalg.norm(next_position - moving_robot.robot_position)
+             next_position_list[j] = next_position
+             dist_list[j] = dist
+             moving_robot.travel_dist += dist
+             moving_robot.robot_position = next_position
+
+         return next_position_list, dist_list
+
+     def work(self, currEpisode):
+         '''
+         Interacts with the environment. The agent gets either gradients or experience buffer
+         '''
+         self.run_episode(currEpisode)
+
+     def calculate_edge_mask(self, edge_inputs):
+         size = len(edge_inputs)
+         bias_matrix = np.ones((size, size))
+         for i in range(size):
+             for j in range(size):
+                 if j in edge_inputs[i]:
+                     bias_matrix[i][j] = 0
+         return bias_matrix
+
+     def make_gif(self, path, n):
+         """ Generate a gif given list of images """
+         with imageio.get_writer('{}/{}_target_rate_{:.2f}.gif'.format(path, n, self.env.targets_found_rate), mode='I',
+                                 fps=5) as writer:
+             for frame in self.env.frame_files:
+                 image = imageio.imread(frame)
+                 writer.append_data(image)
+         print('gif complete\n')
+
+         # Remove files
+         for filename in self.env.frame_files[:-1]:
+             os.remove(filename)
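As a closing sketch (not part of the commit): calculate_edge_mask, identical in worker.py and test_worker.py, builds the attention bias the policy network consumes — entry [i][j] is 0 where node j is a neighbour of node i and 1 otherwise, so ones mark attention pairs to block. A tiny hand-checkable instance:

import numpy as np

edge_inputs = [[0, 1], [1, 2], [2, 0]]   # 2 neighbours per node, self-loops included
size = len(edge_inputs)
bias = np.ones((size, size))
for i in range(size):
    for j in range(size):
        if j in edge_inputs[i]:
            bias[i][j] = 0
# bias == [[0, 0, 1],
#          [1, 0, 0],
#          [0, 1, 0]]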