Spaces: Running on Zero
derektan committed · e330ebf
Parent(s): (none — initial commit)
Initial Commit
This view is limited to 50 files because it contains too many changes. See raw diff.
- .gitattributes +40 -0
- .gitignore +38 -0
- .vscode/launch.json +15 -0
- README.md +13 -0
- app.py +528 -0
- app_multimodal_inference.py +350 -0
- examples/Animalia_Arthropoda_Malacostraca_Decapoda_Ocypodidae_Ocypode_quadrata/0b9cc264-a2ba-44bd-8e41-0d01a6edd1e8.jpg +3 -0
- examples/Animalia_Arthropoda_Malacostraca_Decapoda_Ocypodidae_Ocypode_quadrata/12372063.mp3 +3 -0
- examples/Animalia_Arthropoda_Malacostraca_Decapoda_Ocypodidae_Ocypode_quadrata/277303_38.72364_-75.07749.jpg +3 -0
- examples/Animalia_Chordata_Aves_Charadriiformes_Laridae_Larus_marinus/80645_39.76079_-74.10316.jpg +3 -0
- examples/Animalia_Chordata_Aves_Charadriiformes_Laridae_Larus_marinus/89758229.mp3 +3 -0
- examples/Animalia_Chordata_Aves_Charadriiformes_Laridae_Larus_marinus/cc1ebaf9-899d-49f2-81c8-d452249a8087.jpg +3 -0
- examples/Animalia_Chordata_Elasmobranchii_Carcharhiniformes_Carcharhinidae_Triaenodon_obesus/5041_-0.28573_-90.54837.jpg +3 -0
- examples/Animalia_Chordata_Elasmobranchii_Carcharhiniformes_Carcharhinidae_Triaenodon_obesus/c834edf7-b073-4cd5-8726-9c88ebe943c0.jpg +3 -0
- examples/Animalia_Chordata_Mammalia_Artiodactyla_Cervidae_Cervus_nippon/1a51f978-c76d-4816-a716-473c21b5495c.jpg +3 -0
- examples/Animalia_Chordata_Mammalia_Artiodactyla_Cervidae_Cervus_nippon/245767_53.0076_-6.35201.jpg +3 -0
- examples/Animalia_Chordata_Mammalia_Carnivora_Canidae_Canis_aureus/1528408_13.00422_80.23033.jpg +3 -0
- examples/Animalia_Chordata_Mammalia_Carnivora_Canidae_Canis_aureus/189318716.mp3 +3 -0
- examples/Animalia_Chordata_Mammalia_Carnivora_Canidae_Canis_aureus/37faabd2-a613-4461-b27e-82fe5955ecaf.jpg +3 -0
- examples/Animalia_Chordata_Mammalia_Carnivora_Canidae_Urocyon_littoralis/0cbdfbf2-6cfe-4d61-9602-c949f24d0293.jpg +3 -0
- examples/Animalia_Chordata_Mammalia_Carnivora_Canidae_Urocyon_littoralis/304160_34.0144_-119.54417.jpg +3 -0
- examples/Animalia_Chordata_Mammalia_Carnivora_Phocidae_Mirounga_angustirostris/27423_35.64005_-121.17595.jpg +3 -0
- examples/Animalia_Chordata_Mammalia_Carnivora_Phocidae_Mirounga_angustirostris/3123948.mp3 +3 -0
- examples/Animalia_Chordata_Mammalia_Carnivora_Phocidae_Mirounga_angustirostris/3aac526d-c921-452a-af6a-cb4f2f52e2c4.jpg +3 -0
- examples/Animalia_Chordata_Mammalia_Rodentia_Caviidae_Hydrochoerus_hydrochaeris/166631961.mp3 +3 -0
- examples/Animalia_Chordata_Mammalia_Rodentia_Caviidae_Hydrochoerus_hydrochaeris/1b8064f8-7deb-4b30-98cd-69da98ba6a3d.jpg +3 -0
- examples/Animalia_Chordata_Mammalia_Rodentia_Caviidae_Hydrochoerus_hydrochaeris/28871_-12.80255_-69.29999.jpg +3 -0
- examples/Animalia_Chordata_Mammalia_Rodentia_Sciuridae_Marmota_marmota/327e1f07-692b-4140-8a3e-bd098bc064ff.jpg +3 -0
- examples/Animalia_Chordata_Mammalia_Rodentia_Sciuridae_Marmota_marmota/388246_45.49036_7.14796.jpg +3 -0
- examples/Animalia_Chordata_Mammalia_Rodentia_Sciuridae_Marmota_marmota/59677071.mp3 +3 -0
- examples/Animalia_Chordata_Reptilia_Crocodylia_Alligatoridae_Caiman_crocodilus/340271_10.52832_-83.49678.jpg +3 -0
- examples/Animalia_Chordata_Reptilia_Crocodylia_Alligatoridae_Caiman_crocodilus/45193295.mp3 +3 -0
- examples/Animalia_Chordata_Reptilia_Crocodylia_Alligatoridae_Caiman_crocodilus/938aab7b-4509-4de7-afad-2c8ea51f4799.jpg +3 -0
- examples/Animalia_Chordata_Reptilia_Squamata_Varanidae_Varanus_salvator/410613_5.35573_100.28948.jpg +3 -0
- examples/Animalia_Chordata_Reptilia_Squamata_Varanidae_Varanus_salvator/461d8e6c-0e66-4acc-8ecd-bfd9c218bc14.jpg +3 -0
- examples/metadata.json +173 -0
- inference/model/avs_rl_policy.pth +3 -0
- maps/example/masks_val/MSK_0001.png +3 -0
- maps/gpt4o/envs_val/MSK_0001.png +3 -0
- planner/env.py +610 -0
- planner/graph.py +167 -0
- planner/graph_generator.py +300 -0
- planner/model.py +312 -0
- planner/node.py +96 -0
- planner/robot.py +58 -0
- planner/sensor.py +128 -0
- planner/test_info_surfing.py +1071 -0
- planner/test_parameter.py +118 -0
- planner/test_worker.py +590 -0
- planner/worker.py +272 -0
.gitattributes
ADDED
@@ -0,0 +1,40 @@
+*.7z filter=lfs diff=lfs merge=lfs -text
+*.arrow filter=lfs diff=lfs merge=lfs -text
+*.bin filter=lfs diff=lfs merge=lfs -text
+*.bz2 filter=lfs diff=lfs merge=lfs -text
+*.ckpt filter=lfs diff=lfs merge=lfs -text
+*.ftz filter=lfs diff=lfs merge=lfs -text
+*.gz filter=lfs diff=lfs merge=lfs -text
+*.h5 filter=lfs diff=lfs merge=lfs -text
+*.joblib filter=lfs diff=lfs merge=lfs -text
+*.lfs.* filter=lfs diff=lfs merge=lfs -text
+*.mlmodel filter=lfs diff=lfs merge=lfs -text
+*.model filter=lfs diff=lfs merge=lfs -text
+*.msgpack filter=lfs diff=lfs merge=lfs -text
+*.npy filter=lfs diff=lfs merge=lfs -text
+*.npz filter=lfs diff=lfs merge=lfs -text
+*.onnx filter=lfs diff=lfs merge=lfs -text
+*.ot filter=lfs diff=lfs merge=lfs -text
+*.parquet filter=lfs diff=lfs merge=lfs -text
+*.pb filter=lfs diff=lfs merge=lfs -text
+*.pickle filter=lfs diff=lfs merge=lfs -text
+*.pkl filter=lfs diff=lfs merge=lfs -text
+*.pt filter=lfs diff=lfs merge=lfs -text
+*.pth filter=lfs diff=lfs merge=lfs -text
+*.rar filter=lfs diff=lfs merge=lfs -text
+*.safetensors filter=lfs diff=lfs merge=lfs -text
+saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+*.tar.* filter=lfs diff=lfs merge=lfs -text
+*.tar filter=lfs diff=lfs merge=lfs -text
+*.tflite filter=lfs diff=lfs merge=lfs -text
+*.tgz filter=lfs diff=lfs merge=lfs -text
+*.wasm filter=lfs diff=lfs merge=lfs -text
+*.xz filter=lfs diff=lfs merge=lfs -text
+*.zip filter=lfs diff=lfs merge=lfs -text
+*.zst filter=lfs diff=lfs merge=lfs -text
+*tfevents* filter=lfs diff=lfs merge=lfs -text
+*.png filter=lfs diff=lfs merge=lfs -text
+*.jpg filter=lfs diff=lfs merge=lfs -text
+*.jpeg filter=lfs diff=lfs merge=lfs -text
+*.mp3 filter=lfs diff=lfs merge=lfs -text
+*.pth filter=lfs diff=lfs merge=lfs -text
.gitignore
ADDED
@@ -0,0 +1,38 @@
+*.idea/
+*.vscode/
+
+*__pycache__/
+*.pyc
+*.pyo
+*.pyd
+
+*.pt
+*.pth
+*.tar.gz
+*.zip
+
+!**/train/
+**/train/*
+!**/train/saved
+
+!**/inference/
+**/inference/*
+!**/inference/saved
+
+!**/maps/
+**/maps/*
+!**/maps/example
+!**/maps/gpt4o
+!**/maps/lisa
+
+# For taxabind_avs
+**/dataset/
+**/checkpoints/
+
+!**/lightning_logs/
+**/lightning_logs/*
+!**/lightning_logs/saved
+
+# Saved weights & logs
+**avs_rl_policy.pth
+**/avs_rl_policy_21.5k/*
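This .gitignore leans on git's re-include idiom three times: keep a directory (`!**/train/`), ignore its contents (`**/train/*`), then whitelist one subfolder (`!**/train/saved`). Below is a minimal sketch (not part of the commit) of how those rules resolve, assuming the third-party `pathspec` package, which implements git's wildmatch semantics — it is an assumption here, not a repo dependency:

```python
# Sketch: evaluate the maps/ whitelist pattern with pathspec (assumed package).
import pathspec

ignore_lines = [
    "!**/maps/",         # keep the maps/ directory itself
    "**/maps/*",         # ignore everything inside it...
    "!**/maps/example",  # ...except these re-included subfolders
    "!**/maps/gpt4o",
]
spec = pathspec.PathSpec.from_lines("gitwildmatch", ignore_lines)

print(spec.match_file("maps/scratch/MSK_0001.png"))  # True  -> ignored
print(spec.match_file("maps/example/MSK_0001.png"))  # False -> tracked
```

This explains why `maps/example/masks_val/MSK_0001.png` and `maps/gpt4o/envs_val/MSK_0001.png` appear in the file list above even though `**/maps/*` is ignored.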
.vscode/launch.json
ADDED
@@ -0,0 +1,15 @@
+{
+    "version": "0.2.0",
+    "configurations": [
+        {
+            "name": "Debug app.py",
+            "type": "debugpy",
+            "request": "launch",
+            "program": "${workspaceFolder}/app.py",
+            "cwd": "${workspaceFolder}",
+            "console": "integratedTerminal",
+            "justMyCode": false,
+            "python": "/home/user/anaconda3/envs/vlm-search/bin/python3"
+        }
+    ]
+}
README.md
ADDED
@@ -0,0 +1,13 @@
+---
+title: Search-TTA
+emoji: 🦁
+colorFrom: green
+colorTo: gray
+sdk: gradio
+sdk_version: 5.31.0
+app_file: app.py
+pinned: false
+short_description: Multimodal Test-time Adaptation Framework for Visual Search
+---
+
+Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py
ADDED
@@ -0,0 +1,528 @@
+"""
+Simplified Gradio demo for Search-TTA evaluation.
+"""
+
+# ────────────────────────── imports ───────────────────────────────────
+from pathlib import Path
+import matplotlib
+matplotlib.use("Agg", force=True)
+
+import gradio as gr
+import ctypes  # for safely stopping background threads
+import os, glob, threading, time
+import torch
+from PIL import Image
+import json
+import shutil
+import spaces  # integration with ZeroGPU on hf
+from planner.test_parameter import *
+from planner.model import PolicyNet
+from planner.test_worker import TestWorker
+from taxabind_avs.satbind.clip_seg_tta import ClipSegTTA
+
+
+# Helper to kill a Python thread by injecting SystemExit
+def _stop_thread(thread: threading.Thread):
+    """Forcefully raise SystemExit in the given thread (best-effort)."""
+    if thread is None or not thread.is_alive():
+        return
+    tid = thread.ident
+    if tid is None:
+        return
+    # Ask CPython to raise SystemExit in the thread context
+    res = ctypes.pythonapi.PyThreadState_SetAsyncExc(ctypes.c_long(tid), ctypes.py_object(SystemExit))
+    if res > 1:
+        # If it returned >1, cleanup and fail safe
+        ctypes.pythonapi.PyThreadState_SetAsyncExc(ctypes.c_long(tid), None)
+
+# ──────────── Thread Registry for Cleanup on Tab Switch ─────────────
+_running_threads: list[threading.Thread] = []
+_running_threads_lock = threading.Lock()
+
+# Map worker threads to their ClipSegTTA instance so UI can read executing_tta flag
+_thread_clip_map: dict[threading.Thread, ClipSegTTA] = {}
+
+# ──────────── Run directory rotation ─────────────
+RUN_HISTORY_LIMIT = 30  # keep at most this many timestamped run directories per instance
+
+def _prune_old_run_dirs(base_dir: str, limit: int = RUN_HISTORY_LIMIT):
+    """Delete the oldest timestamp-named run directories, keeping only the *limit* newest ones."""
+    try:
+        dirs = [d for d in os.listdir(base_dir) if os.path.isdir(os.path.join(base_dir, d))]
+        dirs.sort()
+        if len(dirs) > limit:
+            for obsolete in dirs[:-limit]:
+                shutil.rmtree(os.path.join(base_dir, obsolete), ignore_errors=True)
+    except Exception:
+        pass
+
+
+# CHANGE ME!
+POLL_INTERVAL = 1.0  # For visualization
+
+# Prepare the model
+device = torch.device('cuda') if USE_GPU and torch.cuda.is_available() else torch.device('cpu')
+policy_net = PolicyNet(INPUT_DIM, EMBEDDING_DIM).to(device)
+script_dir = Path(__file__).resolve().parent
+print("real_script_dir: ", script_dir)
+checkpoint = torch.load(f'{MODEL_PATH}/{MODEL_NAME}')
+policy_net.load_state_dict(checkpoint['policy_model'])
+print('Model loaded!')
+
+# Load metadata json
+tgts_metadata_json_path = os.path.join(script_dir, "examples/metadata.json")
+tgts_metadata = json.load(open(tgts_metadata_json_path))
+
+
+# ────────────────────────── Gradio process fn ─────────────────────────
+
+### integration with ZeroGPU on hf
+# @spaces.GPU
+def process_search_tta(
+    sat_path: str | None,
+    ground_path: str | None,
+    taxonomy: str | None = None,
+    session_threads: list[threading.Thread] | None = None,
+):
+    """Run both TTA and non-TTA search episodes concurrently and stream both heat-maps."""
+
+    if session_threads is None:
+        session_threads = []
+
+    # Disable Run button and clear image/status outputs, hide sliders, clear frame states
+    yield (
+        gr.update(interactive=False),
+        gr.update(value=None),
+        gr.update(value=None),
+        gr.update(value="Initializing model…", visible=True),
+        gr.update(value="Initializing model…", visible=True),
+        gr.update(visible=False),
+        gr.update(visible=False),
+        [],
+        [],
+        session_threads,
+    )
+
+    # Bail early if satellite image missing
+    if sat_path is None:
+        yield (
+            gr.update(interactive=True),
+            gr.update(value=None),
+            gr.update(value=None),
+            gr.update(value="No satellite image provided.", visible=True),
+            gr.update(value="", visible=True),
+            gr.update(visible=False),
+            gr.update(visible=False),
+            [],
+            [],
+            session_threads,
+        )
+        return
+
+    # Prepare PIL images
+    sat_img = Image.open(sat_path).convert("RGB")
+    ground_img_pil = Image.open(ground_path).convert("RGB") if ground_path else None
+
+    # Lookup target positions metadata (may be empty)
+    tgt_positions = []
+    if taxonomy and taxonomy in tgts_metadata:
+        tgt_positions = [tuple(t) for t in tgts_metadata[taxonomy]["target_positions"]]
+
+    # Helper to build a TestWorker with/without TTA
+    def build_planner(enable_tta: bool, save_dir: str, clip_obj):
+        # Lazily (re)create a ClipSegTTA instance per thread if not provided
+        local_clip = clip_obj
+        if LOAD_AVS_BENCH and local_clip is None:
+            local_clip = ClipSegTTA(
+                img_dir=AVS_IMG_DIR,
+                imo_dir=AVS_IMO_DIR,
+                json_path=AVS_INAT_JSON_PATH,
+                sat_to_img_ids_path=AVS_SAT_TO_IMG_IDS_PATH,
+                sat_checkpoint_path=AVS_SAT_CHECKPOINT_PATH,
+                load_pretrained_hf_ckpt=AVS_LOAD_PRETRAINED_HF_CHECKPOINT,
+                blur_kernel=AVS_GAUSSIAN_BLUR_KERNEL,
+                sample_index=-1,
+                device=device,
+                sat_to_img_ids_json_is_train_dict=False,
+                tax_to_filter_val=QUERY_TAX,
+                load_model=USE_CLIP_PREDS,
+                query_modality=QUERY_MODALITY,
+                sound_dir=AVS_SOUND_DIR,
+                sound_checkpoint_path=AVS_SOUND_CHECKPOINT_PATH,
+            )
+
+        if local_clip is not None:
+            # Feed inputs to ClipSegTTA copy
+            local_clip.img_paths = [ground_path] if ground_path else []
+            local_clip.imo_path = sat_path
+            local_clip.imgs = ([local_clip.dataset.img_transform(ground_img_pil).to(device)] if ground_img_pil else [])
+            local_clip.imo = local_clip.dataset.imo_transform(sat_img).to(device)
+            local_clip.sounds = []
+            local_clip.sound_ids = []
+            local_clip.species_name = taxonomy or ""
+            local_clip.gt_mask_name = taxonomy.replace(" ", "_") if taxonomy else ""
+            local_clip.target_positions = tgt_positions if tgt_positions else [(0, 0)]
+
+        planner = TestWorker(
+            meta_agent_id=0,
+            n_agent=1,
+            policy_net=policy_net,
+            global_step=-1,
+            device=device,
+            greedy=True,
+            save_image=SAVE_GIFS,
+            clip_seg_tta=local_clip,
+        )
+        planner.execute_tta = enable_tta
+        planner.gifs_path = save_dir
+        return planner
+
+    # ────────────── Per-run output directories ──────────────
+    # Ensure base directory exists
+    os.makedirs(GIFS_PATH, exist_ok=True)
+
+    run_id = time.strftime("%Y%m%d_%H%M%S")  # unique timestamp
+    run_root = os.path.join(GIFS_PATH, run_id)
+    gifs_dir_tta = os.path.join(run_root, "with_tta")
+    gifs_dir_no = os.path.join(run_root, "no_tta")
+
+    os.makedirs(gifs_dir_tta, exist_ok=True)
+    os.makedirs(gifs_dir_no, exist_ok=True)
+
+    # House-keep old runs so we never keep more than RUN_HISTORY_LIMIT
+    _prune_old_run_dirs(GIFS_PATH, RUN_HISTORY_LIMIT)
+
+    # Shared dict to record if a thread hit an exception
+    error_flags = {"tta": False, "no": False}
+
+    def _planner_thread(enable_tta: bool, save_dir: str, clip_obj, key: str):
+        """Prepare directory, build planner, run an episode, record errors."""
+        try:
+            planner = build_planner(enable_tta, save_dir, clip_obj)
+            _thread_clip_map[threading.current_thread()] = planner.clip_seg_tta
+            planner.run_episode(0)
+        except Exception as exc:
+            # Mark that this planner crashed so UI can show an error status
+            error_flags[key] = True
+            # Log full traceback so developers can debug via console logs
+            import traceback, sys
+            traceback.print_exc()
+            # Still exit the thread
+            return
+
+    # Launch both planners in background threads – preparation included
+    thread_tta = threading.Thread(
+        target=_planner_thread,
+        args=(True, gifs_dir_tta, None, "tta"),
+        daemon=True,
+    )
+    thread_no = threading.Thread(
+        target=_planner_thread,
+        args=(False, gifs_dir_no, None, "no"),
+        daemon=True,
+    )
+    # Track threads for this user session
+    session_threads.extend([thread_tta, thread_no])
+    thread_tta.start()
+    thread_no.start()
+
+
+    sent_tta: set[str] = set()
+    sent_no: set[str] = set()
+    last_tta = None
+    last_no = None
+    # Track previous status strings so we can emit updates when only the
+    # status (Running…/Done.) changes even if no new frame was produced.
+    # Previous status values so we can detect changes and yield updates
+    prev_status_tta = "Initializing model…"
+    prev_status_no = "Initializing model…"
+
+    try:
+        while thread_tta.is_alive() or thread_no.is_alive():
+            updated = False
+            # Collect new frames from TTA dir
+            pngs = glob.glob(os.path.join(gifs_dir_tta, "*.png"))
+            pngs.sort(key=lambda p: int(os.path.splitext(os.path.basename(p))[0]))
+            for fp in pngs:
+                if fp not in sent_tta:
+                    # Ensure file is fully written (non-empty & readable)
+                    try:
+                        if os.path.getsize(fp) == 0:
+                            continue
+                        with open(fp, "rb") as fh:
+                            fh.read(1)
+                    except Exception:
+                        # Skip this round; we'll retry next poll
+                        continue
+                    sent_tta.add(fp)
+                    last_tta = fp
+                    updated = True
+            # Collect new frames from no-TTA dir
+            pngs = glob.glob(os.path.join(gifs_dir_no, "*.png"))
+            pngs.sort(key=lambda p: int(os.path.splitext(os.path.basename(p))[0]))
+            for fp in pngs:
+                if fp not in sent_no:
+                    try:
+                        if os.path.getsize(fp) == 0:
+                            continue
+                        with open(fp, "rb") as fh:
+                            fh.read(1)
+                    except Exception:
+                        continue
+                    sent_no.add(fp)
+                    last_no = fp
+                    updated = True
+
+            # Determine status based on whether we already have a frame and whether
+            # the corresponding thread is still alive.
+            def _mk_status(last_frame, thread_alive, errored: bool, running_tta: bool = False):
+                if errored:
+                    return "Error!"
+                if last_frame is None:
+                    return "Initializing model…"
+                if not thread_alive:
+                    return "Done."
+                return "Executing TTA (Scheduling GPUs)…" if running_tta else "Executing Planner…"
+
+            exec_tta_flag = False
+            if thread_tta.is_alive():
+                clip_obj = _thread_clip_map.get(thread_tta)
+                if clip_obj is not None and getattr(clip_obj, "executing_tta", False):
+                    exec_tta_flag = True
+
+            status_tta = _mk_status(last_tta, thread_tta.is_alive(), error_flags["tta"], exec_tta_flag)
+            status_no = _mk_status(last_no, thread_no.is_alive(), error_flags["no"], False)
+
+            # Determine if we should reveal sliders (once corresponding thread has finished)
+            show_slider_tta = (not thread_tta.is_alive()) and (last_tta is not None)
+            show_slider_no = (not thread_no.is_alive()) and (last_no is not None)
+
+            # Build slider updates
+            slider_tta_upd = gr.update()
+            slider_no_upd = gr.update()
+            frames_tta_upd = gr.update()
+            frames_no_upd = gr.update()
+
+            if show_slider_tta:
+                n_tta_frames = max(len(sent_tta), 1)
+                slider_tta_upd = gr.update(visible=True, minimum=1, maximum=n_tta_frames, value=n_tta_frames)
+                frames_tta_upd = sorted(sent_tta, key=lambda p: int(os.path.splitext(os.path.basename(p))[0]))
+            if show_slider_no:
+                n_no_frames = max(len(sent_no), 1)
+                slider_no_upd = gr.update(visible=True, minimum=1, maximum=n_no_frames, value=n_no_frames)
+                frames_no_upd = sorted(sent_no, key=lambda p: int(os.path.splitext(os.path.basename(p))[0]))
+
+            # Emit update if we have a new frame OR status changed OR slider visibility changed
+            if (
+                updated
+                or status_tta != prev_status_tta
+                or status_no != prev_status_no
+                or show_slider_tta
+                or show_slider_no
+            ):
+                yield (
+                    gr.update(interactive=False),
+                    last_tta,
+                    last_no,
+                    gr.update(value=status_tta, visible=True),
+                    gr.update(value=status_no, visible=True),
+                    slider_tta_upd,
+                    slider_no_upd,
+                    frames_tta_upd,
+                    frames_no_upd,
+                    session_threads,
+                )
+
+            prev_status_tta = status_tta
+            prev_status_no = status_no
+
+            time.sleep(POLL_INTERVAL)
+    finally:
+        # Ensure background threads are stopped on cancel
+        for th in (thread_tta, thread_no):
+            if th.is_alive():
+                _stop_thread(th)
+                th.join(timeout=1)
+
+        # Remove finished threads from global registry
+        with _running_threads_lock:
+            # Clear session thread list
+            session_threads.clear()
+
+    # Small delay to ensure last frame files are fully flushed
+    time.sleep(0.2)
+    # One last scan after both threads have finished to catch any frame
+    # that may have been written just before termination but after the last
+    # polling iteration.
+    for fp in sorted(glob.glob(os.path.join(gifs_dir_tta, "*.png")), key=lambda p: int(os.path.splitext(os.path.basename(p))[0])):
+        if fp not in sent_tta:
+            sent_tta.add(fp)
+            last_tta = fp
+    for fp in sorted(glob.glob(os.path.join(gifs_dir_no, "*.png")), key=lambda p: int(os.path.splitext(os.path.basename(p))[0])):
+        if fp not in sent_no:
+            sent_no.add(fp)
+            last_no = fp
+
+    # Prepare frames list and slider configs
+    frames_tta = sorted(glob.glob(os.path.join(gifs_dir_tta, "*.png")), key=lambda p: int(os.path.splitext(os.path.basename(p))[0]))
+    frames_no = sorted(glob.glob(os.path.join(gifs_dir_no, "*.png")), key=lambda p: int(os.path.splitext(os.path.basename(p))[0]))
+    if last_tta is None and frames_tta:
+        last_tta = frames_tta[-1]
+    if last_no is None and frames_no:
+        last_no = frames_no[-1]
+    n_tta = len(frames_tta) or 1  # prevent zero-range slider
+    n_no = len(frames_no) or 1
+
+    # Final emit: re-enable button, hide statuses, show sliders set to last frame
+    yield (
+        gr.update(interactive=True),
+        last_tta,
+        last_no,
+        gr.update(visible=False),
+        gr.update(visible=False),
+        gr.update(visible=True, minimum=1, maximum=n_tta, value=n_tta),
+        gr.update(visible=True, minimum=1, maximum=n_no, value=n_no),
+        frames_tta,
+        frames_no,
+        session_threads,
+    )
+
+
+# ────────────────────────── Gradio UI ─────────────────────────────────
+with gr.Blocks(title="Search-TTA (Simplified)", theme=gr.themes.Base()) as demo:
+
+    gr.Markdown(
+        """
+        # Search-TTA: A Multimodal Test-Time Adaptation Framework for Visual Search in the Wild Demo
+        Click on any of the <b>examples below</b> and run the <b>TTA demo</b>. Check out the <b>multimodal heatmap generation feature</b> by switching to the other tab above. <br>
+        Note that the model initialization, RL planner, and TTA updates are not fully GPU-optimized for this Hugging Face demo, so you may experience some lag during execution. <br>
+        If you encounter an 'Error' status, refresh the browser and rerun the demo, or try again the next day. We will improve this in the future. <br>
+        <a href="https://search-tta.github.io">Project Website</a>
+        """
+    )
+
+    with gr.Row(variant="panel"):
+        with gr.Column():
+            gr.Markdown("### Model Inputs")
+            sat_input = gr.Image(
+                label="Satellite Image",
+                sources=["upload"],
+                type="filepath",
+                height=320,
+            )
+            taxonomy_input = gr.Textbox(
+                label="Full Taxonomy Name (optional)",
+                placeholder="e.g. Animalia Chordata Mammalia Rodentia Sciuridae Marmota marmota",
+            )
+            ground_input = gr.Image(
+                label="Ground-level Image (optional)",
+                sources=["upload"],
+                type="filepath",
+                height=320,
+            )
+            run_btn = gr.Button("Run Search-TTA", variant="primary")
+
+        with gr.Column():
+            gr.Markdown("### Live Heatmap Output")
+            display_img_tta = gr.Image(label="Heatmap (TTA per 20 steps)", type="filepath", height=400)  # 512
+            status_tta = gr.Markdown("")
+            slider_tta = gr.Slider(label="TTA Frame", minimum=1, maximum=1, step=1, value=1, visible=False)
+
+            display_img_no_tta = gr.Image(label="Heatmap (no TTA)", type="filepath", height=400)  # 512
+            status_no_tta = gr.Markdown("")
+            slider_no = gr.Slider(label="No-TTA Frame", minimum=1, maximum=1, step=1, value=1, visible=False)
+
+    frames_state_tta = gr.State([])
+    frames_state_no = gr.State([])
+    session_threads_state = gr.State([])
+
+    # Slider callbacks (updates image when user drags slider)
+    def _show_frame(idx: int, frames: list[str]):
+        # Slider is 1-indexed; convert to 0-indexed list access
+        if 1 <= idx <= len(frames):
+            return frames[idx - 1]
+        return gr.update()
+
+    slider_tta.change(_show_frame, inputs=[slider_tta, frames_state_tta], outputs=display_img_tta)
+    slider_no.change(_show_frame, inputs=[slider_no, frames_state_no], outputs=display_img_no_tta)
+
+    # EXAMPLES
+    with gr.Row():
+        gr.Markdown("### Taxonomy")
+    with gr.Row():
+        gr.Examples(
+            examples=[
+                [
+                    "examples/Animalia_Chordata_Reptilia_Squamata_Varanidae_Varanus_salvator/410613_5.35573_100.28948.jpg",
+                    "examples/Animalia_Chordata_Reptilia_Squamata_Varanidae_Varanus_salvator/461d8e6c-0e66-4acc-8ecd-bfd9c218bc14.jpg",
+                    "Animalia Chordata Reptilia Squamata Varanidae Varanus salvator",
+                ],
+                [
+                    "examples/Animalia_Chordata_Mammalia_Carnivora_Canidae_Canis_aureus/1528408_13.00422_80.23033.jpg",
+                    "examples/Animalia_Chordata_Mammalia_Carnivora_Canidae_Canis_aureus/37faabd2-a613-4461-b27e-82fe5955ecaf.jpg",
+                    "Animalia Chordata Mammalia Carnivora Canidae Canis aureus",
+                ],
+                [
+                    "examples/Animalia_Chordata_Reptilia_Crocodylia_Alligatoridae_Caiman_crocodilus/340271_10.52832_-83.49678.jpg",
+                    "examples/Animalia_Chordata_Reptilia_Crocodylia_Alligatoridae_Caiman_crocodilus/938aab7b-4509-4de7-afad-2c8ea51f4799.jpg",
+                    "Animalia Chordata Reptilia Crocodylia Alligatoridae Caiman crocodilus",
+                ],
+                [
+                    "examples/Animalia_Chordata_Mammalia_Carnivora_Canidae_Urocyon_littoralis/304160_34.0144_-119.54417.jpg",
+                    "examples/Animalia_Chordata_Mammalia_Carnivora_Canidae_Urocyon_littoralis/0cbdfbf2-6cfe-4d61-9602-c949f24d0293.jpg",
+                    "Animalia Chordata Mammalia Carnivora Canidae Urocyon littoralis",
+                ],
+            ],
+            inputs=[sat_input, ground_input, taxonomy_input],
+            outputs=[run_btn, display_img_tta, display_img_no_tta, status_tta, status_no_tta, slider_tta, slider_no, frames_state_tta, frames_state_no],
+            fn=process_search_tta,
+            cache_examples=False,
+        )
+
+    run_btn.click(
+        fn=process_search_tta,
+        inputs=[sat_input, ground_input, taxonomy_input, session_threads_state],
+        outputs=[run_btn, display_img_tta, display_img_no_tta, status_tta, status_no_tta, slider_tta, slider_no, frames_state_tta, frames_state_no, session_threads_state],
+    )
+
+    # Footer pointing to the model and data sources used by this app.
+    gr.Markdown(
+        """
+        The satellite image CLIP encoder is fine-tuned using [Sentinel-2 Level 2A](https://docs.sentinel-hub.com/api/latest/data/sentinel-2-l2a/) satellite images and taxonomy images (with GPS locations) from [iNaturalist](https://inaturalist.org/). The sound CLIP encoder is fine-tuned with a subset of the same taxonomy images and their corresponding sounds from [iNaturalist](https://inaturalist.org/). Some of these iNaturalist data are also used in [Taxabind](https://arxiv.org/abs/2411.00683). Note that while some of the examples above result in poor probability distributions, they will be improved using our test-time adaptation framework during the search process.
+        """
+    )
+
+
+if __name__ == "__main__":
+
+    # Build UI with explicit Tabs so we can detect tab selection and clean up
+    from app_multimodal_inference import demo as multimodal_demo
+
+    with gr.Blocks() as root:
+        with gr.Tabs() as tabs:
+            with gr.TabItem("Multimodal Inference"):
+                multimodal_demo.render()
+            with gr.TabItem("Search-TTA"):
+                demo.render()
+
+        # Hidden textbox purely to satisfy Gradio's need for an output component.
+        _cleanup_status = gr.Textbox(visible=False)
+
+        outputs_on_tab = [_cleanup_status]
+
+        def _on_tab_change(evt: gr.SelectData, session_threads: list[threading.Thread]):
+            # evt.value contains the name of the newly-selected tab.
+            if evt.value == "Multimodal Inference":
+                # Stop only threads started in this session
+                for th in list(session_threads):
+                    if th is not None and th.is_alive():
+                        _stop_thread(th)
+                        th.join(timeout=1)
+                session_threads.clear()
+                return "Stopped running Search-TTA threads."
+            return ""
+
+        tabs.select(_on_tab_change, inputs=[session_threads_state], outputs=outputs_on_tab)
+
+    root.queue(max_size=15)
+    root.launch(share=True)
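`process_search_tta` above is a generator: the two planner threads write numbered PNG frames to disk while the generator polls both directories and yields a fresh Gradio update whenever a new frame or status appears. Below is a minimal, self-contained sketch (not from the commit) of that poll-and-yield pattern; the `demo_worker` function is a hypothetical stand-in for `TestWorker.run_episode`:

```python
# Sketch: a worker thread writes numbered frames; a generator streams them.
import glob, os, tempfile, threading, time

def demo_worker(out_dir: str) -> None:
    # Stand-in for TestWorker.run_episode: emits one frame per half second.
    for step in range(5):
        with open(os.path.join(out_dir, f"{step}.png"), "wb") as fh:
            fh.write(b"frame")  # placeholder payload
        time.sleep(0.5)

def _numeric(p: str) -> int:
    # Sort numerically, like the app does, so frame 10 follows frame 9.
    return int(os.path.splitext(os.path.basename(p))[0])

def stream_frames(out_dir: str):
    """Generator mirroring the poll/yield loop in process_search_tta."""
    worker = threading.Thread(target=demo_worker, args=(out_dir,), daemon=True)
    worker.start()
    sent: set[str] = set()
    while worker.is_alive():
        for fp in sorted(glob.glob(os.path.join(out_dir, "*.png")), key=_numeric):
            if fp not in sent and os.path.getsize(fp) > 0:  # skip half-written files
                sent.add(fp)
                yield fp  # in the app, this becomes a Gradio image update
        time.sleep(0.2)  # poll interval
    # Final sweep, as in the app, to catch frames written just before exit.
    for fp in sorted(glob.glob(os.path.join(out_dir, "*.png")), key=_numeric):
        if fp not in sent:
            sent.add(fp)
            yield fp

if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as d:
        for frame in stream_frames(d):
            print("new frame:", frame)
```

Polling the filesystem decouples the planner from the UI: the planner needs no Gradio knowledge, and a crash in either thread cannot corrupt the other's output stream.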
app_multimodal_inference.py
ADDED
@@ -0,0 +1,350 @@
+"""
+Search-TTA multimodal heatmap generation demo
+"""
+
+# ────────────────────────── imports ───────────────────────────────────
+import cv2
+import gradio as gr
+import torch
+import numpy as np
+from PIL import Image
+import matplotlib.pyplot as plt
+import io
+import torchaudio
+import spaces  # integration with ZeroGPU on hf
+
+from torchvision import transforms
+import open_clip
+from taxabind_avs.satbind.clip_vision_per_patch_model import CLIPVisionPerPatchModel
+from transformers import ClapAudioModelWithProjection
+from transformers import ClapProcessor
+
+# ────────────────────────── global config & models ────────────────────
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+# BioCLIP (ground-image & text encoder)
+bio_model, _, _ = open_clip.create_model_and_transforms("hf-hub:imageomics/bioclip")
+bio_model = bio_model.to(device).eval()
+bio_tokenizer = open_clip.get_tokenizer("hf-hub:imageomics/bioclip")
+
+# Satellite patch encoder (CLIP-L-336, per-patch)
+sat_model: CLIPVisionPerPatchModel = (
+    CLIPVisionPerPatchModel.from_pretrained("derektan95/search-tta-sat")
+    .to(device)
+    .eval()
+)
+
+# Sound CLAP model
+sound_model: ClapAudioModelWithProjection = (
+    ClapAudioModelWithProjection.from_pretrained("derektan95/search-tta-sound")
+    .to(device)
+    .eval()
+)
+sound_processor: ClapProcessor = ClapProcessor.from_pretrained("derektan95/search-tta-sound")
+SAMPLE_RATE = 48000
+
+logit_scale = torch.nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
+logit_scale = logit_scale.exp()
+blur_kernel = (5, 5)
+
+# ────────────────────────── transforms (exact spec) ───────────────────
+img_transform = transforms.Compose(
+    [
+        transforms.Resize((256, 256)),
+        transforms.CenterCrop((224, 224)),
+        transforms.ToTensor(),
+        transforms.Normalize(
+            mean=[0.485, 0.456, 0.406],
+            std=[0.229, 0.224, 0.225],
+        ),
+    ]
+)
+
+imo_transform = transforms.Compose(
+    [
+        transforms.Resize((336, 336)),
+        transforms.ToTensor(),
+        transforms.Normalize(
+            mean=[0.485, 0.456, 0.406],
+            std=[0.229, 0.224, 0.225],
+        ),
+    ]
+)
+
+def get_audio_clap(path_to_audio, format="mp3", padding="repeatpad", truncation="fusion"):
+    track, sr = torchaudio.load(path_to_audio, format=format)  # torchaudio.load(path_to_audio)
+    track = track.mean(axis=0)
+    track = torchaudio.functional.resample(track, orig_freq=sr, new_freq=SAMPLE_RATE)
+    output = sound_processor(audios=track, sampling_rate=SAMPLE_RATE, max_length_s=10, return_tensors="pt", padding=padding, truncation=truncation)
+    return output
+
+# ────────────────────────── helpers ───────────────────────────────────
+
+@torch.no_grad()
+def _encode_ground(img_pil: Image.Image) -> torch.Tensor:
+    img = img_transform(img_pil).unsqueeze(0).to(device)
+    img_embeds, *_ = bio_model(img)
+    return img_embeds
+
+
+@torch.no_grad()
+def _encode_text(text: str) -> torch.Tensor:
+    toks = bio_tokenizer(text).to(device)
+    _, txt_embeds, _ = bio_model(text=toks)
+    return txt_embeds
+
+
+@torch.no_grad()
+def _encode_sat(img_pil: Image.Image) -> torch.Tensor:
+    imo = imo_transform(img_pil).unsqueeze(0).to(device)
+    imo_embeds = sat_model(imo)
+    return imo_embeds
+
+
+@torch.no_grad()
+def _encode_sound(sound) -> torch.Tensor:
+    processed_sound = get_audio_clap(sound)
+    for k in processed_sound.keys():
+        processed_sound[k] = processed_sound[k].to(device)
+    unnormalized_audio_embeds = sound_model(**processed_sound).audio_embeds
+    sound_embeds = torch.nn.functional.normalize(unnormalized_audio_embeds, dim=-1)
+    return sound_embeds
+
+
+def _similarity_heatmap(query: torch.Tensor, patches: torch.Tensor) -> np.ndarray:
+    sims = torch.matmul(query, patches.t()) * logit_scale
+    sims = sims.t().sigmoid()
+    sims = sims[1:].squeeze()  # drop CLS token
+    side = int(np.sqrt(len(sims)))
+    sims = sims.reshape(side, side)
+    return sims.cpu().detach().numpy()
+
+
+def _array_to_pil(arr: np.ndarray) -> Image.Image:
+    """
+    Render arr with viridis, automatically stretching its own min→max to 0→1
+    so that the most-similar patches appear yellow.
+    """
+
+    # Gaussian smoothing
+    if blur_kernel != (0, 0):
+        arr = cv2.GaussianBlur(arr, blur_kernel, 0)
+
+    # --- contrast-stretch to local 0-1 range --------------------------
+    arr_min, arr_max = float(arr.min()), float(arr.max())
+    if arr_max - arr_min < 1e-6:  # avoid /0 when the heat-map is flat
+        arr_scaled = np.zeros_like(arr)
+    else:
+        arr_scaled = (arr - arr_min) / (arr_max - arr_min)
+    # ------------------------------------------------------------------
+    fig, ax = plt.subplots(figsize=(2.6, 2.6), dpi=96)
+    ax.imshow(arr_scaled, cmap="viridis", vmin=0.0, vmax=1.0)
+    ax.axis("off")
+    buf = io.BytesIO()
+    plt.tight_layout(pad=0)
+    fig.savefig(buf, format="png", bbox_inches="tight", pad_inches=0)
+    plt.close(fig)
+    buf.seek(0)
+    return Image.open(buf)
+
+# ────────────────────────── main inference ────────────────────────────
+# integration with ZeroGPU on hf
+@spaces.GPU(duration=5)
+def process(
+    sat_img: Image.Image,
+    taxonomy: str,
+    ground_img: Image.Image | None,
+    sound: torch.Tensor | None,
+):
+    if sat_img is None:
+        return None, None, None  # one value per heat-map output
+
+    patches = _encode_sat(sat_img)
+
+    heat_ground, heat_text, heat_sound = None, None, None
+
+    if ground_img is not None:
+        q_img = _encode_ground(ground_img)
+        heat_ground = _array_to_pil(_similarity_heatmap(q_img, patches))
+
+    if taxonomy.strip():
+        q_txt = _encode_text(taxonomy.strip())
+        heat_text = _array_to_pil(_similarity_heatmap(q_txt, patches))
+
+    if sound is not None:
+        q_sound = _encode_sound(sound)
+        heat_sound = _array_to_pil(_similarity_heatmap(q_sound, patches))
+
+    return heat_ground, heat_text, heat_sound
+
+
+# ────────────────────────── Gradio UI ─────────────────────────────────
+with gr.Blocks(title="Search-TTA", theme=gr.themes.Base()) as demo:
+
+    gr.Markdown(
+        """
+        # Search-TTA: A Multimodal Test-Time Adaptation Framework for Visual Search in the Wild Demo
+        Click on any of the <b>examples below</b> and run the <b>multimodal inference demo</b>. Check out the <b>test-time adaptation feature</b> by switching to the other tab above. <br>
+        If you encounter any errors, refresh the browser and rerun the demo, or try again the next day. We will improve this in the future. <br>
+        <a href="https://search-tta.github.io">Project Website</a>
+        """
+    )
+
+    with gr.Row(variant="panel"):
+
+        # LEFT COLUMN (satellite, taxonomy, run)
+        with gr.Column():
+            sat_input = gr.Image(
+                label="Satellite Image",
+                sources=["upload"],
+                type="pil",
+                height=320,
+            )
+            taxonomy_input = gr.Textbox(
+                label="Full Taxonomy Name (optional)",
+                placeholder="e.g. Animalia Chordata Mammalia Rodentia Sciuridae Marmota marmota",
+            )
+
+            # ─── NEW: sound input ───────────────────────────
+            sound_input = gr.Audio(
+                label="Sound Input (optional)",
+                sources=["upload"],
+                type="filepath",
+            )
+            run_btn = gr.Button("Run", variant="primary")
+
+        # RIGHT COLUMN (ground image + two heat-maps)
+        with gr.Column():
+            ground_input = gr.Image(
+                label="Ground-level Image (optional)",
+                sources=["upload"],
+                type="pil",
+                height=320,
+            )
+            gr.Markdown("### Heat-map Results")
+            with gr.Row():
+                # Separate label and image to avoid overlap
+                with gr.Column(scale=1, min_width=100):
+                    gr.Markdown("**Ground Image Query**", elem_id="label-ground")
+                    heat_ground_out = gr.Image(
+                        show_label=False,
+                        height=160,
+                    )
+                with gr.Column(scale=1, min_width=100):
+                    gr.Markdown("**Text Query**", elem_id="label-text")
+                    heat_text_out = gr.Image(
+                        show_label=False,
+                        height=160,
+                    )
+                with gr.Column(scale=1, min_width=100):
+                    gr.Markdown("**Sound Query**", elem_id="label-sound")
+                    heat_sound_out = gr.Image(
+                        show_label=False,
+                        height=160,
+                    )
+
+
+    # EXAMPLES
+    with gr.Row():
+        gr.Markdown("### In-Domain Taxonomy")
+    with gr.Row():
+        gr.Examples(
+            examples=[
+                [
+                    "examples/Animalia_Chordata_Aves_Charadriiformes_Laridae_Larus_marinus/80645_39.76079_-74.10316.jpg",
+                    "examples/Animalia_Chordata_Aves_Charadriiformes_Laridae_Larus_marinus/cc1ebaf9-899d-49f2-81c8-d452249a8087.jpg",
+                    "Animalia Chordata Aves Charadriiformes Laridae Larus marinus",
+                    "examples/Animalia_Chordata_Aves_Charadriiformes_Laridae_Larus_marinus/89758229.mp3"
+                ],
+                [
+                    "examples/Animalia_Chordata_Mammalia_Rodentia_Caviidae_Hydrochoerus_hydrochaeris/28871_-12.80255_-69.29999.jpg",
+                    "examples/Animalia_Chordata_Mammalia_Rodentia_Caviidae_Hydrochoerus_hydrochaeris/1b8064f8-7deb-4b30-98cd-69da98ba6a3d.jpg",
+                    "Animalia Chordata Mammalia Rodentia Caviidae Hydrochoerus hydrochaeris",
+                    "examples/Animalia_Chordata_Mammalia_Rodentia_Caviidae_Hydrochoerus_hydrochaeris/166631961.mp3"
+                ],
+                [
+                    "examples/Animalia_Arthropoda_Malacostraca_Decapoda_Ocypodidae_Ocypode_quadrata/277303_38.72364_-75.07749.jpg",
+                    "examples/Animalia_Arthropoda_Malacostraca_Decapoda_Ocypodidae_Ocypode_quadrata/0b9cc264-a2ba-44bd-8e41-0d01a6edd1e8.jpg",
+                    "Animalia Arthropoda Malacostraca Decapoda Ocypodidae Ocypode quadrata",
+                    "examples/Animalia_Arthropoda_Malacostraca_Decapoda_Ocypodidae_Ocypode_quadrata/12372063.mp3"
+                ],
+                [
+                    "examples/Animalia_Chordata_Mammalia_Rodentia_Sciuridae_Marmota_marmota/388246_45.49036_7.14796.jpg",
+                    "examples/Animalia_Chordata_Mammalia_Rodentia_Sciuridae_Marmota_marmota/327e1f07-692b-4140-8a3e-bd098bc064ff.jpg",
+                    "Animalia Chordata Mammalia Rodentia Sciuridae Marmota marmota",
+                    "examples/Animalia_Chordata_Mammalia_Rodentia_Sciuridae_Marmota_marmota/59677071.mp3"
+                ],
+                [
+                    "examples/Animalia_Chordata_Reptilia_Squamata_Varanidae_Varanus_salvator/410613_5.35573_100.28948.jpg",
+                    "examples/Animalia_Chordata_Reptilia_Squamata_Varanidae_Varanus_salvator/461d8e6c-0e66-4acc-8ecd-bfd9c218bc14.jpg",
+                    "Animalia Chordata Reptilia Squamata Varanidae Varanus salvator",
+                    None
+                ],
+            ],
+            inputs=[sat_input, ground_input, taxonomy_input, sound_input],
+            outputs=[heat_ground_out, heat_text_out, heat_sound_out],
+            fn=process,
+            cache_examples=False,
+        )
+
+    # EXAMPLES
+    with gr.Row():
+        gr.Markdown("### Out-Domain Taxonomy")
+    with gr.Row():
+        gr.Examples(
+            examples=[
+                [
+                    "examples/Animalia_Chordata_Mammalia_Carnivora_Phocidae_Mirounga_angustirostris/27423_35.64005_-121.17595.jpg",
+                    "examples/Animalia_Chordata_Mammalia_Carnivora_Phocidae_Mirounga_angustirostris/3aac526d-c921-452a-af6a-cb4f2f52e2c4.jpg",
+                    "Animalia Chordata Mammalia Carnivora Phocidae Mirounga angustirostris",
+                    "examples/Animalia_Chordata_Mammalia_Carnivora_Phocidae_Mirounga_angustirostris/3123948.mp3"
+                ],
+                [
+                    "examples/Animalia_Chordata_Mammalia_Carnivora_Canidae_Canis_aureus/1528408_13.00422_80.23033.jpg",
+                    "examples/Animalia_Chordata_Mammalia_Carnivora_Canidae_Canis_aureus/37faabd2-a613-4461-b27e-82fe5955ecaf.jpg",
+                    "Animalia Chordata Mammalia Carnivora Canidae Canis aureus",
+                    "examples/Animalia_Chordata_Mammalia_Carnivora_Canidae_Canis_aureus/189318716.mp3"
+                ],
+                [
+                    "examples/Animalia_Chordata_Reptilia_Crocodylia_Alligatoridae_Caiman_crocodilus/340271_10.52832_-83.49678.jpg",
+                    "examples/Animalia_Chordata_Reptilia_Crocodylia_Alligatoridae_Caiman_crocodilus/938aab7b-4509-4de7-afad-2c8ea51f4799.jpg",
+                    "Animalia Chordata Reptilia Crocodylia Alligatoridae Caiman crocodilus",
+                    "examples/Animalia_Chordata_Reptilia_Crocodylia_Alligatoridae_Caiman_crocodilus/45193295.mp3"
+                ],
+                [
+                    "examples/Animalia_Chordata_Mammalia_Carnivora_Canidae_Urocyon_littoralis/304160_34.0144_-119.54417.jpg",
+                    "examples/Animalia_Chordata_Mammalia_Carnivora_Canidae_Urocyon_littoralis/0cbdfbf2-6cfe-4d61-9602-c949f24d0293.jpg",
+                    "Animalia Chordata Mammalia Carnivora Canidae Urocyon littoralis",
+                    None
+                ],
+                [
+                    "examples/Animalia_Chordata_Elasmobranchii_Carcharhiniformes_Carcharhinidae_Triaenodon_obesus/5041_-0.28573_-90.54837.jpg",
+                    "examples/Animalia_Chordata_Elasmobranchii_Carcharhiniformes_Carcharhinidae_Triaenodon_obesus/c834edf7-b073-4cd5-8726-9c88ebe943c0.jpg",
+                    "Animalia Chordata Elasmobranchii Carcharhiniformes Carcharhinidae Triaenodon obesus",
+                    None
+                ],
+            ],
+            inputs=[sat_input, ground_input, taxonomy_input, sound_input],
+            outputs=[heat_ground_out, heat_text_out, heat_sound_out],
+            fn=process,
+            cache_examples=False,
+        )
+
+    # CALLBACK
+    run_btn.click(
+        fn=process,
+        inputs=[sat_input, taxonomy_input, ground_input, sound_input],
+        outputs=[heat_ground_out, heat_text_out, heat_sound_out],
+    )
+
+    # Footer pointing to the model and data sources used by this app.
+    gr.Markdown(
+        """
+        The satellite image CLIP encoder is fine-tuned using [Sentinel-2 Level 2A](https://docs.sentinel-hub.com/api/latest/data/sentinel-2-l2a/) satellite images and taxonomy images (with GPS locations) from [iNaturalist](https://inaturalist.org/). The sound CLIP encoder is fine-tuned with a subset of the same taxonomy images and their corresponding sounds from [iNaturalist](https://inaturalist.org/). Some of these iNaturalist data are also used in [Taxabind](https://arxiv.org/abs/2411.00683). Note that while some of the examples above result in poor probability distributions, they will be improved using our test-time adaptation framework during the search process.
+        """
+    )
+
+# LAUNCH
+if __name__ == "__main__":
+    demo.queue(max_size=15)
+    demo.launch(share=True)
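`_similarity_heatmap` above compares one query embedding against per-patch satellite embeddings (CLS token first) and folds the result back into the ViT patch grid. A shape-only sketch with dummy tensors, assuming a 336×336 input with 14×14 patches (24×24 = 576 patches plus a CLS token, consistent with CLIP-L-336; the embedding dimension `D` below is illustrative):

```python
# Sketch: shape flow of the per-patch similarity heatmap.
import torch

D = 768                        # shared embedding dimension (illustrative)
query = torch.randn(1, D)      # e.g. output of _encode_text(...)
patches = torch.randn(577, D)  # CLS + 576 patch embeddings from _encode_sat(...)

sims = (query @ patches.t()).t().sigmoid()  # (577, 1): similarity per token
grid = sims[1:].squeeze()                   # drop CLS token -> (576,)
side = int(grid.numel() ** 0.5)             # 24
heatmap = grid.reshape(side, side)          # (24, 24), ready for viridis rendering
print(heatmap.shape)                        # torch.Size([24, 24])
```

The sigmoid keeps each patch score independent rather than normalizing across patches, which is why `_array_to_pil` contrast-stretches the map to its own min–max before rendering.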
examples/Animalia_Arthropoda_Malacostraca_Decapoda_Ocypodidae_Ocypode_quadrata/0b9cc264-a2ba-44bd-8e41-0d01a6edd1e8.jpg
ADDED
Git LFS Details
examples/Animalia_Arthropoda_Malacostraca_Decapoda_Ocypodidae_Ocypode_quadrata/12372063.mp3
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:575959883981159f2e40593bf5be87be006026c41da36a34d1e40783de648116
+size 54027
examples/Animalia_Arthropoda_Malacostraca_Decapoda_Ocypodidae_Ocypode_quadrata/277303_38.72364_-75.07749.jpg
ADDED
Git LFS Details
examples/Animalia_Chordata_Aves_Charadriiformes_Laridae_Larus_marinus/80645_39.76079_-74.10316.jpg
ADDED
Git LFS Details
examples/Animalia_Chordata_Aves_Charadriiformes_Laridae_Larus_marinus/89758229.mp3
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:e2c7ad6df49668d29f9b7f9f9f0739b97ef4edc5219413a41d01983a9863cccc
+size 2601487
examples/Animalia_Chordata_Aves_Charadriiformes_Laridae_Larus_marinus/cc1ebaf9-899d-49f2-81c8-d452249a8087.jpg
ADDED
Git LFS Details
examples/Animalia_Chordata_Elasmobranchii_Carcharhiniformes_Carcharhinidae_Triaenodon_obesus/5041_-0.28573_-90.54837.jpg
ADDED
Git LFS Details
examples/Animalia_Chordata_Elasmobranchii_Carcharhiniformes_Carcharhinidae_Triaenodon_obesus/c834edf7-b073-4cd5-8726-9c88ebe943c0.jpg
ADDED
Git LFS Details
examples/Animalia_Chordata_Mammalia_Artiodactyla_Cervidae_Cervus_nippon/1a51f978-c76d-4816-a716-473c21b5495c.jpg
ADDED
Git LFS Details
examples/Animalia_Chordata_Mammalia_Artiodactyla_Cervidae_Cervus_nippon/245767_53.0076_-6.35201.jpg
ADDED
Git LFS Details
examples/Animalia_Chordata_Mammalia_Carnivora_Canidae_Canis_aureus/1528408_13.00422_80.23033.jpg
ADDED
|
Git LFS Details
|
examples/Animalia_Chordata_Mammalia_Carnivora_Canidae_Canis_aureus/189318716.mp3
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:4639c226ea5a0464b98e89b33a1f821b6625c6637d206d3d355e05bc7c89c641
|
| 3 |
+
size 148019
|
examples/Animalia_Chordata_Mammalia_Carnivora_Canidae_Canis_aureus/37faabd2-a613-4461-b27e-82fe5955ecaf.jpg
ADDED
|
Git LFS Details
|
examples/Animalia_Chordata_Mammalia_Carnivora_Canidae_Urocyon_littoralis/0cbdfbf2-6cfe-4d61-9602-c949f24d0293.jpg
ADDED
|
Git LFS Details
|
examples/Animalia_Chordata_Mammalia_Carnivora_Canidae_Urocyon_littoralis/304160_34.0144_-119.54417.jpg
ADDED
|
Git LFS Details
|
examples/Animalia_Chordata_Mammalia_Carnivora_Phocidae_Mirounga_angustirostris/27423_35.64005_-121.17595.jpg
ADDED
|
Git LFS Details
|
examples/Animalia_Chordata_Mammalia_Carnivora_Phocidae_Mirounga_angustirostris/3123948.mp3
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:96ca3a92e6f614cce82972dacb04f5c0c170c1aea3d70d15778af56820ed02c9
|
| 3 |
+
size 276768
|
examples/Animalia_Chordata_Mammalia_Carnivora_Phocidae_Mirounga_angustirostris/3aac526d-c921-452a-af6a-cb4f2f52e2c4.jpg
ADDED
|
Git LFS Details
|
examples/Animalia_Chordata_Mammalia_Rodentia_Caviidae_Hydrochoerus_hydrochaeris/166631961.mp3
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:cc02eca19d0c408d038e205d82f6624c0515858ac374cf7298161a14e169e6a9
|
| 3 |
+
size 266258
|
examples/Animalia_Chordata_Mammalia_Rodentia_Caviidae_Hydrochoerus_hydrochaeris/1b8064f8-7deb-4b30-98cd-69da98ba6a3d.jpg
ADDED
|
Git LFS Details
|
examples/Animalia_Chordata_Mammalia_Rodentia_Caviidae_Hydrochoerus_hydrochaeris/28871_-12.80255_-69.29999.jpg
ADDED
|
Git LFS Details
|
examples/Animalia_Chordata_Mammalia_Rodentia_Sciuridae_Marmota_marmota/327e1f07-692b-4140-8a3e-bd098bc064ff.jpg
ADDED
|
Git LFS Details
|
examples/Animalia_Chordata_Mammalia_Rodentia_Sciuridae_Marmota_marmota/388246_45.49036_7.14796.jpg
ADDED
|
Git LFS Details
|
examples/Animalia_Chordata_Mammalia_Rodentia_Sciuridae_Marmota_marmota/59677071.mp3
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:bb043991fe851d6a1e12f32c5a9277dad5a77a939cf15ccb4afcb215b4bc08e3
|
| 3 |
+
size 92876
|
examples/Animalia_Chordata_Reptilia_Crocodylia_Alligatoridae_Caiman_crocodilus/340271_10.52832_-83.49678.jpg
ADDED
|
Git LFS Details
|
examples/Animalia_Chordata_Reptilia_Crocodylia_Alligatoridae_Caiman_crocodilus/45193295.mp3
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:c4cd2e4fd7094a07d79da7fd54788705e8ce7567e65911d87edfd23ff1c0e484
|
| 3 |
+
size 247762
|
examples/Animalia_Chordata_Reptilia_Crocodylia_Alligatoridae_Caiman_crocodilus/938aab7b-4509-4de7-afad-2c8ea51f4799.jpg
ADDED
|
Git LFS Details
|
examples/Animalia_Chordata_Reptilia_Squamata_Varanidae_Varanus_salvator/410613_5.35573_100.28948.jpg
ADDED
|
Git LFS Details
|
examples/Animalia_Chordata_Reptilia_Squamata_Varanidae_Varanus_salvator/461d8e6c-0e66-4acc-8ecd-bfd9c218bc14.jpg
ADDED
|
Git LFS Details
|

examples/metadata.json
ADDED
@@ -0,0 +1,173 @@
{
    "Animalia Chordata Reptilia Squamata Varanidae Varanus salvator": {
        "id": 410613,
        "sat_key": "410613_5.35573_100.28948",
        "sat_path": "410613_5.35573_100.28948.jpg",
        "taxonomy": "Animalia Chordata Reptilia Squamata Varanidae Varanus salvator",
        "count": 6,
        "spread": 58.00460580210422,
        "sat_bounds": {
            "min_lat": 5.344155081363914,
            "max_lat": 5.367304914271601,
            "min_lon": 100.27793148340874,
            "max_lon": 100.30102851659126
        },
        "img_ids": [707815, 411949, 701168, 1619682, 2100008, 1548498],
        "target_positions": [[225, 240], [232, 275], [277, 449], [220, 369], [180, 393], [294, 478]],
        "num_landmarks": 2
    },
    "Animalia Chordata Mammalia Carnivora Canidae Canis aureus": {
        "id": 1528408,
        "sat_key": "1528408_13.00422_80.23033",
        "sat_path": "1528408_13.00422_80.23033.jpg",
        "taxonomy": "Animalia Chordata Mammalia Carnivora Canidae Canis aureus",
        "count": 3,
        "spread": 58.14007011752667,
        "sat_bounds": {
            "min_lat": 12.992649951077192,
            "max_lat": 13.015790038631529,
            "min_lon": 80.21853090802841,
            "max_lon": 80.24212909197156
        },
        "img_ids": [1528479, 2555188, 2555189],
        "target_positions": [[309, 128], [239, 428], [240, 419]],
        "num_landmarks": 3
    },
    "Animalia Chordata Reptilia Crocodylia Alligatoridae Caiman crocodilus": {
        "id": 340271,
        "sat_key": "340271_10.52832_-83.49678",
        "sat_path": "340271_10.52832_-83.49678.jpg",
        "taxonomy": "Animalia Chordata Reptilia Crocodylia Alligatoridae Caiman crocodilus",
        "count": 7,
        "spread": 40.13902957324975,
        "sat_bounds": {
            "min_lat": 10.516747947357544,
            "max_lat": 10.53989204420829,
            "min_lon": -83.50847402265151,
            "max_lon": -83.48508597734848
        },
        "img_ids": [1683531, 1281855, 223089, 688111, 330757, 2408375, 1955359],
        "target_positions": [[347, 75], [47, 22], [111, 43], [116, 51], [86, 108], [31, 62], [4, 78]],
        "num_landmarks": 3
    },
    "Animalia Chordata Mammalia Carnivora Canidae Urocyon littoralis": {
        "id": 304160,
        "sat_key": "304160_34.0144_-119.54417",
        "sat_path": "304160_34.0144_-119.54417.jpg",
        "taxonomy": "Animalia Chordata Mammalia Carnivora Canidae Urocyon littoralis",
        "count": 3,
        "spread": 237.64152837579553,
        "sat_bounds": {
            "min_lat": 34.00286041606169,
            "max_lat": 34.02593956225012,
            "min_lon": -119.55802743361286,
            "max_lon": -119.53031256638712
        },
        "img_ids": [304160, 1473173, 384867],
        "target_positions": [[255, 256], [19, 22], [29, 274]],
        "num_landmarks": 3
    }
}
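
A hedged sketch (not part of the commit) of how an entry in examples/metadata.json could be consumed: load one record and map a GPS observation into pixel coordinates of its satellite image via the sat_bounds box. The 512x512 image size is an assumption for illustration; the target_positions above are consistent with a roughly 512-pixel grid.

import json

with open("examples/metadata.json") as f:
    metadata = json.load(f)

entry = metadata["Animalia Chordata Reptilia Squamata Varanidae Varanus salvator"]
bounds = entry["sat_bounds"]

def latlon_to_pixel(lat, lon, bounds, width=512, height=512):
    """Linearly interpolate a (lat, lon) pair into (x, y) pixel space; row 0 is the northern edge."""
    x = (lon - bounds["min_lon"]) / (bounds["max_lon"] - bounds["min_lon"]) * (width - 1)
    y = (bounds["max_lat"] - lat) / (bounds["max_lat"] - bounds["min_lat"]) * (height - 1)
    return int(round(x)), int(round(y))

# The sat_key encodes the center coordinate, so this lands near (256, 256).
print(entry["sat_path"], latlon_to_pixel(5.35573, 100.28948, bounds))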

inference/model/avs_rl_policy.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:44e642df9aaa2847ba44dd4707985c67ef712f5264272ef7993aeb7805c80f5a
size 52167246
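
A hedged sketch for inspecting the RL policy checkpoint above, assuming it is a standard PyTorch file (the actual serialization format is not shown in this commit):

import torch

ckpt = torch.load("inference/model/avs_rl_policy.pth", map_location="cpu")
# A checkpoint is commonly a state_dict or a dict wrapping one; peek at the keys.
if isinstance(ckpt, dict):
    for key in list(ckpt)[:5]:
        print(key)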

maps/example/masks_val/MSK_0001.png
ADDED
Git LFS Details

maps/gpt4o/envs_val/MSK_0001.png
ADDED
Git LFS Details

planner/env.py
ADDED
@@ -0,0 +1,610 @@
#######################################################################
# Name: env.py
#
# - Reads and processes training and test maps
# - Processes rewards, new frontiers given action
# - Updates a graph representation of environment for input into network
#######################################################################

import sys
if sys.modules['TRAINING']:  # flag registered by the entry script (training vs. testing)
    from .parameter import *
else:
    from .test_parameter import *

import os
import math
import copy
import cv2
import numpy as np
import matplotlib.image as mpimg
import matplotlib.pyplot as plt
from skimage import io
from skimage.measure import block_reduce
from scipy.ndimage import label, find_objects
from .sensor import *
from .graph_generator import *
from .node import *


class Env():
    def __init__(self, map_index, n_agent, k_size=20, plot=False, test=False, mask_index=None):
        self.n_agent = n_agent
        self.test = test
        self.map_dir = GRIDMAP_SET_DIR

        # Import environment gridmap
        self.map_list = os.listdir(self.map_dir)
        self.map_list.sort(reverse=True)

        # NEW: Import segmentation utility map
        self.seg_dir = MASK_SET_DIR
        self.segmentation_mask, self.target_positions, self.target_found_idxs = None, [], []
        self.segmentation_mask_list = os.listdir(self.seg_dir)
        self.segmentation_mask_list.sort(reverse=True)

        # NEW: Find common files in both directories
        self.map_index = map_index % len(self.map_list)
        if mask_index is not None:
            self.mask_index = mask_index % len(self.segmentation_mask_list)
        else:
            self.mask_index = map_index % len(self.segmentation_mask_list)

        # Import ground truth and segmentation mask
        self.ground_truth, self.map_start_position = self.import_ground_truth(
            os.path.join(self.map_dir, self.map_list[self.map_index]))
        self.ground_truth_size = np.shape(self.ground_truth)
        self.robot_belief = np.ones(self.ground_truth_size) * 127  # unexplored 127
        self.downsampled_belief = None
        self.old_robot_belief = copy.deepcopy(self.robot_belief)
        self.coverage_belief = np.ones(self.ground_truth_size) * 127  # unexplored 127

        # Import segmentation mask
        mask_filename = self.segmentation_mask_list[self.mask_index]
        self.segmentation_mask = self.import_segmentation_mask(
            os.path.join(self.seg_dir, mask_filename))

        # Overwrite target positions if directory specified
        if self.test and TARGETS_SET_DIR != "":
            self.target_positions = self.import_targets(
                os.path.join(TARGETS_SET_DIR, self.map_list[self.map_index]))

        self.segmentation_info_mask = None
        self.segmentation_info_mask_unnormalized = None
        self.filtered_seg_info_mask = None
        self.num_targets_found = 0
        self.num_new_targets_found = 0
        self.resolution = 4
        self.sensor_range = SENSOR_RANGE
        self.explored_rate = 0
        self.targets_found_rate = 0
        self.frontiers = None
        self.start_positions = []
        self.plot = plot
        self.frame_files = []
        self.graph_generator = Graph_generator(map_size=self.ground_truth_size, sensor_range=self.sensor_range, k_size=k_size, plot=plot)
        self.node_coords, self.graph, self.node_utility, self.guidepost = None, None, None, None

        self.begin(self.map_start_position)

    def find_index_from_coords(self, position):
        index = np.argmin(np.linalg.norm(self.node_coords - position, axis=1))
        return index

    def begin(self, start_position):
        self.robot_belief = self.ground_truth
        self.downsampled_belief = block_reduce(self.robot_belief.copy(), block_size=(self.resolution, self.resolution), func=np.min)
        self.frontiers = self.find_frontier()
        self.old_robot_belief = copy.deepcopy(self.robot_belief)

        self.node_coords, self.graph, self.node_utility, self.guidepost = self.graph_generator.generate_graph(
            self.robot_belief, self.frontiers)

        # Define start positions
        if FIX_START_POSITION:
            coords_res_row = int(self.robot_belief.shape[0] / NUM_COORDS_HEIGHT)
            coords_res_col = int(self.robot_belief.shape[1] / NUM_COORDS_WIDTH)
            self.start_positions = [(int(self.robot_belief.shape[1] / 2) - coords_res_col / 2, int(self.robot_belief.shape[0] / 2) - coords_res_row / 2) for _ in range(self.n_agent)]
        else:
            nearby_coords = self.graph_generator.get_neighbors_grid_coords(start_position)
            itr = 0
            for i in range(self.n_agent):
                if i == 0 or len(nearby_coords) == 0:
                    self.start_positions.append(start_position)
                else:
                    idx = min(itr, len(nearby_coords) - 1)
                    self.start_positions.append(nearby_coords[idx])
                    itr += 1

        for i in range(len(self.start_positions)):
            self.start_positions[i] = self.node_coords[self.find_index_from_coords(self.start_positions[i])]
            self.coverage_belief = self.update_robot_belief(self.start_positions[i], self.sensor_range, self.coverage_belief,
                                                            self.ground_truth)

        for start_position in self.start_positions:
            self.graph_generator.route_node.append(start_position)

        # Info map from ground truth
        rng_x = 0.5 * (self.ground_truth.shape[1] / NUM_COORDS_WIDTH)
        rng_y = 0.5 * (self.ground_truth.shape[0] / NUM_COORDS_HEIGHT)
        self.segmentation_info_mask = np.zeros((len(self.node_coords), 1))
        for i, node_coord in enumerate(self.node_coords):
            max_x = min(node_coord[0] + int(math.ceil(rng_x)), self.ground_truth.shape[1])
            min_x = max(node_coord[0] - int(math.ceil(rng_x)), 0)
            max_y = min(node_coord[1] + int(math.ceil(rng_y)), self.ground_truth.shape[0])
            min_y = max(node_coord[1] - int(math.ceil(rng_y)), 0)

            if TARGETS_SET_DIR == "":
                exclude = {208}  # Exclude target positions
            else:
                exclude = set()
            self.segmentation_info_mask[i] = max(x for x in self.segmentation_mask[min_y:max_y, min_x:max_x].flatten() if x not in exclude) / 100.0

        self.filtered_seg_info_mask = copy.deepcopy(self.segmentation_info_mask)
        done, num_targets_found = self.check_done()
        self.num_targets_found = num_targets_found

    def multi_robot_step(self, next_position_list, dist_list, travel_dist_list):
        reward_list = []
        for dist, robot_position in zip(dist_list, next_position_list):
            self.graph_generator.route_node.append(robot_position)
            next_node_index = self.find_index_from_coords(robot_position)
            self.graph_generator.nodes_list[next_node_index].set_visited()
            self.coverage_belief = self.update_robot_belief(robot_position, self.sensor_range, self.coverage_belief,
                                                            self.ground_truth)
            self.robot_belief = self.ground_truth
            self.downsampled_belief = block_reduce(self.robot_belief.copy(),
                                                   block_size=(self.resolution, self.resolution),
                                                   func=np.min)

            frontiers = self.find_frontier()
            individual_reward = -dist / 32

            # Reward visiting informative, not-yet-visited nodes
            robot_position_idx = self.find_index_from_coords(robot_position)
            info_gain_reward = self.filtered_seg_info_mask[robot_position_idx][0] * 1.5
            if self.guidepost[robot_position_idx] == 0.0:
                info_gain_reward += 0.2
            individual_reward += info_gain_reward

            reward_list.append(individual_reward)

        self.node_coords, self.graph, self.node_utility, self.guidepost = self.graph_generator.update_graph(self.robot_belief, self.old_robot_belief, frontiers, self.frontiers)
        self.old_robot_belief = copy.deepcopy(self.robot_belief)

        self.filtered_seg_info_mask = [info[0] if self.guidepost[i] == 0.0 else 0.0 for i, info in enumerate(self.segmentation_info_mask)]
        self.filtered_seg_info_mask = np.expand_dims(np.array(self.filtered_seg_info_mask), axis=1)

        self.frontiers = frontiers
        self.explored_rate = self.evaluate_exploration_rate()

        done, num_targets_found = self.check_done()
        self.num_new_targets_found = num_targets_found - self.num_targets_found
        team_reward = 0.0

        self.num_targets_found = num_targets_found
        self.targets_found_rate = self.evaluate_targets_found_rate()

        if done:
            team_reward += 40
        for i in range(len(reward_list)):
            reward_list[i] += team_reward

        return reward_list, done

    def import_ground_truth(self, map_index):
        # occupied 1, free 255, unexplored 127
        try:
            ground_truth = (io.imread(map_index, 1)).astype(int)
            if np.all(ground_truth == 0):
                ground_truth = (io.imread(map_index, 1) * 255).astype(int)
        except:
            new_map_index = self.map_dir + '/' + self.map_list[0]
            ground_truth = (io.imread(new_map_index, 1)).astype(int)
            print('could not read the map_path ({}), hence skipping it and using ({}).'.format(map_index, new_map_index))

        robot_location = np.nonzero(ground_truth == 208)
        robot_location = np.array([np.array(robot_location)[1, 127], np.array(robot_location)[0, 127]])
        ground_truth = (ground_truth > 150)
        ground_truth = ground_truth * 254 + 1
        return ground_truth, robot_location

    def import_segmentation_mask(self, map_index):
        mask = cv2.imread(map_index).astype(int)
        return mask

    def import_targets(self, map_index):
        # occupied 1, free 255, unexplored 127, target 208
        mask = cv2.imread(map_index).astype(int)
        target_positions = self.find_target_locations(mask)
        return target_positions

    def find_target_locations(self, image_array, grey_value=208):
        grey_pixels = np.where(image_array == grey_value)
        binary_array = np.zeros_like(image_array, dtype=bool)
        binary_array[grey_pixels] = True
        labeled_array, num_features = label(binary_array)
        slices = find_objects(labeled_array)

        # Calculate the center of each box
        centers = []
        for slc in slices:
            row_center = (slc[0].start + slc[0].stop - 1) // 2
            col_center = (slc[1].start + slc[1].stop - 1) // 2
            centers.append((col_center, row_center))  # (x, y)

        return centers

    def free_cells(self):
        index = np.where(self.ground_truth == 255)
        free = np.asarray([index[1], index[0]]).T
        return free

    def update_robot_belief(self, robot_position, sensor_range, robot_belief, ground_truth):
        robot_belief = sensor_work(robot_position, sensor_range, robot_belief, ground_truth)
        return robot_belief

    def check_done(self):
        done = False
        num_targets_found = 0
        self.target_found_idxs = []
        for i, target in enumerate(self.target_positions):
            if self.coverage_belief[target[1], target[0]] == 255:
                num_targets_found += 1
                self.target_found_idxs.append(i)

        if TERMINATE_ON_TGTS_FOUND and num_targets_found >= len(self.target_positions):
            done = True
        if not TERMINATE_ON_TGTS_FOUND and np.sum(self.coverage_belief == 255) / np.sum(self.ground_truth == 255) >= 0.99:
            done = True

        return done, num_targets_found

    def calculate_num_observed_frontiers(self, old_frontiers, frontiers):
        frontiers_to_check = frontiers[:, 0] + frontiers[:, 1] * 1j
        pre_frontiers_to_check = old_frontiers[:, 0] + old_frontiers[:, 1] * 1j
        frontiers_num = np.intersect1d(frontiers_to_check, pre_frontiers_to_check).shape[0]
        pre_frontiers_num = pre_frontiers_to_check.shape[0]
        delta_num = pre_frontiers_num - frontiers_num

        return delta_num

    def calculate_reward(self, dist, frontiers):
        reward = 0
        reward -= dist / 64

        frontiers_to_check = frontiers[:, 0] + frontiers[:, 1] * 1j
        pre_frontiers_to_check = self.frontiers[:, 0] + self.frontiers[:, 1] * 1j
        frontiers_num = np.intersect1d(frontiers_to_check, pre_frontiers_to_check).shape[0]
        pre_frontiers_num = pre_frontiers_to_check.shape[0]
        delta_num = pre_frontiers_num - frontiers_num

        reward += delta_num / 50

        return reward

    def evaluate_exploration_rate(self):
        rate = np.sum(self.coverage_belief == 255) / np.sum(self.ground_truth == 255)
        return rate

    def evaluate_targets_found_rate(self):
        if len(self.target_positions) == 0:
            return 0
        else:
            rate = self.num_targets_found / len(self.target_positions)
            return rate

    def calculate_new_free_area(self):
        old_free_area = self.old_robot_belief == 255
        current_free_area = self.robot_belief == 255

        new_free_area = (current_free_area.astype(int) - old_free_area.astype(int)) * 255

        return new_free_area, np.sum(old_free_area)

    def calculate_dist_path(self, path):
        dist = 0
        start = path[0]
        end = path[-1]
        for index in path:
            if index == end:
                break
            dist += np.linalg.norm(self.node_coords[start] - self.node_coords[index])
            start = index
        return dist

    def find_frontier(self):
        y_len = self.downsampled_belief.shape[0]
        x_len = self.downsampled_belief.shape[1]
        mapping = self.downsampled_belief.copy()
        belief = self.downsampled_belief.copy()
        # 0-1 unknown area map
        mapping = (mapping == 127) * 1
        mapping = np.lib.pad(mapping, ((1, 1), (1, 1)), 'constant', constant_values=0)
        # Sum of the eight neighbours of every cell
        fro_map = mapping[2:][:, 1:x_len + 1] + mapping[:y_len][:, 1:x_len + 1] + mapping[1:y_len + 1][:, 2:] + \
                  mapping[1:y_len + 1][:, :x_len] + mapping[:y_len][:, 2:] + mapping[2:][:, :x_len] + \
                  mapping[2:][:, 2:] + mapping[:y_len][:, :x_len]
        ind_free = np.where(belief.ravel(order='F') == 255)[0]
        ind_fron_1 = np.where(1 < fro_map.ravel(order='F'))[0]
        ind_fron_2 = np.where(fro_map.ravel(order='F') < 8)[0]
        ind_fron = np.intersect1d(ind_fron_1, ind_fron_2)
        ind_to = np.intersect1d(ind_free, ind_fron)

        map_x = x_len
        map_y = y_len
        x = np.linspace(0, map_x - 1, map_x)
        y = np.linspace(0, map_y - 1, map_y)
        t1, t2 = np.meshgrid(x, y)
        points = np.vstack([t1.T.ravel(), t2.T.ravel()]).T

        f = points[ind_to]
        f = f.astype(int)

        f = f * self.resolution

        return f

    def plot_env(self, n, path, step, travel_dist, robots_route, img_path_override=None, sat_path_override=None, msk_name_override=None, sound_id_override=None):
        plt.switch_backend('agg')
        plt.cla()
        color_list = ["r", "g", "c", "m", "y", "k"]

        if not LOAD_AVS_BENCH:
            fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 5))
        else:
            fig, (ax1, ax2, ax3, ax4) = plt.subplots(1, 4, figsize=(20, 5.5))

        ### Fig: Ground Image ###
        if LOAD_AVS_BENCH:
            ax = ax1
            image = mpimg.imread(img_path_override)
            ax.imshow(image)
            ax.set_title("Ground Image")
            ax.axis("off")

        ### Fig: Environment ###
        msk_name = ""
        if LOAD_AVS_BENCH:
            image = mpimg.imread(sat_path_override)
            msk_name = msk_name_override

        # NOTE: `image` is only defined when LOAD_AVS_BENCH is set.
        ax = ax2
        ax.imshow(image)
        ax.axis((0, self.ground_truth_size[1], self.ground_truth_size[0], 0))
        ax.set_title("Image")
        for i, route in enumerate(robots_route):
            robot_marker_color = color_list[i % len(color_list)]
            xPoints = route[0]
            yPoints = route[1]
            ax.plot(xPoints, yPoints, c=robot_marker_color, linewidth=2)
            ax.plot(xPoints[-1], yPoints[-1], markersize=12, zorder=99, marker="^", ls="-", c=robot_marker_color, mec="black")
            ax.plot(xPoints[0], yPoints[0], 'co', c=robot_marker_color, markersize=8, zorder=5)

            # Sensor range
            rng_x = 0.5 * (self.ground_truth.shape[1] / NUM_COORDS_WIDTH)
            rng_y = 0.5 * (self.ground_truth.shape[0] / NUM_COORDS_HEIGHT)
            max_x = min(xPoints[-1] + int(math.ceil(rng_x)), self.ground_truth.shape[1])
            min_x = max(xPoints[-1] - int(math.ceil(rng_x)), 0)
            max_y = min(yPoints[-1] + int(math.ceil(rng_y)), self.ground_truth.shape[0])
            min_y = max(yPoints[-1] - int(math.ceil(rng_y)), 0)
            ax.plot((min_x, min_x), (min_y, max_y), c=robot_marker_color, linewidth=1)
            ax.plot((min_x, max_x), (max_y, max_y), c=robot_marker_color, linewidth=1)
            ax.plot((max_x, max_x), (max_y, min_y), c=robot_marker_color, linewidth=1)
            ax.plot((max_x, min_x), (min_y, min_y), c=robot_marker_color, linewidth=1)

        ### Fig: Graph ###
        ax = ax3 if LOAD_AVS_BENCH else ax1
        ax.imshow(self.coverage_belief, cmap='gray')
        ax.axis((0, self.ground_truth_size[1], self.ground_truth_size[0], 0))
        ax.set_title("Information Graph")
        if VIZ_GRAPH_EDGES:
            for i in range(len(self.graph_generator.x)):
                ax.plot(self.graph_generator.x[i], self.graph_generator.y[i], 'tan', zorder=1)
        ax.scatter(self.node_coords[:, 0], self.node_coords[:, 1], c=self.filtered_seg_info_mask, zorder=5, s=8)

        for i, route in enumerate(robots_route):
            robot_marker_color = color_list[i % len(color_list)]
            xPoints = route[0]
            yPoints = route[1]
            ax.plot(xPoints, yPoints, c=robot_marker_color, linewidth=2)
            ax.plot(xPoints[-1], yPoints[-1], markersize=12, zorder=99, marker="^", ls="-", c=robot_marker_color, mec="black")
            ax.plot(xPoints[0], yPoints[0], 'co', c=robot_marker_color, markersize=8, zorder=5)

            # Sensor range
            rng_x = 0.5 * (self.ground_truth.shape[1] / NUM_COORDS_WIDTH)
            rng_y = 0.5 * (self.ground_truth.shape[0] / NUM_COORDS_HEIGHT)
            max_x = min(xPoints[-1] + int(math.ceil(rng_x)), self.ground_truth.shape[1])
            min_x = max(xPoints[-1] - int(math.ceil(rng_x)), 0)
            max_y = min(yPoints[-1] + int(math.ceil(rng_y)), self.ground_truth.shape[0])
            min_y = max(yPoints[-1] - int(math.ceil(rng_y)), 0)
            ax.plot((min_x, min_x), (min_y, max_y), c=robot_marker_color, linewidth=1)
            ax.plot((min_x, max_x), (max_y, max_y), c=robot_marker_color, linewidth=1)
            ax.plot((max_x, max_x), (max_y, min_y), c=robot_marker_color, linewidth=1)
            ax.plot((max_x, min_x), (min_y, min_y), c=robot_marker_color, linewidth=1)

        # Plot target positions
        for target in self.target_positions:
            if self.coverage_belief[target[1], target[0]] == 255:
                ax.plot(target[0], target[1], color='g', marker='x', linestyle='-', markersize=12, markeredgewidth=4, zorder=99)
            else:
                ax.plot(target[0], target[1], color='r', marker='x', linestyle='-', markersize=12, markeredgewidth=4, zorder=99)

        ### Fig: Segmentation Mask ###
        ax = ax4 if LOAD_AVS_BENCH else ax2
        if LOAD_AVS_BENCH and USE_CLIP_PREDS:
            H, W = self.ground_truth_size
            mask_viz = self.segmentation_info_mask.squeeze().reshape((NUM_COORDS_WIDTH, NUM_COORDS_HEIGHT)).T
            im = ax.imshow(
                mask_viz,
                cmap="viridis",
                origin="upper",
                extent=[0, W, H, 0],
                interpolation="nearest",
                zorder=0,
            )
            ax.set_xlim(0, W)
            ax.set_ylim(H, 0)
            ax.set_axis_off()
        else:
            im = ax.imshow(self.segmentation_mask.mean(axis=-1), cmap='viridis', vmin=0, vmax=100)  # cmap='gray'
            ax.axis((0, self.ground_truth_size[1], self.ground_truth_size[0], 0))
        ax.set_title("Predicted Mask (Normalized)")
        for i, route in enumerate(robots_route):
            robot_marker_color = color_list[i % len(color_list)]
            xPoints = route[0]
            yPoints = route[1]
            ax.plot(xPoints, yPoints, c=robot_marker_color, linewidth=2)
            ax.plot(xPoints[-1], yPoints[-1], markersize=12, zorder=99, marker="^", ls="-", c=robot_marker_color, mec="black")
            ax.plot(xPoints[0], yPoints[0], 'co', c=robot_marker_color, markersize=8, zorder=5)

            # Sensor range
            rng_x = 0.5 * (self.ground_truth.shape[1] / NUM_COORDS_WIDTH)
            rng_y = 0.5 * (self.ground_truth.shape[0] / NUM_COORDS_HEIGHT)
            max_x = min(xPoints[-1] + int(math.ceil(rng_x)), self.ground_truth.shape[1])
            min_x = max(xPoints[-1] - int(math.ceil(rng_x)), 0)
            max_y = min(yPoints[-1] + int(math.ceil(rng_y)), self.ground_truth.shape[0])
            min_y = max(yPoints[-1] - int(math.ceil(rng_y)), 0)
            ax.plot((min_x, min_x), (min_y, max_y), c=robot_marker_color, linewidth=1)
            ax.plot((min_x, max_x), (max_y, max_y), c=robot_marker_color, linewidth=1)
            ax.plot((max_x, max_x), (max_y, min_y), c=robot_marker_color, linewidth=1)
            ax.plot((max_x, min_x), (min_y, min_y), c=robot_marker_color, linewidth=1)

        # Add a colorbar
        cbar = fig.colorbar(im, ax=ax, fraction=0.046, pad=0.04)
        cbar.set_label("Normalized Probs")

        if sound_id_override is not None:
            plt.suptitle('Targets Found: {}/{} Coverage ratio: {:.4g} Travel Dist: {:.4g} \n ({}) \n (Sound ID: {})'.format(self.num_targets_found, len(self.target_positions), self.explored_rate, travel_dist, msk_name, sound_id_override))
        elif msk_name != "":
            plt.suptitle('Targets Found: {}/{} Coverage ratio: {:.4g} Travel Dist: {:.4g} \n ({})'.format(self.num_targets_found, len(self.target_positions), self.explored_rate, travel_dist, msk_name))
        else:
            plt.suptitle('Targets Found: {}/{} Coverage ratio: {:.4g} Travel Dist: {:.4g}'.format(self.num_targets_found, len(self.target_positions), self.explored_rate, travel_dist))

        plt.tight_layout()
        plt.savefig('{}/{}_{}_samples.png'.format(path, n, step), dpi=100)
        frame = '{}/{}_{}_samples.png'.format(path, n, step)
        self.frame_files.append(frame)
        plt.close()

    ####################
    # ADDED: For app.py
    ####################

    def plot_heatmap(self, save_dir, step, travel_dist, robots_route=None):
        """Plot only the segmentation heatmap and save it as ``{step}.png`` in
        ``save_dir``. This lightweight helper is meant for asynchronous
        streaming in the Gradio demo when the full ``plot_env`` is too heavy.

        Parameters
        ----------
        save_dir : str
            Directory to save the generated PNG file.
        step : int
            Current timestep; becomes the filename ``{step}.png``.
        travel_dist : float
            Total distance travelled so far (currently unused here).
        robots_route : list | None
            Optional list of routes (xPoints, yPoints) to overlay.

        Returns
        -------
        str
            Full path to the generated PNG file.
        """
        plt.switch_backend('agg')
        # Do not clear the global figure state in case it interferes with
        # the current figure. Each call creates its own Figure object that
        # we close explicitly at the end, so a global clear is unnecessary
        # and may break concurrent drawing.
        # plt.cla()

        color_list = ["r", "g", "c", "m", "y", "k"]
        fig, ax = plt.subplots(1, 1, figsize=(6, 6))

        # Select the mask to visualise
        # if TAXABIND_TTA and USE_CLIP_PREDS:
        side_dim = int(np.sqrt(self.segmentation_info_mask.shape[0]))
        mask_viz = self.segmentation_info_mask.squeeze().reshape((side_dim, side_dim)).T

        # Properly map image to pixel coordinates and keep limits fixed
        H, W = self.ground_truth_size  # rows (y), cols (x)
        im = ax.imshow(
            mask_viz,
            cmap="viridis",
            origin="upper",
            extent=[0, W, H, 0],      # x: 0..W, y: H..0 (origin at top-left)
            interpolation="nearest",  # keep cell edges sharp & aligned
            zorder=0,
        )
        ax.set_xlim(0, W)
        ax.set_ylim(H, 0)
        ax.set_axis_off()  # hide ticks but keep limits
        # else:
        #     im = ax.imshow(self.segmentation_mask.mean(axis=-1), cmap='viridis', vmin=0, vmax=100)
        #     ax.axis((0, self.ground_truth_size[1], self.ground_truth_size[0], 0))

        # Optionally overlay robot paths
        if robots_route is not None:
            for i, route in enumerate(robots_route):
                robot_marker_color = color_list[i % len(color_list)]
                xPoints, yPoints = route
                ax.plot(xPoints, yPoints, c=robot_marker_color, linewidth=2)
                ax.plot(xPoints[-1], yPoints[-1], markersize=12, zorder=99, marker="^", ls="-", c=robot_marker_color, mec="black")
                ax.plot(xPoints[0], yPoints[0], 'co', c=robot_marker_color, markersize=8, zorder=5)

                # Sensor range (only well-defined when a route exists)
                rng_x = 0.5 * (self.ground_truth.shape[1] / NUM_COORDS_WIDTH)
                rng_y = 0.5 * (self.ground_truth.shape[0] / NUM_COORDS_HEIGHT)
                max_x = min(xPoints[-1] + int(math.ceil(rng_x)), self.ground_truth.shape[1])
                min_x = max(xPoints[-1] - int(math.ceil(rng_x)), 0)
                max_y = min(yPoints[-1] + int(math.ceil(rng_y)), self.ground_truth.shape[0])
                min_y = max(yPoints[-1] - int(math.ceil(rng_y)), 0)
                ax.plot((min_x, min_x), (min_y, max_y), c=robot_marker_color, linewidth=1)
                ax.plot((min_x, max_x), (max_y, max_y), c=robot_marker_color, linewidth=1)
                ax.plot((max_x, max_x), (max_y, min_y), c=robot_marker_color, linewidth=1)
                ax.plot((max_x, min_x), (min_y, min_y), c=robot_marker_color, linewidth=1)

        # Plot target positions
        for target in self.target_positions:
            if self.coverage_belief[target[1], target[0]] == 255:
                # ax.plot(target[0], target[1], 'go', markersize=8, zorder=99)
                ax.plot(target[0], target[1], color='g', marker='x', linestyle='-', markersize=12, markeredgewidth=4, zorder=99)
            else:
                # ax.plot(target[0], target[1], 'ro', markersize=8, zorder=99)
                ax.plot(target[0], target[1], color='r', marker='x', linestyle='-', markersize=12, markeredgewidth=4, zorder=99)

        # Color bar
        cbar = fig.colorbar(im, ax=ax, fraction=0.046, pad=0.04)
        cbar.set_label("Normalized Probs")

        # Coverage shown to 1 decimal place
        plt.suptitle('Targets Found: {}/{} Coverage: {:.1f}% Steps: {}/{}'.format(
            self.num_targets_found,
            len(self.target_positions),
            self.explored_rate * 100,
            step + 1,
            NUM_EPS_STEPS),
            y=0.94,  # Closer to plot
        )

        plt.tight_layout()
        os.makedirs(save_dir, exist_ok=True)
        out_path = os.path.join(save_dir, f"{step}.png")
        # Save atomically: write to a temp file, then move it into place so the poller never sees a partial file.
        tmp_path = out_path + ".tmp"
        fig.savefig(tmp_path, dpi=100, format='png')
        os.replace(tmp_path, out_path)  # atomic on the same filesystem
        plt.close(fig)
        return out_path
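
A self-contained sketch of the frontier test implemented in Env.find_frontier above: a free cell (255) is a frontier if more than one but fewer than eight of its eight neighbours are unknown (127); zero-padding reproduces the same border handling.

import numpy as np

belief = np.array([
    [255, 255, 127],
    [255, 255, 127],
    [  1, 255, 127],
])

unknown = (belief == 127).astype(int)
padded = np.pad(unknown, 1, mode="constant", constant_values=0)

# Sum of the eight neighbours of every cell, via shifted slices of the padded map.
neighbour_sum = sum(
    padded[1 + dy : padded.shape[0] - 1 + dy, 1 + dx : padded.shape[1] - 1 + dx]
    for dy in (-1, 0, 1) for dx in (-1, 0, 1) if (dy, dx) != (0, 0)
)

frontiers = np.argwhere((belief == 255) & (neighbour_sum > 1) & (neighbour_sum < 8))
print(frontiers)  # [[0 1] [1 1] [2 1]]: free cells bordering the unknown column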

planner/graph.py
ADDED
@@ -0,0 +1,167 @@
#######################################################################
# Name: graph.py
#
# - Adapted from https://gist.github.com/betandr/541a1f6466b6855471de5ca30b74cb31
# - Simple graph class to perform distance calculations (e.g. A-Star, Dijkstra)
#######################################################################


class Edge:
    def __init__(self, to_node, length):
        self.to_node = to_node
        self.length = length


class Graph:
    def __init__(self):
        self.nodes = set()
        self.edges = dict()

    def add_node(self, node):
        self.nodes.add(node)

    def add_edge(self, from_node, to_node, length):
        edge = Edge(to_node, length)
        if from_node in self.edges:
            from_node_edges = self.edges[from_node]
        else:
            self.edges[from_node] = dict()
            from_node_edges = self.edges[from_node]
        from_node_edges[to_node] = edge

    def clear_edge(self, from_node):
        if from_node in self.edges:
            self.edges[from_node] = dict()


def min_dist(q, dist):
    """
    Returns the node with the smallest distance in q.
    Implemented to keep the main algorithm clean.
    """
    min_node = None
    for node in q:
        if min_node is None:
            min_node = node
        elif dist[node] < dist[min_node]:
            min_node = node

    return min_node


INFINITY = float('Infinity')


def dijkstra(graph, source):
    q = set()
    dist = {}
    prev = {}

    for v in graph.nodes:
        dist[v] = INFINITY  # unknown distance from source to v
        prev[v] = INFINITY  # previous node in optimal path from source
        q.add(v)            # all nodes initially in q (unvisited nodes)

    # distance from source to source
    dist[source] = 0

    while q:
        # node with the least distance selected first
        u = min_dist(q, dist)

        q.remove(u)

        try:
            if u in graph.edges:
                for _, v in graph.edges[u].items():
                    alt = dist[u] + v.length
                    if alt < dist[v.to_node]:
                        # a shorter path to v has been found
                        dist[v.to_node] = alt
                        prev[v.to_node] = u
        except:
            pass

    return dist, prev


def to_array(prev, from_node):
    """Creates an ordered list of labels as a route."""
    previous_node = prev[from_node]
    route = [from_node]

    while previous_node != INFINITY:
        route.append(previous_node)
        temp = previous_node
        previous_node = prev[temp]

    route.reverse()
    return route


def h(index, destination, node_coords):
    # Manhattan-distance heuristic between two node coordinates
    current = node_coords[index]
    end = node_coords[destination]
    h = abs(end[0] - current[0]) + abs(end[1] - current[1])
    return h


def a_star(start, destination, node_coords, graph):
    if start == destination:
        return [], 0
    if str(destination) in graph.edges[str(start)].keys():
        cost = graph.edges[str(start)][str(destination)].length
        return [start, destination], cost
    open_list = {start}
    closed_list = set([])

    g = {start: 0}
    parents = {start: start}

    while len(open_list) > 0:
        n = None
        h_n = 1e5
        for v in open_list:
            h_v = h(v, destination, node_coords)
            if n is not None:
                h_n = h(n, destination, node_coords)
            if n is None or g[v] + h_v < g[n] + h_n:
                n = v

        if n is None:
            print('Path does not exist!')
            return None, 1e5

        if n == destination:
            reconst_path = []
            while parents[n] != n:
                reconst_path.append(n)
                n = parents[n]
            reconst_path.append(start)
            reconst_path.reverse()
            return reconst_path, g[destination]

        for edge in graph.edges[str(n)].values():
            m = int(edge.to_node)
            cost = edge.length
            if m not in open_list and m not in closed_list:
                open_list.add(m)
                parents[m] = n
                g[m] = g[n] + cost
            else:
                if g[m] > g[n] + cost:
                    g[m] = g[n] + cost
                    parents[m] = n

                    if m in closed_list:
                        closed_list.remove(m)
                        open_list.add(m)

        open_list.remove(n)
        closed_list.add(n)

    print('Path does not exist!')
    return None, 1e5
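
A usage sketch for the Graph/dijkstra helpers above (assuming the planner package is importable): build a tiny weighted graph and recover the shortest route from "a" to "c".

from planner.graph import Graph, dijkstra, to_array

g = Graph()
for n in ("a", "b", "c"):
    g.add_node(n)
g.add_edge("a", "b", 1.0)
g.add_edge("b", "c", 2.0)
g.add_edge("a", "c", 5.0)

dist, prev = dijkstra(g, "a")
print(dist["c"])            # 3.0: the two-hop route via "b" beats the direct 5.0 edge
print(to_array(prev, "c"))  # ['a', 'b', 'c']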

planner/graph_generator.py
ADDED
@@ -0,0 +1,300 @@
#######################################################################
# Name: graph_generator.py
#
# - Wrapper for graph.py
# - Sends the formatted inputs into graph.py to get useful info
#######################################################################

import sys
if sys.modules['TRAINING']:  # flag registered by the entry script (training vs. testing)
    from .parameter import *
else:
    from .test_parameter import *

import numpy as np
import shapely.geometry
from sklearn.neighbors import NearestNeighbors
from .node import Node
from .graph import Graph, a_star


class Graph_generator:
    def __init__(self, map_size, k_size, sensor_range, plot=False):
        self.k_size = k_size
        self.graph = Graph()
        self.node_coords = None
        self.plot = plot
        self.x = []
        self.y = []
        self.map_x = map_size[1]
        self.map_y = map_size[0]
        self.uniform_points, self.grid_coords = self.generate_uniform_points()
        self.sensor_range = sensor_range
        self.route_node = []
        self.nodes_list = []
        self.node_utility = None
        self.guidepost = None

    def edge_clear_all_nodes(self):
        self.graph = Graph()
        self.x = []
        self.y = []

    def edge_clear(self, coords):
        node_index = str(self.find_index_from_coords(self.node_coords, coords))
        self.graph.clear_edge(node_index)

    def generate_graph(self, robot_belief, frontiers):
        self.edge_clear_all_nodes()
        free_area = self.free_area(robot_belief)

        free_area_to_check = free_area[:, 0] + free_area[:, 1] * 1j
        uniform_points_to_check = self.uniform_points[:, 0] + self.uniform_points[:, 1] * 1j
        _, _, candidate_indices = np.intersect1d(free_area_to_check, uniform_points_to_check, return_indices=True)
        node_coords = self.uniform_points[candidate_indices]

        self.node_coords = self.unique_coords(node_coords).reshape(-1, 2)
        self.find_nearest_neighbor_all_nodes(self.node_coords, robot_belief)

        self.node_utility = []
        for coords in self.node_coords:
            node = Node(coords, frontiers, robot_belief)
            self.nodes_list.append(node)
            utility = node.utility
            self.node_utility.append(utility)
        self.node_utility = np.array(self.node_utility)

        self.guidepost = np.zeros((self.node_coords.shape[0], 1))
        x = self.node_coords[:, 0] + self.node_coords[:, 1] * 1j
        for node in self.route_node:
            index = np.argwhere(x.reshape(-1) == node[0] + node[1] * 1j)[0]
            self.guidepost[index] = 1

        return self.node_coords, self.graph.edges, self.node_utility, self.guidepost

    def update_graph(self, robot_belief, old_robot_belief, frontiers, old_frontiers):
        new_free_area = self.free_area((robot_belief - old_robot_belief > 0) * 255)
        free_area_to_check = new_free_area[:, 0] + new_free_area[:, 1] * 1j
        uniform_points_to_check = self.uniform_points[:, 0] + self.uniform_points[:, 1] * 1j
        _, _, candidate_indices = np.intersect1d(free_area_to_check, uniform_points_to_check, return_indices=True)
        new_node_coords = self.uniform_points[candidate_indices]
        self.node_coords = np.concatenate((self.node_coords, new_node_coords))

        old_node_to_update = []
        for coords in new_node_coords:
            neighbor_indices = self.find_k_neighbor(coords, self.node_coords, robot_belief)
            old_node_to_update += neighbor_indices
        old_node_to_update = set(old_node_to_update)
        for index in old_node_to_update:
            coords = self.node_coords[index]
            self.edge_clear(coords)
            self.find_k_neighbor(coords, self.node_coords, robot_belief)

        old_frontiers_to_check = old_frontiers[:, 0] + old_frontiers[:, 1] * 1j
        new_frontiers_to_check = frontiers[:, 0] + frontiers[:, 1] * 1j
        observed_frontiers_index = np.where(
            np.isin(old_frontiers_to_check, new_frontiers_to_check, assume_unique=True) == False)
        new_frontiers_index = np.where(
            np.isin(new_frontiers_to_check, old_frontiers_to_check, assume_unique=True) == False)
        observed_frontiers = old_frontiers[observed_frontiers_index]
        new_frontiers = frontiers[new_frontiers_index]
        for node in self.nodes_list:
            if node.zero_utility_node is True:
                pass
            else:
                node.update_observable_frontiers(observed_frontiers, new_frontiers, robot_belief)

        for new_coords in new_node_coords:
            node = Node(new_coords, frontiers, robot_belief)
            self.nodes_list.append(node)

        self.node_utility = []
        for i, coords in enumerate(self.node_coords):
            utility = self.nodes_list[i].utility
            self.node_utility.append(utility)
        self.node_utility = np.array(self.node_utility)

        self.guidepost = np.zeros((self.node_coords.shape[0], 1))
        x = self.node_coords[:, 0] + self.node_coords[:, 1] * 1j
        for node in self.route_node:
            index = np.argwhere(x.reshape(-1) == node[0] + node[1] * 1j)
            self.guidepost[index] = 1

        return self.node_coords, self.graph.edges, self.node_utility, self.guidepost

    def generate_uniform_points(self):
        padding_x = 0.5 * (self.map_x / NUM_COORDS_WIDTH)
        padding_y = 0.5 * (self.map_y / NUM_COORDS_HEIGHT)
        x = np.linspace(padding_x, self.map_x - padding_x - 1, NUM_COORDS_WIDTH).round().astype(int)
        y = np.linspace(padding_y, self.map_y - padding_y - 1, NUM_COORDS_HEIGHT).round().astype(int)

        t1, t2 = np.meshgrid(x, y)
        points = np.vstack([t1.T.ravel(), t2.T.ravel()]).T
        matrix = np.stack((t1, t2), axis=-1)
        return points, matrix

    def free_area(self, robot_belief):
        index = np.where(robot_belief == 255)
        free = np.asarray([index[1], index[0]]).T
        return free

    def unique_coords(self, coords):
        x = coords[:, 0] + coords[:, 1] * 1j
        indices = np.unique(x, return_index=True)[1]
        coords = np.array([coords[idx] for idx in sorted(indices)])
        return coords

    def find_k_neighbor(self, coords, node_coords, robot_belief):
        dist_list = np.linalg.norm((coords - node_coords), axis=-1)
        sorted_index = np.argsort(dist_list)
        k = 0
        neighbor_index_list = []
        while k < self.k_size and k < node_coords.shape[0]:
            neighbor_index = sorted_index[k]
            neighbor_index_list.append(neighbor_index)
            dist = dist_list[neighbor_index]  # distance to this neighbor in the unsorted dist_list
            start = coords
            end = node_coords[neighbor_index]
            if not self.check_collision(start, end, robot_belief):
                a = str(self.find_index_from_coords(node_coords, start))
                b = str(neighbor_index)
                self.graph.add_node(a)
                self.graph.add_edge(a, b, dist)

                if self.plot:
                    self.x.append([start[0], end[0]])
                    self.y.append([start[1], end[1]])
            k += 1
        return neighbor_index_list

    def find_k_neighbor_all_nodes(self, node_coords, robot_belief):
        X = node_coords
        if len(node_coords) >= self.k_size:
            knn = NearestNeighbors(n_neighbors=self.k_size)
        else:
            knn = NearestNeighbors(n_neighbors=len(node_coords))
        knn.fit(X)
        distances, indices = knn.kneighbors(X)

        for i, p in enumerate(X):
            for j, neighbour in enumerate(X[indices[i][:]]):
                start = p
                end = neighbour
                if not self.check_collision(start, end, robot_belief):
                    a = str(self.find_index_from_coords(node_coords, p))
                    b = str(self.find_index_from_coords(node_coords, neighbour))
                    self.graph.add_node(a)
                    self.graph.add_edge(a, b, distances[i, j])

                    if self.plot:
                        self.x.append([p[0], neighbour[0]])
                        self.y.append([p[1], neighbour[1]])

    def find_nearest_neighbor_all_nodes(self, node_coords, robot_belief):
        for i, p in enumerate(node_coords):
            filtered_coords = self.get_neighbors_grid_coords(p)

            for j, neighbour in enumerate(filtered_coords):
                start = p
                end = neighbour
                if not self.check_collision(start, end, robot_belief):
                    a = str(self.find_index_from_coords(node_coords, p))
                    b = str(self.find_index_from_coords(node_coords, neighbour))
|
| 213 |
+
self.graph.add_node(a)
|
| 214 |
+
self.graph.add_edge(a, b, np.linalg.norm(start-end))
|
| 215 |
+
|
| 216 |
+
if self.plot:
|
| 217 |
+
self.x.append([p[0], neighbour[0]])
|
| 218 |
+
self.y.append([p[1], neighbour[1]])
|
| 219 |
+
|
| 220 |
+
|
| 221 |
+
def find_index_from_coords(self, node_coords, p):
|
| 222 |
+
return np.where(np.linalg.norm(node_coords - p, axis=1) < 1e-5)[0][0]
|
| 223 |
+
|
| 224 |
+
|
| 225 |
+
def find_closest_index_from_coords(self, node_coords, p):
|
| 226 |
+
return np.argmin(np.linalg.norm(node_coords - p, axis=1))
|
| 227 |
+
|
| 228 |
+
|
| 229 |
+
def find_index_from_grid_coords_2d(self, p):
|
| 230 |
+
diffs = np.linalg.norm(self.grid_coords - p, axis=2)
|
| 231 |
+
indices = np.where(diffs < 1e-5)
|
| 232 |
+
|
| 233 |
+
if indices[0].size > 0:
|
| 234 |
+
return indices[0][0], indices[1][0]
|
| 235 |
+
else:
|
| 236 |
+
raise ValueError(f"Coordinate {p} not found in self.grid_coords.")
|
| 237 |
+
|
| 238 |
+
|
| 239 |
+
def find_closest_index_from_grid_coords_2d(self, p):
|
| 240 |
+
distances = np.linalg.norm(self.grid_coords - p, axis=2)
|
| 241 |
+
flat_index = np.argmin(distances)
|
| 242 |
+
return np.unravel_index(flat_index, distances.shape)
|
| 243 |
+
|
| 244 |
+
|
| 245 |
+
def check_collision(self, start, end, robot_belief):
|
| 246 |
+
collision = False
|
| 247 |
+
line = shapely.geometry.LineString([start, end])
|
| 248 |
+
|
| 249 |
+
sortx = np.sort([start[0], end[0]])
|
| 250 |
+
sorty = np.sort([start[1], end[1]])
|
| 251 |
+
|
| 252 |
+
robot_belief = robot_belief[sorty[0]:sorty[1]+1, sortx[0]:sortx[1]+1]
|
| 253 |
+
|
| 254 |
+
occupied_area_index = np.where(robot_belief == 1)
|
| 255 |
+
occupied_area_coords = np.asarray([occupied_area_index[1]+sortx[0], occupied_area_index[0]+sorty[0]]).T
|
| 256 |
+
unexplored_area_index = np.where(robot_belief == 127)
|
| 257 |
+
unexplored_area_coords = np.asarray([unexplored_area_index[1]+sortx[0], unexplored_area_index[0]+sorty[0]]).T
|
| 258 |
+
unfree_area_coords = np.concatenate((occupied_area_coords, unexplored_area_coords))
|
| 259 |
+
|
| 260 |
+
for i in range(unfree_area_coords.shape[0]):
|
| 261 |
+
coords = ([(unfree_area_coords[i][0], unfree_area_coords[i][1]),
|
| 262 |
+
(unfree_area_coords[i][0] + 1, unfree_area_coords[i][1]),
|
| 263 |
+
(unfree_area_coords[i][0], unfree_area_coords[i][1] + 1),
|
| 264 |
+
(unfree_area_coords[i][0] + 1, unfree_area_coords[i][1] + 1)])
|
| 265 |
+
obstacle = shapely.geometry.Polygon(coords)
|
| 266 |
+
collision = line.intersects(obstacle)
|
| 267 |
+
if collision:
|
| 268 |
+
break
|
| 269 |
+
|
| 270 |
+
return collision
|
| 271 |
+
|
| 272 |
+
|
| 273 |
+
def find_shortest_path(self, current, destination, node_coords):
|
| 274 |
+
start_node = str(self.find_index_from_coords(node_coords, current))
|
| 275 |
+
end_node = str(self.find_index_from_coords(node_coords, destination))
|
| 276 |
+
route, dist = a_star(int(start_node), int(end_node), self.node_coords, self.graph)
|
| 277 |
+
if start_node != end_node:
|
| 278 |
+
assert route != []
|
| 279 |
+
route = list(map(str, route))
|
| 280 |
+
return dist, route
|
| 281 |
+
|
| 282 |
+
def get_neighbors_grid_coords(self, coord):
|
| 283 |
+
# Return the 8 closest neighbors of a given coordinate
|
| 284 |
+
|
| 285 |
+
nearest_coord = self.node_coords[self.find_closest_index_from_coords(self.node_coords, coord)]
|
| 286 |
+
rows, cols = self.grid_coords.shape[:2]
|
| 287 |
+
neighbors = []
|
| 288 |
+
i, j = self.find_index_from_grid_coords_2d(nearest_coord)
|
| 289 |
+
|
| 290 |
+
# Create a range of indices for rows and columns
|
| 291 |
+
row_range = np.clip([i - 1, i, i + 1], 0, rows - 1)
|
| 292 |
+
col_range = np.clip([j - 1, j, j + 1], 0, cols - 1)
|
| 293 |
+
|
| 294 |
+
# Iterate over the valid indices
|
| 295 |
+
for ni in row_range:
|
| 296 |
+
for nj in col_range:
|
| 297 |
+
if (ni, nj) != (i, j): # Skip the center point
|
| 298 |
+
neighbors.append(tuple(self.grid_coords[ni, nj]))
|
| 299 |
+
|
| 300 |
+
return neighbors
|
planner/model.py
ADDED
@@ -0,0 +1,312 @@
#######################################################################
# Name: model.py
#
# - Attention-based encoders & decoders
# - Policy Net: Input = Augmented Graph, Output = Node to go to
# - Critic Net: Input = Augmented Graph + Action, Output = Q_Value
#######################################################################

import torch
import torch.nn as nn
import math


class SingleHeadAttention(nn.Module):
    def __init__(self, embedding_dim):
        super(SingleHeadAttention, self).__init__()
        self.input_dim = embedding_dim
        self.embedding_dim = embedding_dim
        self.value_dim = embedding_dim
        self.key_dim = self.value_dim
        self.tanh_clipping = 10
        self.norm_factor = 1 / math.sqrt(self.key_dim)

        self.w_query = nn.Parameter(torch.Tensor(self.input_dim, self.key_dim))
        self.w_key = nn.Parameter(torch.Tensor(self.input_dim, self.key_dim))

        self.init_parameters()

    def init_parameters(self):
        for param in self.parameters():
            stdv = 1. / math.sqrt(param.size(-1))
            param.data.uniform_(-stdv, stdv)

    def forward(self, q, k, mask=None):
        n_batch, n_key, n_dim = k.size()
        n_query = q.size(1)

        k_flat = k.reshape(-1, n_dim)
        q_flat = q.reshape(-1, n_dim)

        shape_k = (n_batch, n_key, -1)
        shape_q = (n_batch, n_query, -1)

        Q = torch.matmul(q_flat, self.w_query).view(shape_q)
        K = torch.matmul(k_flat, self.w_key).view(shape_k)

        U = self.norm_factor * torch.matmul(Q, K.transpose(1, 2))
        U = self.tanh_clipping * torch.tanh(U)

        if mask is not None:
            U = U.masked_fill(mask == 1, -1e8)
        attention = torch.log_softmax(U, dim=-1)  # n_batch*n_query*n_key

        return attention


class MultiHeadAttention(nn.Module):
    def __init__(self, embedding_dim, n_heads=8):
        super(MultiHeadAttention, self).__init__()
        self.n_heads = n_heads
        self.input_dim = embedding_dim
        self.embedding_dim = embedding_dim
        self.value_dim = self.embedding_dim // self.n_heads
        self.key_dim = self.value_dim
        self.norm_factor = 1 / math.sqrt(self.key_dim)

        self.w_query = nn.Parameter(torch.Tensor(self.n_heads, self.input_dim, self.key_dim))
        self.w_key = nn.Parameter(torch.Tensor(self.n_heads, self.input_dim, self.key_dim))
        self.w_value = nn.Parameter(torch.Tensor(self.n_heads, self.input_dim, self.value_dim))
        self.w_out = nn.Parameter(torch.Tensor(self.n_heads, self.value_dim, self.embedding_dim))

        self.init_parameters()

    def init_parameters(self):
        for param in self.parameters():
            stdv = 1. / math.sqrt(param.size(-1))
            param.data.uniform_(-stdv, stdv)

    def forward(self, q, k=None, v=None, key_padding_mask=None, attn_mask=None):
        if k is None:
            k = q
        if v is None:
            v = q

        n_batch, n_key, n_dim = k.size()
        n_query = q.size(1)
        n_value = v.size(1)

        k_flat = k.contiguous().view(-1, n_dim)
        v_flat = v.contiguous().view(-1, n_dim)
        q_flat = q.contiguous().view(-1, n_dim)
        shape_v = (self.n_heads, n_batch, n_value, -1)
        shape_k = (self.n_heads, n_batch, n_key, -1)
        shape_q = (self.n_heads, n_batch, n_query, -1)

        Q = torch.matmul(q_flat, self.w_query).view(shape_q)  # n_heads*batch_size*n_query*key_dim
        K = torch.matmul(k_flat, self.w_key).view(shape_k)  # n_heads*batch_size*targets_size*key_dim
        V = torch.matmul(v_flat, self.w_value).view(shape_v)  # n_heads*batch_size*targets_size*value_dim

        U = self.norm_factor * torch.matmul(Q, K.transpose(2, 3))  # n_heads*batch_size*n_query*targets_size

        if attn_mask is not None:
            attn_mask = attn_mask.view(1, n_batch, n_query, n_key).expand_as(U)

        if key_padding_mask is not None:
            key_padding_mask = key_padding_mask.repeat(1, n_query, 1)
            key_padding_mask = key_padding_mask.view(1, n_batch, n_query, n_key).expand_as(U)  # copy for n_heads times

        if attn_mask is not None and key_padding_mask is not None:
            mask = (attn_mask + key_padding_mask)
        elif attn_mask is not None:
            mask = attn_mask
        elif key_padding_mask is not None:
            mask = key_padding_mask
        else:
            mask = None

        if mask is not None:
            U = U.masked_fill(mask > 0, -1e8)

        attention = torch.softmax(U, dim=-1)  # n_heads*batch_size*n_query*targets_size
        heads = torch.matmul(attention, V)  # n_heads*batch_size*n_query*value_dim
        out = torch.mm(
            heads.permute(1, 2, 0, 3).reshape(-1, self.n_heads * self.value_dim),
            # batch_size*n_query*n_heads*value_dim
            self.w_out.view(-1, self.embedding_dim)
            # n_heads*value_dim*embedding_dim
        ).view(-1, n_query, self.embedding_dim)

        return out, attention  # batch_size*n_query*embedding_dim


class Normalization(nn.Module):
    def __init__(self, embedding_dim):
        super(Normalization, self).__init__()
        self.normalizer = nn.LayerNorm(embedding_dim)

    def forward(self, input):
        return self.normalizer(input.view(-1, input.size(-1))).view(*input.size())


class EncoderLayer(nn.Module):
    def __init__(self, embedding_dim, n_head):
        super(EncoderLayer, self).__init__()
        self.multiHeadAttention = MultiHeadAttention(embedding_dim, n_head)
        self.normalization1 = Normalization(embedding_dim)
        self.feedForward = nn.Sequential(nn.Linear(embedding_dim, 512), nn.ReLU(inplace=True),
                                         nn.Linear(512, embedding_dim))
        self.normalization2 = Normalization(embedding_dim)

    def forward(self, src, key_padding_mask=None, attn_mask=None):
        # Pre-norm residual block: self-attention, then feed-forward
        h0 = src
        h = self.normalization1(src)
        h, _ = self.multiHeadAttention(q=h, key_padding_mask=key_padding_mask, attn_mask=attn_mask)
        h = h + h0
        h1 = h
        h = self.normalization2(h)
        h = self.feedForward(h)
        h2 = h + h1
        return h2


class DecoderLayer(nn.Module):
    def __init__(self, embedding_dim, n_head):
        super(DecoderLayer, self).__init__()
        self.multiHeadAttention = MultiHeadAttention(embedding_dim, n_head)
        self.normalization1 = Normalization(embedding_dim)
        self.feedForward = nn.Sequential(nn.Linear(embedding_dim, 512),
                                         nn.ReLU(inplace=True),
                                         nn.Linear(512, embedding_dim))
        self.normalization2 = Normalization(embedding_dim)

    def forward(self, tgt, memory, key_padding_mask=None, attn_mask=None):
        h0 = tgt
        tgt = self.normalization1(tgt)
        memory = self.normalization1(memory)
        h, w = self.multiHeadAttention(q=tgt, k=memory, v=memory, key_padding_mask=key_padding_mask, attn_mask=attn_mask)
        h = h + h0
        h1 = h
        h = self.normalization2(h)
        h = self.feedForward(h)
        h2 = h + h1
        return h2, w


class Encoder(nn.Module):
    def __init__(self, embedding_dim=128, n_head=8, n_layer=1):
        super(Encoder, self).__init__()
        self.layers = nn.ModuleList(EncoderLayer(embedding_dim, n_head) for i in range(n_layer))

    def forward(self, src, key_padding_mask=None, attn_mask=None):
        for layer in self.layers:
            src = layer(src, key_padding_mask=key_padding_mask, attn_mask=attn_mask)
        return src


class Decoder(nn.Module):
    def __init__(self, embedding_dim=128, n_head=8, n_layer=1):
        super(Decoder, self).__init__()
        self.layers = nn.ModuleList([DecoderLayer(embedding_dim, n_head) for i in range(n_layer)])

    def forward(self, tgt, memory, key_padding_mask=None, attn_mask=None):
        for layer in self.layers:
            tgt, w = layer(tgt, memory, key_padding_mask=key_padding_mask, attn_mask=attn_mask)
        return tgt, w


class PolicyNet(nn.Module):
    def __init__(self, input_dim, embedding_dim):
        super(PolicyNet, self).__init__()
        self.initial_embedding = nn.Linear(input_dim, embedding_dim)  # layer for non-end position

        self.current_embedding = nn.Linear(embedding_dim * 2, embedding_dim)

        self.encoder = Encoder(embedding_dim=embedding_dim, n_head=8, n_layer=6)
        self.decoder = Decoder(embedding_dim=embedding_dim, n_head=8, n_layer=1)
        self.pointer = SingleHeadAttention(embedding_dim)

    def encode_graph(self, node_inputs, node_padding_mask, edge_mask):
        node_feature = self.initial_embedding(node_inputs)
        enhanced_node_feature = self.encoder(src=node_feature, key_padding_mask=node_padding_mask, attn_mask=edge_mask)

        return enhanced_node_feature

    def output_policy(self, enhanced_node_feature, edge_inputs, current_index, edge_padding_mask, node_padding_mask):
        k_size = edge_inputs.size()[2]
        current_edge = torch.gather(edge_inputs, 1, current_index.repeat(1, 1, k_size))
        current_edge = current_edge.permute(0, 2, 1)
        embedding_dim = enhanced_node_feature.size()[2]

        neighboring_feature = torch.gather(enhanced_node_feature, 1, current_edge.repeat(1, 1, embedding_dim))

        current_node_feature = torch.gather(enhanced_node_feature, 1, current_index.repeat(1, 1, embedding_dim))

        # NOTE: the masking below assumes edge_padding_mask is always provided;
        # with current_mask = None the indexing on the next line would fail.
        if edge_padding_mask is not None:
            current_mask = torch.gather(edge_padding_mask, 1, current_index.repeat(1, 1, k_size)).to(enhanced_node_feature.device)
        else:
            current_mask = None
        current_mask[:, :, 0] = 1  # don't stay at current position

        if not 0 in current_mask:
            current_mask[:, :, 0] = 0

        enhanced_current_node_feature, _ = self.decoder(current_node_feature, enhanced_node_feature, node_padding_mask)
        enhanced_current_node_feature = self.current_embedding(torch.cat((enhanced_current_node_feature, current_node_feature), dim=-1))
        logp = self.pointer(enhanced_current_node_feature, neighboring_feature, current_mask)
        logp = logp.squeeze(1)  # batch_size*k_size

        return logp

    def forward(self, node_inputs, edge_inputs, current_index, node_padding_mask=None, edge_padding_mask=None, edge_mask=None):
        enhanced_node_feature = self.encode_graph(node_inputs, node_padding_mask, edge_mask)
        logp = self.output_policy(enhanced_node_feature, edge_inputs, current_index, edge_padding_mask, node_padding_mask)
        return logp


class QNet(nn.Module):
    def __init__(self, input_dim, embedding_dim):
        super(QNet, self).__init__()
        self.initial_embedding = nn.Linear(input_dim, embedding_dim)  # layer for non-end position
        self.action_embedding = nn.Linear(embedding_dim * 3, embedding_dim)

        self.encoder = Encoder(embedding_dim=embedding_dim, n_head=8, n_layer=6)
        self.decoder = Decoder(embedding_dim=embedding_dim, n_head=8, n_layer=1)

        self.q_values_layer = nn.Linear(embedding_dim, 1)

    def encode_graph(self, node_inputs, node_padding_mask, edge_mask):
        embedding_feature = self.initial_embedding(node_inputs)
        embedding_feature = self.encoder(src=embedding_feature, key_padding_mask=node_padding_mask, attn_mask=edge_mask)

        return embedding_feature

    def output_q_values(self, enhanced_node_feature, edge_inputs, current_index, edge_padding_mask, node_padding_mask):
        k_size = edge_inputs.size()[2]
        current_edge = torch.gather(edge_inputs, 1, current_index.repeat(1, 1, k_size))
        current_edge = current_edge.permute(0, 2, 1)
        embedding_dim = enhanced_node_feature.size()[2]

        neighboring_feature = torch.gather(enhanced_node_feature, 1, current_edge.repeat(1, 1, embedding_dim))

        current_node_feature = torch.gather(enhanced_node_feature, 1, current_index.repeat(1, 1, embedding_dim))

        enhanced_current_node_feature, attention_weights = self.decoder(current_node_feature, enhanced_node_feature, node_padding_mask)
        action_features = torch.cat((enhanced_current_node_feature.repeat(1, k_size, 1), current_node_feature.repeat(1, k_size, 1), neighboring_feature), dim=-1)
        action_features = self.action_embedding(action_features)
        q_values = self.q_values_layer(action_features)

        # NOTE: as in PolicyNet.output_policy, edge_padding_mask is assumed to be provided.
        if edge_padding_mask is not None:
            current_mask = torch.gather(edge_padding_mask, 1, current_index.repeat(1, 1, k_size)).to(
                enhanced_node_feature.device)
        else:
            current_mask = None
        current_mask[:, :, 0] = 1  # don't stay at current position

        if not 0 in current_mask:
            current_mask[:, :, 0] = 0

        current_mask = current_mask.permute(0, 2, 1)
        zero = torch.zeros_like(q_values).to(q_values.device)
        q_values = torch.where(current_mask == 1, zero, q_values)

        return q_values, attention_weights

    def forward(self, node_inputs, edge_inputs, current_index, node_padding_mask=None, edge_padding_mask=None,
                edge_mask=None):
        enhanced_node_feature = self.encode_graph(node_inputs, node_padding_mask, edge_mask)
        q_values, attention_weights = self.output_q_values(enhanced_node_feature, edge_inputs, current_index, edge_padding_mask, node_padding_mask)
        return q_values, attention_weights
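
As a quick sanity check of the tensor shapes these networks expect, here is a hedged smoke test with dummy inputs. The sizes (36 nodes, K=9 neighbours, 4 node features) are illustrative only, not the project's actual parameters, and the import assumes this repo's package layout:

import torch
from planner.model import PolicyNet   # assumed package layout

batch, n_nodes, k_size, input_dim, emb_dim = 1, 36, 9, 4, 128
policy = PolicyNet(input_dim, emb_dim)

node_inputs = torch.rand(batch, n_nodes, input_dim)                    # per-node features
edge_inputs = torch.randint(0, n_nodes, (batch, n_nodes, k_size))      # K neighbour indices per node
current_index = torch.zeros(batch, 1, 1, dtype=torch.long)             # robot sits at node 0
node_padding_mask = torch.zeros(batch, 1, n_nodes, dtype=torch.int64)  # 0 = real node
edge_padding_mask = torch.zeros(batch, n_nodes, k_size, dtype=torch.int64)
edge_mask = torch.zeros(batch, n_nodes, n_nodes, dtype=torch.int64)    # 0 = edge allowed

logp = policy(node_inputs, edge_inputs, current_index,
              node_padding_mask, edge_padding_mask, edge_mask)
print(logp.shape)   # torch.Size([1, 9]) -- log-probabilities over the K neighbours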
planner/node.py
ADDED
@@ -0,0 +1,96 @@
#######################################################################
# Name: node.py
#
# - Contains info per node on graph (vertex)
# - Contains: Position, Utility, Visitation History
#######################################################################

import sys
if sys.modules['TRAINING']:
    from .parameter import *
else:
    from .test_parameter import *

import numpy as np
import shapely.geometry


class Node():
    def __init__(self, coords, frontiers, robot_belief):
        self.coords = coords
        self.observable_frontiers = []
        self.sensor_range = SENSOR_RANGE
        self.initialize_observable_frontiers(frontiers, robot_belief)
        self.utility = self.get_node_utility()
        if self.utility == 0:
            self.zero_utility_node = True
        else:
            self.zero_utility_node = False

    def initialize_observable_frontiers(self, frontiers, robot_belief):
        dist_list = np.linalg.norm(frontiers - self.coords, axis=-1)
        frontiers_in_range = frontiers[dist_list < self.sensor_range - 10]
        for point in frontiers_in_range:
            collision = self.check_collision(self.coords, point, robot_belief)
            if not collision:
                self.observable_frontiers.append(point)

    def get_node_utility(self):
        return len(self.observable_frontiers)

    def update_observable_frontiers(self, observed_frontiers, new_frontiers, robot_belief):
        # Drop frontiers that have been observed since the last update
        if len(observed_frontiers) > 0:
            observed_index = []
            for i, point in enumerate(self.observable_frontiers):
                if point[0] + point[1] * 1j in observed_frontiers[:, 0] + observed_frontiers[:, 1] * 1j:
                    observed_index.append(i)
            for index in reversed(observed_index):
                self.observable_frontiers.pop(index)

        # Add newly discovered frontiers that are visible from this node
        if len(new_frontiers) > 0:
            dist_list = np.linalg.norm(new_frontiers - self.coords, axis=-1)
            new_frontiers_in_range = new_frontiers[dist_list < self.sensor_range - 15]
            for point in new_frontiers_in_range:
                collision = self.check_collision(self.coords, point, robot_belief)
                if not collision:
                    self.observable_frontiers.append(point)

        self.utility = self.get_node_utility()
        if self.utility == 0:
            self.zero_utility_node = True
        else:
            self.zero_utility_node = False

    def set_visited(self):
        self.observable_frontiers = []
        self.utility = 0
        self.zero_utility_node = True

    def check_collision(self, start, end, robot_belief):
        collision = False
        line = shapely.geometry.LineString([start, end])

        sortx = np.sort([start[0], end[0]])
        sorty = np.sort([start[1], end[1]])

        robot_belief = robot_belief[sorty[0]:sorty[1] + 1, sortx[0]:sortx[1] + 1]

        occupied_area_index = np.where(robot_belief == 1)
        occupied_area_coords = np.asarray(
            [occupied_area_index[1] + sortx[0], occupied_area_index[0] + sorty[0]]).T
        unexplored_area_index = np.where(robot_belief == 127)
        unexplored_area_coords = np.asarray(
            [unexplored_area_index[1] + sortx[0], unexplored_area_index[0] + sorty[0]]).T
        unfree_area_coords = np.concatenate((occupied_area_coords, unexplored_area_coords))

        for i in range(unfree_area_coords.shape[0]):
            coords = ([(unfree_area_coords[i][0], unfree_area_coords[i][1]),
                       (unfree_area_coords[i][0] + 1, unfree_area_coords[i][1]),
                       (unfree_area_coords[i][0], unfree_area_coords[i][1] + 1),
                       (unfree_area_coords[i][0] + 1, unfree_area_coords[i][1] + 1)])
            obstacle = shapely.geometry.Polygon(coords)
            collision = line.intersects(obstacle)
            if collision:
                break

        return collision
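
The utility bookkeeping above reduces to: a frontier counts toward a node's utility iff it lies within (sensor range minus a margin) and the straight line to it crosses no non-free cell. A standalone re-derivation of that rule on a toy belief map; this mirrors, rather than imports, the class logic (it scans all non-free cells instead of a bounding box), and cell labels follow the code: 255 free, 127 unknown, 1 occupied:

import numpy as np
import shapely.geometry

def visible_frontier_count(node_xy, frontiers, belief, sensor_range, margin=10):
    """Count frontiers a node can see: in range and not blocked by non-free cells."""
    dists = np.linalg.norm(frontiers - node_xy, axis=-1)
    candidates = frontiers[dists < sensor_range - margin]
    ys, xs = np.where(belief != 255)            # occupied (1) or unknown (127) cells
    count = 0
    for fx, fy in candidates:
        line = shapely.geometry.LineString([node_xy, (fx, fy)])
        blocked = any(line.intersects(shapely.geometry.box(x, y, x + 1, y + 1))
                      for x, y in zip(xs, ys))  # unit squares, as in check_collision
        if not blocked:
            count += 1
    return count

belief = np.full((40, 40), 255, dtype=np.uint8)
belief[20, 5:15] = 1                            # a short wall
frontiers = np.array([[10, 25], [30, 10]])
print(visible_frontier_count(np.array([10, 10]), frontiers, belief, sensor_range=80))  # -> 1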
planner/robot.py
ADDED
@@ -0,0 +1,58 @@
#######################################################################
# Name: robot.py
#
# - Stores S(t), A(t), R(t), S(t+1)
#######################################################################

import torch
from copy import deepcopy


class Robot:
    def __init__(self, robot_id, position, plot=False):
        self.robot_id = robot_id
        self.plot = plot
        self.travel_dist = 0
        self.robot_position = position
        self.observations = None
        self.trajectory_coords = []
        self.targets_found_on_path = []

        # Buffer slots 0-5: S(t), 6: A(t), 7: R(t), 8: done, 9-14: S(t+1)
        self.episode_buffer = []
        for i in range(15):
            self.episode_buffer.append([])

        if self.plot:
            self.xPoints = [self.robot_position[0]]
            self.yPoints = [self.robot_position[1]]

    def save_observations(self, observations):
        node_inputs, edge_inputs, current_index, node_padding_mask, edge_padding_mask, edge_mask = observations
        self.episode_buffer[0] += deepcopy(node_inputs).to('cpu')
        self.episode_buffer[1] += deepcopy(edge_inputs).to('cpu')
        self.episode_buffer[2] += deepcopy(current_index).to('cpu')
        self.episode_buffer[3] += deepcopy(node_padding_mask).to('cpu')
        self.episode_buffer[4] += deepcopy(edge_padding_mask).to('cpu')
        self.episode_buffer[5] += deepcopy(edge_mask).to('cpu')

    def save_action(self, action_index):
        self.episode_buffer[6] += action_index.unsqueeze(0).unsqueeze(0)

    def save_reward_done(self, reward, done):
        self.episode_buffer[7] += deepcopy(torch.FloatTensor([[[reward]]])).to('cpu')
        self.episode_buffer[8] += deepcopy(torch.tensor([[[(int(done))]]])).to('cpu')
        if self.plot:
            self.xPoints.append(self.robot_position[0])
            self.yPoints.append(self.robot_position[1])

    def save_next_observations(self, observations):
        node_inputs, edge_inputs, current_index, node_padding_mask, edge_padding_mask, edge_mask = observations
        self.episode_buffer[9] += deepcopy(node_inputs).to('cpu')
        self.episode_buffer[10] += deepcopy(edge_inputs).to('cpu')
        self.episode_buffer[11] += deepcopy(current_index).to('cpu')
        self.episode_buffer[12] += deepcopy(node_padding_mask).to('cpu')
        self.episode_buffer[13] += deepcopy(edge_padding_mask).to('cpu')
        self.episode_buffer[14] += deepcopy(edge_mask).to('cpu')

    def save_trajectory_coords(self, robot_position_coords, num_target_found):
        self.trajectory_coords.append(robot_position_coords)
        self.targets_found_on_path.append(num_target_found)
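
A hedged usage sketch of the buffer above, with dummy tensors shaped like the observation tuple (batch of 1; the 36-node / K=9 sizes are illustrative, and the import assumes this repo's package layout). Note that `list += tensor` appends the tensor's first-dimension slices, so each slot grows by one entry per step:

import torch
from planner.robot import Robot   # assumed package layout

robot = Robot(robot_id=0, position=(0, 0))
obs = (torch.rand(1, 36, 4),                    # node_inputs
       torch.randint(0, 36, (1, 36, 9)),        # edge_inputs
       torch.zeros(1, 1, 1, dtype=torch.long),  # current_index
       torch.zeros(1, 1, 36),                   # node_padding_mask
       torch.zeros(1, 36, 9),                   # edge_padding_mask
       torch.zeros(1, 36, 36))                  # edge_mask

robot.save_observations(obs)            # S(t)   -> slots 0-5
robot.save_action(torch.tensor(3))      # A(t)   -> slot 6
robot.save_reward_done(0.5, False)      # R, done -> slots 7-8
robot.save_next_observations(obs)       # S(t+1) -> slots 9-14
assert all(len(slot) == 1 for slot in robot.episode_buffer)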
planner/sensor.py
ADDED
@@ -0,0 +1,128 @@
#######################################################################
# Name: sensor.py
#
# - Computes sensor related checks (e.g. collision, utility etc)
#######################################################################

import sys
if sys.modules['TRAINING']:
    from .parameter import *
else:
    from .test_parameter import *

import math
import numpy as np
import copy


def collision_check(x0, y0, x1, y1, ground_truth, robot_belief):
    # Bresenham-style ray cast from (x0, y0) towards (x1, y1), copying
    # ground-truth cells into the belief until the ray is blocked.
    x0 = x0.round()
    y0 = y0.round()
    x1 = x1.round()
    y1 = y1.round()
    dx, dy = abs(x1 - x0), abs(y1 - y0)
    x, y = x0, y0
    error = dx - dy
    x_inc = 1 if x1 > x0 else -1
    y_inc = 1 if y1 > y0 else -1
    dx *= 2
    dy *= 2

    collision_flag = 0
    max_collision = 10

    while 0 <= x < ground_truth.shape[1] and 0 <= y < ground_truth.shape[0]:
        k = ground_truth.item(y, x)
        if k == 1 and collision_flag < max_collision:
            collision_flag += 1
            if collision_flag >= max_collision:
                break

        if k != 1 and collision_flag > 0:
            break

        if x == x1 and y == y1:
            break

        robot_belief.itemset((y, x), k)

        if error > 0:
            x += x_inc
            error -= dy
        else:
            y += y_inc
            error += dx

    return robot_belief


def sensor_work(robot_position, sensor_range, robot_belief, ground_truth, sensor_model=SENSOR_MODEL):
    x0 = robot_position[0]
    y0 = robot_position[1]
    rng_x = 0.5 * (ground_truth.shape[1] / NUM_COORDS_WIDTH)
    rng_y = 0.5 * (ground_truth.shape[0] / NUM_COORDS_HEIGHT)

    if sensor_model == "rectangular":  # TODO: add collision check
        max_x = min(x0 + int(math.ceil(rng_x)), ground_truth.shape[1])
        min_x = max(x0 - int(math.ceil(rng_x)), 0)
        max_y = min(y0 + int(math.ceil(rng_y)), ground_truth.shape[0])
        min_y = max(y0 - int(math.ceil(rng_y)), 0)
        robot_belief[min_y:max_y, min_x:max_x] = ground_truth[min_y:max_y, min_x:max_x]
    else:
        # Omnidirectional lidar: sweep rays over 360 degrees in 0.5-degree steps
        sensor_angle_inc = 0.5 / 180 * np.pi
        sensor_angle = 0
        while sensor_angle < 2 * np.pi:
            x1 = x0 + np.cos(sensor_angle) * sensor_range
            y1 = y0 + np.sin(sensor_angle) * sensor_range
            robot_belief = collision_check(x0, y0, x1, y1, ground_truth, robot_belief)
            sensor_angle += sensor_angle_inc
    return robot_belief


def unexplored_area_check(x0, y0, x1, y1, current_belief):
    x0 = x0.round()
    y0 = y0.round()
    x1 = x1.round()
    y1 = y1.round()
    dx, dy = abs(x1 - x0), abs(y1 - y0)
    x, y = x0, y0
    error = dx - dy
    x_inc = 1 if x1 > x0 else -1
    y_inc = 1 if y1 > y0 else -1
    dx *= 2
    dy *= 2

    while 0 <= x < current_belief.shape[1] and 0 <= y < current_belief.shape[0]:
        k = current_belief.item(y, x)
        if x == x1 and y == y1:
            break

        if k == 1:
            break

        if k == 127:
            current_belief.itemset((y, x), 0)
            break

        if error > 0:
            x += x_inc
            error -= dy
        else:
            y += y_inc
            error += dx

    return current_belief


def calculate_utility(waypoint_position, sensor_range, robot_belief):
    # Utility = number of unknown (127) cells a sensor sweep from the
    # waypoint would reveal
    sensor_angle_inc = 5 / 180 * np.pi
    sensor_angle = 0
    x0 = waypoint_position[0]
    y0 = waypoint_position[1]
    current_belief = copy.deepcopy(robot_belief)
    while sensor_angle < 2 * np.pi:
        x1 = x0 + np.cos(sensor_angle) * sensor_range
        y1 = y0 + np.sin(sensor_angle) * sensor_range
        current_belief = unexplored_area_check(x0, y0, x1, y1, current_belief)
        sensor_angle += sensor_angle_inc
    utility = np.sum(robot_belief == 127) - np.sum(current_belief == 127)
    return utility
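
To see the ray-casting behaviour in isolation, a toy call to `collision_check` (cell labels follow the code: 255 free, 127 unknown, 1 occupied). NumPy scalars are used because the function calls `.round()` and `.itemset()`, the latter requiring NumPy < 2.0; the import mirrors the TRAINING-flag hack the test scripts use and assumes this repo's layout:

import sys
sys.modules['TRAINING'] = False            # flag hack used throughout this repo
import numpy as np
from planner.sensor import collision_check  # assumed package layout

ground_truth = np.full((20, 20), 255, dtype=np.uint8)   # everything free...
ground_truth[10, :] = 1                                 # ...except a horizontal wall
belief = np.full((20, 20), 127, dtype=np.uint8)         # start fully unknown

# Cast a single ray straight down from (x=5, y=2) towards (x=5, y=18)
belief = collision_check(np.int64(5), np.int64(2), np.float64(5), np.float64(18),
                         ground_truth, belief)

print(belief[0:12, 5])   # free cells revealed up to the wall; beyond it stays 127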
planner/test_info_surfing.py
ADDED
@@ -0,0 +1,1071 @@
| 1 |
+
#######################################################################
|
| 2 |
+
# Name: test_info_surfing.py
|
| 3 |
+
#
|
| 4 |
+
# - Runs robot in environment using Info Surfing Planner
|
| 5 |
+
#######################################################################
|
| 6 |
+
|
| 7 |
+
import sys
|
| 8 |
+
sys.modules['TRAINING'] = False # False = Inference Testing
|
| 9 |
+
|
| 10 |
+
import copy
|
| 11 |
+
import os
|
| 12 |
+
import imageio
|
| 13 |
+
import numpy as np
|
| 14 |
+
import matplotlib.pyplot as plt
|
| 15 |
+
from pathlib import Path
|
| 16 |
+
from time import time
|
| 17 |
+
from types import SimpleNamespace
|
| 18 |
+
from skimage.transform import resize
|
| 19 |
+
from taxabind_avs.satbind.kmeans_clustering import CombinedSilhouetteInertiaClusterer
|
| 20 |
+
from .env import Env
|
| 21 |
+
from .test_parameter import *
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
OPPOSITE_ACTIONS = {1: 3, 2: 4, 3: 1, 4: 2, 5: 7, 6: 8, 7: 5, 8: 6}
|
| 25 |
+
# color
|
| 26 |
+
agentColor = (1, 0.2, 0.6)
|
| 27 |
+
agentCommColor = (1, 0.6, 0.2)
|
| 28 |
+
obstacleColor = (0., 0., 0.)
|
| 29 |
+
targetNotFound = (0., 1., 0.)
|
| 30 |
+
targetFound = (0.545, 0.27, 0.075)
|
| 31 |
+
highestProbColor = (1., 0., 0.)
|
| 32 |
+
highestUncertaintyColor = (0., 0., 1.)
|
| 33 |
+
lowestProbColor = (1., 1., 1.)
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
class ISEnv:
|
| 37 |
+
"""Custom Environment that follows gym interface"""
|
| 38 |
+
metadata = {'render.modes': ['human']}
|
| 39 |
+
|
| 40 |
+
def __init__(self, global_step=0, state=None, shape=(24, 24), numAgents=8, observationSize=11, sensorSize=1, diag=False, save_image=False, clip_seg_tta=None):
|
| 41 |
+
|
| 42 |
+
self.global_step = global_step
|
| 43 |
+
self.infoMap = None
|
| 44 |
+
self.targetMap = None
|
| 45 |
+
self.agents = []
|
| 46 |
+
self.targets = []
|
| 47 |
+
self.numAgents = numAgents
|
| 48 |
+
self.found_target = []
|
| 49 |
+
self.shape = shape
|
| 50 |
+
self.observationSize = observationSize
|
| 51 |
+
self.sensorSize = sensorSize
|
| 52 |
+
self.diag = diag
|
| 53 |
+
self.communicateCircle = 11
|
| 54 |
+
self.distribs = []
|
| 55 |
+
self.mask = None
|
| 56 |
+
self.finished = False
|
| 57 |
+
self.action_vects = [[-1., 0.], [0., 1.], [1., 0], [0., -1.]] if not diag else [[-1., 0.], [0., 1.], [1., 0], [0., -1.], [-0.707, -0.707], [-0.707, 0.707], [0.707, 0.707], [0.707, -0.707]]
|
| 58 |
+
self.actionlist = []
|
| 59 |
+
self.IS_step = 0
|
| 60 |
+
self.save_image = save_image
|
| 61 |
+
self.clip_seg_tta = clip_seg_tta
|
| 62 |
+
self.perf_metrics = dict()
|
| 63 |
+
self.steps_to_first_tgt = None
|
| 64 |
+
self.steps_to_mid_tgt = None
|
| 65 |
+
self.steps_to_last_tgt = None
|
| 66 |
+
self.targets_found_on_path = []
|
| 67 |
+
self.step_since_tta = 0
|
| 68 |
+
self.IS_frame_files = []
|
| 69 |
+
self.bad_mask_init = False
|
| 70 |
+
|
| 71 |
+
# define env
|
| 72 |
+
self.env = Env(map_index=self.global_step, n_agent=numAgents, k_size=K_SIZE, plot=save_image, test=True)
|
| 73 |
+
|
| 74 |
+
# Overwrite state
|
| 75 |
+
if self.clip_seg_tta is not None:
|
| 76 |
+
self.clip_seg_tta.reset(sample_idx=self.global_step)
|
| 77 |
+
|
| 78 |
+
# Override target positions in env
|
| 79 |
+
self.env.target_positions = [(pose[1], pose[0]) for pose in self.clip_seg_tta.target_positions]
|
| 80 |
+
|
| 81 |
+
# Override segmentation mask
|
| 82 |
+
if not USE_CLIP_PREDS and OVERRIDE_MASK_DIR != "":
|
| 83 |
+
score_mask_path = os.path.join(OVERRIDE_MASK_DIR, self.clip_seg_tta.gt_mask_name)
|
| 84 |
+
print("score_mask_path: ", score_mask_path)
|
| 85 |
+
if os.path.exists(score_mask_path):
|
| 86 |
+
self.env.segmentation_mask = self.env.import_segmentation_mask(score_mask_path)
|
| 87 |
+
self.env.begin(self.env.map_start_position)
|
| 88 |
+
else:
|
| 89 |
+
print(f"\n\n{RED}ERROR: Trying to override, but score mask not found at path:{NC} ", score_mask_path)
|
| 90 |
+
self.bad_mask_init = True
|
| 91 |
+
|
| 92 |
+
# Save clustered embeds from sat encoder
|
| 93 |
+
if USE_CLIP_PREDS:
|
| 94 |
+
self.kmeans_clusterer = CombinedSilhouetteInertiaClusterer(
|
| 95 |
+
k_min=1,
|
| 96 |
+
k_max=8,
|
| 97 |
+
k_avg_max=4,
|
| 98 |
+
silhouette_threshold=0.15,
|
| 99 |
+
relative_threshold=0.15,
|
| 100 |
+
random_state=0,
|
| 101 |
+
min_patch_size=5,
|
| 102 |
+
n_smooth_iter=2,
|
| 103 |
+
ignore_label=-1,
|
| 104 |
+
plot=self.save_image,
|
| 105 |
+
gifs_dir = GIFS_PATH
|
| 106 |
+
)
|
| 107 |
+
# Generate kmeans clusters
|
| 108 |
+
self.kmeans_sat_embeds_clusters = self.kmeans_clusterer.fit_predict(
|
| 109 |
+
patch_embeds=self.clip_seg_tta.patch_embeds,
|
| 110 |
+
map_shape=(CLIP_GRIDS_DIMS[0], CLIP_GRIDS_DIMS[1]),
|
| 111 |
+
)
|
| 112 |
+
|
| 113 |
+
if EXECUTE_TTA:
|
| 114 |
+
print("Will execute TTA...")
|
| 115 |
+
|
| 116 |
+
IS_info_map = copy.deepcopy(self.env.segmentation_info_mask)
|
| 117 |
+
IS_agent_loc = copy.deepcopy(self.env.start_positions)
|
| 118 |
+
IS_target_loc = copy.deepcopy(self.env.target_positions)
|
| 119 |
+
state=[IS_info_map, IS_agent_loc, IS_target_loc]
|
| 120 |
+
self.setWorld(state)
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
def init_render(self):
|
| 124 |
+
"""
|
| 125 |
+
Call this once (e.g., in __init__ or just before the scenario loop)
|
| 126 |
+
to initialize storage for agent paths and turn interactive plotting on.
|
| 127 |
+
"""
|
| 128 |
+
# Keep track of each agent's trajectory
|
| 129 |
+
self.trajectories = [[] for _ in range(self.numAgents)]
|
| 130 |
+
self.trajectories_upscaled = [[] for _ in range(self.numAgents)]
|
| 131 |
+
|
| 132 |
+
# Turn on interactive mode so we can update the same figure repeatedly
|
| 133 |
+
plt.ion()
|
| 134 |
+
plt.figure(figsize=(6,6))
|
| 135 |
+
plt.title("Information Map with Agents, Targets, and Sensor Ranges")
|
| 136 |
+
|
| 137 |
+
|
| 138 |
+
def record_positions(self):
|
| 139 |
+
"""
|
| 140 |
+
Call this after all agents have moved in a step (or whenever you want to update
|
| 141 |
+
the trajectory). It appends the current positions of each agent to `self.trajectories`.
|
| 142 |
+
"""
|
| 143 |
+
for idx, agent in enumerate(self.agents):
|
| 144 |
+
self.trajectories[idx].append((agent.row, agent.col))
|
| 145 |
+
self.trajectories_upscaled[idx].append(self.env.graph_generator.grid_coords[agent.row, agent.col])
|
| 146 |
+
|
| 147 |
+
|
| 148 |
+
def render(self, episode_num, step_num):
|
| 149 |
+
"""
|
| 150 |
+
Renders the current state in a single matplotlib plot.
|
| 151 |
+
Ensures consistent image size for GIF generation.
|
| 152 |
+
"""
|
| 153 |
+
|
| 154 |
+
# Completely reset the figure to avoid leftover state
|
| 155 |
+
plt.close('all')
|
| 156 |
+
fig = plt.figure(figsize=(6.4, 4.8), dpi=100)
|
| 157 |
+
ax = fig.add_subplot(111)
|
| 158 |
+
|
| 159 |
+
# Plot the information map
|
| 160 |
+
ax.imshow(self.infoMap, origin='lower', cmap='gray')
|
| 161 |
+
|
| 162 |
+
# Show agent positions and their trajectories
|
| 163 |
+
for idx, agent in enumerate(self.agents):
|
| 164 |
+
positions = self.trajectories[idx]
|
| 165 |
+
if len(positions) > 1:
|
| 166 |
+
rows = [p[0] for p in positions]
|
| 167 |
+
cols = [p[1] for p in positions]
|
| 168 |
+
ax.plot(cols, rows, linewidth=1)
|
| 169 |
+
|
| 170 |
+
ax.scatter(agent.col, agent.row, marker='o', s=50)
|
| 171 |
+
|
| 172 |
+
# Plot target locations
|
| 173 |
+
for t in self.targets:
|
| 174 |
+
color = 'green' if np.isnan(t.time_found) else 'red'
|
| 175 |
+
ax.scatter(t.col, t.row, marker='x', s=100, color=color)
|
| 176 |
+
|
| 177 |
+
# Title and axis formatting
|
| 178 |
+
ax.set_title(f"Step: {self.IS_step}")
|
| 179 |
+
ax.invert_yaxis()
|
| 180 |
+
|
| 181 |
+
# Create output folder if it doesn't exist
|
| 182 |
+
if not os.path.exists(GIFS_PATH):
|
| 183 |
+
os.makedirs(GIFS_PATH)
|
| 184 |
+
|
| 185 |
+
# Save the frame with consistent canvas
|
| 186 |
+
frame_path = f'{GIFS_PATH}/IS_{episode_num}_{step_num}.png'
|
| 187 |
+
plt.savefig(frame_path, bbox_inches='tight', pad_inches=0.1)
|
| 188 |
+
self.IS_frame_files.append(frame_path)
|
| 189 |
+
|
| 190 |
+
# Cleanup
|
| 191 |
+
plt.close(fig)
|
| 192 |
+
|
| 193 |
+
|
| 194 |
+
def setWorld(self, state=None):
|
| 195 |
+
"""
|
| 196 |
+
1. empty all the element
|
| 197 |
+
2. create the new episode
|
| 198 |
+
"""
|
| 199 |
+
if state is not None:
|
| 200 |
+
self.infoMap = copy.deepcopy(state[0].reshape(self.shape).T)
|
| 201 |
+
agents = []
|
| 202 |
+
self.numAgents = len(state[1])
|
| 203 |
+
for a in range(1, self.numAgents + 1):
|
| 204 |
+
abs_pos = state[1].pop(0)
|
| 205 |
+
abs_pos = np.array(abs_pos)
|
| 206 |
+
row, col = self.env.graph_generator.find_closest_index_from_grid_coords_2d(np.array(abs_pos))
|
| 207 |
+
agents.append(Agent(ID=a, row=row, col=col, sensorSize=self.sensorSize, infoMap=np.copy(self.infoMap),
|
| 208 |
+
uncertaintyMap=np.copy(self.infoMap), shape=self.shape, numAgents=self.numAgents))
|
| 209 |
+
self.agents = agents
|
| 210 |
+
|
| 211 |
+
targets, n_targets = [], 1
|
| 212 |
+
for t in range(len(state[2])):
|
| 213 |
+
abs_pos = state[2].pop(0)
|
| 214 |
+
abs_pos = np.array(abs_pos)
|
| 215 |
+
row, col = self.env.graph_generator.find_closest_index_from_grid_coords_2d(abs_pos)
|
| 216 |
+
targets.append(Target(ID=n_targets, row=row, col=col, time_found=np.nan))
|
| 217 |
+
n_targets = n_targets + 1
|
| 218 |
+
self.targets = targets
|
| 219 |
+
|
| 220 |
+
def extractObservation(self, agent):
|
| 221 |
+
"""
|
| 222 |
+
Extract observations from information map
|
| 223 |
+
"""
|
| 224 |
+
|
| 225 |
+
transform_row = self.observationSize // 2 - agent.row
|
| 226 |
+
transform_col = self.observationSize // 2 - agent.col
|
| 227 |
+
|
| 228 |
+
observation_layers = np.zeros((1, self.observationSize, self.observationSize))
|
| 229 |
+
min_row = max((agent.row - self.observationSize // 2), 0)
|
| 230 |
+
max_row = min((agent.row + self.observationSize // 2 + 1), self.shape[0])
|
| 231 |
+
min_col = max((agent.col - self.observationSize // 2), 0)
|
| 232 |
+
max_col = min((agent.col + self.observationSize // 2 + 1), self.shape[1])
|
| 233 |
+
|
| 234 |
+
observation = np.full((self.observationSize, self.observationSize), 0.)
|
| 235 |
+
infoMap = np.full((self.observationSize, self.observationSize), 0.)
|
| 236 |
+
densityMap = np.full((self.observationSize, self.observationSize), 0.)
|
| 237 |
+
|
| 238 |
+
infoMap[(min_row + transform_row):(max_row + transform_row),
|
| 239 |
+
(min_col + transform_col):(max_col + transform_col)] = self.infoMap[
|
| 240 |
+
min_row:max_row, min_col:max_col]
|
| 241 |
+
observation_layers[0] = infoMap
|
| 242 |
+
|
| 243 |
+
return observation_layers
|
| 244 |
+
|
| 245 |
+
|
| 246 |
+
def listNextValidActions(self, agent_id, prev_action=0):
|
| 247 |
+
"""
|
| 248 |
+
No movement: 0
|
| 249 |
+
North (-1,0): 1
|
| 250 |
+
East (0,1): 2
|
| 251 |
+
South (1,0): 3
|
| 252 |
+
West (0,-1): 4
|
| 253 |
+
"""
|
| 254 |
+
available_actions = [0]
|
| 255 |
+
agent = self.agents[agent_id - 1]
|
| 256 |
+
|
| 257 |
+
MOVES = [(-1, 0), (0, 1), (1, 0), (0, -1), (-1, -1), (-1, 1), (1, 1), (1, -1)]
|
| 258 |
+
size = 4 + self.diag * 4
|
| 259 |
+
for action in range(size):
|
| 260 |
+
out_of_bounds = agent.row + MOVES[action][0] >= self.shape[0] \
|
| 261 |
+
or agent.row + MOVES[action][0] < 0\
|
| 262 |
+
or agent.col + MOVES[action][1] >= self.shape[1] \
|
| 263 |
+
or agent.col + MOVES[action][1] < 0
|
| 264 |
+
|
| 265 |
+
if (not out_of_bounds) and not (prev_action == OPPOSITE_ACTIONS[action + 1]):
|
| 266 |
+
available_actions.append(action + 1)
|
| 267 |
+
|
| 268 |
+
return np.array(available_actions)
|
| 269 |
+
|
| 270 |
+
|
| 271 |
+
    def executeAction(self, agentID, action, timeStep):
        """
        No movement: 0
        North (-1,0): 1
        East (0,1): 2
        South (1,0): 3
        West (0,-1): 4
        LeftUp (-1,-1): 5
        RightUp (-1,1): 6
        RightDown (1,1): 7
        LeftDown (1,-1): 8
        """
        agent = self.agents[agentID - 1]
        origLoc = agent.getLocation()

        if (action >= 1) and (action <= 8):
            agent.move(action)
            row, col = agent.getLocation()

            # If the move is not valid, roll it back
            if (row < 0) or (col < 0) or (row >= self.shape[0]) or (col >= self.shape[1]):
                agent.reverseMove(action)
                self.updateInfoCheckTarget(agentID, timeStep, origLoc)
                return 0

        elif action == 0:
            self.updateInfoCheckTarget(agentID, timeStep, origLoc)
            return 0

        else:
            print("INVALID ACTION: {}".format(action))
            sys.exit()

        self.updateInfoCheckTarget(agentID, timeStep, origLoc)
        return action
    def updateInfoCheckTarget(self, agentID, timeStep, origLoc):
        """
        Update self.infoMap within the sensor footprint and check whether
        the agent has found a target at its current cell.
        """
        agent = self.agents[agentID - 1]
        transform_row = self.sensorSize // 2 - agent.row
        transform_col = self.sensorSize // 2 - agent.col

        min_row = max((agent.row - self.sensorSize // 2), 0)
        max_row = min((agent.row + self.sensorSize // 2 + 1), self.shape[0])
        min_col = max((agent.col - self.sensorSize // 2), 0)
        max_col = min((agent.col + self.sensorSize // 2 + 1), self.shape[1])
        for t in self.targets:
            if (t.row == agent.row) and (t.col == agent.col):
                t.updateFound(timeStep)
                self.found_target.append(t)
                t.status = True

        # Decay the observed window so it is not revisited
        self.infoMap[min_row:max_row, min_col:max_col] *= 0.05
    def updateInfoEntireTrajectory(self, agentID):
        """
        Re-apply the sensor decay to self.infoMap along the agent's entire
        trajectory (used after the info map is refreshed by TTA).
        """
        traj = self.trajectories[agentID - 1]

        for (row, col) in traj:
            min_row = max((row - self.sensorSize // 2), 0)
            max_row = min((row + self.sensorSize // 2 + 1), self.shape[0])
            min_col = max((col - self.sensorSize // 2), 0)
            max_col = min((col + self.sensorSize // 2 + 1), self.shape[1])
            self.infoMap[min_row:max_row, min_col:max_col] *= 0.05
    # Execute one time step within the environment
    def step(self, agentID, action, timeStep):
        """
        The agent executes the given action.
        No movement: 0
        North (-1,0): 1
        East (0,1): 2
        South (1,0): 3
        West (0,-1): 4
        (5-8 are the diagonal moves listed in executeAction.)
        """
        assert (agentID > 0)

        self.executeAction(agentID, action, timeStep)


    def observe(self, agentID):
        assert (agentID > 0)
        vectorObs = self.extractObservation(self.agents[agentID - 1])
        return [vectorObs]


    def check_finish(self):
        if TERMINATE_ON_TGTS_FOUND:
            found_status = [t.time_found for t in self.targets]
            d = False
            if np.isnan(found_status).sum() == 0:
                d = True
            return d
        else:
            return False
    def gradVec(self, observation, agent):
        a = observation[0]

        # Zero out info cells with negligible value
        a[a < 0.0002] = 0.0

        # Center 3x3 square of the raw 11x11 window
        a_11x11 = a[4:7, 4:7]
        m_11x11 = np.array(a_11x11)

        # Center 3x3 square after 3x3 max-pooling (9x9 result)
        a_9x9 = self.pooling(a, (3, 3), stride=(1, 1), method='max', pad=False)
        a_9x9 = a_9x9[3:6, 3:6]
        m_9x9 = np.array(a_9x9)

        # Center 3x3 square after 6x6 max-pooling (6x6 result)
        a_6x6 = self.pooling(a, (6, 6), stride=(1, 1), method='max', pad=False)
        a_6x6 = a_6x6[1:4, 1:4]
        m_6x6 = np.array(a_6x6)

        # 3x3 result from 5x5 max-pooling with stride 3
        a_3x3 = self.pooling(a, (5, 5), stride=(3, 3), method='max', pad=False)
        m_3x3 = np.array(a_3x3)

        # Merge the multi-scale maps with equal weights
        m = m_3x3 * 0.25 + m_6x6 * 0.25 + m_9x9 * 0.25 + m_11x11 * 0.25
        a = m

        adx, ady = np.gradient(a)
        den = np.linalg.norm(np.array([adx[1, 1], ady[1, 1]]))
        if (den != 0) and (not np.isnan(den)):
            infovec = np.array([adx[1, 1], ady[1, 1]]) / den
        else:
            infovec = 0
        agentvec = []  # (currently always empty, so only the info gradient is used)

        if len(agentvec) == 0:
            den = np.linalg.norm(infovec)
            if (den != 0) and (not np.isnan(den)):
                direction = infovec / den
            else:
                direction = self.action_vects[np.random.randint(4 + self.diag * 4)]
        else:
            den = np.linalg.norm(np.mean(agentvec, 0))
            if (den != 0) and (not np.isnan(den)):
                agentvec = np.mean(agentvec, 0) / den
            else:
                agentvec = 0

            den = np.linalg.norm(0.6 * infovec + 0.4 * agentvec)
            if (den != 0) and (not np.isnan(den)):
                direction = (0.6 * infovec + 0.4 * agentvec) / den
            else:
                direction = self.action_vects[np.random.randint(4 + self.diag * 4)]

        action_vec = [[0, 0], [-1, 0], [0, 1], [1, 0], [0, -1]] if not self.diag else [[0., 0.], [-1., 0.], [0., 1.], [1., 0], [0., -1.], [-0.707, -0.707], [-0.707, 0.707], [0.707, 0.707], [0.707, -0.707]]
        actionid = np.argmax([np.dot(direction, v) for v in action_vec])
        actionid = self.best_valid_action(actionid, agent, direction)
        return actionid
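    # Intuition (added note): each pooled map summarizes info values at a coarser
    # radius around the agent, so averaging the four 3x3 maps before np.gradient
    # steers the agent toward large high-value regions rather than single bright cells.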
    def best_valid_action(self, actionid, agent, direction):
        if len(self.actionlist) > 1:
            if self.action_invalid(actionid, agent):
                action_vec = [[0, 0], [-1, 0], [0, 1], [1, 0], [0, -1]] if not self.diag else [[0., 0.], [-1., 0.], [0., 1.], [1., 0], [0., -1.], [-0.707, -0.707], [-0.707, 0.707], [0.707, 0.707], [0.707, -0.707]]
                actionid = np.array([np.dot(direction, v) for v in action_vec])
                actionid = actionid.argsort()
                # Walk candidates from best to worst alignment with `direction`
                # (check pi >= 0 first so actionid[-1] never wraps around)
                pi = 3 + self.diag * 4
                while pi >= 0 and self.action_invalid(actionid[pi], agent):
                    pi -= 1
                if pi == -1:
                    return OPPOSITE_ACTIONS[self.actionlist[self.IS_step - 1][agent - 1]]
                elif actionid[pi] == 0:
                    return OPPOSITE_ACTIONS[self.actionlist[self.IS_step - 1][agent - 1]]
                else:
                    return actionid[pi]
        return actionid
    def action_invalid(self, action, agent):
        # Going back to the previous cell is disabled
        if action == OPPOSITE_ACTIONS[self.actionlist[self.IS_step - 1][agent - 1]]:
            return True
        # Apply the move (N/E/S/W or diagonal), then test the resulting cell
        if (action >= 1) and (action <= 8):
            agent = self.agents[agent - 1]
            agent.move(action)
            row, col = agent.getLocation()

            # If the move is not valid, roll it back
            if (row < 0) or (col < 0) or (row >= self.shape[0]) or (col >= self.shape[1]):
                agent.reverseMove(action)
                return True

            agent.reverseMove(action)
            return False
        return False
    def step_all_parallel(self):
        actions = []
        # Decide actions for all agents before any of them moves
        for agent_id in range(1, self.numAgents + 1):
            o = self.observe(agent_id)
            actions.append(self.gradVec(o[0], agent_id))
        self.actionlist.append(actions)

        # Execute those actions
        for agent_id in range(1, self.numAgents + 1):
            self.step(agent_id, actions[agent_id - 1], self.IS_step)

        # Record for visualization
        self.record_positions()
    def is_scenario(self, max_step=512, episode_number=0):

        # Return all metrics as None if faulty mask init
        if self.bad_mask_init:
            self.perf_metrics['tax'] = None
            self.perf_metrics['travel_dist'] = None
            self.perf_metrics['travel_steps'] = None
            self.perf_metrics['steps_to_first_tgt'] = None
            self.perf_metrics['steps_to_mid_tgt'] = None
            self.perf_metrics['steps_to_last_tgt'] = None
            self.perf_metrics['explored_rate'] = None
            self.perf_metrics['targets_found'] = None
            self.perf_metrics['targets_total'] = None
            self.perf_metrics['kmeans_k'] = None
            self.perf_metrics['tgts_gt_score'] = None
            self.perf_metrics['clip_inference_time'] = None
            self.perf_metrics['tta_time'] = None
            self.perf_metrics['success_rate'] = None
            return

        eps_start = time()
        self.IS_step = 0
        self.finished = False

        # Initialize the rendering just once before the loop
        self.init_render()
        self.record_positions()

        # Initial setup: seed the info map from the CLIP heatmap
        if LOAD_AVS_BENCH and USE_CLIP_PREDS:
            if NUM_COORDS_WIDTH != CLIP_GRIDS_DIMS[0] or NUM_COORDS_HEIGHT != CLIP_GRIDS_DIMS[1]:  # Heatmap is resized from CLIP's original dims
                heatmap = self.convert_heatmap_resolution(self.clip_seg_tta.heatmap, full_dims=(512, 512), new_dims=(NUM_COORDS_WIDTH, NUM_COORDS_HEIGHT))
                self.env.segmentation_info_mask = np.expand_dims(heatmap.T.flatten(), axis=1)
                unnormalized_heatmap = self.convert_heatmap_resolution(self.clip_seg_tta.heatmap_unnormalized, full_dims=(512, 512), new_dims=(NUM_COORDS_WIDTH, NUM_COORDS_HEIGHT))
                self.env.segmentation_info_mask_unnormalized = np.expand_dims(unnormalized_heatmap.T.flatten(), axis=1)
                self.infoMap = copy.deepcopy(heatmap)
                print("Resized heatmap to", NUM_COORDS_WIDTH, "x", NUM_COORDS_HEIGHT)
            else:
                self.env.segmentation_info_mask = np.expand_dims(self.clip_seg_tta.heatmap.T.flatten(), axis=1)
                self.env.segmentation_info_mask_unnormalized = np.expand_dims(self.clip_seg_tta.heatmap_unnormalized.T.flatten(), axis=1)
                self.infoMap = copy.deepcopy(self.clip_seg_tta.heatmap)

        self.targets_found_on_path.append(self.env.num_new_targets_found)

        while self.IS_step < max_step and not self.check_finish():
            self.step_all_parallel()
            self.IS_step += 1

            # Render after each step
            if self.save_image:
                self.render(episode_num=self.global_step, step_num=self.IS_step)

            # Mirror this step into the underlying env
            next_position_list = [self.trajectories_upscaled[i][-1] for i, agent in enumerate(self.agents)]
            dist_list = [0 for _ in range(self.numAgents)]
            travel_dist_list = [self.compute_travel_distance(traj) for traj in self.trajectories]
            self.env.multi_robot_step(next_position_list, dist_list, travel_dist_list)
            self.targets_found_on_path.append(self.env.num_new_targets_found)

            # TTA update via Poisson test (with KMeans clustering stats)
            robot_id = 0  # Assume 1 agent for now
            robot_traj = self.trajectories[robot_id]
            if LOAD_AVS_BENCH and USE_CLIP_PREDS and EXECUTE_TTA:
                flat_traj_coords = [robot_traj[i][1] * self.shape[0] + robot_traj[i][0] for i in range(len(robot_traj))]
                robot = SimpleNamespace(
                    trajectory_coords=flat_traj_coords,
                    targets_found_on_path=self.targets_found_on_path
                )
                self.poisson_tta_update(robot, self.global_step, self.IS_step)
                self.infoMap = copy.deepcopy(self.env.segmentation_info_mask.reshape((self.shape[1], self.shape[0])).T)
                self.updateInfoEntireTrajectory(robot_id)

            # Update metrics
            self.log_metrics(step=self.IS_step - 1)

            ### Save a frame to generate gif of robot trajectories ###
            if self.save_image:
                robots_route = [([], [])]  # Assume 1 robot
                for point in self.trajectories_upscaled[robot_id]:
                    robots_route[robot_id][0].append(point[0])
                    robots_route[robot_id][1].append(point[1])
                if not os.path.exists(GIFS_PATH):
                    os.makedirs(GIFS_PATH)
                if LOAD_AVS_BENCH:
                    sound_id_override = None if self.clip_seg_tta.sound_ids == [] else self.clip_seg_tta.sound_ids[0]
                    self.env.plot_env(
                        self.global_step,
                        GIFS_PATH,
                        self.IS_step - 1,
                        max(travel_dist_list),
                        robots_route,
                        img_path_override=self.clip_seg_tta.img_paths[0],  # Visualize the first image query
                        sat_path_override=self.clip_seg_tta.imo_path,
                        msk_name_override=self.clip_seg_tta.species_name,
                        sound_id_override=sound_id_override,
                    )
                else:
                    self.env.plot_env(
                        self.global_step,
                        GIFS_PATH,
                        self.IS_step - 1,
                        max(travel_dist_list),
                        robots_route
                    )

        # Log metrics
        if LOAD_AVS_BENCH:
            tax = Path(self.clip_seg_tta.gt_mask_name).stem
            self.perf_metrics['tax'] = " ".join(tax.split("_")[1:])
        else:
            self.perf_metrics['tax'] = None
        travel_distances = [self.compute_travel_distance(traj) for traj in self.trajectories]
        self.perf_metrics['travel_dist'] = max(travel_distances)
        self.perf_metrics['travel_steps'] = self.IS_step
        self.perf_metrics['steps_to_first_tgt'] = self.steps_to_first_tgt
        self.perf_metrics['steps_to_mid_tgt'] = self.steps_to_mid_tgt
        self.perf_metrics['steps_to_last_tgt'] = self.steps_to_last_tgt
        self.perf_metrics['explored_rate'] = self.env.explored_rate
        self.perf_metrics['targets_found'] = self.env.targets_found_rate
        self.perf_metrics['targets_total'] = len(self.env.target_positions)
        if USE_CLIP_PREDS:
            self.perf_metrics['kmeans_k'] = self.kmeans_clusterer.final_k
            self.perf_metrics['tgts_gt_score'] = self.clip_seg_tta.tgts_gt_score
            self.perf_metrics['clip_inference_time'] = self.clip_seg_tta.clip_inference_time
            self.perf_metrics['tta_time'] = self.clip_seg_tta.tta_time
        else:
            self.perf_metrics['kmeans_k'] = None
            self.perf_metrics['tgts_gt_score'] = None
            self.perf_metrics['clip_inference_time'] = None
            self.perf_metrics['tta_time'] = None
        if FORCE_LOGGING_DONE_TGTS_FOUND and self.env.targets_found_rate == 1.0:
            self.perf_metrics['success_rate'] = True
        else:
            self.perf_metrics['success_rate'] = self.env.check_done()[0]

        # Save gif
        if self.save_image:
            path = GIFS_PATH
            self.make_gif(path, self.global_step)

        print(YELLOW, f"[Eps {episode_number} Completed] Time Taken: {time()-eps_start:.2f}s, Steps: {self.IS_step}", NC)
    def asStride(self, arr, sub_shape, stride):
        """
        Get a strided sub-matrices view of an ndarray.
        See also skimage.util.shape.view_as_windows().
        """
        s0, s1 = arr.strides[:2]
        m1, n1 = arr.shape[:2]
        m2, n2 = sub_shape
        view_shape = (1 + (m1 - m2) // stride[0], 1 + (n1 - n2) // stride[1], m2, n2) + arr.shape[2:]
        strides = (stride[0] * s0, stride[1] * s1, s0, s1) + arr.strides[2:]
        subs = np.lib.stride_tricks.as_strided(arr, view_shape, strides=strides)
        return subs
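    # Shape sketch (illustrative): for an 11x11 array with sub_shape=(3, 3) and
    # stride=(1, 1), view_shape = (1+(11-3)//1, 1+(11-3)//1, 3, 3) = (9, 9, 3, 3),
    # i.e. one 3x3 window per output location, which pooling() reduces over axes (2, 3).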
    def pooling(self, mat, ksize, stride=None, method='max', pad=False):
        """
        Overlapping pooling on 2D or 3D data.

        <mat>: ndarray, input array to pool.
        <ksize>: tuple of 2, kernel size in (ky, kx).
        <stride>: tuple of 2 or None, stride of pooling window.
                  If None, same as <ksize> (non-overlapping pooling).
        <method>: str, 'max' for max-pooling,
                  'mean' for mean-pooling.
        <pad>: bool, pad <mat> or not. If no pad, output has size
               (n-f)//s+1, n being <mat> size, f being kernel size, s stride.
               If pad, output has size ceil(n/s).

        Return <result>: pooled matrix.
        """

        m, n = mat.shape[:2]
        ky, kx = ksize
        if stride is None:
            stride = (ky, kx)
        sy, sx = stride

        _ceil = lambda x, y: int(np.ceil(x / float(y)))

        if pad:
            ny = _ceil(m, sy)
            nx = _ceil(n, sx)
            size = ((ny - 1) * sy + ky, (nx - 1) * sx + kx) + mat.shape[2:]
            mat_pad = np.full(size, np.nan)
            mat_pad[:m, :n, ...] = mat
        else:
            # Trim so that every window fits entirely inside the input
            mat_pad = mat[:(m - ky) // sy * sy + ky, :(n - kx) // sx * sx + kx, ...]

        view = self.asStride(mat_pad, ksize, stride)

        if method == 'max':
            result = np.nanmax(view, axis=(2, 3))
        else:
            result = np.nanmean(view, axis=(2, 3))

        return result
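    # Illustrative usage sketch (values assumed, not from the original file):
    # max-pooling an 11x11 map with a 3x3 window and stride 1 yields a 9x9 map,
    # matching the (n-f)//s+1 sizing noted in the docstring:
    #   out = self.pooling(np.random.rand(11, 11), (3, 3), stride=(1, 1), method='max')
    #   assert out.shape == (9, 9)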
    def compute_travel_distance(self, trajectory):
        distance = 0.0
        for i in range(1, len(trajectory)):
            # Convert the tuple positions to numpy arrays for easy computation
            prev_pos = np.array(trajectory[i - 1])
            curr_pos = np.array(trajectory[i])
            # Euclidean distance between consecutive positions
            distance += np.linalg.norm(curr_pos - prev_pos)
        return distance
    ################################################################################
    # SPPP Related Fns
    ################################################################################

    def log_metrics(self, step):
        # Update target-found metrics (steps to first, middle, and last target)
        if self.steps_to_first_tgt is None and self.env.num_targets_found == 1:
            self.steps_to_first_tgt = step + 1
        if self.steps_to_mid_tgt is None and self.env.num_targets_found == int(len(self.env.target_positions) / 2):
            self.steps_to_mid_tgt = step + 1
        if self.steps_to_last_tgt is None and self.env.num_targets_found == len(self.env.target_positions):
            self.steps_to_last_tgt = step + 1
    def transpose_flat_idx(self, idx, H=NUM_COORDS_HEIGHT, W=NUM_COORDS_WIDTH):
        """
        Transpose a flat index from an ``H×W`` grid to the equivalent position
        in the ``W×H`` transposed grid, while keeping the result in 1-D.
        """
        # --- Safety check to catch out-of-range indices ---
        assert 0 <= idx < H * W, f"idx {idx} out of bounds for shape ({H}, {W})"

        # Original (row, col)
        row, col = divmod(idx, W)
        # After transpose these coordinates swap
        row_T, col_T = col, row

        # Flatten back into 1-D (row-major) for the W×H grid
        return row_T * H + col_T
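    # Worked example (illustrative, using a small 3x4 grid rather than the real dims):
    # idx 5 in a 3x4 grid is (row=1, col=1); in the transposed 4x3 grid the same cell
    # flattens to 1*3 + 1 = 4, so self.transpose_flat_idx(5, H=3, W=4) -> 4.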
    def poisson_tta_update(self, robot, episode, step):

        # Generate KMeans cluster stats.
        # Scale indices back to CLIP_GRIDS_DIMS to be compatible with the CLIP patch size
        if NUM_COORDS_WIDTH != CLIP_GRIDS_DIMS[0] or NUM_COORDS_HEIGHT != CLIP_GRIDS_DIMS[1]:
            # High-res remap via pixel coordinates preserves the exact neighbourhood
            filt_traj_coords, filt_targets_found_on_path = self.scale_trajectory(
                robot.trajectory_coords,
                self.env.target_positions,
                old_dims=(NUM_COORDS_HEIGHT, NUM_COORDS_WIDTH),
                full_dims=(512, 512),
                new_dims=(CLIP_GRIDS_DIMS[0], CLIP_GRIDS_DIMS[1])
            )
        else:
            filt_traj_coords = [self.transpose_flat_idx(idx) for idx in robot.trajectory_coords]
            filt_targets_found_on_path = robot.targets_found_on_path

        region_stats_dict = self.kmeans_clusterer.compute_region_statistics(
            self.kmeans_sat_embeds_clusters,
            self.clip_seg_tta.heatmap_unnormalized,
            filt_traj_coords,
            episode_num=episode,
            step_num=step
        )

        # Prep & execute TTA
        self.step_since_tta += 1
        if robot.targets_found_on_path[-1] or self.step_since_tta % STEPS_PER_TTA == 0:

            num_cells = self.clip_seg_tta.heatmap.shape[0] * self.clip_seg_tta.heatmap.shape[1]
            pos_sample_weight_scale, neg_sample_weight_scale = [], []

            for i, sample_loc in enumerate(filt_traj_coords):
                label = self.kmeans_clusterer.get_label_id(self.kmeans_sat_embeds_clusters, sample_loc)
                num_patches = region_stats_dict[label]['num_patches']
                patches_visited = region_stats_dict[label]['patches_visited']
                expectation = region_stats_dict[label]['expectation']

                # Focal-loss-like exponent: wait for more samples before confidently decreasing
                pos_weight = 4.0
                neg_weight = min(1.0, (patches_visited / (3 * num_patches)) ** GAMMA_EXPONENT)
                pos_sample_weight_scale.append(pos_weight)
                neg_sample_weight_scale.append(neg_weight)

            # Adaptive LR (as samples accumulate, raise the LR to fit more datapoints)
            adaptive_lr = MIN_LR + (MAX_LR - MIN_LR) * (step / num_cells)

            # TTA update
            self.clip_seg_tta.execute_tta(
                filt_traj_coords,
                filt_targets_found_on_path,
                tta_steps=NUM_TTA_STEPS,
                lr=adaptive_lr,
                pos_sample_weight=pos_sample_weight_scale,
                neg_sample_weight=neg_sample_weight_scale,
                reset_weights=RESET_WEIGHTS
            )
            if NUM_COORDS_WIDTH != CLIP_GRIDS_DIMS[0] or NUM_COORDS_HEIGHT != CLIP_GRIDS_DIMS[1]:  # Heatmap is resized from CLIP's original dims
                heatmap = self.convert_heatmap_resolution(self.clip_seg_tta.heatmap, full_dims=(512, 512), new_dims=(NUM_COORDS_WIDTH, NUM_COORDS_HEIGHT))
                self.env.segmentation_info_mask = np.expand_dims(heatmap.T.flatten(), axis=1)
                unnormalized_heatmap = self.convert_heatmap_resolution(self.clip_seg_tta.heatmap_unnormalized, full_dims=(512, 512), new_dims=(NUM_COORDS_WIDTH, NUM_COORDS_HEIGHT))
                self.env.segmentation_info_mask_unnormalized = np.expand_dims(unnormalized_heatmap.T.flatten(), axis=1)
                print("~Resized heatmap to", NUM_COORDS_WIDTH, "x", NUM_COORDS_HEIGHT)
            else:
                self.env.segmentation_info_mask = np.expand_dims(self.clip_seg_tta.heatmap.T.flatten(), axis=1)
                self.env.segmentation_info_mask_unnormalized = np.expand_dims(self.clip_seg_tta.heatmap_unnormalized.T.flatten(), axis=1)

            self.step_since_tta = 0
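    # Numeric sketch of the adaptive LR above (assuming the defaults from
    # test_parameter.py: MIN_LR=1e-6, MAX_LR=1e-5, and a 24x24 heatmap => 576 cells):
    # at step 288, lr = 1e-6 + (1e-5 - 1e-6) * 288/576 = 5.5e-6.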
    def convert_heatmap_resolution(self, heatmap, full_dims=(512, 512), new_dims=(24, 24)):
        heatmap_large = resize(heatmap, full_dims, order=1,  # order=1 → bilinear
                               mode='reflect', anti_aliasing=True)

        coords = self.env.graph_generator.grid_coords  # (N, N, 2)
        rows, cols = coords[..., 1], coords[..., 0]
        heatmap_resized = heatmap_large[rows, cols]
        heatmap_resized = heatmap_resized.reshape(new_dims[1], new_dims[0])
        return heatmap_resized
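    # Illustrative usage (dims assumed, not from the original file): shrink the
    # CLIP heatmap onto a coarser planning grid by first upsampling to the 512px
    # satellite frame, then sampling at each graph node's pixel coordinate:
    #   small = self.convert_heatmap_resolution(heatmap, full_dims=(512, 512), new_dims=(17, 17))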
    def convert_labelmap_resolution(self, labelmap, full_dims=(512, 512), new_dims=(24, 24)):
        """
        1) Upsample via nearest-neighbor to full_dims
        2) Sample back down to the graph grid using grid_coords
        """
        # 1) Upsample with nearest-neighbor, preserving integer labels
        up = resize(
            labelmap,
            full_dims,
            order=0,               # nearest-neighbor
            mode='edge',           # padding mode
            preserve_range=True,   # don't normalize labels
            anti_aliasing=False    # must be False for labels
        ).astype(labelmap.dtype)   # back to original integer dtype

        # 2) Downsample via the precomputed grid coords (N x N x 2)
        coords = self.env.graph_generator.grid_coords  # shape (N, N, 2)
        rows = coords[..., 1].astype(int)
        cols = coords[..., 0].astype(int)

        small = up[rows, cols]  # shape (N, N)
        small = small.reshape(new_dims[0], new_dims[1])
        return small
    def scale_trajectory(self,
                         flat_indices,
                         targets,
                         old_dims=(17, 17),
                         full_dims=(512, 512),
                         new_dims=(24, 24)):
        """
        Args:
            flat_indices: list of ints in [0..old_H*old_W-1]
            targets: list of (y_pix, x_pix) in [0..full_H-1]
            old_dims: (old_H, old_W)
            full_dims: (full_H, full_W)
            new_dims: (new_H, new_W)

        Returns:
            new_flat_traj: list of unique flattened indices in the new_H x new_W grid
            counts: list of ints (targets per visited new cell), same length as new_flat_traj
        """
        old_H, old_W = old_dims
        full_H, full_W = full_dims
        new_H, new_W = new_dims

        # 1) Bin targets into the new grid
        cell_h_new = full_H / new_H
        cell_w_new = full_W / new_W
        grid_counts = [[0] * new_W for _ in range(new_H)]
        for x_pix, y_pix in targets:  # note (x, y) order, as in the original implementation
            i_t = min(int(y_pix / cell_h_new), new_H - 1)
            j_t = min(int(x_pix / cell_w_new), new_W - 1)
            grid_counts[i_t][j_t] += 1

        # 2) Walk the trajectory indices and project each old cell's entire
        #    pixel footprint onto the finer new grid.
        cell_h_full = full_H / old_H
        cell_w_full = full_W / old_W

        seen = set()
        new_flat_traj = []

        for node_idx in flat_indices:
            if node_idx < 0 or node_idx >= len(self.env.graph_generator.node_coords):
                continue

            coord_xy = self.env.graph_generator.node_coords[node_idx]
            try:
                row_old, col_old = self.env.graph_generator.find_index_from_grid_coords_2d(coord_xy)
            except Exception:
                continue

            # Bounding box of the old cell in full-resolution pixel space
            y0 = row_old * cell_h_full
            y1 = (row_old + 1) * cell_h_full
            x0 = col_old * cell_w_full
            x1 = (col_old + 1) * cell_w_full

            # Which new-grid rows & cols overlap? (inclusive ranges)
            i_start = max(0, min(int(y0 / cell_h_new), new_H - 1))
            i_end = max(0, min(int((y1 - 1) / cell_h_new), new_H - 1))
            j_start = max(0, min(int(x0 / cell_w_new), new_W - 1))
            j_end = max(0, min(int((x1 - 1) / cell_w_new), new_W - 1))

            for ii in range(i_start, i_end + 1):
                for jj in range(j_start, j_end + 1):
                    f_new = ii * new_W + jj
                    if f_new not in seen:
                        seen.add(f_new)
                        new_flat_traj.append(f_new)

        # 3) Annotate counts
        counts = []
        for f in new_flat_traj:
            i_new, j_new = divmod(f, new_W)
            counts.append(grid_counts[i_new][j_new])

        return new_flat_traj, counts
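    # Worked example (illustrative): with old_dims=(17, 17), full_dims=(512, 512),
    # new_dims=(24, 24), each old cell spans ~30.1px and each new cell ~21.3px, so
    # one old cell overlaps a 2x2 (sometimes 2x3) block of new cells; e.g. old cell
    # (0, 0) covers pixels [0, 30.1) and maps to new rows 0..1, cols 0..1.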
    ################################################################################

    def make_gif(self, path, n):
        """ Generate a gif given a list of saved frames """
        with imageio.get_writer('{}/{}_target_rate_{:.2f}.gif'.format(path, n, self.env.targets_found_rate), mode='I',
                                fps=5) as writer:
            for frame in self.env.frame_files:
                image = imageio.imread(frame)
                writer.append_data(image)
        print('gif complete\n')

        # Remove frame files
        for filename in self.env.frame_files[:-1]:
            os.remove(filename)

        # KMeans gif
        if LOAD_AVS_BENCH and USE_CLIP_PREDS:
            with imageio.get_writer('{}/{}_kmeans_stats.gif'.format(path, n), mode='I',
                                    fps=5) as writer:
                for frame in self.kmeans_clusterer.kmeans_frame_files:
                    image = imageio.imread(frame)
                    writer.append_data(image)
            print('KMeans clusterer gif complete\n')

            # Remove frame files
            for filename in self.kmeans_clusterer.kmeans_frame_files[:-1]:
                os.remove(filename)


        # IS gif
        with imageio.get_writer('{}/{}_IS.gif'.format(path, n), mode='I',
                                fps=5) as writer:
            for frame in self.IS_frame_files:
                image = imageio.imread(frame)
                writer.append_data(image)
        print('IS gif complete\n')

        # Remove frame files
        for filename in self.IS_frame_files[:-1]:
            os.remove(filename)

    ################################################################################
class Agent:
    def __init__(self, ID, infoMap=None, uncertaintyMap=None, shape=None, row=0, col=0, sensorSize=9, numAgents=8):
        self.ID = ID
        self.row = row
        self.col = col
        self.numAgents = numAgents
        self.sensorSize = sensorSize

    def setLocation(self, row, col):
        self.row = row
        self.col = col

    def getLocation(self):
        return [self.row, self.col]

    def move(self, action):
        """
        No movement: 0
        North (-1,0): 1
        East (0,1): 2
        South (1,0): 3
        West (0,-1): 4
        LeftUp (-1,-1): 5
        RightUp (-1,1): 6
        RightDown (1,1): 7
        LeftDown (1,-1): 8
        The caller is responsible for checking that the resulting position
        stays within the map boundary.
        """
        if action == 0:
            return 0
        elif action == 1:
            self.row -= 1
        elif action == 2:
            self.col += 1
        elif action == 3:
            self.row += 1
        elif action == 4:
            self.col -= 1
        elif action == 5:
            self.row -= 1
            self.col -= 1
        elif action == 6:
            self.row -= 1
            self.col += 1
        elif action == 7:
            self.row += 1
            self.col += 1
        elif action == 8:
            self.row += 1
            self.col -= 1

    def reverseMove(self, action):
        if action == 0:
            return 0
        elif action == 1:
            self.row += 1
        elif action == 2:
            self.col -= 1
        elif action == 3:
            self.row -= 1
        elif action == 4:
            self.col += 1
        elif action == 5:
            self.row += 1
            self.col += 1
        elif action == 6:
            self.row += 1
            self.col -= 1
        elif action == 7:
            self.row -= 1
            self.col -= 1
        elif action == 8:
            self.row -= 1
            self.col += 1
        else:
            print("agent can only move with actions 0-8")
            sys.exit()
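# Illustrative equivalence (added note, not part of the original file): the
# move/reverseMove pairs above implement these (d_row, d_col) deltas, so
# reverseMove(a) is simply move(a) with both deltas negated:
#   ACTION_DELTAS = {0: (0, 0), 1: (-1, 0), 2: (0, 1), 3: (1, 0), 4: (0, -1),
#                    5: (-1, -1), 6: (-1, 1), 7: (1, 1), 8: (1, -1)}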
class Target:
    def __init__(self, row, col, ID, time_found=np.nan):
        self.row = row
        self.col = col
        self.ID = ID
        self.time_found = time_found
        self.status = None
        self.time_visited = time_found

    def getLocation(self):
        return self.row, self.col

    def updateFound(self, timeStep):
        if np.isnan(self.time_found):
            self.time_found = timeStep

    def updateVisited(self, timeStep):
        if np.isnan(self.time_visited):
            self.time_visited = timeStep
if __name__ == "__main__":

    search_env = Env(map_index=1, k_size=K_SIZE, n_agent=NUM_ROBOTS, plot=SAVE_GIFS)

    IS_info_map = search_env.segmentation_info_mask
    IS_agent_loc = search_env.start_positions
    IS_target_loc = [[312, 123], [123, 312], [312, 312], [123, 123]]

    env = ISEnv(state=[IS_info_map, IS_agent_loc, IS_target_loc], shape=(NUM_COORDS_HEIGHT, NUM_COORDS_WIDTH))
    env.is_scenario(NUM_EPS_STEPS)
    print()
planner/test_parameter.py
ADDED
@@ -0,0 +1,118 @@
############################################################################################
# Name: test_parameter.py
#
# NOTE: Change all your hyper-params here for eval
# Simple How-To Guide:
# 1. CLIP TTA:                 USE_CLIP_PREDS=True,  EXECUTE_TTA=True
# 2. CLIP (No TTA):            USE_CLIP_PREDS=True,  EXECUTE_TTA=False
# 3. Custom masks (e.g. LISA): USE_CLIP_PREDS=False, EXECUTE_TTA=False
############################################################################################

import os
import sys
sys.modules['TRAINING'] = False  # False = Inference Testing

###############################################################
# Overload Params
###############################################################

OPT_VARS = {}
def getenv(var_name, default=None, cast_type=str):
    """ Read an env var, cast it, and fall back to `default` on any failure. """
    try:
        value = os.environ.get(var_name, None)
        if value is None:
            result = default
        elif cast_type == bool:
            result = value.lower() in ("true", "1", "yes")
        else:
            result = cast_type(value)
    except (ValueError, TypeError):
        result = default

    OPT_VARS[var_name] = result  # Log the resolved value
    return result
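# Illustrative usage (values assumed): with `EXECUTE_TTA=false` exported in the
# shell, getenv("EXECUTE_TTA", default=True, cast_type=bool) returns False and
# records {"EXECUTE_TTA": False} in OPT_VARS; unset variables fall back to `default`.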
###############################################################
# General
###############################################################

# --- GENERAL --- #
USE_GPU = False
NUM_GPU = getenv("NUM_GPU", default=1, cast_type=int)  # the number of GPUs
NUM_META_AGENT = getenv("NUM_META_AGENT", default=2, cast_type=int)  # the number of concurrent processes
NUM_EPS_STEPS = getenv("NUM_EPS_STEPS", default=400, cast_type=int)
FIX_START_POSITION = getenv("FIX_START_POSITION", default=True, cast_type=bool)  # Whether to fix the starting position of the robots (middle index)
NUM_ROBOTS = 1                # Only 1 robot is supported
NUM_COORDS_WIDTH = 24         # How many node coords across the width?
NUM_COORDS_HEIGHT = 24        # How many node coords across the height?
CLIP_GRIDS_DIMS = [24, 24]    # [16,16] if 'openai/clip-vit-large-patch14-336'
SENSOR_RANGE = 80             # Only applicable to the 'circle' sensor model
SENSOR_MODEL = "rectangular"  # "rectangular", "circle" (NOTE: no collision check for rectangular)
TERMINATE_ON_TGTS_FOUND = True        # Whether to terminate the episode when all targets are found
FORCE_LOGGING_DONE_TGTS_FOUND = True  # Whether to force csv logging when all targets are found
# --- Planner Params --- #
POLICY = getenv("POLICY", default="RL", cast_type=str)
NUM_TEST = 800  # Overridden if LOAD_AVS_BENCH
NUM_RUN = 1
MODEL_NAME = "avs_rl_policy.pth"
INPUT_DIM = 4
EMBEDDING_DIM = 128
K_SIZE = 8

# --- Folders & Visualizations --- #
GRIDMAP_SET_DIR = "maps/gpt4o/envs_val"
MASK_SET_DIR = "maps/example/masks_val"  # Overridden if LOAD_AVS_BENCH
TARGETS_SET_DIR = ""
# TARGETS_SET_DIR = "maps/example/gt_masks_val_with_tgts"  # Overridden if LOAD_AVS_BENCH
OVERRIDE_MASK_DIR = getenv("OVERRIDE_MASK_DIR", default="", cast_type=str)  # Override initial score mask from CLIP
SAVE_GIFS = getenv("SAVE_GIFS", default=True, cast_type=bool)  # Do you want to save GIFs?
FOLDER_NAME = 'avs_search'
MODEL_PATH = 'inference/model'
GIFS_PATH = f'inference/test_results/gifs/{FOLDER_NAME}'
LOG_PATH = f'inference/test_results/log/{FOLDER_NAME}'
LOG_TEMPLATE_XLSX = 'inference/template.xlsx'
CSV_EXPT_NAME = getenv("CSV_EXPT_NAME", default="data", cast_type=str)
VIZ_GRAPH_EDGES = False  # Do you want to visualize the graph edges?
#######################################################################
# AVS Params
#######################################################################

# General params
USE_CLIP_PREDS = getenv("USE_CLIP_PREDS", default=True, cast_type=bool)  # If false, use custom masks from OVERRIDE_MASK_DIR
QUERY_TAX = getenv("QUERY_TAX", default="", cast_type=str)  # "" = Test all tax (can accept taxonomy substrings)
EXECUTE_TTA = getenv("EXECUTE_TTA", default=True, cast_type=bool)  # Whether to execute TTA mask updates
QUERY_MODALITY = getenv("QUERY_MODALITY", default="image", cast_type=str)  # "image", "text", "sound"
STEPS_PER_TTA = 20   # no. of steps before each TTA series
NUM_TTA_STEPS = 1    # no. of TTA steps during each series
RESET_WEIGHTS = True
MIN_LR = 1e-6
MAX_LR = 1e-5
GAMMA_EXPONENT = 2

# Paths related to AVS (TRAIN w/ TARGETS)
LOAD_AVS_BENCH = True  # Whether to init AVS datasets
AVS_IMG_DIR = '/mnt/hdd/avs_bench_ds/inat21'
AVS_IMO_DIR = '/mnt/hdd/avs_bench_ds/sat_jpg/train_512px'
AVS_INAT_JSON_PATH = '/mnt/hdd/avs_bench_ds/inat21/train.json'
AVS_SOUND_DIR = '/mnt/hdd/avs_bench_ds/sound_mp3/test'
AVS_GAUSSIAN_BLUR_KERNEL = (5, 5)
AVS_SAT_TO_IMG_IDS_PATH = getenv("AVS_SAT_TO_IMG_IDS_PATH", default="search_tri_modal|val_in_domain", cast_type=str)
AVS_LOAD_PRETRAINED_HF_CHECKPOINT = getenv("AVS_LOAD_PRETRAINED_HF_CHECKPOINT", default=True, cast_type=bool)  # If false, load locally using the CHECKPOINT_PATHs
AVS_SAT_CHECKPOINT_PATH = getenv("AVS_SAT_CHECKPOINT_PATH", default="", cast_type=str)
AVS_SOUND_CHECKPOINT_PATH = getenv("AVS_SOUND_CHECKPOINT_PATH", default="", cast_type=str)

#######################################################################
# UTILS
#######################################################################

# COLORS (for printing)
RED = '\033[1;31m'
GREEN = '\033[1;32m'
YELLOW = '\033[1;93m'
NC_BOLD = '\033[1m'  # Bold, No Color
NC = '\033[0m'       # No Color
planner/test_worker.py
ADDED
@@ -0,0 +1,590 @@
#######################################################################
# Name: test_worker.py
#
# - Runs the robot in the environment using the RL planner
#######################################################################

from .test_parameter import *

import imageio
import os
import copy
import numpy as np
import torch
from time import time
from pathlib import Path
from skimage.transform import resize
from taxabind_avs.satbind.kmeans_clustering import CombinedSilhouetteInertiaClusterer
from .env import Env
from .robot import Robot

np.seterr(invalid='raise', divide='raise')
class TestWorker:
    def __init__(self, meta_agent_id, n_agent, policy_net, global_step, device='cuda', greedy=False, save_image=False, clip_seg_tta=None):
        self.device = device
        self.greedy = greedy
        self.n_agent = n_agent
        self.metaAgentID = meta_agent_id
        self.global_step = global_step
        self.k_size = K_SIZE
        self.save_image = save_image
        self.clip_seg_tta = clip_seg_tta
        self.execute_tta = EXECUTE_TTA  # Added to interface with app.py

        self.env = Env(map_index=self.global_step, n_agent=n_agent, k_size=self.k_size, plot=save_image, test=True)
        self.local_policy_net = policy_net

        self.robot_list = []
        self.all_robot_positions = []
        for i in range(self.n_agent):
            robot_position = self.env.start_positions[i]
            robot = Robot(robot_id=i, position=robot_position, plot=save_image)
            self.robot_list.append(robot)
            self.all_robot_positions.append(robot_position)

        self.perf_metrics = dict()
        self.bad_mask_init = False

        # NOTE: Option to override gifs_path to interface with app.py
        self.gifs_path = GIFS_PATH

        # NOTE: Updated due to app.py (HF does not allow the heatmap to persist)
        if LOAD_AVS_BENCH:
            if clip_seg_tta is not None:
                heatmap, heatmap_unnormalized, heatmap_unnormalized_initial, patch_embeds = self.clip_seg_tta.reset(sample_idx=self.global_step)
                self.clip_seg_tta.heatmap = heatmap
                self.clip_seg_tta.heatmap_unnormalized = heatmap_unnormalized
                self.clip_seg_tta.heatmap_unnormalized_initial = heatmap_unnormalized_initial
                self.clip_seg_tta.patch_embeds = patch_embeds

            # Override target positions in env
            self.env.target_positions = [(pose[1], pose[0]) for pose in self.clip_seg_tta.target_positions]

            # Override segmentation mask
            if not USE_CLIP_PREDS and OVERRIDE_MASK_DIR != "":
                score_mask_path = os.path.join(OVERRIDE_MASK_DIR, self.clip_seg_tta.gt_mask_name)
                print("score_mask_path: ", score_mask_path)
                if os.path.exists(score_mask_path):
                    self.env.segmentation_mask = self.env.import_segmentation_mask(score_mask_path)
                    self.env.begin(self.env.map_start_position)
                else:
                    print(f"\n\n{RED}ERROR: Trying to override, but score mask not found at path:{NC} ", score_mask_path)
                    self.bad_mask_init = True

        # Save clustered embeds from the sat encoder
        if USE_CLIP_PREDS:
            self.kmeans_clusterer = CombinedSilhouetteInertiaClusterer(
                k_min=1,
                k_max=8,
                k_avg_max=4,
                silhouette_threshold=0.15,
                relative_threshold=0.15,
                random_state=0,
                min_patch_size=5,
                n_smooth_iter=2,
                ignore_label=-1,
                plot=self.save_image,
                gifs_dir=GIFS_PATH
            )
            # Generate kmeans clusters
            self.kmeans_sat_embeds_clusters = self.kmeans_clusterer.fit_predict(
                patch_embeds=self.clip_seg_tta.patch_embeds,
                map_shape=(CLIP_GRIDS_DIMS[0], CLIP_GRIDS_DIMS[1]),
            )
            print("Chosen k:", self.kmeans_clusterer.final_k)

        # Define Poisson TTA params
        self.step_since_tta = 0
        self.steps_to_first_tgt = None
        self.steps_to_mid_tgt = None
        self.steps_to_last_tgt = None
    def run_episode(self, curr_episode):

        # Return all metrics as None if faulty mask init
        if self.bad_mask_init:
            self.perf_metrics['tax'] = None
            self.perf_metrics['travel_dist'] = None
            self.perf_metrics['travel_steps'] = None
            self.perf_metrics['steps_to_first_tgt'] = None
            self.perf_metrics['steps_to_mid_tgt'] = None
            self.perf_metrics['steps_to_last_tgt'] = None
            self.perf_metrics['explored_rate'] = None
            self.perf_metrics['targets_found'] = None
            self.perf_metrics['targets_total'] = None
            self.perf_metrics['kmeans_k'] = None
            self.perf_metrics['tgts_gt_score'] = None
            self.perf_metrics['clip_inference_time'] = None
            self.perf_metrics['tta_time'] = None
            self.perf_metrics['success_rate'] = None
            return

        eps_start = time()
        done = False
        for robot_id, deciding_robot in enumerate(self.robot_list):
            deciding_robot.observations = self.get_observations(deciding_robot.robot_position)
        if LOAD_AVS_BENCH and USE_CLIP_PREDS:
            if NUM_COORDS_WIDTH != CLIP_GRIDS_DIMS[0] or NUM_COORDS_HEIGHT != CLIP_GRIDS_DIMS[1]:  # Heatmap is resized from CLIP's original dims
                heatmap = self.convert_heatmap_resolution(self.clip_seg_tta.heatmap, full_dims=(512, 512), new_dims=(NUM_COORDS_WIDTH, NUM_COORDS_HEIGHT))
                self.env.segmentation_info_mask = np.expand_dims(heatmap.T.flatten(), axis=1)
                unnormalized_heatmap = self.convert_heatmap_resolution(self.clip_seg_tta.heatmap_unnormalized, full_dims=(512, 512), new_dims=(NUM_COORDS_WIDTH, NUM_COORDS_HEIGHT))
                self.env.segmentation_info_mask_unnormalized = np.expand_dims(unnormalized_heatmap.T.flatten(), axis=1)
                print("Resized heatmap to", NUM_COORDS_WIDTH, "x", NUM_COORDS_HEIGHT)
            else:
                self.env.segmentation_info_mask = np.expand_dims(self.clip_seg_tta.heatmap.T.flatten(), axis=1)
                self.env.segmentation_info_mask_unnormalized = np.expand_dims(self.clip_seg_tta.heatmap_unnormalized.T.flatten(), axis=1)

        ### Run episode ###
        for step in range(NUM_EPS_STEPS):

            next_position_list = []
            dist_list = []
            travel_dist_list = []
            dist_array = np.zeros((self.n_agent, 1))
            for robot_id, deciding_robot in enumerate(self.robot_list):
                observations = deciding_robot.observations

                ### Forward pass through policy to get next position ###
                next_position, action_index = self.select_node(observations)
                dist = np.linalg.norm(next_position - deciding_robot.robot_position)

                ### Log results of action (e.g. distance travelled) ###
                dist_array[robot_id] = dist
                dist_list.append(dist)
                travel_dist_list.append(deciding_robot.travel_dist)
                next_position_list.append(next_position)
                self.all_robot_positions[robot_id] = next_position

            # Sort robots by arrival order (closest destination first)
            arriving_sequence = np.argsort(dist_list)
            next_position_list = np.array(next_position_list)
            dist_list = np.array(dist_list)
            travel_dist_list = np.array(travel_dist_list)
            next_position_list = next_position_list[arriving_sequence]
            dist_list = dist_list[arriving_sequence]
            travel_dist_list = travel_dist_list[arriving_sequence]

            ### Take action (deconflict if 2 agents choose the same target position) ###
            next_position_list, dist_list = self.solve_conflict(arriving_sequence, next_position_list, dist_list)
            reward_list, done = self.env.multi_robot_step(next_position_list, dist_list, travel_dist_list)

            ### Update observations + rewards from action ###
            for reward, robot_id in zip(reward_list, arriving_sequence):
                robot = self.robot_list[robot_id]
                robot.save_trajectory_coords(self.env.find_index_from_coords(robot.robot_position), self.env.num_new_targets_found)

                # TTA update via Poisson test (with KMeans clustering stats)
                if LOAD_AVS_BENCH and USE_CLIP_PREDS and self.execute_tta:
                    self.poisson_tta_update(robot, self.global_step, step)

                robot.observations = self.get_observations(robot.robot_position)
                robot.save_reward_done(reward, done)

            # Update metrics
            self.log_metrics(step=step)

            ### Save a frame to generate gif of robot trajectories ###
            if self.save_image:
                robots_route = []
                for robot in self.robot_list:
                    robots_route.append([robot.xPoints, robot.yPoints])
                if not os.path.exists(self.gifs_path):
                    os.makedirs(self.gifs_path)
                if LOAD_AVS_BENCH:
                    # NOTE: Replaced since using app.py
                    self.env.plot_heatmap(self.gifs_path, step, max(travel_dist_list), robots_route)

            if done:
                break

        if LOAD_AVS_BENCH:
            tax = Path(self.clip_seg_tta.gt_mask_name).stem
            self.perf_metrics['tax'] = " ".join(tax.split("_")[1:])
        else:
            self.perf_metrics['tax'] = None
        self.perf_metrics['travel_dist'] = max(travel_dist_list)
        self.perf_metrics['travel_steps'] = step + 1
        self.perf_metrics['steps_to_first_tgt'] = self.steps_to_first_tgt
        self.perf_metrics['steps_to_mid_tgt'] = self.steps_to_mid_tgt
        self.perf_metrics['steps_to_last_tgt'] = self.steps_to_last_tgt
        self.perf_metrics['explored_rate'] = self.env.explored_rate
        self.perf_metrics['targets_found'] = self.env.targets_found_rate
        self.perf_metrics['targets_total'] = len(self.env.target_positions)
        if USE_CLIP_PREDS:
            self.perf_metrics['kmeans_k'] = self.kmeans_clusterer.final_k
            self.perf_metrics['tgts_gt_score'] = self.clip_seg_tta.tgts_gt_score
            self.perf_metrics['clip_inference_time'] = self.clip_seg_tta.clip_inference_time
            self.perf_metrics['tta_time'] = self.clip_seg_tta.tta_time
        else:
            self.perf_metrics['kmeans_k'] = None
            self.perf_metrics['tgts_gt_score'] = None
            self.perf_metrics['clip_inference_time'] = None
            self.perf_metrics['tta_time'] = None
        if FORCE_LOGGING_DONE_TGTS_FOUND and self.env.targets_found_rate == 1.0:
            self.perf_metrics['success_rate'] = True
        else:
            self.perf_metrics['success_rate'] = done

        # Save gif
        if self.save_image:
            path = self.gifs_path  # NOTE: Set to self.gifs_path since using app.py
            self.make_gif(path, curr_episode)

        print(YELLOW, f"[Eps {curr_episode} Completed] Time Taken: {time()-eps_start:.2f}s, Steps: {step+1}", NC)
def get_observations(self, robot_position):
|
| 241 |
+
""" Get robot's sensor observation of environment given position """
|
| 242 |
+
current_node_index = self.env.find_index_from_coords(robot_position)
|
| 243 |
+
current_index = torch.tensor([current_node_index]).unsqueeze(0).unsqueeze(0).to(self.device) # (1,1,1)
|
| 244 |
+
|
| 245 |
+
node_coords = copy.deepcopy(self.env.node_coords)
|
| 246 |
+
graph = copy.deepcopy(self.env.graph)
|
| 247 |
+
node_utility = copy.deepcopy(self.env.node_utility)
|
| 248 |
+
guidepost = copy.deepcopy(self.env.guidepost)
|
| 249 |
+
segmentation_info_mask = copy.deepcopy(self.env.filtered_seg_info_mask)
|
| 250 |
+
|
| 251 |
+
n_nodes = node_coords.shape[0]
|
| 252 |
+
node_coords = node_coords / 640
|
| 253 |
+
node_utility = node_utility / 50
|
| 254 |
+
node_utility_inputs = node_utility.reshape((n_nodes, 1))
|
| 255 |
+
|
| 256 |
+
occupied_node = np.zeros((n_nodes, 1))
|
| 257 |
+
for position in self.all_robot_positions:
|
| 258 |
+
index = self.env.find_index_from_coords(position)
|
| 259 |
+
if index == current_index.item():
|
| 260 |
+
occupied_node[index] = -1
|
| 261 |
+
else:
|
| 262 |
+
occupied_node[index] = 1
|
| 263 |
+
|
| 264 |
+
node_inputs = np.concatenate((node_coords, segmentation_info_mask, guidepost), axis=1)
|
| 265 |
+
node_inputs = torch.FloatTensor(node_inputs).unsqueeze(0).to(self.device)
|
| 266 |
+
node_padding_mask = None
|
| 267 |
+
|
| 268 |
+
graph = list(graph.values())
|
| 269 |
+
edge_inputs = []
|
| 270 |
+
for node in graph:
|
| 271 |
+
node_edges = list(map(int, node))
|
| 272 |
+
edge_inputs.append(node_edges)
|
| 273 |
+
|
| 274 |
+
bias_matrix = self.calculate_edge_mask(edge_inputs)
|
| 275 |
+
edge_mask = torch.from_numpy(bias_matrix).float().unsqueeze(0).to(self.device)
|
| 276 |
+
|
| 277 |
+
for edges in edge_inputs:
|
| 278 |
+
while len(edges) < self.k_size:
|
| 279 |
+
edges.append(0)
|
| 280 |
+
|
| 281 |
+
edge_inputs = torch.tensor(edge_inputs).unsqueeze(0).to(self.device)
|
| 282 |
+
edge_padding_mask = torch.zeros((1, len(edge_inputs), K_SIZE), dtype=torch.int64).to(self.device)
|
| 283 |
+
one = torch.ones_like(edge_padding_mask, dtype=torch.int64).to(self.device)
|
| 284 |
+
edge_padding_mask = torch.where(edge_inputs == 0, one, edge_padding_mask)
|
| 285 |
+
|
| 286 |
+
observations = node_inputs, edge_inputs, current_index, node_padding_mask, edge_padding_mask, edge_mask
|
| 287 |
+
return observations
|
| 288 |
+
|
| 289 |
+
|
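The padded neighbour lists above double as the source of edge_padding_mask: any slot filled with the sentinel 0 is masked out of attention. A minimal standalone sketch of that step, with toy neighbour lists and k_size=4 (values chosen for illustration, not taken from parameter.py):

import torch

k_size = 4
edge_inputs = [[1, 2], [0, 2, 3], [0, 1]]   # variable-length k-NN neighbour lists
for edges in edge_inputs:
    while len(edges) < k_size:              # pad each list to k_size with sentinel 0
        edges.append(0)

edge_inputs = torch.tensor(edge_inputs).unsqueeze(0)   # (1, 3, 4)
edge_padding_mask = torch.where(edge_inputs == 0,
                                torch.ones_like(edge_inputs),
                                torch.zeros_like(edge_inputs))
# Note: node id 0 is also a legitimate neighbour id, so real edges to node 0
# get masked here as well -- the sentinel choice trades that corner case for simplicity.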
    def select_node(self, observations):
        """ Forward pass through policy to get next position to go to on map """
        node_inputs, edge_inputs, current_index, node_padding_mask, edge_padding_mask, edge_mask = observations
        with torch.no_grad():
            logp_list = self.local_policy_net(node_inputs, edge_inputs, current_index, node_padding_mask, edge_padding_mask, edge_mask)

        if self.greedy:
            action_index = torch.argmax(logp_list, dim=1).long()
        else:
            action_index = torch.multinomial(logp_list.exp(), 1).long().squeeze(1)

        next_node_index = edge_inputs[:, current_index.item(), action_index.item()]

        next_position = self.env.node_coords[next_node_index]

        return next_position, action_index
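select_node treats the policy output as per-neighbour log-probabilities: greedy evaluation takes the argmax, while rollouts sample from the distribution. A toy illustration of the two branches, with a random stand-in for the policy head (the (1, k) logp shape is assumed from the surrounding code):

import torch

logp_list = torch.log_softmax(torch.randn(1, 4), dim=1)  # stand-in policy output
greedy_action = torch.argmax(logp_list, dim=1).long()                      # exploit
sampled_action = torch.multinomial(logp_list.exp(), 1).long().squeeze(1)   # explore

Sampling keeps exploration stochastic during training; argmax makes evaluation deterministic and repeatable.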
    def solve_conflict(self, arriving_sequence, next_position_list, dist_list):
        """ Deconflict if 2 agents choose the same target position """
        for j, [robot_id, next_position] in enumerate(zip(arriving_sequence, next_position_list)):
            moving_robot = self.robot_list[robot_id]
            # if next_position[0] + next_position[1] * 1j in (next_position_list[:, 0] + next_position_list[:, 1] * 1j)[:j]:
            #     dist_to_next_position = np.argsort(np.linalg.norm(self.env.node_coords - next_position, axis=1))
            #     k = 0
            #     while next_position[0] + next_position[1] * 1j in (next_position_list[:, 0] + next_position_list[:, 1] * 1j)[:j]:
            #         k += 1
            #         next_position = self.env.node_coords[dist_to_next_position[k]]

            dist = np.linalg.norm(next_position - moving_robot.robot_position)
            next_position_list[j] = next_position
            dist_list[j] = dist
            moving_robot.travel_dist += dist
            moving_robot.robot_position = next_position

        return next_position_list, dist_list
    def work(self, currEpisode):
        '''
        Interacts with the environment. The agent gets either gradients or experience buffer
        '''
        self.run_episode(currEpisode)
    def calculate_edge_mask(self, edge_inputs):
        size = len(edge_inputs)
        bias_matrix = np.ones((size, size))
        for i in range(size):
            for j in range(size):
                if j in edge_inputs[i]:
                    bias_matrix[i][j] = 0
        return bias_matrix
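calculate_edge_mask encodes the graph as an additive attention bias: 0 where an edge exists, 1 where attention should be blocked. A quick self-contained check on a 3-node toy graph (same logic, written as a free function for brevity):

import numpy as np

def calculate_edge_mask(edge_inputs):
    size = len(edge_inputs)
    bias_matrix = np.ones((size, size))
    for i in range(size):
        for j in range(size):
            if j in edge_inputs[i]:
                bias_matrix[i][j] = 0   # edge i -> j exists: attention allowed
    return bias_matrix

print(calculate_edge_mask([[0, 1], [0, 1, 2], [1, 2]]))
# [[0. 0. 1.]
#  [0. 0. 0.]
#  [1. 0. 0.]]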
    def make_gif(self, path, n):
        """ Generate a gif given list of images """
        with imageio.get_writer('{}/{}_target_rate_{:.2f}.gif'.format(path, n, self.env.targets_found_rate), mode='I',
                                fps=5) as writer:
            for frame in self.env.frame_files:
                image = imageio.imread(frame)
                writer.append_data(image)
        print('gif complete\n')

        # Remove files
        for filename in self.env.frame_files[:-1]:
            os.remove(filename)

        # For gif during TTA
        if LOAD_AVS_BENCH:
            with imageio.get_writer('{}/{}_kmeans_stats.gif'.format(path, n), mode='I',
                                    fps=5) as writer:
                for frame in self.kmeans_clusterer.kmeans_frame_files:
                    image = imageio.imread(frame)
                    writer.append_data(image)
            print('Kmeans Clusterer gif complete\n')

            # Remove files
            for filename in self.kmeans_clusterer.kmeans_frame_files[:-1]:
                os.remove(filename)
    ################################################################################
    # SPPP Related Fns
    ################################################################################

    def log_metrics(self, step):
        # Update tgt found metrics
        if self.steps_to_first_tgt is None and self.env.num_targets_found == 1:
            self.steps_to_first_tgt = step + 1
        if self.steps_to_mid_tgt is None and self.env.num_targets_found == int(len(self.env.target_positions) / 2):
            self.steps_to_mid_tgt = step + 1
        if self.steps_to_last_tgt is None and self.env.num_targets_found == len(self.env.target_positions):
            self.steps_to_last_tgt = step + 1
    def transpose_flat_idx(self, idx, H=NUM_COORDS_HEIGHT, W=NUM_COORDS_WIDTH):
        """
        Transpose a flat index from an ``H×W`` grid to the equivalent
        position in the ``W×H`` transposed grid while **keeping the result
        in 1-D**.
        """
        # --- Safety check to catch out-of-range indices ---
        assert 0 <= idx < H * W, f"idx {idx} out of bounds for shape ({H}, {W})"

        # Original (row, col)
        row, col = divmod(idx, W)
        # After transpose these coordinates swap
        row_T, col_T = col, row

        # Flatten back into 1-D (row-major) for the W×H grid
        return row_T * H + col_T
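Both segmentation masks handed to the env are built from heatmap.T.flatten(), so flat indices computed on the H×W grid must be remapped to the transposed W×H layout; transpose_flat_idx does exactly that. A self-contained check on a 2×3 grid (sizes chosen for illustration):

def transpose_flat_idx(idx, H=2, W=3):
    row, col = divmod(idx, W)   # (row, col) in the H x W grid
    return col * H + row        # same cell, flattened in the W x H grid

# idx 1 is (row 0, col 1) in the 2x3 grid -> (row 1, col 0) in the 3x2 grid -> flat 2
assert transpose_flat_idx(1) == 2
# the remap is a bijection over all H*W cells
assert sorted(transpose_flat_idx(i) for i in range(6)) == list(range(6))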
    def poisson_tta_update(self, robot, episode, step):

        # Generate Kmeans Clusters Stats
        # Scale index back to CLIP_GRIDS_DIMS to be compatible with CLIP patch size
        if NUM_COORDS_WIDTH != CLIP_GRIDS_DIMS[0] or NUM_COORDS_HEIGHT != CLIP_GRIDS_DIMS[1]:
            # High-res remap via pixel coordinates preserves exact neighbourhood
            filt_traj_coords, filt_targets_found_on_path = self.scale_trajectory(
                robot.trajectory_coords,
                self.env.target_positions,
                old_dims=(NUM_COORDS_HEIGHT, NUM_COORDS_WIDTH),
                full_dims=(512, 512),
                new_dims=(CLIP_GRIDS_DIMS[0], CLIP_GRIDS_DIMS[1])
            )
        else:
            filt_traj_coords = [self.transpose_flat_idx(idx) for idx in robot.trajectory_coords]
            filt_targets_found_on_path = robot.targets_found_on_path

        region_stats_dict = self.kmeans_clusterer.compute_region_statistics(
            self.kmeans_sat_embeds_clusters,
            self.clip_seg_tta.heatmap_unnormalized,
            filt_traj_coords,
            episode_num=episode,
            step_num=step
        )

        # Prep & execute TTA
        self.step_since_tta += 1
        if robot.targets_found_on_path[-1] or self.step_since_tta % STEPS_PER_TTA == 0:

            # NOTE: integration with app.py on hf
            self.clip_seg_tta.executing_tta = True

            num_cells = self.clip_seg_tta.heatmap.shape[0] * self.clip_seg_tta.heatmap.shape[1]
            pos_sample_weight_scale, neg_sample_weight_scale = [], []

            for i, sample_loc in enumerate(filt_traj_coords):
                label = self.kmeans_clusterer.get_label_id(self.kmeans_sat_embeds_clusters, sample_loc)
                num_patches = region_stats_dict[label]['num_patches']
                patches_visited = region_stats_dict[label]['patches_visited']
                expectation = region_stats_dict[label]['expectation']

                # Exponent like focal loss to wait for more samples before confidently decreasing
                pos_weight = 4.0
                neg_weight = min(1.0, (patches_visited / (3 * num_patches)) ** GAMMA_EXPONENT)
                pos_sample_weight_scale.append(pos_weight)
                neg_sample_weight_scale.append(neg_weight)

            # Adaptive LR (as samples increase, increase LR to fit more datapoints)
            adaptive_lr = MIN_LR + (MAX_LR - MIN_LR) * (step / num_cells)

            # TTA Update
            # NOTE: updated due to app.py (hf does not allow heatmap to persist)
            heatmap = self.clip_seg_tta.execute_tta(
                filt_traj_coords,
                filt_targets_found_on_path,
                tta_steps=NUM_TTA_STEPS,
                lr=adaptive_lr,
                pos_sample_weight=pos_sample_weight_scale,
                neg_sample_weight=neg_sample_weight_scale,
                reset_weights=RESET_WEIGHTS
            )
            self.clip_seg_tta.heatmap = heatmap

            if NUM_COORDS_WIDTH != CLIP_GRIDS_DIMS[0] or NUM_COORDS_HEIGHT != CLIP_GRIDS_DIMS[1]:  # If heatmap is resized from clip original dims
                heatmap = self.convert_heatmap_resolution(self.clip_seg_tta.heatmap, full_dims=(512, 512), new_dims=(NUM_COORDS_WIDTH, NUM_COORDS_HEIGHT))
                self.env.segmentation_info_mask = np.expand_dims(heatmap.T.flatten(), axis=1)
                unnormalized_heatmap = self.convert_heatmap_resolution(self.clip_seg_tta.heatmap_unnormalized, full_dims=(512, 512), new_dims=(NUM_COORDS_WIDTH, NUM_COORDS_HEIGHT))
                self.env.segmentation_info_mask_unnormalized = np.expand_dims(unnormalized_heatmap.T.flatten(), axis=1)
                print("~Resized heatmap to", NUM_COORDS_WIDTH, "x", NUM_COORDS_HEIGHT)
            else:
                self.env.segmentation_info_mask = np.expand_dims(self.clip_seg_tta.heatmap.T.flatten(), axis=1)
                self.env.segmentation_info_mask_unnormalized = np.expand_dims(self.clip_seg_tta.heatmap_unnormalized.T.flatten(), axis=1)

            self.step_since_tta = 0

        # NOTE: integration with app.py on hf
        self.clip_seg_tta.executing_tta = False
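The weighting above is deliberately asymmetric: positives always push hard (fixed 4.0), while the negative weight stays near zero until a cluster has been sampled enough, so sparse early evidence cannot confidently suppress a region. A standalone sketch of both schedules (the constants here are illustrative, not the values in parameter.py):

MIN_LR, MAX_LR = 1e-4, 1e-3   # illustrative
GAMMA_EXPONENT = 2.0          # illustrative focal-style exponent

def neg_weight(patches_visited, num_patches):
    # ramps from ~0 to 1 as coverage approaches 3x the region size
    return min(1.0, (patches_visited / (3 * num_patches)) ** GAMMA_EXPONENT)

def adaptive_lr(step, num_cells):
    # more samples -> larger updates to fit them
    return MIN_LR + (MAX_LR - MIN_LR) * (step / num_cells)

print(neg_weight(2, 10))        # ~0.004: barely trusted
print(neg_weight(30, 10))       # 1.0: fully trusted
print(adaptive_lr(288, 576))    # halfway between MIN_LR and MAX_LR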
    def convert_heatmap_resolution(self, heatmap, full_dims=(512, 512), new_dims=(24, 24)):
        heatmap_large = resize(heatmap, full_dims, order=1,  # order=1 -> bilinear
                               mode='reflect', anti_aliasing=True)

        coords = self.env.graph_generator.grid_coords  # (N, N, 2)
        rows, cols = coords[..., 1], coords[..., 0]
        heatmap_resized = heatmap_large[rows, cols]
        heatmap_resized = heatmap_resized.reshape(new_dims[1], new_dims[0])
        return heatmap_resized
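Rather than a second resize on the way down, the helper above samples the upsampled heatmap at the planner's own node pixel coordinates, keeping cell values aligned with graph nodes. A minimal sketch with a synthetic uniform lattice standing in for graph_generator.grid_coords (the lattice is an assumption for illustration):

import numpy as np
from skimage.transform import resize

heatmap = np.random.rand(24, 24)
heatmap_large = resize(heatmap, (512, 512), order=1,   # bilinear
                       mode='reflect', anti_aliasing=True)

xs = np.linspace(0, 511, 17).astype(int)
cols, rows = np.meshgrid(xs, xs)           # coords[..., 0] = x (col), coords[..., 1] = y (row)
heatmap_small = heatmap_large[rows, cols]  # sample at node locations -> (17, 17)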
    def convert_labelmap_resolution(self, labelmap, full_dims=(512, 512), new_dims=(24, 24)):
        """
        1) Upsample via nearest-neighbor to full_dims
        2) Sample back down to your graph grid using grid_coords
        """
        # 1) Upsample with nearest-neighbor, preserving integer labels
        up = resize(
            labelmap,
            full_dims,
            order=0,              # nearest-neighbor
            mode='edge',          # padding mode
            preserve_range=True,  # don't normalize labels
            anti_aliasing=False   # must be False for labels
        ).astype(labelmap.dtype)  # back to original integer dtype

        # 2) Downsample via your precomputed grid coords
        coords = self.env.graph_generator.grid_coords  # shape (N, N, 2)
        rows = coords[..., 1].astype(int)
        cols = coords[..., 0].astype(int)

        small = up[rows, cols]  # shape (N, N)
        small = small.reshape(new_dims[0], new_dims[1])
        return small
    def scale_trajectory(self,
                         flat_indices,
                         targets,
                         old_dims=(17, 17),
                         full_dims=(512, 512),
                         new_dims=(24, 24)):
        """
        Args:
            flat_indices: list of ints in [0..old_H*old_W-1]
            targets: list of (x_pix, y_pix) in full-resolution pixel space
            old_dims: (old_H, old_W)
            full_dims: (full_H, full_W)
            new_dims: (new_H, new_W)

        Returns:
            new_flat_traj: list of unique flattened indices in new_H×new_W
            counts: list of ints, same length as new_flat_traj
        """
        old_H, old_W = old_dims
        full_H, full_W = full_dims
        new_H, new_W = new_dims

        # 1) Bin targets into the new grid
        cell_h_new = full_H / new_H
        cell_w_new = full_W / new_W
        grid_counts = [[0] * new_W for _ in range(new_H)]
        for x_pix, y_pix in targets:  # note (x, y) order as in original implementation
            i_t = min(int(y_pix / cell_h_new), new_H - 1)
            j_t = min(int(x_pix / cell_w_new), new_W - 1)
            grid_counts[i_t][j_t] += 1

        # 2) Walk the trajectory indices and project each old cell's *entire
        #    pixel footprint* onto the finer new grid.
        cell_h_full = full_H / old_H
        cell_w_full = full_W / old_W

        seen = set()
        new_flat_traj = []

        for node_idx in flat_indices:
            if node_idx < 0 or node_idx >= len(self.env.graph_generator.node_coords):
                continue

            coord_xy = self.env.graph_generator.node_coords[node_idx]
            try:
                row_old, col_old = self.env.graph_generator.find_index_from_grid_coords_2d(coord_xy)
            except Exception:
                continue

            # Bounding box of the old cell in full-resolution pixel space
            y0 = row_old * cell_h_full
            y1 = (row_old + 1) * cell_h_full
            x0 = col_old * cell_w_full
            x1 = (col_old + 1) * cell_w_full

            # Which new-grid rows & cols overlap? (inclusive ranges)
            i_start = max(0, min(int(y0 / cell_h_new), new_H - 1))
            i_end = max(0, min(int((y1 - 1) / cell_h_new), new_H - 1))
            j_start = max(0, min(int(x0 / cell_w_new), new_W - 1))
            j_end = max(0, min(int((x1 - 1) / cell_w_new), new_W - 1))

            for ii in range(i_start, i_end + 1):
                for jj in range(j_start, j_end + 1):
                    f_new = ii * new_W + jj
                    if f_new not in seen:
                        seen.add(f_new)
                        new_flat_traj.append(f_new)

        # 3) Annotate counts
        counts = []
        for f in new_flat_traj:
            i_new, j_new = divmod(f, new_W)
            counts.append(grid_counts[i_new][j_new])

        return new_flat_traj, counts

    ################################################################################
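The heart of scale_trajectory is that it projects each coarse cell's entire pixel footprint onto the finer grid rather than remapping centre to centre, so every fine cell the robot actually covered is kept. A toy run of the column-overlap computation (a 4-cell coarse row onto a 6-cell fine row over 12 pixels; numbers are illustrative):

full_W, old_W, new_W = 12, 4, 6
cell_w_full = full_W / old_W   # 3.0 px per coarse cell
cell_w_new = full_W / new_W    # 2.0 px per fine cell

col_old = 1                                     # coarse cell spans pixels [3, 6)
x0, x1 = col_old * cell_w_full, (col_old + 1) * cell_w_full
j_start = max(0, min(int(x0 / cell_w_new), new_W - 1))
j_end = max(0, min(int((x1 - 1) / cell_w_new), new_W - 1))
print(list(range(j_start, j_end + 1)))          # [1, 2]: both overlapped fine cells kept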
planner/worker.py
ADDED
@@ -0,0 +1,272 @@
#######################################################################
# Name: worker.py
#
# - Runs robot in environment for N steps
# - Collects & Returns S(t), A(t), R(t), S(t+1)
#######################################################################

from .parameter import *

import os
import json
import copy
import imageio
import numpy as np
import torch
from time import time
from .env import Env
from .robot import Robot
class Worker:
    def __init__(self, meta_agent_id, n_agent, policy_net, q_net, global_step, device='cuda', greedy=False, save_image=False, clip_seg_tta=None):
        self.device = device
        self.greedy = greedy
        self.n_agent = n_agent
        self.metaAgentID = meta_agent_id
        self.global_step = global_step
        self.node_padding_size = NODE_PADDING_SIZE
        self.k_size = K_SIZE
        self.save_image = save_image
        self.clip_seg_tta = clip_seg_tta

        # Randomize map_index
        mask_index = None
        if MASKS_RAND_INDICES_PATH != "":
            with open(MASKS_RAND_INDICES_PATH, 'r') as f:
                mask_index_rand_json = json.load(f)
            mask_index = mask_index_rand_json[self.global_step % len(mask_index_rand_json)]
            print("mask_index: ", mask_index)

        self.env = Env(map_index=self.global_step, n_agent=n_agent, k_size=self.k_size, plot=save_image, mask_index=mask_index)
        self.local_policy_net = policy_net
        self.local_q_net = q_net

        self.robot_list = []
        self.all_robot_positions = []

        for i in range(self.n_agent):
            robot_position = self.env.start_positions[i]
            robot = Robot(robot_id=i, position=robot_position, plot=save_image)
            self.robot_list.append(robot)
            self.all_robot_positions.append(robot_position)

        self.perf_metrics = dict()
        self.episode_buffer = []
        for i in range(15):
            self.episode_buffer.append([])
    def run_episode(self, curr_episode):

        eps_start = time()
        done = False
        for robot_id, deciding_robot in enumerate(self.robot_list):
            deciding_robot.observations = self.get_observations(deciding_robot.robot_position)

        ### Run episode ###
        for step in range(NUM_EPS_STEPS):

            next_position_list = []
            dist_list = []
            travel_dist_list = []
            dist_array = np.zeros((self.n_agent, 1))
            for robot_id, deciding_robot in enumerate(self.robot_list):
                observations = deciding_robot.observations
                deciding_robot.save_observations(observations)

                ### Forward pass through policy to get next position ###
                next_position, action_index = self.select_node(observations)
                deciding_robot.save_action(action_index)

                dist = np.linalg.norm(next_position - deciding_robot.robot_position)

                ### Log results of action (e.g. distance travelled) ###
                dist_array[robot_id] = dist
                dist_list.append(dist)
                travel_dist_list.append(deciding_robot.travel_dist)
                next_position_list.append(next_position)
                self.all_robot_positions[robot_id] = next_position

            arriving_sequence = np.argsort(dist_list)
            next_position_list = np.array(next_position_list)
            dist_list = np.array(dist_list)
            travel_dist_list = np.array(travel_dist_list)
            next_position_list = next_position_list[arriving_sequence]
            dist_list = dist_list[arriving_sequence]
            travel_dist_list = travel_dist_list[arriving_sequence]

            ### Take Action (Deconflict if 2 agents choose the same target position) ###
            next_position_list, dist_list = self.solve_conflict(arriving_sequence, next_position_list, dist_list)
            reward_list, done = self.env.multi_robot_step(next_position_list, dist_list, travel_dist_list)

            ### Update observations + rewards from action ###
            for reward, robot_id in zip(reward_list, arriving_sequence):
                robot = self.robot_list[robot_id]
                robot.observations = self.get_observations(robot.robot_position)
                robot.save_trajectory_coords(self.env.find_index_from_coords(robot.robot_position), self.env.num_new_targets_found)
                robot.save_reward_done(reward, done)
                robot.save_next_observations(robot.observations)

            ### Save a frame to generate gif of robot trajectories ###
            if self.save_image:
                robots_route = []
                for robot in self.robot_list:
                    robots_route.append([robot.xPoints, robot.yPoints])
                if not os.path.exists(GIFS_PATH):
                    os.makedirs(GIFS_PATH)
                self.env.plot_env(self.global_step, GIFS_PATH, step, max(travel_dist_list), robots_route)

            if done:
                break

        for robot in self.robot_list:
            for i in range(15):
                self.episode_buffer[i] += robot.episode_buffer[i]

        self.perf_metrics['travel_dist'] = max(travel_dist_list)
        self.perf_metrics['explored_rate'] = self.env.explored_rate
        self.perf_metrics['targets_found'] = self.env.targets_found_rate
        self.perf_metrics['success_rate'] = done

        # save gif
        if self.save_image:
            path = GIFS_PATH
            self.make_gif(path, curr_episode)

        print(YELLOW, f"[Eps {curr_episode} Completed] Time Taken: {time()-eps_start:.2f}s, Steps: {step+1}", NC)
    def get_observations(self, robot_position):
        """ Get robot's sensor observation of environment given position """
        current_node_index = self.env.find_index_from_coords(robot_position)
        current_index = torch.tensor([current_node_index]).unsqueeze(0).unsqueeze(0).to(self.device)  # (1,1,1)

        node_coords = copy.deepcopy(self.env.node_coords)
        graph = copy.deepcopy(self.env.graph)
        node_utility = copy.deepcopy(self.env.node_utility)
        guidepost = copy.deepcopy(self.env.guidepost)
        segmentation_info_mask = copy.deepcopy(self.env.filtered_seg_info_mask)

        n_nodes = node_coords.shape[0]
        node_coords = node_coords / 640
        node_utility = node_utility / 50

        node_utility_inputs = node_utility.reshape((n_nodes, 1))

        occupied_node = np.zeros((n_nodes, 1))
        for position in self.all_robot_positions:
            index = self.env.find_index_from_coords(position)
            if index == current_index.item():
                occupied_node[index] = -1
            else:
                occupied_node[index] = 1

        node_inputs = np.concatenate((node_coords, segmentation_info_mask, guidepost), axis=1)
        node_inputs = torch.FloatTensor(node_inputs).unsqueeze(0).to(self.device)  # (1, node_padding_size+1, 3)

        assert node_coords.shape[0] < self.node_padding_size
        padding = torch.nn.ZeroPad2d((0, 0, 0, self.node_padding_size - node_coords.shape[0]))
        node_inputs = padding(node_inputs)

        node_padding_mask = torch.zeros((1, 1, node_coords.shape[0]), dtype=torch.int64).to(self.device)
        node_padding = torch.ones((1, 1, self.node_padding_size - node_coords.shape[0]), dtype=torch.int64).to(
            self.device)
        node_padding_mask = torch.cat((node_padding_mask, node_padding), dim=-1)

        graph = list(graph.values())
        edge_inputs = []
        for node in graph:
            node_edges = list(map(int, node))
            edge_inputs.append(node_edges)

        bias_matrix = self.calculate_edge_mask(edge_inputs)
        edge_mask = torch.from_numpy(bias_matrix).float().unsqueeze(0).to(self.device)

        assert len(edge_inputs) < self.node_padding_size
        padding = torch.nn.ConstantPad2d(
            (0, self.node_padding_size - len(edge_inputs), 0, self.node_padding_size - len(edge_inputs)), 1)
        edge_mask = padding(edge_mask)
        padding2 = torch.nn.ZeroPad2d((0, 0, 0, self.node_padding_size - len(edge_inputs)))

        for edges in edge_inputs:
            while len(edges) < self.k_size:
                edges.append(0)

        edge_inputs = torch.tensor(edge_inputs).unsqueeze(0).to(self.device)  # (1, node_padding_size+1, k_size)
        edge_inputs = padding2(edge_inputs)

        edge_padding_mask = torch.zeros((1, len(edge_inputs), K_SIZE), dtype=torch.int64).to(self.device)
        one = torch.ones_like(edge_padding_mask, dtype=torch.int64).to(self.device)
        edge_padding_mask = torch.where(edge_inputs == 0, one, edge_padding_mask)

        observations = node_inputs, edge_inputs, current_index, node_padding_mask, edge_padding_mask, edge_mask
        return observations
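Unlike the evaluation-time variant earlier in this commit, this training worker pads every graph to a fixed NODE_PADDING_SIZE so episodes with different node counts can be batched, and node_padding_mask flags the padded rows for attention. A small sketch of that padding step (sizes are illustrative):

import torch

node_padding_size, n_nodes, feat = 8, 5, 3
node_inputs = torch.randn(1, n_nodes, feat)

padding = torch.nn.ZeroPad2d((0, 0, 0, node_padding_size - n_nodes))
node_inputs = padding(node_inputs)              # (1, 8, 3): rows 5-7 zeroed

node_padding_mask = torch.zeros((1, 1, n_nodes), dtype=torch.int64)
node_padding = torch.ones((1, 1, node_padding_size - n_nodes), dtype=torch.int64)
node_padding_mask = torch.cat((node_padding_mask, node_padding), dim=-1)  # 1 = ignore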
    def select_node(self, observations):
        """ Forward pass through policy to get next position to go to on map """
        node_inputs, edge_inputs, current_index, node_padding_mask, edge_padding_mask, edge_mask = observations
        with torch.no_grad():
            logp_list = self.local_policy_net(node_inputs, edge_inputs, current_index, node_padding_mask,
                                              edge_padding_mask, edge_mask)

        if self.greedy:
            action_index = torch.argmax(logp_list, dim=1).long()
        else:
            action_index = torch.multinomial(logp_list.exp(), 1).long().squeeze(1)

        next_node_index = edge_inputs[:, current_index.item(), action_index.item()]

        next_position = self.env.node_coords[next_node_index]

        return next_position, action_index
    def solve_conflict(self, arriving_sequence, next_position_list, dist_list):
        """ Deconflict if 2 agents choose the same target position """
        for j, [robot_id, next_position] in enumerate(zip(arriving_sequence, next_position_list)):
            moving_robot = self.robot_list[robot_id]
            # if next_position[0] + next_position[1] * 1j in (next_position_list[:, 0] + next_position_list[:, 1] * 1j)[:j]:
            #     dist_to_next_position = np.argsort(np.linalg.norm(self.env.node_coords - next_position, axis=1))
            #     k = 0
            #     while next_position[0] + next_position[1] * 1j in (next_position_list[:, 0] + next_position_list[:, 1] * 1j)[:j]:
            #         k += 1
            #         next_position = self.env.node_coords[dist_to_next_position[k]]

            dist = np.linalg.norm(next_position - moving_robot.robot_position)
            next_position_list[j] = next_position
            dist_list[j] = dist
            moving_robot.travel_dist += dist
            moving_robot.robot_position = next_position

        return next_position_list, dist_list


    def work(self, currEpisode):
        '''
        Interacts with the environment. The agent gets either gradients or experience buffer
        '''
        self.run_episode(currEpisode)

    def calculate_edge_mask(self, edge_inputs):
        size = len(edge_inputs)
        bias_matrix = np.ones((size, size))
        for i in range(size):
            for j in range(size):
                if j in edge_inputs[i]:
                    bias_matrix[i][j] = 0
        return bias_matrix


    def make_gif(self, path, n):
        """ Generate a gif given list of images """
        with imageio.get_writer('{}/{}_target_rate_{:.2f}.gif'.format(path, n, self.env.targets_found_rate), mode='I',
                                fps=5) as writer:
            for frame in self.env.frame_files:
                image = imageio.imread(frame)
                writer.append_data(image)
        print('gif complete\n')

        # Remove files
        for filename in self.env.frame_files[:-1]:
            os.remove(filename)