diff --git a/configs/dataset/bedlam_wai/default.yaml b/configs/dataset/bedlam_wai/default.yaml deleted file mode 100644 index 78448df144ecb19398c761f707b5e264dcaaae29..0000000000000000000000000000000000000000 --- a/configs/dataset/bedlam_wai/default.yaml +++ /dev/null @@ -1,3 +0,0 @@ -defaults: - - train: default - - val: default diff --git a/configs/dataset/bedlam_wai/train/default.yaml b/configs/dataset/bedlam_wai/train/default.yaml deleted file mode 100644 index 11dc8db66e605f9f42a33677c1bcba7236ca3ef7..0000000000000000000000000000000000000000 --- a/configs/dataset/bedlam_wai/train/default.yaml +++ /dev/null @@ -1,26 +0,0 @@ -dataset_str: - "BedlamWAI( - split='${dataset.bedlam_wai.train.split}', - resolution=${dataset.bedlam_wai.train.dataset_resolution}, - principal_point_centered=${dataset.bedlam_wai.train.principal_point_centered}, - aug_crop=${dataset.bedlam_wai.train.aug_crop}, - transform='${dataset.bedlam_wai.train.transform}', - data_norm_type='${dataset.bedlam_wai.train.data_norm_type}', - ROOT='${dataset.bedlam_wai.train.ROOT}', - dataset_metadata_dir='${dataset.bedlam_wai.train.dataset_metadata_dir}', - overfit_num_sets=${dataset.bedlam_wai.train.overfit_num_sets}, - variable_num_views=${dataset.bedlam_wai.train.variable_num_views}, - num_views=${dataset.bedlam_wai.train.num_views}, - covisibility_thres=${dataset.bedlam_wai.train.covisibility_thres})" -split: 'train' -dataset_resolution: ${dataset.resolution_train} -principal_point_centered: ${dataset.principal_point_centered} -aug_crop: 16 -transform: 'colorjitter+grayscale+gaublur' -data_norm_type: ${model.data_norm_type} -ROOT: ${root_data_dir}/bedlam -dataset_metadata_dir: ${mapanything_dataset_metadata_dir} -overfit_num_sets: null -variable_num_views: ${dataset.train.variable_num_views} -num_views: ${dataset.num_views} -covisibility_thres: 0.25 diff --git a/configs/dataset/bedlam_wai/val/default.yaml b/configs/dataset/bedlam_wai/val/default.yaml deleted file mode 100644 index 8d1471050f84b32fa1858a4d11ad4dc798c0f002..0000000000000000000000000000000000000000 --- a/configs/dataset/bedlam_wai/val/default.yaml +++ /dev/null @@ -1,26 +0,0 @@ -dataset_str: - "BedlamWAI( - split='${dataset.bedlam_wai.val.split}', - resolution=${dataset.bedlam_wai.val.dataset_resolution}, - principal_point_centered=${dataset.bedlam_wai.val.principal_point_centered}, - seed=${dataset.bedlam_wai.val.seed}, - transform='${dataset.bedlam_wai.val.transform}', - data_norm_type='${dataset.bedlam_wai.val.data_norm_type}', - ROOT='${dataset.bedlam_wai.val.ROOT}', - dataset_metadata_dir='${dataset.bedlam_wai.val.dataset_metadata_dir}', - overfit_num_sets=${dataset.bedlam_wai.val.overfit_num_sets}, - variable_num_views=${dataset.bedlam_wai.val.variable_num_views}, - num_views=${dataset.bedlam_wai.val.num_views}, - covisibility_thres=${dataset.bedlam_wai.val.covisibility_thres})" -split: 'val' -dataset_resolution: ${dataset.resolution_val_bedlam} -principal_point_centered: ${dataset.principal_point_centered} -seed: 777 -transform: 'imgnorm' -data_norm_type: ${model.data_norm_type} -ROOT: ${root_data_dir}/bedlam -dataset_metadata_dir: ${mapanything_dataset_metadata_dir} -overfit_num_sets: null -variable_num_views: ${dataset.val.variable_num_views} -num_views: ${dataset.num_views} -covisibility_thres: 0.25 diff --git a/configs/dataset/benchmark_504_eth3d_snpp_tav2.yaml b/configs/dataset/benchmark_504_eth3d_snpp_tav2.yaml new file mode 100644 index 0000000000000000000000000000000000000000..3f529eac82df718972eb46d36421f40887c2fecc --- /dev/null +++ 
b/configs/dataset/benchmark_504_eth3d_snpp_tav2.yaml @@ -0,0 +1,20 @@ +defaults: + - default + +# Number of views parameter for the multi-view datasets +num_views: 2 + +# Test Resolution +resolution_test_eth3d: ${dataset.resolution_options.504_1_52_ar} +resolution_test_scannetpp: ${dataset.resolution_options.504_1_52_ar} +resolution_test_tav2_wb: ${dataset.resolution_options.504_1_00_ar} + +# Test Set +# Sample 10 multi-view sets from each scene +# ETH3D: 13 scenes +# ScanNet++V2: 30 scenes +# TartanAirV2-WB: 5 scenes +test_dataset: + "+ 130 @ ${dataset.eth3d_wai.test.dataset_str} + + 300 @ ${dataset.scannetpp_wai.test.dataset_str} + + 50 @ ${dataset.tav2_wb_wai.test.dataset_str}" diff --git a/configs/dataset/benchmark_512_snpp_tav2.yaml b/configs/dataset/benchmark_512_snpp_tav2.yaml deleted file mode 100644 index 4be5baf2584b8557cef2d80d1e8e41bc9b4689e8..0000000000000000000000000000000000000000 --- a/configs/dataset/benchmark_512_snpp_tav2.yaml +++ /dev/null @@ -1,17 +0,0 @@ -defaults: - - default - -# Number of views parameter for the multi-view datasets -num_views: 2 - -# Test Resolution -resolution_test_scannetpp: ${dataset.resolution_options.512_1_52_ar} -resolution_test_tav2_wb: ${dataset.resolution_options.512_1_00_ar} - -# Test Set -# Sample 10 multi-view sets from each scene -# ScanNet++V2: 30 scenes -# TartanAirV2-WB: 5 scenes -test_dataset: - "+ 300 @ ${dataset.scannetpp_wai.test.dataset_str} - + 50 @ ${dataset.tav2_wb_wai.test.dataset_str}" diff --git a/configs/dataset/benchmark_518_snpp_tav2.yaml b/configs/dataset/benchmark_518_snpp_tav2.yaml deleted file mode 100644 index 4b15f825631ce0421857902f570c33812c6236f2..0000000000000000000000000000000000000000 --- a/configs/dataset/benchmark_518_snpp_tav2.yaml +++ /dev/null @@ -1,17 +0,0 @@ -defaults: - - default - -# Number of views parameter for the multi-view datasets -num_views: 2 - -# Test Resolution -resolution_test_scannetpp: ${dataset.resolution_options.518_1_52_ar} -resolution_test_tav2_wb: ${dataset.resolution_options.518_1_00_ar} - -# Test Set -# Sample 10 multi-view sets from each scene -# ScanNet++V2: 30 scenes -# TartanAirV2-WB: 5 scenes -test_dataset: - "+ 300 @ ${dataset.scannetpp_wai.test.dataset_str} - + 50 @ ${dataset.tav2_wb_wai.test.dataset_str}" diff --git a/configs/dataset/bmvs_518_many_ar_48ipg_8g.yaml b/configs/dataset/bmvs_518_many_ar_48ipg_8g.yaml new file mode 100644 index 0000000000000000000000000000000000000000..e4ca79e90dd1ba616b2b774ee2c060ebfae09908 --- /dev/null +++ b/configs/dataset/bmvs_518_many_ar_48ipg_8g.yaml @@ -0,0 +1,23 @@ +defaults: + - default + +# Number of views parameter for the multi-view datasets +num_views: 4 + +train: + # If True, the number of views can vary from batch to batch. The maximum number of views is num_views and minimum is 2. 
(On by default for N-view training) + variable_num_views: true + +# Train Resolution +resolution_train: ${dataset.resolution_options.518_many_ar} + +# Validation Resolution +resolution_val_blendedmvs: ${dataset.resolution_options.518_1_33_ar} + +# Training Set +train_dataset: + "+ 140_000 @ ${dataset.blendedmvs_wai.train.dataset_str}" + +# Validation Set +test_dataset: + "+ 4_000 @ ${dataset.blendedmvs_wai.val.dataset_str}" diff --git a/configs/dataset/default.yaml b/configs/dataset/default.yaml index 84b4954bca4b0ac2127f7eafdc896ae4c46ce1b7..97cf9cd09b72815fb8337fe26d7cc2ccd56d200d 100644 --- a/configs/dataset/default.yaml +++ b/configs/dataset/default.yaml @@ -1,14 +1,10 @@ defaults: - resolution_options: default - ase_wai: default - - bedlam_wai: default - blendedmvs_wai: default - dl3dv_wai: default - - dtu_wai: default - dynamicreplica_wai: default - eth3d_wai: default - - gta_sfm_wai: default - - matrixcity_wai: default - megadepth_wai: default - mpsd_wai: default - mvs_synth_wai: default @@ -16,10 +12,8 @@ defaults: - sailvos3d_wai: default - scannetpp_wai: default - spring_wai: default - - structured3d_wai: default - tav2_wb_wai: default - unrealstereo4k_wai: default - - xrooms_wai: default # Training Set, For example: BlendedMVS(split='train', resolution=(512, 384), transform=...) train_dataset: ??? diff --git a/configs/dataset/dtu_wai/default.yaml b/configs/dataset/dtu_wai/default.yaml deleted file mode 100644 index b1278dcc74c8a2ee16b87a31ebabca50234ab9fa..0000000000000000000000000000000000000000 --- a/configs/dataset/dtu_wai/default.yaml +++ /dev/null @@ -1,2 +0,0 @@ -defaults: - - test: default diff --git a/configs/dataset/dtu_wai/test/default.yaml b/configs/dataset/dtu_wai/test/default.yaml deleted file mode 100644 index 7910a3aae6da7cef8a26ca77932b090d998cf8bb..0000000000000000000000000000000000000000 --- a/configs/dataset/dtu_wai/test/default.yaml +++ /dev/null @@ -1,22 +0,0 @@ -dataset_str: - "DTUWAI( - resolution=${dataset.dtu_wai.test.dataset_resolution}, - principal_point_centered=${dataset.dtu_wai.test.principal_point_centered}, - seed=${dataset.dtu_wai.test.seed}, - transform='${dataset.dtu_wai.test.transform}', - data_norm_type='${dataset.dtu_wai.test.data_norm_type}', - ROOT='${dataset.dtu_wai.test.ROOT}', - dataset_metadata_dir='${dataset.dtu_wai.test.dataset_metadata_dir}', - variable_num_views=${dataset.dtu_wai.test.variable_num_views}, - num_views=${dataset.dtu_wai.test.num_views}, - covisibility_thres=${dataset.dtu_wai.test.covisibility_thres})" -dataset_resolution: ${dataset.resolution_test_dtu} -principal_point_centered: ${dataset.principal_point_centered} -seed: 777 -transform: 'imgnorm' -data_norm_type: ${model.data_norm_type} -ROOT: ${root_data_dir}/dtu -dataset_metadata_dir: ${mapanything_dataset_metadata_dir} -variable_num_views: ${dataset.test.variable_num_views} -num_views: ${dataset.num_views} -covisibility_thres: 0.25 diff --git a/configs/dataset/gta_sfm_wai/default.yaml b/configs/dataset/gta_sfm_wai/default.yaml deleted file mode 100644 index 78448df144ecb19398c761f707b5e264dcaaae29..0000000000000000000000000000000000000000 --- a/configs/dataset/gta_sfm_wai/default.yaml +++ /dev/null @@ -1,3 +0,0 @@ -defaults: - - train: default - - val: default diff --git a/configs/dataset/gta_sfm_wai/train/default.yaml b/configs/dataset/gta_sfm_wai/train/default.yaml deleted file mode 100644 index 971b4f4a78f3207d38f667860becf035269a46a6..0000000000000000000000000000000000000000 --- a/configs/dataset/gta_sfm_wai/train/default.yaml +++ /dev/null @@ -1,26 +0,0 @@ 
-dataset_str: - "GTASfMWAI( - split='${dataset.gta_sfm_wai.train.split}', - resolution=${dataset.gta_sfm_wai.train.dataset_resolution}, - principal_point_centered=${dataset.gta_sfm_wai.train.principal_point_centered}, - aug_crop=${dataset.gta_sfm_wai.train.aug_crop}, - transform='${dataset.gta_sfm_wai.train.transform}', - data_norm_type='${dataset.gta_sfm_wai.train.data_norm_type}', - ROOT='${dataset.gta_sfm_wai.train.ROOT}', - dataset_metadata_dir='${dataset.gta_sfm_wai.train.dataset_metadata_dir}', - overfit_num_sets=${dataset.gta_sfm_wai.train.overfit_num_sets}, - variable_num_views=${dataset.gta_sfm_wai.train.variable_num_views}, - num_views=${dataset.gta_sfm_wai.train.num_views}, - covisibility_thres=${dataset.gta_sfm_wai.train.covisibility_thres})" -split: 'train' -dataset_resolution: ${dataset.resolution_train} -principal_point_centered: ${dataset.principal_point_centered} -aug_crop: 16 -transform: 'colorjitter+grayscale+gaublur' -data_norm_type: ${model.data_norm_type} -ROOT: ${root_data_dir}/gta_sfm -dataset_metadata_dir: ${mapanything_dataset_metadata_dir} -overfit_num_sets: null -variable_num_views: ${dataset.train.variable_num_views} -num_views: ${dataset.num_views} -covisibility_thres: 0.25 diff --git a/configs/dataset/gta_sfm_wai/val/default.yaml b/configs/dataset/gta_sfm_wai/val/default.yaml deleted file mode 100644 index 430ac9e292dc1059dcccbbab6bdf82b0f46f391e..0000000000000000000000000000000000000000 --- a/configs/dataset/gta_sfm_wai/val/default.yaml +++ /dev/null @@ -1,26 +0,0 @@ -dataset_str: - "GTASfMWAI( - split='${dataset.gta_sfm_wai.val.split}', - resolution=${dataset.gta_sfm_wai.val.dataset_resolution}, - principal_point_centered=${dataset.gta_sfm_wai.val.principal_point_centered}, - seed=${dataset.gta_sfm_wai.val.seed}, - transform='${dataset.gta_sfm_wai.val.transform}', - data_norm_type='${dataset.gta_sfm_wai.val.data_norm_type}', - ROOT='${dataset.gta_sfm_wai.val.ROOT}', - dataset_metadata_dir='${dataset.gta_sfm_wai.val.dataset_metadata_dir}', - overfit_num_sets=${dataset.gta_sfm_wai.val.overfit_num_sets}, - variable_num_views=${dataset.gta_sfm_wai.val.variable_num_views}, - num_views=${dataset.gta_sfm_wai.val.num_views}, - covisibility_thres=${dataset.gta_sfm_wai.val.covisibility_thres})" -split: 'val' -dataset_resolution: ${dataset.resolution_val_gta_sfm} -principal_point_centered: ${dataset.principal_point_centered} -seed: 777 -transform: 'imgnorm' -data_norm_type: ${model.data_norm_type} -ROOT: ${root_data_dir}/gta_sfm -dataset_metadata_dir: ${mapanything_dataset_metadata_dir} -overfit_num_sets: null -variable_num_views: ${dataset.val.variable_num_views} -num_views: ${dataset.num_views} -covisibility_thres: 0.25 diff --git a/configs/dataset/matrixcity_wai/default.yaml b/configs/dataset/matrixcity_wai/default.yaml deleted file mode 100644 index 78448df144ecb19398c761f707b5e264dcaaae29..0000000000000000000000000000000000000000 --- a/configs/dataset/matrixcity_wai/default.yaml +++ /dev/null @@ -1,3 +0,0 @@ -defaults: - - train: default - - val: default diff --git a/configs/dataset/matrixcity_wai/train/default.yaml b/configs/dataset/matrixcity_wai/train/default.yaml deleted file mode 100644 index ca7412ba48e10620b969af8311bb9e511bf5e437..0000000000000000000000000000000000000000 --- a/configs/dataset/matrixcity_wai/train/default.yaml +++ /dev/null @@ -1,26 +0,0 @@ -dataset_str: - "MatrixCityWAI( - split='${dataset.matrixcity_wai.train.split}', - resolution=${dataset.matrixcity_wai.train.dataset_resolution}, - 
principal_point_centered=${dataset.matrixcity_wai.train.principal_point_centered}, - aug_crop=${dataset.matrixcity_wai.train.aug_crop}, - transform='${dataset.matrixcity_wai.train.transform}', - data_norm_type='${dataset.matrixcity_wai.train.data_norm_type}', - ROOT='${dataset.matrixcity_wai.train.ROOT}', - dataset_metadata_dir='${dataset.matrixcity_wai.train.dataset_metadata_dir}', - overfit_num_sets=${dataset.matrixcity_wai.train.overfit_num_sets}, - variable_num_views=${dataset.matrixcity_wai.train.variable_num_views}, - num_views=${dataset.matrixcity_wai.train.num_views}, - covisibility_thres=${dataset.matrixcity_wai.train.covisibility_thres})" -split: 'train' -dataset_resolution: ${dataset.resolution_train} -principal_point_centered: ${dataset.principal_point_centered} -aug_crop: 16 -transform: 'colorjitter+grayscale+gaublur' -data_norm_type: ${model.data_norm_type} -ROOT: ${root_data_dir}/matrixcity -dataset_metadata_dir: ${mapanything_dataset_metadata_dir} -overfit_num_sets: null -variable_num_views: ${dataset.train.variable_num_views} -num_views: ${dataset.num_views} -covisibility_thres: 0.25 diff --git a/configs/dataset/matrixcity_wai/val/default.yaml b/configs/dataset/matrixcity_wai/val/default.yaml deleted file mode 100644 index 64a73059704da9e721b18b949ec3487b436c3607..0000000000000000000000000000000000000000 --- a/configs/dataset/matrixcity_wai/val/default.yaml +++ /dev/null @@ -1,26 +0,0 @@ -dataset_str: - "MatrixCityWAI( - split='${dataset.matrixcity_wai.val.split}', - resolution=${dataset.matrixcity_wai.val.dataset_resolution}, - principal_point_centered=${dataset.matrixcity_wai.val.principal_point_centered}, - seed=${dataset.matrixcity_wai.val.seed}, - transform='${dataset.matrixcity_wai.val.transform}', - data_norm_type='${dataset.matrixcity_wai.val.data_norm_type}', - ROOT='${dataset.matrixcity_wai.val.ROOT}', - dataset_metadata_dir='${dataset.matrixcity_wai.val.dataset_metadata_dir}', - overfit_num_sets=${dataset.matrixcity_wai.val.overfit_num_sets}, - variable_num_views=${dataset.matrixcity_wai.val.variable_num_views}, - num_views=${dataset.matrixcity_wai.val.num_views}, - covisibility_thres=${dataset.matrixcity_wai.val.covisibility_thres})" -split: 'val' -dataset_resolution: ${dataset.resolution_val_matrixcity} -principal_point_centered: ${dataset.principal_point_centered} -seed: 777 -transform: 'imgnorm' -data_norm_type: ${model.data_norm_type} -ROOT: ${root_data_dir}/matrixcity -dataset_metadata_dir: ${mapanything_dataset_metadata_dir} -overfit_num_sets: null -variable_num_views: ${dataset.val.variable_num_views} -num_views: ${dataset.num_views} -covisibility_thres: 0.25 diff --git a/configs/dataset/megatrain_12d_518_many_ar_24ipg_16g.yaml b/configs/dataset/megatrain_13d_518_many_ar_24ipg_8g.yaml similarity index 69% rename from configs/dataset/megatrain_12d_518_many_ar_24ipg_16g.yaml rename to configs/dataset/megatrain_13d_518_many_ar_24ipg_8g.yaml index e97bbb2c38a81334f7a57eb93854774c4df01b78..66485dff4794a1b1850dc7094ab2f092b985796e 100644 --- a/configs/dataset/megatrain_12d_518_many_ar_24ipg_16g.yaml +++ b/configs/dataset/megatrain_13d_518_many_ar_24ipg_8g.yaml @@ -14,6 +14,7 @@ resolution_train: ${dataset.resolution_options.518_many_ar} # Validation Resolution resolution_val_ase: ${dataset.resolution_options.518_1_00_ar} resolution_val_blendedmvs: ${dataset.resolution_options.518_1_33_ar} +resolution_val_dl3dv: ${dataset.resolution_options.518_1_77_ar} resolution_val_dynamicreplica: ${dataset.resolution_options.518_1_77_ar} resolution_val_megadepth: 
${dataset.resolution_options.518_1_52_ar} resolution_val_mpsd: ${dataset.resolution_options.518_1_77_ar} @@ -27,23 +28,25 @@ resolution_val_unrealstereo4k: ${dataset.resolution_options.518_1_77_ar} # Training Set train_dataset: - "+ 58_000 @ ${dataset.ase_wai.train.dataset_str} - + 58_000 @ ${dataset.blendedmvs_wai.train.dataset_str} - + 45_000 @ ${dataset.dynamicreplica_wai.train.dataset_str} - + 58_000 @ ${dataset.megadepth_wai.train.dataset_str} - + 58_000 @ ${dataset.mpsd_wai.train.dataset_str} - + 58_000 @ ${dataset.mvs_synth_wai.train.dataset_str} - + 58_000 @ ${dataset.paralleldomain4d_wai.train.dataset_str} - + 58_000 @ ${dataset.sailvos3d_wai.train.dataset_str} - + 58_000 @ ${dataset.scannetpp_wai.train.dataset_str} - + 2_000 @ ${dataset.spring_wai.train.dataset_str} - + 58_000 @ ${dataset.tav2_wb_wai.train.dataset_str} - + 5_500 @ ${dataset.unrealstereo4k_wai.train.dataset_str}" + "+ 26_250 @ ${dataset.ase_wai.train.dataset_str} + + 26_250 @ ${dataset.blendedmvs_wai.train.dataset_str} + + 26_250 @ ${dataset.dl3dv_wai.train.dataset_str} + + 20_000 @ ${dataset.dynamicreplica_wai.train.dataset_str} + + 26_250 @ ${dataset.megadepth_wai.train.dataset_str} + + 26_250 @ ${dataset.mpsd_wai.train.dataset_str} + + 26_250 @ ${dataset.mvs_synth_wai.train.dataset_str} + + 26_250 @ ${dataset.paralleldomain4d_wai.train.dataset_str} + + 26_250 @ ${dataset.sailvos3d_wai.train.dataset_str} + + 26_250 @ ${dataset.scannetpp_wai.train.dataset_str} + + 1_000 @ ${dataset.spring_wai.train.dataset_str} + + 26_250 @ ${dataset.tav2_wb_wai.train.dataset_str} + + 2_750 @ ${dataset.unrealstereo4k_wai.train.dataset_str}" # Validation Set test_dataset: "+ 4_000 @ ${dataset.ase_wai.val.dataset_str} + 4_000 @ ${dataset.blendedmvs_wai.val.dataset_str} + + 4_000 @ ${dataset.dl3dv_wai.val.dataset_str} + 4_000 @ ${dataset.dynamicreplica_wai.val.dataset_str} + 4_000 @ ${dataset.megadepth_wai.val.dataset_str} + 4_000 @ ${dataset.mpsd_wai.val.dataset_str} diff --git a/configs/dataset/megatrain_11d_se_518_many_ar_48ipg_64g.yaml b/configs/dataset/megatrain_13d_518_many_ar_36ipg_64g.yaml similarity index 64% rename from configs/dataset/megatrain_11d_se_518_many_ar_48ipg_64g.yaml rename to configs/dataset/megatrain_13d_518_many_ar_36ipg_64g.yaml index a5c5f087b72828f98ae5406f871e763145dbd7b0..0fd0f06b6f892af8776fdbd98c61c370ba1078bf 100644 --- a/configs/dataset/megatrain_11d_se_518_many_ar_48ipg_64g.yaml +++ b/configs/dataset/megatrain_13d_518_many_ar_36ipg_64g.yaml @@ -13,8 +13,10 @@ resolution_train: ${dataset.resolution_options.518_many_ar} # Validation Resolution resolution_val_ase: ${dataset.resolution_options.518_1_00_ar} +resolution_val_blendedmvs: ${dataset.resolution_options.518_1_33_ar} resolution_val_dl3dv: ${dataset.resolution_options.518_1_77_ar} resolution_val_dynamicreplica: ${dataset.resolution_options.518_1_77_ar} +resolution_val_megadepth: ${dataset.resolution_options.518_1_52_ar} resolution_val_mpsd: ${dataset.resolution_options.518_1_77_ar} resolution_val_mvs_synth: ${dataset.resolution_options.518_1_77_ar} resolution_val_paralleldomain4d: ${dataset.resolution_options.518_1_33_ar} @@ -26,23 +28,27 @@ resolution_val_unrealstereo4k: ${dataset.resolution_options.518_1_77_ar} # Training Set train_dataset: - "+ 2_450_000 @ ${dataset.ase_wai.train.dataset_str} - + 250_000 @ ${dataset.dl3dv_wai.train.dataset_str} - + 12_400 @ ${dataset.dynamicreplica_wai.train.dataset_str} - + 1_675_000 @ ${dataset.mpsd_wai.train.dataset_str} - + 3_000 @ ${dataset.mvs_synth_wai.train.dataset_str} - + 36_000 @ 
${dataset.paralleldomain4d_wai.train.dataset_str} - + 4_000 @ ${dataset.sailvos3d_wai.train.dataset_str} - + 22_600 @ ${dataset.scannetpp_wai.train.dataset_str} - + 800 @ ${dataset.spring_wai.train.dataset_str} - + 4_000 @ ${dataset.tav2_wb_wai.train.dataset_str} - + 200 @ ${dataset.unrealstereo4k_wai.train.dataset_str}" + "+ 315_000 @ ${dataset.ase_wai.train.dataset_str} + + 315_000 @ ${dataset.blendedmvs_wai.train.dataset_str} + + 315_000 @ ${dataset.dl3dv_wai.train.dataset_str} + + 240_000 @ ${dataset.dynamicreplica_wai.train.dataset_str} + + 315_000 @ ${dataset.megadepth_wai.train.dataset_str} + + 315_000 @ ${dataset.mpsd_wai.train.dataset_str} + + 315_000 @ ${dataset.mvs_synth_wai.train.dataset_str} + + 315_000 @ ${dataset.paralleldomain4d_wai.train.dataset_str} + + 315_000 @ ${dataset.sailvos3d_wai.train.dataset_str} + + 315_000 @ ${dataset.scannetpp_wai.train.dataset_str} + + 12_000 @ ${dataset.spring_wai.train.dataset_str} + + 315_000 @ ${dataset.tav2_wb_wai.train.dataset_str} + + 33_000 @ ${dataset.unrealstereo4k_wai.train.dataset_str}" # Validation Set test_dataset: "+ 4_000 @ ${dataset.ase_wai.val.dataset_str} + + 4_000 @ ${dataset.blendedmvs_wai.val.dataset_str} + 4_000 @ ${dataset.dl3dv_wai.val.dataset_str} + 4_000 @ ${dataset.dynamicreplica_wai.val.dataset_str} + + 4_000 @ ${dataset.megadepth_wai.val.dataset_str} + 4_000 @ ${dataset.mpsd_wai.val.dataset_str} + 4_000 @ ${dataset.mvs_synth_wai.val.dataset_str} + 4_000 @ ${dataset.paralleldomain4d_wai.val.dataset_str} diff --git a/configs/dataset/megatrain_13d_518_many_ar_48ipg_8g_mono.yaml b/configs/dataset/megatrain_13d_518_many_ar_48ipg_8g_mono.yaml new file mode 100644 index 0000000000000000000000000000000000000000..47b9ab207e6794776895b4b63d9be18586beead1 --- /dev/null +++ b/configs/dataset/megatrain_13d_518_many_ar_48ipg_8g_mono.yaml @@ -0,0 +1,59 @@ +defaults: + - default + +# Number of views parameter for the multi-view datasets +num_views: 1 + +train: + # If True, the number of views can vary from batch to batch. The maximum number of views is num_views and minimum is 2. 
(On by default for N-view training) + variable_num_views: true + +# Train Resolution +resolution_train: ${dataset.resolution_options.518_many_ar} + +# Validation Resolution +resolution_val_ase: ${dataset.resolution_options.518_1_00_ar} +resolution_val_blendedmvs: ${dataset.resolution_options.518_1_33_ar} +resolution_val_dl3dv: ${dataset.resolution_options.518_1_77_ar} +resolution_val_dynamicreplica: ${dataset.resolution_options.518_1_77_ar} +resolution_val_megadepth: ${dataset.resolution_options.518_1_52_ar} +resolution_val_mpsd: ${dataset.resolution_options.518_1_77_ar} +resolution_val_mvs_synth: ${dataset.resolution_options.518_1_77_ar} +resolution_val_paralleldomain4d: ${dataset.resolution_options.518_1_33_ar} +resolution_val_sailvos3d: ${dataset.resolution_options.518_1_52_ar} +resolution_val_scannetpp: ${dataset.resolution_options.518_1_52_ar} +resolution_val_spring: ${dataset.resolution_options.518_1_77_ar} +resolution_val_tav2_wb: ${dataset.resolution_options.518_1_00_ar} +resolution_val_unrealstereo4k: ${dataset.resolution_options.518_1_77_ar} + +# Training Set +train_dataset: + "+ 105_000 @ ${dataset.ase_wai.train.dataset_str} + + 105_000 @ ${dataset.blendedmvs_wai.train.dataset_str} + + 105_000 @ ${dataset.dl3dv_wai.train.dataset_str} + + 80_000 @ ${dataset.dynamicreplica_wai.train.dataset_str} + + 105_000 @ ${dataset.megadepth_wai.train.dataset_str} + + 105_000 @ ${dataset.mpsd_wai.train.dataset_str} + + 105_000 @ ${dataset.mvs_synth_wai.train.dataset_str} + + 105_000 @ ${dataset.paralleldomain4d_wai.train.dataset_str} + + 105_000 @ ${dataset.sailvos3d_wai.train.dataset_str} + + 105_000 @ ${dataset.scannetpp_wai.train.dataset_str} + + 4_000 @ ${dataset.spring_wai.train.dataset_str} + + 105_000 @ ${dataset.tav2_wb_wai.train.dataset_str} + + 11_000 @ ${dataset.unrealstereo4k_wai.train.dataset_str}" + +# Validation Set +test_dataset: + "+ 4_000 @ ${dataset.ase_wai.val.dataset_str} + + 4_000 @ ${dataset.blendedmvs_wai.val.dataset_str} + + 4_000 @ ${dataset.dl3dv_wai.val.dataset_str} + + 4_000 @ ${dataset.dynamicreplica_wai.val.dataset_str} + + 4_000 @ ${dataset.megadepth_wai.val.dataset_str} + + 4_000 @ ${dataset.mpsd_wai.val.dataset_str} + + 4_000 @ ${dataset.mvs_synth_wai.val.dataset_str} + + 4_000 @ ${dataset.paralleldomain4d_wai.val.dataset_str} + + 4_000 @ ${dataset.sailvos3d_wai.val.dataset_str} + + 4_000 @ ${dataset.scannetpp_wai.val.dataset_str} + + 500 @ ${dataset.spring_wai.val.dataset_str} + + 4_000 @ ${dataset.tav2_wb_wai.val.dataset_str} + + 500 @ ${dataset.unrealstereo4k_wai.val.dataset_str}" diff --git a/configs/dataset/megatrain_6d_518_many_ar_36ipg_64g.yaml b/configs/dataset/megatrain_6d_518_many_ar_36ipg_64g.yaml new file mode 100644 index 0000000000000000000000000000000000000000..e9ed6e9faf0b0cb86c8c7d2c4db6f10e79987e71 --- /dev/null +++ b/configs/dataset/megatrain_6d_518_many_ar_36ipg_64g.yaml @@ -0,0 +1,38 @@ +defaults: + - default + +# Number of views parameter for the multi-view datasets +num_views: 4 + +train: + # If True, the number of views can vary from batch to batch. The maximum number of views is num_views and minimum is 2. 
(On by default for N-view training) + variable_num_views: true + +# Train Resolution +resolution_train: ${dataset.resolution_options.518_many_ar} + +# Validation Resolution +resolution_val_blendedmvs: ${dataset.resolution_options.518_1_33_ar} +resolution_val_mpsd: ${dataset.resolution_options.518_1_77_ar} +resolution_val_scannetpp: ${dataset.resolution_options.518_1_52_ar} +resolution_val_spring: ${dataset.resolution_options.518_1_77_ar} +resolution_val_tav2_wb: ${dataset.resolution_options.518_1_00_ar} +resolution_val_unrealstereo4k: ${dataset.resolution_options.518_1_77_ar} + +# Training Set +train_dataset: + "+ 840_000 @ ${dataset.blendedmvs_wai.train.dataset_str} + + 840_000 @ ${dataset.mpsd_wai.train.dataset_str} + + 840_000 @ ${dataset.scannetpp_wai.train.dataset_str} + + 33_000 @ ${dataset.spring_wai.train.dataset_str} + + 840_000 @ ${dataset.tav2_wb_wai.train.dataset_str} + + 87_000 @ ${dataset.unrealstereo4k_wai.train.dataset_str}" + +# Validation Set +test_dataset: + "+ 4_000 @ ${dataset.blendedmvs_wai.val.dataset_str} + + 4_000 @ ${dataset.mpsd_wai.val.dataset_str} + + 4_000 @ ${dataset.scannetpp_wai.val.dataset_str} + + 500 @ ${dataset.spring_wai.val.dataset_str} + + 4_000 @ ${dataset.tav2_wb_wai.val.dataset_str} + + 500 @ ${dataset.unrealstereo4k_wai.val.dataset_str}" diff --git a/configs/dataset/structured3d_wai/default.yaml b/configs/dataset/structured3d_wai/default.yaml deleted file mode 100644 index 78448df144ecb19398c761f707b5e264dcaaae29..0000000000000000000000000000000000000000 --- a/configs/dataset/structured3d_wai/default.yaml +++ /dev/null @@ -1,3 +0,0 @@ -defaults: - - train: default - - val: default diff --git a/configs/dataset/structured3d_wai/train/default.yaml b/configs/dataset/structured3d_wai/train/default.yaml deleted file mode 100644 index 8556d92cf9f9e830b35c797778ca63f79bc31a56..0000000000000000000000000000000000000000 --- a/configs/dataset/structured3d_wai/train/default.yaml +++ /dev/null @@ -1,26 +0,0 @@ -dataset_str: - "Structured3DWAI( - split='${dataset.structured3d_wai.train.split}', - resolution=${dataset.structured3d_wai.train.dataset_resolution}, - principal_point_centered=${dataset.structured3d_wai.train.principal_point_centered}, - aug_crop=${dataset.structured3d_wai.train.aug_crop}, - transform='${dataset.structured3d_wai.train.transform}', - data_norm_type='${dataset.structured3d_wai.train.data_norm_type}', - ROOT='${dataset.structured3d_wai.train.ROOT}', - dataset_metadata_dir='${dataset.structured3d_wai.train.dataset_metadata_dir}', - overfit_num_sets=${dataset.structured3d_wai.train.overfit_num_sets}, - variable_num_views=${dataset.structured3d_wai.train.variable_num_views}, - num_views=${dataset.structured3d_wai.train.num_views}, - covisibility_thres=${dataset.structured3d_wai.train.covisibility_thres})" -split: 'train' -dataset_resolution: ${dataset.resolution_train} -principal_point_centered: ${dataset.principal_point_centered} -aug_crop: 16 -transform: 'colorjitter+grayscale+gaublur' -data_norm_type: ${model.data_norm_type} -ROOT: ${root_data_dir}/structured3d -dataset_metadata_dir: ${mapanything_dataset_metadata_dir} -overfit_num_sets: null -variable_num_views: ${dataset.train.variable_num_views} -num_views: ${dataset.num_views} -covisibility_thres: 0.25 diff --git a/configs/dataset/structured3d_wai/val/default.yaml b/configs/dataset/structured3d_wai/val/default.yaml deleted file mode 100644 index 396a399b95993b4c20e452aa77f8e43d7177205b..0000000000000000000000000000000000000000 --- a/configs/dataset/structured3d_wai/val/default.yaml 
+++ /dev/null @@ -1,26 +0,0 @@ -dataset_str: - "Structured3DWAI( - split='${dataset.structured3d_wai.val.split}', - resolution=${dataset.structured3d_wai.val.dataset_resolution}, - principal_point_centered=${dataset.structured3d_wai.val.principal_point_centered}, - seed=${dataset.structured3d_wai.val.seed}, - transform='${dataset.structured3d_wai.val.transform}', - data_norm_type='${dataset.structured3d_wai.val.data_norm_type}', - ROOT='${dataset.structured3d_wai.val.ROOT}', - dataset_metadata_dir='${dataset.structured3d_wai.val.dataset_metadata_dir}', - overfit_num_sets=${dataset.structured3d_wai.val.overfit_num_sets}, - variable_num_views=${dataset.structured3d_wai.val.variable_num_views}, - num_views=${dataset.structured3d_wai.val.num_views}, - covisibility_thres=${dataset.structured3d_wai.val.covisibility_thres})" -split: 'val' -dataset_resolution: ${dataset.resolution_val_structured3d} -principal_point_centered: ${dataset.principal_point_centered} -seed: 777 -transform: 'imgnorm' -data_norm_type: ${model.data_norm_type} -ROOT: ${root_data_dir}/structured3d -dataset_metadata_dir: ${mapanything_dataset_metadata_dir} -overfit_num_sets: null -variable_num_views: ${dataset.val.variable_num_views} -num_views: ${dataset.num_views} -covisibility_thres: 0.25 diff --git a/configs/dataset/xrooms_wai/default.yaml b/configs/dataset/xrooms_wai/default.yaml deleted file mode 100644 index 78448df144ecb19398c761f707b5e264dcaaae29..0000000000000000000000000000000000000000 --- a/configs/dataset/xrooms_wai/default.yaml +++ /dev/null @@ -1,3 +0,0 @@ -defaults: - - train: default - - val: default diff --git a/configs/dataset/xrooms_wai/train/default.yaml b/configs/dataset/xrooms_wai/train/default.yaml deleted file mode 100644 index 2a6131e36392f0efe537dbdf3b6767c83d7b9a3b..0000000000000000000000000000000000000000 --- a/configs/dataset/xrooms_wai/train/default.yaml +++ /dev/null @@ -1,26 +0,0 @@ -dataset_str: - "XRoomsWAI( - split='${dataset.xrooms_wai.train.split}', - resolution=${dataset.xrooms_wai.train.dataset_resolution}, - principal_point_centered=${dataset.xrooms_wai.train.principal_point_centered}, - aug_crop=${dataset.xrooms_wai.train.aug_crop}, - transform='${dataset.xrooms_wai.train.transform}', - data_norm_type='${dataset.xrooms_wai.train.data_norm_type}', - ROOT='${dataset.xrooms_wai.train.ROOT}', - dataset_metadata_dir='${dataset.xrooms_wai.train.dataset_metadata_dir}', - overfit_num_sets=${dataset.xrooms_wai.train.overfit_num_sets}, - variable_num_views=${dataset.xrooms_wai.train.variable_num_views}, - num_views=${dataset.xrooms_wai.train.num_views}, - covisibility_thres=${dataset.xrooms_wai.train.covisibility_thres})" -split: 'train' -dataset_resolution: ${dataset.resolution_train} -principal_point_centered: ${dataset.principal_point_centered} -aug_crop: 16 -transform: 'colorjitter+grayscale+gaublur' -data_norm_type: ${model.data_norm_type} -ROOT: ${root_data_dir}/xrooms -dataset_metadata_dir: ${mapanything_dataset_metadata_dir} -overfit_num_sets: null -variable_num_views: ${dataset.train.variable_num_views} -num_views: ${dataset.num_views} -covisibility_thres: 0.25 diff --git a/configs/dataset/xrooms_wai/val/default.yaml b/configs/dataset/xrooms_wai/val/default.yaml deleted file mode 100644 index 90044d43e2cc0f3f6f67cd6e73e27c4898f88d30..0000000000000000000000000000000000000000 --- a/configs/dataset/xrooms_wai/val/default.yaml +++ /dev/null @@ -1,26 +0,0 @@ -dataset_str: - "XRoomsWAI( - split='${dataset.xrooms_wai.val.split}', - resolution=${dataset.xrooms_wai.val.dataset_resolution}, - 
principal_point_centered=${dataset.xrooms_wai.val.principal_point_centered}, - seed=${dataset.xrooms_wai.val.seed}, - transform='${dataset.xrooms_wai.val.transform}', - data_norm_type='${dataset.xrooms_wai.val.data_norm_type}', - ROOT='${dataset.xrooms_wai.val.ROOT}', - dataset_metadata_dir='${dataset.xrooms_wai.val.dataset_metadata_dir}', - overfit_num_sets=${dataset.xrooms_wai.val.overfit_num_sets}, - variable_num_views=${dataset.xrooms_wai.val.variable_num_views}, - num_views=${dataset.xrooms_wai.val.num_views}, - covisibility_thres=${dataset.xrooms_wai.val.covisibility_thres})" -split: 'val' -dataset_resolution: ${dataset.resolution_val_xrooms} -principal_point_centered: ${dataset.principal_point_centered} -seed: 777 -transform: 'imgnorm' -data_norm_type: ${model.data_norm_type} -ROOT: ${root_data_dir}/xrooms -dataset_metadata_dir: ${mapanything_dataset_metadata_dir} -overfit_num_sets: null -variable_num_views: ${dataset.val.variable_num_views} -num_views: ${dataset.num_views} -covisibility_thres: 0.25 diff --git a/configs/loss/moge2_loss.yaml b/configs/loss/moge2_loss.yaml new file mode 100644 index 0000000000000000000000000000000000000000..6ec8604fbfd2608bc5fd8999966562392fe3ced9 --- /dev/null +++ b/configs/loss/moge2_loss.yaml @@ -0,0 +1,4 @@ +# Training Loss +train_criterion: "ExcludeTopNPercentPixelLoss(Regr3D(RobustRegressionLoss(alpha=0.5, scaling_c=0.05), norm_mode='?avg_dis', loss_in_log=True, flatten_across_image_only=True), top_n_percent=5, apply_to_real_data_only=True) + 3.0 * NormalGMLoss(norm_mode='avg_dis', apply_normal_and_gm_loss_to_synthetic_data_only=True)" +# Validation Loss +test_criterion: "ExcludeTopNPercentPixelLoss(Regr3D(RobustRegressionLoss(alpha=0.5, scaling_c=0.05), norm_mode='?avg_dis', loss_in_log=True, flatten_across_image_only=True), top_n_percent=5, apply_to_real_data_only=True) + 3.0 * NormalGMLoss(norm_mode='avg_dis', apply_normal_and_gm_loss_to_synthetic_data_only=True)" diff --git a/configs/loss/overall_loss_highpm_plus_rel_pose.yaml b/configs/loss/overall_loss_highpm_plus_rel_pose.yaml new file mode 100644 index 0000000000000000000000000000000000000000..5f253bc4ede8bb7788ad6159efbc38036cc84fe4 --- /dev/null +++ b/configs/loss/overall_loss_highpm_plus_rel_pose.yaml @@ -0,0 +1,4 @@ +# Training Loss +train_criterion: "ConfAndExcludeTopNPercentPixelLoss(FactoredGeometryScaleRegr3DPlusNormalGMLoss(RobustRegressionLoss(alpha=0.5, scaling_c=0.05), norm_mode='avg_dis', depth_type_for_loss='depth_along_ray', loss_in_log=True, flatten_across_image_only=True, compute_absolute_pose_loss=True, compute_pairwise_relative_pose_loss=True, convert_predictions_to_view0_frame=False, compute_world_frame_points_loss=True, apply_normal_and_gm_loss_to_synthetic_data_only=True, cam_frame_points_loss_weight=0.1, depth_loss_weight=0.1, ray_directions_loss_weight=0.1, pose_quats_loss_weight=0.1, pose_trans_loss_weight=0.1, scale_loss_weight=0.1, world_frame_points_loss_weight=1, normal_loss_weight=0.3, gm_loss_weight=0.3), conf_alpha=0.2, top_n_percent=5, apply_to_real_data_only=True, conf_loss_set_indices=[0], exclude_loss_set_indices=[1, 2]) + 0.03 * NonAmbiguousMaskLoss(BCELoss())" +# Validation Loss +test_criterion: "ExcludeTopNPercentPixelLoss(FactoredGeometryScaleRegr3DPlusNormalGMLoss(RobustRegressionLoss(alpha=0.5, scaling_c=0.05), norm_mode='avg_dis', depth_type_for_loss='depth_along_ray', loss_in_log=True, flatten_across_image_only=True, compute_absolute_pose_loss=True, compute_pairwise_relative_pose_loss=True, convert_predictions_to_view0_frame=False, 
compute_world_frame_points_loss=True, apply_normal_and_gm_loss_to_synthetic_data_only=True, cam_frame_points_loss_weight=0.1, depth_loss_weight=0.1, ray_directions_loss_weight=0.1, pose_quats_loss_weight=0.1, pose_trans_loss_weight=0.1, scale_loss_weight=0.1, world_frame_points_loss_weight=1, normal_loss_weight=0.3, gm_loss_weight=0.3), top_n_percent=5, apply_to_real_data_only=True, loss_set_indices=[0, 1, 2]) + 0.03 * NonAmbiguousMaskLoss(BCELoss())" diff --git a/configs/loss/overall_loss_highpm_plus_rel_pose_no_conf.yaml b/configs/loss/overall_loss_highpm_plus_rel_pose_no_conf.yaml new file mode 100644 index 0000000000000000000000000000000000000000..00d9c0e44b5fe9d1fb864020d87102e641566788 --- /dev/null +++ b/configs/loss/overall_loss_highpm_plus_rel_pose_no_conf.yaml @@ -0,0 +1,4 @@ +# Training Loss +train_criterion: "ExcludeTopNPercentPixelLoss(FactoredGeometryScaleRegr3DPlusNormalGMLoss(RobustRegressionLoss(alpha=0.5, scaling_c=0.05), norm_mode='avg_dis', depth_type_for_loss='depth_along_ray', loss_in_log=True, flatten_across_image_only=True, compute_absolute_pose_loss=True, compute_pairwise_relative_pose_loss=True, convert_predictions_to_view0_frame=False, compute_world_frame_points_loss=True, apply_normal_and_gm_loss_to_synthetic_data_only=True, cam_frame_points_loss_weight=0.1, depth_loss_weight=0.1, ray_directions_loss_weight=0.1, pose_quats_loss_weight=0.1, pose_trans_loss_weight=0.1, scale_loss_weight=0.1, world_frame_points_loss_weight=1, normal_loss_weight=0.3, gm_loss_weight=0.3), top_n_percent=5, apply_to_real_data_only=True, loss_set_indices=[0, 1, 2]) + 0.03 * NonAmbiguousMaskLoss(BCELoss())" +# Validation Loss +test_criterion: "ExcludeTopNPercentPixelLoss(FactoredGeometryScaleRegr3DPlusNormalGMLoss(RobustRegressionLoss(alpha=0.5, scaling_c=0.05), norm_mode='avg_dis', depth_type_for_loss='depth_along_ray', loss_in_log=True, flatten_across_image_only=True, compute_absolute_pose_loss=True, compute_pairwise_relative_pose_loss=True, convert_predictions_to_view0_frame=False, compute_world_frame_points_loss=True, apply_normal_and_gm_loss_to_synthetic_data_only=True, cam_frame_points_loss_weight=0.1, depth_loss_weight=0.1, ray_directions_loss_weight=0.1, pose_quats_loss_weight=0.1, pose_trans_loss_weight=0.1, scale_loss_weight=0.1, world_frame_points_loss_weight=1, normal_loss_weight=0.3, gm_loss_weight=0.3), top_n_percent=5, apply_to_real_data_only=True, loss_set_indices=[0, 1, 2]) + 0.03 * NonAmbiguousMaskLoss(BCELoss())" diff --git a/configs/loss/overall_loss_highpm_rel_pose_no_ref_view.yaml b/configs/loss/overall_loss_highpm_rel_pose_no_ref_view.yaml new file mode 100644 index 0000000000000000000000000000000000000000..3e59edb63525c862d5859d96ef0795ed0b47813a --- /dev/null +++ b/configs/loss/overall_loss_highpm_rel_pose_no_ref_view.yaml @@ -0,0 +1,4 @@ +# Training Loss +train_criterion: "ConfAndExcludeTopNPercentPixelLoss(FactoredGeometryScaleRegr3DPlusNormalGMLoss(RobustRegressionLoss(alpha=0.5, scaling_c=0.05), norm_mode='avg_dis', depth_type_for_loss='depth_along_ray', loss_in_log=True, flatten_across_image_only=True, compute_absolute_pose_loss=False, compute_pairwise_relative_pose_loss=True, convert_predictions_to_view0_frame=True, compute_world_frame_points_loss=True, apply_normal_and_gm_loss_to_synthetic_data_only=True, cam_frame_points_loss_weight=0.1, depth_loss_weight=0.1, ray_directions_loss_weight=0.1, pose_quats_loss_weight=0.1, pose_trans_loss_weight=0.1, scale_loss_weight=0.1, world_frame_points_loss_weight=1, normal_loss_weight=0.3, gm_loss_weight=0.3), 
conf_alpha=0.2, top_n_percent=5, apply_to_real_data_only=True, conf_loss_set_indices=[0], exclude_loss_set_indices=[1, 2]) + 0.03 * NonAmbiguousMaskLoss(BCELoss())" +# Validation Loss +test_criterion: "ExcludeTopNPercentPixelLoss(FactoredGeometryScaleRegr3DPlusNormalGMLoss(RobustRegressionLoss(alpha=0.5, scaling_c=0.05), norm_mode='avg_dis', depth_type_for_loss='depth_along_ray', loss_in_log=True, flatten_across_image_only=True, compute_absolute_pose_loss=False, compute_pairwise_relative_pose_loss=True, convert_predictions_to_view0_frame=True, compute_world_frame_points_loss=True, apply_normal_and_gm_loss_to_synthetic_data_only=True, cam_frame_points_loss_weight=0.1, depth_loss_weight=0.1, ray_directions_loss_weight=0.1, pose_quats_loss_weight=0.1, pose_trans_loss_weight=0.1, scale_loss_weight=0.1, world_frame_points_loss_weight=1, normal_loss_weight=0.3, gm_loss_weight=0.3), top_n_percent=5, apply_to_real_data_only=True, loss_set_indices=[0, 1, 2]) + 0.03 * NonAmbiguousMaskLoss(BCELoss())" diff --git a/configs/loss/pi3_loss.yaml b/configs/loss/pi3_loss.yaml new file mode 100644 index 0000000000000000000000000000000000000000..998f9028698d0500d69201b7d17f5f7aee0ca206 --- /dev/null +++ b/configs/loss/pi3_loss.yaml @@ -0,0 +1,4 @@ +# Training Loss +train_criterion: "ExcludeTopNPercentPixelLoss(FactoredGeometryRegr3DPlusNormalGMLoss(RobustRegressionLoss(alpha=0.5, scaling_c=0.05), norm_mode='avg_dis', depth_type_for_loss='depth_z', loss_in_log=True, flatten_across_image_only=True, compute_pairwise_relative_pose_loss=True, convert_predictions_to_view0_frame=True, compute_world_frame_points_loss=True, apply_normal_and_gm_loss_to_synthetic_data_only=True, normal_loss_weight=3.0, gm_loss_weight=3.0), top_n_percent=5, apply_to_real_data_only=True, loss_set_indices=[0, 1, 2])" +# Validation Loss +test_criterion: "ExcludeTopNPercentPixelLoss(FactoredGeometryRegr3DPlusNormalGMLoss(RobustRegressionLoss(alpha=0.5, scaling_c=0.05), norm_mode='avg_dis', depth_type_for_loss='depth_z', loss_in_log=True, flatten_across_image_only=True, compute_pairwise_relative_pose_loss=True, convert_predictions_to_view0_frame=True, compute_world_frame_points_loss=True, apply_normal_and_gm_loss_to_synthetic_data_only=True, normal_loss_weight=3.0, gm_loss_weight=3.0), top_n_percent=5, apply_to_real_data_only=True, loss_set_indices=[0, 1, 2])" diff --git a/configs/machine/aws.yaml b/configs/machine/aws.yaml index 584639ff57977dcb505b43e866443ac8df155295..e9cba6a072f897bccf6428f2b717977c6bb5fddd 100644 --- a/configs/machine/aws.yaml +++ b/configs/machine/aws.yaml @@ -2,12 +2,14 @@ defaults: - default # Root directory containing all datasets -root_data_dir: "/fsx/xrtech/data" +root_data_dir: "/ai4rl/fsx/xrtech/data" # Dataset metadata directory -mapanything_dataset_metadata_dir: "/fsx/nkeetha/mapanything_dataset_metadata" +mapanything_dataset_metadata_dir: "/ai4rl/fsx/nkeetha/mapanything_dataset_metadata" # Root directory containing pretrained checkpoints for custom models -root_pretrained_checkpoints_dir: "/fsx/nkeetha/mapanything_checkpoints" +root_pretrained_checkpoints_dir: "/ai4rl/fsx/nkeetha/mapanything_checkpoints" # Root directory to log experiments -root_experiments_dir: "/fsx/nkeetha/experiments" +root_experiments_dir: "/ai4rl/fsx/nkeetha/experiments" # Root directory containing UniCeption pretrained checkpoints -root_uniception_pretrained_checkpoints_dir: "/fsx/nkeetha/uniception_checkpoints" +root_uniception_pretrained_checkpoints_dir: "/ai4rl/fsx/nkeetha/uniception_checkpoints" +# Root directory containing external 
benchmark data +external_benchmark_data_root_data_dir: "/ai4rl/fsx/xrtech/external_benchmark_data/rmvd_mvs_benchmark/rmvd_test_data" diff --git a/configs/machine/default.yaml b/configs/machine/default.yaml index d89f6d360340a9eb79a12707ce19b07aba3a083a..1155bbd9ebee093c7e998b980ba56daa5fea0209 100644 --- a/configs/machine/default.yaml +++ b/configs/machine/default.yaml @@ -8,3 +8,5 @@ root_pretrained_checkpoints_dir: ??? root_experiments_dir: ??? # Root directory containing UniCeption pretrained checkpoints root_uniception_pretrained_checkpoints_dir: ??? +# Root directory containing external benchmark data +external_benchmark_data_root_data_dir: ??? diff --git a/configs/machine/psc.yaml b/configs/machine/psc.yaml index 1e529be838f2f9838e428bb25244b7b4c94aefa3..bcf88df5ce9a68c9d53f494f0566970c8460f24b 100644 --- a/configs/machine/psc.yaml +++ b/configs/machine/psc.yaml @@ -6,8 +6,10 @@ root_data_dir: "/ocean/projects/cis220039p/shared/datasets" # Dataset metadata directory mapanything_dataset_metadata_dir: "/ocean/projects/cis220039p/shared/mapanything_dataset_metadata" # Root directory containing pretrained checkpoints for custom models -root_pretrained_checkpoints_dir: "/ocean/projects/cis220039p/nkeetha/code/AnyMap/checkpoints" +root_pretrained_checkpoints_dir: "/jet/home/yzhang25/mapanything/checkpoints" # Root directory to log experiments -root_experiments_dir: "/ocean/projects/cis220039p/nkeetha/experiments" +root_experiments_dir: "/jet/home/yzhang25/mapanything/outputs" # Root directory containing UniCeption pretrained checkpoints -root_uniception_pretrained_checkpoints_dir: "/ocean/projects/cis220039p/nkeetha/code/AnyMap/UniCeption/checkpoints" +root_uniception_pretrained_checkpoints_dir: "/ocean/projects/cis220039p/shared/uniception/checkpoints/" +# Root directory containing external benchmark data +external_benchmark_data_root_data_dir: "/jet/home/yzhang25/mapanything/benchmarking/rmvd_mvs_benchmark/rmvd_test_data" diff --git a/configs/machine/psc_yuchen.yaml b/configs/machine/psc_yuchen.yaml deleted file mode 100644 index d071afa04eadb2cd35b4f41bc12c1e8f820fe27c..0000000000000000000000000000000000000000 --- a/configs/machine/psc_yuchen.yaml +++ /dev/null @@ -1,13 +0,0 @@ -defaults: - - default - -# Root directory containing all datasets -root_data_dir: "/ocean/projects/cis220039p/shared/datasets" -# Dataset metadata directory -mapanything_dataset_metadata_dir: "/ocean/projects/cis220039p/shared/mapanything_dataset_metadata" -# Root directory containing pretrained checkpoints for custom models -root_pretrained_checkpoints_dir: "/jet/home/yzhang25/AnyMap/checkpoints" -# Root directory to log experiments -root_experiments_dir: "/jet/home/yzhang25/AnyMap/outputs" -# Root directory containing UniCeption pretrained checkpoints -root_uniception_pretrained_checkpoints_dir: "/ocean/projects/cis220039p/shared/uniception/checkpoints/" diff --git a/configs/machine/xri_dgx.yaml b/configs/machine/xri_dgx.yaml index ba77beedfdd736d6a6d7c236ff3a1b9033ea8b24..3a0702886b2c09eb92d3f7ca47bdfee8006854c6 100644 --- a/configs/machine/xri_dgx.yaml +++ b/configs/machine/xri_dgx.yaml @@ -6,8 +6,10 @@ root_data_dir: "/mnt/xri_mapsresearch/data/nkeetha" # Dataset metadata directory mapanything_dataset_metadata_dir: "/mnt/xri_mapsresearch/data/nkeetha/mapanything_dataset_metadata" # Root directory containing pretrained checkpoints for custom models -root_pretrained_checkpoints_dir: "/mnt/xri_mapsresearch/code/nkeetha/AnyMap/checkpoints" +root_pretrained_checkpoints_dir: 
"/mnt/xri_mapsresearch/code/nkeetha/mapanything/checkpoints" # Root directory to log experiments root_experiments_dir: "/mnt/xri_mapsresearch/experiments/nkeetha" # Root directory containing UniCeption pretrained checkpoints -root_uniception_pretrained_checkpoints_dir: "/mnt/xri_mapsresearch/code/nkeetha/AnyMap/UniCeption/checkpoints" +root_uniception_pretrained_checkpoints_dir: "/mnt/xri_mapsresearch/code/nkeetha/mapanything/UniCeption/checkpoints" +# Root directory containing external benchmark data +external_benchmark_data_root_data_dir: "/mnt/xri_mapsresearch/data/nkeetha/rmvd_mvs_benchmark/rmvd_test_data" diff --git a/configs/model/da3.yaml b/configs/model/da3.yaml new file mode 100644 index 0000000000000000000000000000000000000000..89ac0bb75d6fda432abebf6be30e604602652b6e --- /dev/null +++ b/configs/model/da3.yaml @@ -0,0 +1,13 @@ +# String for model factory +model_str: "da3" +# Model config +model_config: + name: "da3" + # HF model string + hf_model_name: "depth-anything/DA3-GIANT" +# Image Normalization Type +data_norm_type: "dinov2" +# DA3 checkpoint is already loaded in the inference wrapper +pretrained: null +# Torch hub force reload +torch_hub_force_reload: False diff --git a/configs/model/da3_nested.yaml b/configs/model/da3_nested.yaml new file mode 100644 index 0000000000000000000000000000000000000000..77ab699517e6a4cdbb111acb322aa8e36cf4d294 --- /dev/null +++ b/configs/model/da3_nested.yaml @@ -0,0 +1,13 @@ +# String for model factory +model_str: "da3" +# Model config +model_config: + name: "da3_nested" + # HF model string + hf_model_name: "depth-anything/DA3NESTED-GIANT-LARGE" +# Image Normalization Type +data_norm_type: "dinov2" +# DA3 checkpoint is already loaded in the inference wrapper +pretrained: null +# Torch hub force reload +torch_hub_force_reload: False diff --git a/configs/model/encoder/dinov2_giant_24_layers.yaml b/configs/model/encoder/dinov2_giant_24_layers.yaml new file mode 100644 index 0000000000000000000000000000000000000000..df63897853429d105c91585fe31ebd5979b2dc01 --- /dev/null +++ b/configs/model/encoder/dinov2_giant_24_layers.yaml @@ -0,0 +1,18 @@ +# UniCeption encoder string used for selecting encoder class (python3 -m uniception.models.encoders.list) +encoder_str: "dinov2" +# Name of the encoder +name: "dinov2_giant_24_layers" +# Data normalization type +data_norm_type: "dinov2" +# ViT size +size: "giant" +# Registers +with_registers: False +# Flag to indicate whether model class uses torch hub +uses_torch_hub: True +# Flag to indicate whether to use gradient checkpointing for encoder +gradient_checkpointing: False +# Turn off final normalization so that the features can be passed to DINOv2 init multi-view transformer +norm_returned_features: False +# Keep only the first 24 layers of DINOv2 ViT-G (other 16 layers are in multi-view transformer) +keep_first_n_layers: 24 diff --git a/configs/model/info_sharing/aat_ifr_16_layers_dinov2_vitg_init.yaml b/configs/model/info_sharing/aat_ifr_16_layers_dinov2_vitg_init.yaml new file mode 100644 index 0000000000000000000000000000000000000000..96c63c180e55eadfc9a6233c756765acea96f19a --- /dev/null +++ b/configs/model/info_sharing/aat_ifr_16_layers_dinov2_vitg_init.yaml @@ -0,0 +1,33 @@ +# Model type (Options: ["cross_attention", "global_attention", "alternating_attention"]) +model_type: "alternating_attention" +# Model class type (Options: ["no_intermediate_features", "intermediate_features"]) +model_return_type: "intermediate_features" +# Custom positional encoding (Options: ["RoPEfreq"], Callable Function, 
null) +custom_positional_encoding: null +# Module arguments +module_args: + # Name of the info sharing module + name: "aat_16_layers_dinov2_vitg_init" + # Indices of the intermediate features to be shared (indices start from 0) + indices: [7, 11] + # Normalize intermediate features + norm_intermediate: True + # Size string + size: "16_layers" + # Depth (this includes both frame-wise and global attention layers) + depth: 16 + # Distinguish Reference and Non-Reference Views + distinguish_ref_and_non_ref_views: True + # Flag to indicate whether to use gradient checkpointing + gradient_checkpointing: False + # Feature dim (similar to ViT-Giant) + dim: 1536 + # Number of heads (similar to ViT-Giant) + num_heads: 24 + # Set transformer parameters similar to DINOv2 + mlp_ratio: 4 + qkv_bias: True + qk_norm: False + init_values: 1e-5 + # Load layers 24 to 40 from DINOv2 ViT-G as init + pretrained_checkpoint_path: '${machine.root_pretrained_checkpoints_dir}/aat_init_w_dinov2_vitg_layers_24_to_40.pth' diff --git a/configs/model/info_sharing/aat_ifr_16_layers_vitg_dim.yaml b/configs/model/info_sharing/aat_ifr_16_layers_vitg_dim.yaml new file mode 100644 index 0000000000000000000000000000000000000000..823cdb2f0ec431f4bcf248915169539af17398f3 --- /dev/null +++ b/configs/model/info_sharing/aat_ifr_16_layers_vitg_dim.yaml @@ -0,0 +1,31 @@ +# Model type (Options: ["cross_attention", "global_attention", "alternating_attention"]) +model_type: "alternating_attention" +# Model class type (Options: ["no_intermediate_features", "intermediate_features"]) +model_return_type: "intermediate_features" +# Custom positional encoding (Options: ["RoPEfreq"], Callable Function, null) +custom_positional_encoding: null +# Module arguments +module_args: + # Name of the info sharing module + name: "aat_16_layers_vitg_dim_ifr" + # Indices of the intermediate features to be shared (indices start from 0) + indices: [7, 11] + # Normalize intermediate features + norm_intermediate: True + # Size string + size: "16_layers" + # Depth (this includes both frame-wise and global attention layers) + depth: 16 + # Distinguish Reference and Non-Reference Views + distinguish_ref_and_non_ref_views: True + # Flag to indicate whether to use gradient checkpointing + gradient_checkpointing: False + # Feature dim (similar to ViT-Giant) + dim: 1536 + # Number of heads (similar to ViT-Giant) + num_heads: 24 + # Set transformer parameters similar to DINOv2 + mlp_ratio: 4 + qkv_bias: True + qk_norm: False + init_values: 1e-5 diff --git a/configs/model/mapanything.yaml b/configs/model/mapanything.yaml index a69c44ec9cd9d684e2dc0764365d462fcc095030..e37faabfae7681bfa8699bbc2edcc7e3b26e9cca 100644 --- a/configs/model/mapanything.yaml +++ b/configs/model/mapanything.yaml @@ -1,7 +1,7 @@ defaults: - default - - encoder: dinov2_large - - info_sharing: aat_ifr_24_layers + - encoder: dinov2_giant_24_layers + - info_sharing: aat_ifr_16_layers_vitg_dim - pred_head: dpt_pose_scale - task: images_only @@ -14,5 +14,7 @@ model_config: info_sharing_config: ${model.info_sharing} pred_head_config: ${model.pred_head} geometric_input_config: ${model.task} + use_register_tokens_from_encoder: True + info_sharing_mlp_layer_str: "swiglufused" # Image Normalization Type data_norm_type: ${model.encoder.data_norm_type} diff --git a/configs/model/mapanything_large_inference.yaml b/configs/model/mapanything_dino_init.yaml similarity index 70% rename from configs/model/mapanything_large_inference.yaml rename to configs/model/mapanything_dino_init.yaml index 
2e57bf7ddbcc416d4e637276202e60515535e0e1..98e0c7ae1d1405478f8ab6b06e773085f8756993 100644 --- a/configs/model/mapanything_large_inference.yaml +++ b/configs/model/mapanything_dino_init.yaml @@ -1,7 +1,7 @@ defaults: - default - - encoder: dinov2_large - - info_sharing: aat_ifr_48_layers_escaling + - encoder: dinov2_giant_24_layers + - info_sharing: aat_ifr_16_layers_dinov2_vitg_init - pred_head: dpt_pose_scale - task: images_only @@ -14,5 +14,7 @@ model_config: info_sharing_config: ${model.info_sharing} pred_head_config: ${model.pred_head} geometric_input_config: ${model.task} + use_register_tokens_from_encoder: True + info_sharing_mlp_layer_str: "swiglufused" # Image Normalization Type data_norm_type: ${model.encoder.data_norm_type} diff --git a/configs/model/mapanything_inference.yaml b/configs/model/mapanything_inference.yaml deleted file mode 100644 index a081212dc00ee6b2e105254346129e430531f642..0000000000000000000000000000000000000000 --- a/configs/model/mapanything_inference.yaml +++ /dev/null @@ -1,18 +0,0 @@ -defaults: - - default - - encoder: dinov2_large - - info_sharing: aat_ifr_24_layers_escaling - - pred_head: dpt_pose_scale - - task: images_only - -# String for model factory -model_str: "mapanything" -# Model config -model_config: - name: "mapanything" - encoder_config: ${model.encoder} - info_sharing_config: ${model.info_sharing} - pred_head_config: ${model.pred_head} - geometric_input_config: ${model.task} -# Image Normalization Type -data_norm_type: ${model.encoder.data_norm_type} diff --git a/configs/model/mapanything_large.yaml b/configs/model/mapanything_v1.yaml similarity index 92% rename from configs/model/mapanything_large.yaml rename to configs/model/mapanything_v1.yaml index 371e2fe6513fee6f5a0619aff5fff0c515a4d8d9..a69c44ec9cd9d684e2dc0764365d462fcc095030 100644 --- a/configs/model/mapanything_large.yaml +++ b/configs/model/mapanything_v1.yaml @@ -1,7 +1,7 @@ defaults: - default - encoder: dinov2_large - - info_sharing: aat_ifr_48_layers + - info_sharing: aat_ifr_24_layers - pred_head: dpt_pose_scale - task: images_only diff --git a/configs/rmvd_benchmark.yaml b/configs/rmvd_benchmark.yaml index b4042f433e9eabcf9fde4221d81c443495af8a64..e4fe28bda7aab7f333550e814a290d0fc9972dde 100644 --- a/configs/rmvd_benchmark.yaml +++ b/configs/rmvd_benchmark.yaml @@ -6,7 +6,7 @@ defaults: # Path Settings output_dir: ${hydra:run.dir} -root_data_dir: ${machine.root_data_dir} +external_benchmark_data_root_data_dir: ${machine.external_benchmark_data_root_data_dir} mapanything_dataset_metadata_dir: ${machine.mapanything_dataset_metadata_dir} root_pretrained_checkpoints_dir: ${machine.root_pretrained_checkpoints_dir} root_experiments_dir: ${machine.root_experiments_dir} diff --git a/configs/train_params/moge2_finetune.yaml b/configs/train_params/moge2_finetune.yaml new file mode 100644 index 0000000000000000000000000000000000000000..9b60a42c8b65e2e7be3e99a4442b981d10765f26 --- /dev/null +++ b/configs/train_params/moge2_finetune.yaml @@ -0,0 +1,6 @@ +defaults: + - default + +# Use lower lr for finetuning +lr: 1e-05 +min_lr: 1e-07 diff --git a/configs/train_params/pi3_finetune.yaml b/configs/train_params/pi3_finetune.yaml new file mode 100644 index 0000000000000000000000000000000000000000..1d4d13a9fefcffc48ab9c1027e4ce831877b63eb --- /dev/null +++ b/configs/train_params/pi3_finetune.yaml @@ -0,0 +1,16 @@ +defaults: + - default + +# Use lower lr for finetuning +lr: 1e-05 +min_lr: 1e-07 + +# Optimizer parameters specific to submodules +submodule_configs: + # DINOv2 + model.encoder: + 
lr: 5e-07 + min_lr: 5e-09 + warmup_epochs: ${train_params.warmup_epochs} + weight_decay: ${train_params.weight_decay} + schedule_type: ${train_params.schedule_type} diff --git a/configs/train_params/vggt_finetune.yaml b/configs/train_params/vggt_finetune.yaml index d130d4f12ad4d8e2de2165bfeb058addd0b668a8..7cf885171c4707199d5d77c061bfd936b5adcccd 100644 --- a/configs/train_params/vggt_finetune.yaml +++ b/configs/train_params/vggt_finetune.yaml @@ -1,7 +1,7 @@ defaults: - default -# Use 10x lower lr for finetuning +# Use lower lr for finetuning lr: 1e-05 min_lr: 1e-07 diff --git a/mapanything/datasets/__init__.py b/mapanything/datasets/__init__.py index ca5b413385e67e07ae394402431ce78a3b480300..b01f8285f38a61eaecaf1e6ee1fcc17d2077d23b 100644 --- a/mapanything/datasets/__init__.py +++ b/mapanything/datasets/__init__.py @@ -1,3 +1,8 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + """ MapAnything Datasets """ @@ -5,14 +10,10 @@ MapAnything Datasets import torch from mapanything.datasets.wai.ase import ASEWAI # noqa -from mapanything.datasets.wai.bedlam import BedlamWAI # noqa from mapanything.datasets.wai.blendedmvs import BlendedMVSWAI # noqa from mapanything.datasets.wai.dl3dv import DL3DVWAI # noqa -from mapanything.datasets.wai.dtu import DTUWAI # noqa from mapanything.datasets.wai.dynamicreplica import DynamicReplicaWAI # noqa from mapanything.datasets.wai.eth3d import ETH3DWAI # noqa -from mapanything.datasets.wai.gta_sfm import GTASfMWAI # noqa -from mapanything.datasets.wai.matrixcity import MatrixCityWAI # noqa from mapanything.datasets.wai.megadepth import MegaDepthWAI # noqa from mapanything.datasets.wai.mpsd import MPSDWAI # noqa from mapanything.datasets.wai.mvs_synth import MVSSynthWAI # noqa @@ -20,10 +21,8 @@ from mapanything.datasets.wai.paralleldomain4d import ParallelDomain4DWAI # noq from mapanything.datasets.wai.sailvos3d import SAILVOS3DWAI # noqa from mapanything.datasets.wai.scannetpp import ScanNetPPWAI # noqa from mapanything.datasets.wai.spring import SpringWAI # noqa -from mapanything.datasets.wai.structured3d import Structured3DWAI # noqa from mapanything.datasets.wai.tav2_wb import TartanAirV2WBWAI # noqa from mapanything.datasets.wai.unrealstereo4k import UnrealStereo4KWAI # noqa -from mapanything.datasets.wai.xrooms import XRoomsWAI # noqa from mapanything.utils.train_tools import get_rank, get_world_size diff --git a/mapanything/datasets/base/base_dataset.py b/mapanything/datasets/base/base_dataset.py index 0a184d200d98b181f68a358ae8dea73c3e3b072f..fbc7d89742a188ebdc8cf41c8d59384d640db64d 100644 --- a/mapanything/datasets/base/base_dataset.py +++ b/mapanything/datasets/base/base_dataset.py @@ -1,3 +1,8 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + """ Base class for MapAnything datasets. """ @@ -314,7 +319,7 @@ class BaseDataset(EasyDataset): use_bidirectional_covis=True, ): """ - Randomly samples S indices from an N x N covisbility matrix by forming adjacency edges such that the resulting subgraph (given by the indices) is connected. + Randomly samples S indices from an N x N covisibility matrix by forming adjacency edges such that the resulting subgraph (given by the indices) is connected. 
If the current node has no new unvisited neighbors, backtracking occurs. Retries with different starting indices if the desired number of samples is not reached, excluding previously visited components. @@ -569,7 +574,7 @@ class BaseDataset(EasyDataset): if "non_ambiguous_mask" in view: assert view["depthmap"].shape == view["non_ambiguous_mask"].shape - # Expand the last dimennsion of the depthmap + # Expand the last dimension of the depthmap view["depthmap"] = view["depthmap"][..., None] # Append RNG state to the views, this allows to check whether the RNG is in the same state each time diff --git a/mapanything/datasets/base/batched_sampler.py b/mapanything/datasets/base/batched_sampler.py index 322a66bfadd30b88ea6d14b1f8241dc570263846..9cf98194756c02fda271aeb0eca86f2db477cada 100644 --- a/mapanything/datasets/base/batched_sampler.py +++ b/mapanything/datasets/base/batched_sampler.py @@ -1,3 +1,8 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + """ Utilities for random sampling under a single or multiple constraints diff --git a/mapanything/datasets/base/easy_dataset.py b/mapanything/datasets/base/easy_dataset.py index 0ac429bf808762e4f83f04af50f562e885a3dd61..4cdaa203596b842bd4b33c0c1188881df2ee8ded 100644 --- a/mapanything/datasets/base/easy_dataset.py +++ b/mapanything/datasets/base/easy_dataset.py @@ -1,3 +1,8 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + """ Base dataset class that enables easy resizing and combining @@ -165,7 +170,7 @@ class EasyDataset: class MulDataset(EasyDataset): - """Artifically augmenting the size of a dataset.""" + """Artificially augmenting the size of a dataset.""" multiplicator: int @@ -239,7 +244,7 @@ class MulDataset(EasyDataset): class ResizedDataset(EasyDataset): - """Artifically changing the size of a dataset.""" + """Artificially changing the size of a dataset.""" new_size: int diff --git a/mapanything/datasets/utils/data_splits.py b/mapanything/datasets/utils/data_splits.py index 2a69ec6df5a98cd21e15c9de1791be49903d8383..d9103e9093219e8c9e30727721bfeb5dbecf1d14 100644 --- a/mapanything/datasets/utils/data_splits.py +++ b/mapanything/datasets/utils/data_splits.py @@ -1,3 +1,8 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + """ Modules containing dataset split information """ @@ -1717,18 +1722,6 @@ class DL3DV10KSplits: ] -class DTUSplits: - """ - This class contains the information about the splits of the DTU dataset. - """ - - def __init__(self): - """ - All scenes are in the test split. - """ - self.test_split_scenes = "all" - - class ETH3DSplits: """ This class contains the information about the splits of the ETH3D dataset. diff --git a/mapanything/datasets/wai/ase.py b/mapanything/datasets/wai/ase.py index d3bf49ffe6fc61323265178a583c3585d53824b2..439caae3cba1e88f34816800e54ade35310c7c8e 100644 --- a/mapanything/datasets/wai/ase.py +++ b/mapanything/datasets/wai/ase.py @@ -1,3 +1,8 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. 
+ """ ASE Dataset using WAI format data. """ @@ -85,7 +90,7 @@ class ASEWAI(BaseDataset): covisibility_map_dir = os.path.join( scene_root, "covisibility", covisibility_version_key ) - # Assumes only npy file in directory is covisbility map + # Assumes only npy file in directory is covisibility map covisibility_map_name = next( f for f in os.listdir(covisibility_map_dir) if f.endswith(".npy") ) @@ -150,11 +155,13 @@ def get_parser(): import argparse parser = argparse.ArgumentParser() - parser.add_argument("-rd", "--root_dir", default="/fsx/xrtech/data/ase", type=str) + parser.add_argument( + "-rd", "--root_dir", default="/ai4rl/fsx/xrtech/data/ase", type=str + ) parser.add_argument( "-dmd", "--dataset_metadata_dir", - default="/fsx/nkeetha/mapanything_dataset_metadata", + default="/ai4rl/fsx/nkeetha/mapanything_dataset_metadata", type=str, ) parser.add_argument( diff --git a/mapanything/datasets/wai/bedlam.py b/mapanything/datasets/wai/bedlam.py deleted file mode 100644 index d4daa56470f73d8e67f9b6a4c5ba3fd822ef9ff1..0000000000000000000000000000000000000000 --- a/mapanything/datasets/wai/bedlam.py +++ /dev/null @@ -1,309 +0,0 @@ -""" -Bedlam Dataset using WAI format data. -""" - -import os - -import numpy as np - -from mapanything.datasets.base.base_dataset import BaseDataset -from mapanything.utils.wai.core import load_data, load_frame - - -class BedlamWAI(BaseDataset): - """ - Bedlam dataset containing diverse synthetic scenes with humans. - """ - - def __init__( - self, - *args, - ROOT, - dataset_metadata_dir, - split, - overfit_num_sets=None, - sample_specific_scene: bool = False, - specific_scene_name: str = None, - **kwargs, - ): - """ - Initialize the dataset attributes. - Args: - ROOT: Root directory of the dataset. - dataset_metadata_dir: Path to the dataset metadata directory. - split: Dataset split (train, val, test). - overfit_num_sets: If None, use all sets. Else, the dataset will be truncated to this number of sets. - sample_specific_scene: Whether to sample a specific scene from the dataset. - specific_scene_name: Name of the specific scene to sample. 
- """ - # Initialize the dataset attributes - super().__init__(*args, **kwargs) - self.ROOT = ROOT - self.dataset_metadata_dir = dataset_metadata_dir - self.split = split - self.overfit_num_sets = overfit_num_sets - self.sample_specific_scene = sample_specific_scene - self.specific_scene_name = specific_scene_name - self._load_data() - - # Define the dataset type flags - self.is_metric_scale = True - self.is_synthetic = True - - def _load_data(self): - "Load the precomputed dataset metadata" - # Load the dataset metadata corresponding to the split - split_metadata_path = os.path.join( - self.dataset_metadata_dir, - self.split, - f"bedlam_scene_list_{self.split}.npy", - ) - split_scene_list = np.load(split_metadata_path, allow_pickle=True) - - # Get the list of all scenes - if not self.sample_specific_scene: - self.scenes = list(split_scene_list) - else: - self.scenes = [self.specific_scene_name] - self.num_of_scenes = len(self.scenes) - - def _get_views(self, sampled_idx, num_views_to_sample, resolution): - # Get the scene name of the sampled index - scene_index = sampled_idx - scene_name = self.scenes[scene_index] - - # Get the metadata corresponding to the scene - scene_root = os.path.join(self.ROOT, scene_name) - scene_meta = load_data( - os.path.join(scene_root, "scene_meta.json"), "scene_meta" - ) - scene_file_names = list(scene_meta["frame_names"].keys()) - num_views_in_scene = len(scene_file_names) - - # Load the scene pairwise covisibility mmap - covisibility_version_key = "v0" - covisibility_map_dir = os.path.join( - scene_root, "covisibility", covisibility_version_key - ) - # Assumes only npy file in directory is covisbility map - covisibility_map_name = next( - f for f in os.listdir(covisibility_map_dir) if f.endswith(".npy") - ) - covisibility_map_path = os.path.join( - scene_root, "covisibility", covisibility_version_key, covisibility_map_name - ) - pairwise_covisibility = load_data(covisibility_map_path, "mmap") - - # Get the indices of the N views in the scene - # Bedlam scenes have very large number of images - # Thus, we use unidirectional covis for faster access - view_indices = self._sample_view_indices( - num_views_to_sample, - num_views_in_scene, - pairwise_covisibility, - use_bidirectional_covis=False, - ) - - # Get the views corresponding to the selected view indices - views = [] - for view_index in view_indices: - # Load the data corresponding to the view - view_file_name = scene_file_names[view_index] - view_data = load_frame( - scene_root, - view_file_name, - modalities=["image", "depth"], - scene_meta=scene_meta, - ) - - # Convert necessary data to numpy - image = view_data["image"].permute(1, 2, 0).numpy() - image = image[:, :, :3] # RGBA to RGB - image = (image * 255).astype(np.uint8) - depthmap = view_data["depth"].numpy().astype(np.float32) - intrinsics = view_data["intrinsics"].numpy().astype(np.float32) - c2w_pose = view_data["extrinsics"].numpy().astype(np.float32) - - # Ensure that the depthmap has all valid values - depthmap = np.nan_to_num(depthmap, nan=0.0, posinf=0.0, neginf=0.0) - - # Get the non ambiguous mask (zero depth pixels are sky or ambiguous) - non_ambiguous_mask = (depthmap > 0).astype(int) - - # Mask out the outlier depth (see through window or horizon depth) - percentile_depth = np.percentile(depthmap, 95) - depthmap[depthmap > percentile_depth] = 0 - - # Resize the data to match the desired resolution - additional_quantities_to_resize = [non_ambiguous_mask] - image, depthmap, intrinsics, additional_quantities_to_resize = ( - 
self._crop_resize_if_necessary( - image=image, - resolution=resolution, - depthmap=depthmap, - intrinsics=intrinsics, - additional_quantities=additional_quantities_to_resize, - ) - ) - non_ambiguous_mask = additional_quantities_to_resize[0] - - # Append the view dictionary to the list of views - views.append( - dict( - img=image, - depthmap=depthmap, - camera_pose=c2w_pose, # cam2world - camera_intrinsics=intrinsics, - non_ambiguous_mask=non_ambiguous_mask, - dataset="Bedlam", - label=scene_name, - instance=os.path.join("images", str(view_file_name)), - ) - ) - - return views - - -def get_parser(): - import argparse - - parser = argparse.ArgumentParser() - parser.add_argument( - "-rd", "--root_dir", default="/fsx/xrtech/data/bedlam", type=str - ) - parser.add_argument( - "-dmd", - "--dataset_metadata_dir", - default="/fsx/nkeetha/mapanything_dataset_metadata", - type=str, - ) - parser.add_argument( - "-nv", - "--num_of_views", - default=2, - type=int, - ) - parser.add_argument("--viz", action="store_true") - - return parser - - -if __name__ == "__main__": - import rerun as rr - from tqdm import tqdm - - from mapanything.datasets.base.base_dataset import view_name - from mapanything.utils.image import rgb - from mapanything.utils.viz import script_add_rerun_args - - parser = get_parser() - script_add_rerun_args( - parser - ) # Options: --headless, --connect, --serve, --addr, --save, --stdout - args = parser.parse_args() - - dataset = BedlamWAI( - num_views=args.num_of_views, - split="train", - covisibility_thres=0.25, - ROOT=args.root_dir, - dataset_metadata_dir=args.dataset_metadata_dir, - resolution=(518, 294), - aug_crop=16, - transform="colorjitter+grayscale+gaublur", - data_norm_type="dinov2", - ) - # dataset = BedlamWAI( - # num_views=args.num_of_views, - # split="val", - # covisibility_thres=0.25, - # ROOT=args.root_dir, - # dataset_metadata_dir=args.dataset_metadata_dir, - # resolution=(518, 294), - # seed=777, - # transform="imgnorm", - # data_norm_type="dinov2", - # ) - print(dataset.get_stats()) - - if args.viz: - rr.script_setup(args, "Bedlam_Dataloader") - rr.set_time("stable_time", sequence=0) - rr.log("world", rr.ViewCoordinates.RDF, static=True) - - sampled_indices = np.random.choice(len(dataset), size=len(dataset), replace=False) - - for num, idx in enumerate(tqdm(sampled_indices)): - views = dataset[idx] - assert len(views) == args.num_of_views - sample_name = f"{idx}" - for view_idx in range(args.num_of_views): - sample_name += f" {view_name(views[view_idx])}" - print(sample_name) - for view_idx in range(args.num_of_views): - image = rgb( - views[view_idx]["img"], norm_type=views[view_idx]["data_norm_type"] - ) - depthmap = views[view_idx]["depthmap"] - pose = views[view_idx]["camera_pose"] - intrinsics = views[view_idx]["camera_intrinsics"] - pts3d = views[view_idx]["pts3d"] - valid_mask = views[view_idx]["valid_mask"] - if "non_ambiguous_mask" in views[view_idx]: - non_ambiguous_mask = views[view_idx]["non_ambiguous_mask"] - else: - non_ambiguous_mask = None - if "prior_depth_along_ray" in views[view_idx]: - prior_depth_along_ray = views[view_idx]["prior_depth_along_ray"] - else: - prior_depth_along_ray = None - if args.viz: - rr.set_time("stable_time", sequence=num) - base_name = f"world/view_{view_idx}" - pts_name = f"world/view_{view_idx}_pointcloud" - # Log camera info and loaded data - height, width = image.shape[0], image.shape[1] - rr.log( - base_name, - rr.Transform3D( - translation=pose[:3, 3], - mat3x3=pose[:3, :3], - ), - ) - rr.log( - f"{base_name}/pinhole", 
- rr.Pinhole( - image_from_camera=intrinsics, - height=height, - width=width, - camera_xyz=rr.ViewCoordinates.RDF, - ), - ) - rr.log( - f"{base_name}/pinhole/rgb", - rr.Image(image), - ) - rr.log( - f"{base_name}/pinhole/depth", - rr.DepthImage(depthmap), - ) - if prior_depth_along_ray is not None: - rr.log( - f"prior_depth_along_ray_{view_idx}", - rr.DepthImage(prior_depth_along_ray), - ) - if non_ambiguous_mask is not None: - rr.log( - f"{base_name}/pinhole/non_ambiguous_mask", - rr.SegmentationImage(non_ambiguous_mask.astype(int)), - ) - # Log points in 3D - filtered_pts = pts3d[valid_mask] - filtered_pts_col = image[valid_mask] - rr.log( - pts_name, - rr.Points3D( - positions=filtered_pts.reshape(-1, 3), - colors=filtered_pts_col.reshape(-1, 3), - ), - ) diff --git a/mapanything/datasets/wai/blendedmvs.py b/mapanything/datasets/wai/blendedmvs.py index 8ef278842807b667f478ba4b1f013bb702173632..bbd2368bda93b88d180e0ce3fd53ddd15f6359ed 100644 --- a/mapanything/datasets/wai/blendedmvs.py +++ b/mapanything/datasets/wai/blendedmvs.py @@ -1,3 +1,8 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + """ BlendedMVS Dataset using WAI format data. """ @@ -86,7 +91,7 @@ class BlendedMVSWAI(BaseDataset): covisibility_map_dir = os.path.join( scene_root, "covisibility", covisibility_version_key ) - # Assumes only npy file in directory is covisbility map + # Assumes only npy file in directory is covisibility map covisibility_map_name = next( f for f in os.listdir(covisibility_map_dir) if f.endswith(".npy") ) @@ -168,12 +173,12 @@ def get_parser(): parser = argparse.ArgumentParser() parser.add_argument( - "-rd", "--root_dir", default="/fsx/xrtech/data/blendedmvs", type=str + "-rd", "--root_dir", default="/ai4rl/fsx/xrtech/data/blendedmvs", type=str ) parser.add_argument( "-dmd", "--dataset_metadata_dir", - default="/fsx/nkeetha/mapanything_dataset_metadata", + default="/ai4rl/fsx/nkeetha/mapanything_dataset_metadata", type=str, ) parser.add_argument( diff --git a/mapanything/datasets/wai/dl3dv.py b/mapanything/datasets/wai/dl3dv.py index ab9bf341a231e77d6a3dde7b7c3e203638b223a7..1903c7acc251d1b208a08036bd2ad6fecec2fe06 100644 --- a/mapanything/datasets/wai/dl3dv.py +++ b/mapanything/datasets/wai/dl3dv.py @@ -1,3 +1,8 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + """ DL3DV Dataset using WAI format data. 
""" @@ -93,7 +98,7 @@ class DL3DVWAI(BaseDataset): covisibility_map_dir = os.path.join( scene_root, "covisibility", covisibility_version_key ) - # Assumes only npy file in directory is covisbility map + # Assumes only npy file in directory is covisibility map covisibility_map_name = next( f for f in os.listdir(covisibility_map_dir) if f.endswith(".npy") ) @@ -210,11 +215,13 @@ def get_parser(): import argparse parser = argparse.ArgumentParser() - parser.add_argument("-rd", "--root_dir", default="/fsx/xrtech/data/dl3dv", type=str) + parser.add_argument( + "-rd", "--root_dir", default="/ai4rl/fsx/xrtech/data/dl3dv", type=str + ) parser.add_argument( "-dmd", "--dataset_metadata_dir", - default="/fsx/nkeetha/mapanything_dataset_metadata", + default="/ai4rl/fsx/nkeetha/mapanything_dataset_metadata", type=str, ) parser.add_argument( diff --git a/mapanything/datasets/wai/dtu.py b/mapanything/datasets/wai/dtu.py deleted file mode 100644 index fde4a666752c36a2668913c010d42fb93d62dd3b..0000000000000000000000000000000000000000 --- a/mapanything/datasets/wai/dtu.py +++ /dev/null @@ -1,272 +0,0 @@ -""" -DTU Dataset using WAI format data. -""" - -import os - -import numpy as np - -from mapanything.datasets.base.base_dataset import BaseDataset -from mapanything.utils.wai.core import load_data, load_frame - - -class DTUWAI(BaseDataset): - """ - DTU dataset containing high-quality multi-view stereo object scans. - """ - - def __init__( - self, - *args, - ROOT, - dataset_metadata_dir, - overfit_num_sets=None, - sample_specific_scene: bool = False, - specific_scene_name: str = None, - **kwargs, - ): - """ - Initialize the dataset attributes. - Args: - ROOT: Root directory of the dataset. - dataset_metadata_dir: Path to the dataset metadata directory. - overfit_num_sets: If None, use all sets. Else, the dataset will be truncated to this number of sets. - sample_specific_scene: Whether to sample a specific scene from the dataset. - specific_scene_name: Name of the specific scene to sample. 
- """ - # Initialize the dataset attributes - super().__init__(*args, **kwargs) - self.ROOT = ROOT - self.dataset_metadata_dir = dataset_metadata_dir - self.split = "test" - self.overfit_num_sets = overfit_num_sets - self.sample_specific_scene = sample_specific_scene - self.specific_scene_name = specific_scene_name - self._load_data() - - # Define the dataset type flags - self.is_metric_scale = False - self.is_synthetic = False - - def _load_data(self): - "Load the precomputed dataset metadata" - # Load the dataset metadata corresponding to the split - split_metadata_path = os.path.join( - self.dataset_metadata_dir, - self.split, - f"dtu_scene_list_{self.split}.npy", - ) - split_scene_list = np.load(split_metadata_path, allow_pickle=True) - - # Get the list of all scenes - if not self.sample_specific_scene: - self.scenes = list(split_scene_list) - else: - self.scenes = [self.specific_scene_name] - self.num_of_scenes = len(self.scenes) - - def _get_views(self, sampled_idx, num_views_to_sample, resolution): - # Get the scene name of the sampled index - scene_index = sampled_idx - scene_name = self.scenes[scene_index] - - # Get the metadata corresponding to the scene - scene_root = os.path.join(self.ROOT, scene_name) - scene_meta = load_data( - os.path.join(scene_root, "scene_meta.json"), "scene_meta" - ) - scene_file_names = list(scene_meta["frame_names"].keys()) - num_views_in_scene = len(scene_file_names) - - # Load the scene pairwise covisibility mmap - covisibility_version_key = "v0" - covisibility_map_dir = os.path.join( - scene_root, "covisibility", covisibility_version_key - ) - # Assumes only npy file in directory is covisbility map - covisibility_map_name = next( - f for f in os.listdir(covisibility_map_dir) if f.endswith(".npy") - ) - covisibility_map_path = os.path.join( - scene_root, "covisibility", covisibility_version_key, covisibility_map_name - ) - pairwise_covisibility = load_data(covisibility_map_path, "mmap") - - # Get the indices of the N views in the scene - view_indices = self._sample_view_indices( - num_views_to_sample, num_views_in_scene, pairwise_covisibility - ) - - # Get the views corresponding to the selected view indices - views = [] - for view_index in view_indices: - # Load the data corresponding to the view - view_file_name = scene_file_names[view_index] - view_data = load_frame( - scene_root, - view_file_name, - modalities=["image", "depth"], - scene_meta=scene_meta, - ) - - # Convert necessary data to numpy - image = view_data["image"].permute(1, 2, 0).numpy() - image = (image * 255).astype(np.uint8) - depthmap = view_data["depth"].numpy().astype(np.float32) - intrinsics = view_data["intrinsics"].numpy().astype(np.float32) - c2w_pose = view_data["extrinsics"].numpy().astype(np.float32) - - # Resize the data to match the desired resolution - image, depthmap, intrinsics = self._crop_resize_if_necessary( - image=image, - resolution=resolution, - depthmap=depthmap, - intrinsics=intrinsics, - additional_quantities=None, - ) - - # Append the view dictionary to the list of views - views.append( - dict( - img=image, - depthmap=depthmap, - camera_pose=c2w_pose, # cam2world - camera_intrinsics=intrinsics, - dataset="DTU", - label=scene_name, - instance=os.path.join("images", str(view_file_name)), - ) - ) - - return views - - -def get_parser(): - import argparse - - parser = argparse.ArgumentParser() - parser.add_argument("-rd", "--root_dir", default="/fsx/xrtech/data/dtu", type=str) - parser.add_argument( - "-dmd", - "--dataset_metadata_dir", - 
default="/fsx/nkeetha/mapanything_dataset_metadata", - type=str, - ) - parser.add_argument( - "-nv", - "--num_of_views", - default=2, - type=int, - ) - parser.add_argument("--viz", action="store_true") - - return parser - - -if __name__ == "__main__": - import rerun as rr - from tqdm import tqdm - - from mapanything.datasets.base.base_dataset import view_name - from mapanything.utils.image import rgb - from mapanything.utils.viz import script_add_rerun_args - - parser = get_parser() - script_add_rerun_args( - parser - ) # Options: --headless, --connect, --serve, --addr, --save, --stdout - args = parser.parse_args() - - dataset = DTUWAI( - num_views=args.num_of_views, - covisibility_thres=0.25, - ROOT=args.root_dir, - dataset_metadata_dir=args.dataset_metadata_dir, - resolution=(518, 392), - seed=777, - transform="imgnorm", - data_norm_type="dinov2", - ) - print(dataset.get_stats()) - - if args.viz: - rr.script_setup(args, "DTU_Dataloader") - rr.set_time("stable_time", sequence=0) - rr.log("world", rr.ViewCoordinates.RDF, static=True) - - sampled_indices = np.random.choice(len(dataset), size=len(dataset), replace=False) - - for num, idx in enumerate(tqdm(sampled_indices)): - views = dataset[idx] - assert len(views) == args.num_of_views - sample_name = f"{idx}" - for view_idx in range(args.num_of_views): - sample_name += f" {view_name(views[view_idx])}" - print(sample_name) - for view_idx in range(args.num_of_views): - image = rgb( - views[view_idx]["img"], norm_type=views[view_idx]["data_norm_type"] - ) - depthmap = views[view_idx]["depthmap"] - pose = views[view_idx]["camera_pose"] - intrinsics = views[view_idx]["camera_intrinsics"] - pts3d = views[view_idx]["pts3d"] - valid_mask = views[view_idx]["valid_mask"] - if "non_ambiguous_mask" in views[view_idx]: - non_ambiguous_mask = views[view_idx]["non_ambiguous_mask"] - else: - non_ambiguous_mask = None - if "prior_depth_along_ray" in views[view_idx]: - prior_depth_along_ray = views[view_idx]["prior_depth_along_ray"] - else: - prior_depth_along_ray = None - if args.viz: - rr.set_time("stable_time", sequence=num) - base_name = f"world/view_{view_idx}" - pts_name = f"world/view_{view_idx}_pointcloud" - # Log camera info and loaded data - height, width = image.shape[0], image.shape[1] - rr.log( - base_name, - rr.Transform3D( - translation=pose[:3, 3], - mat3x3=pose[:3, :3], - ), - ) - rr.log( - f"{base_name}/pinhole", - rr.Pinhole( - image_from_camera=intrinsics, - height=height, - width=width, - camera_xyz=rr.ViewCoordinates.RDF, - ), - ) - rr.log( - f"{base_name}/pinhole/rgb", - rr.Image(image), - ) - rr.log( - f"{base_name}/pinhole/depth", - rr.DepthImage(depthmap), - ) - if prior_depth_along_ray is not None: - rr.log( - f"prior_depth_along_ray_{view_idx}", - rr.DepthImage(prior_depth_along_ray), - ) - if non_ambiguous_mask is not None: - rr.log( - f"{base_name}/pinhole/non_ambiguous_mask", - rr.SegmentationImage(non_ambiguous_mask.astype(int)), - ) - # Log points in 3D - filtered_pts = pts3d[valid_mask] - filtered_pts_col = image[valid_mask] - rr.log( - pts_name, - rr.Points3D( - positions=filtered_pts.reshape(-1, 3), - colors=filtered_pts_col.reshape(-1, 3), - ), - ) diff --git a/mapanything/datasets/wai/dynamicreplica.py b/mapanything/datasets/wai/dynamicreplica.py index 20d1ce109631f394c6dd5469f2439b0de5e52e97..363e491760d841de33c58049529a330eb4ae47e7 100644 --- a/mapanything/datasets/wai/dynamicreplica.py +++ b/mapanything/datasets/wai/dynamicreplica.py @@ -1,3 +1,8 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. 
+# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + """ Dynamic Replica Dataset using WAI format data. """ @@ -85,7 +90,7 @@ class DynamicReplicaWAI(BaseDataset): covisibility_map_dir = os.path.join( scene_root, "covisibility", covisibility_version_key ) - # Assumes only npy file in directory is covisbility map + # Assumes only npy file in directory is covisibility map covisibility_map_name = next( f for f in os.listdir(covisibility_map_dir) if f.endswith(".npy") ) @@ -152,12 +157,12 @@ def get_parser(): parser = argparse.ArgumentParser() parser.add_argument( - "-rd", "--root_dir", default="/fsx/xrtech/data/dynamicreplica", type=str + "-rd", "--root_dir", default="/ai4rl/fsx/xrtech/data/dynamicreplica", type=str ) parser.add_argument( "-dmd", "--dataset_metadata_dir", - default="/fsx/nkeetha/mapanything_dataset_metadata", + default="/ai4rl/fsx/nkeetha/mapanything_dataset_metadata", type=str, ) parser.add_argument( diff --git a/mapanything/datasets/wai/eth3d.py b/mapanything/datasets/wai/eth3d.py index 51a8cebe347aeb18031379f31fff4f6581f9802b..8e61f6137a93f5940efcb11738095523f04707cb 100644 --- a/mapanything/datasets/wai/eth3d.py +++ b/mapanything/datasets/wai/eth3d.py @@ -1,3 +1,8 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + """ ETH3D Dataset using WAI format data. """ @@ -83,7 +88,7 @@ class ETH3DWAI(BaseDataset): covisibility_map_dir = os.path.join( scene_root, "covisibility", covisibility_version_key ) - # Assumes only npy file in directory is covisbility map + # Assumes only npy file in directory is covisibility map covisibility_map_name = next( f for f in os.listdir(covisibility_map_dir) if f.endswith(".npy") ) @@ -145,11 +150,13 @@ def get_parser(): import argparse parser = argparse.ArgumentParser() - parser.add_argument("-rd", "--root_dir", default="/fsx/xrtech/data/eth3d", type=str) + parser.add_argument( + "-rd", "--root_dir", default="/ai4rl/fsx/xrtech/data/eth3d", type=str + ) parser.add_argument( "-dmd", "--dataset_metadata_dir", - default="/fsx/nkeetha/mapanything_dataset_metadata", + default="/ai4rl/fsx/nkeetha/mapanything_dataset_metadata", type=str, ) parser.add_argument( diff --git a/mapanything/datasets/wai/gta_sfm.py b/mapanything/datasets/wai/gta_sfm.py deleted file mode 100644 index 06122b8f305dd1a63dfb7f992914fb334aa63f56..0000000000000000000000000000000000000000 --- a/mapanything/datasets/wai/gta_sfm.py +++ /dev/null @@ -1,303 +0,0 @@ -""" -GTA SfM Dataset using WAI format data. -""" - -import os - -import numpy as np - -from mapanything.datasets.base.base_dataset import BaseDataset -from mapanything.utils.wai.core import load_data, load_frame - - -class GTASfMWAI(BaseDataset): - """ - GTA SfM dataset containing large diversity of synthetic in-the-wild scenes. - """ - - def __init__( - self, - *args, - ROOT, - dataset_metadata_dir, - split, - overfit_num_sets=None, - sample_specific_scene: bool = False, - specific_scene_name: str = None, - **kwargs, - ): - """ - Initialize the dataset attributes. - Args: - ROOT: Root directory of the dataset. - dataset_metadata_dir: Path to the dataset metadata directory. - split: Dataset split (train, val, test). - overfit_num_sets: If None, use all sets. Else, the dataset will be truncated to this number of sets. 
- sample_specific_scene: Whether to sample a specific scene from the dataset. - specific_scene_name: Name of the specific scene to sample. - """ - # Initialize the dataset attributes - super().__init__(*args, **kwargs) - self.ROOT = ROOT - self.dataset_metadata_dir = dataset_metadata_dir - self.split = split - self.overfit_num_sets = overfit_num_sets - self.sample_specific_scene = sample_specific_scene - self.specific_scene_name = specific_scene_name - self._load_data() - - # Define the dataset type flags - self.is_metric_scale = True - self.is_synthetic = True - - def _load_data(self): - "Load the precomputed dataset metadata" - # Load the dataset metadata corresponding to the split - split_metadata_path = os.path.join( - self.dataset_metadata_dir, - self.split, - f"gta_sfm_scene_list_{self.split}.npy", - ) - split_scene_list = np.load(split_metadata_path, allow_pickle=True) - - # Get the list of all scenes - if not self.sample_specific_scene: - self.scenes = list(split_scene_list) - else: - self.scenes = [self.specific_scene_name] - self.num_of_scenes = len(self.scenes) - - def _get_views(self, sampled_idx, num_views_to_sample, resolution): - # Get the scene name of the sampled index - scene_index = sampled_idx - scene_name = self.scenes[scene_index] - - # Get the metadata corresponding to the scene - scene_root = os.path.join(self.ROOT, scene_name) - scene_meta = load_data( - os.path.join(scene_root, "scene_meta.json"), "scene_meta" - ) - scene_file_names = list(scene_meta["frame_names"].keys()) - num_views_in_scene = len(scene_file_names) - - # Load the scene pairwise covisibility mmap - covisibility_version_key = "v0" - covisibility_map_dir = os.path.join( - scene_root, "covisibility", covisibility_version_key - ) - # Assumes only npy file in directory is covisbility map - covisibility_map_name = next( - f for f in os.listdir(covisibility_map_dir) if f.endswith(".npy") - ) - covisibility_map_path = os.path.join( - scene_root, "covisibility", covisibility_version_key, covisibility_map_name - ) - pairwise_covisibility = load_data(covisibility_map_path, "mmap") - - # Get the indices of the N views in the scene - view_indices = self._sample_view_indices( - num_views_to_sample, num_views_in_scene, pairwise_covisibility - ) - - # Get the views corresponding to the selected view indices - views = [] - for view_index in view_indices: - # Load the data corresponding to the view - view_file_name = scene_file_names[view_index] - view_data = load_frame( - scene_root, - view_file_name, - modalities=["image", "depth"], - scene_meta=scene_meta, - ) - - # Convert necessary data to numpy - image = view_data["image"].permute(1, 2, 0).numpy() - image = (image * 255).astype(np.uint8) - depthmap = view_data["depth"].numpy().astype(np.float32) - intrinsics = view_data["intrinsics"].numpy().astype(np.float32) - c2w_pose = view_data["extrinsics"].numpy().astype(np.float32) - - # Ensure that the depthmap has all valid values - depthmap = np.nan_to_num(depthmap, nan=0.0, posinf=0.0, neginf=0.0) - - # Get the non ambiguous mask (zero depth pixels are sky or ambiguous) - non_ambiguous_mask = (depthmap > 0).astype(int) - - # Mask out the outlier depth (horizon depth) - percentile_depth = np.percentile(depthmap, 95) - depthmap[depthmap > percentile_depth] = 0 - - # Resize the data to match the desired resolution - additional_quantities_to_resize = [non_ambiguous_mask] - image, depthmap, intrinsics, additional_quantities_to_resize = ( - self._crop_resize_if_necessary( - image=image, - resolution=resolution, - 
depthmap=depthmap, - intrinsics=intrinsics, - additional_quantities=additional_quantities_to_resize, - ) - ) - non_ambiguous_mask = additional_quantities_to_resize[0] - - # Append the view dictionary to the list of views - views.append( - dict( - img=image, - depthmap=depthmap, - camera_pose=c2w_pose, # cam2world - camera_intrinsics=intrinsics, - non_ambiguous_mask=non_ambiguous_mask, - dataset="GTASfM", - label=scene_name, - instance=os.path.join("images", str(view_file_name)), - ) - ) - - return views - - -def get_parser(): - import argparse - - parser = argparse.ArgumentParser() - parser.add_argument( - "-rd", "--root_dir", default="/fsx/xrtech/data/gta_sfm", type=str - ) - parser.add_argument( - "-dmd", - "--dataset_metadata_dir", - default="/fsx/nkeetha/mapanything_dataset_metadata", - type=str, - ) - parser.add_argument( - "-nv", - "--num_of_views", - default=2, - type=int, - ) - parser.add_argument("--viz", action="store_true") - - return parser - - -if __name__ == "__main__": - import rerun as rr - from tqdm import tqdm - - from mapanything.datasets.base.base_dataset import view_name - from mapanything.utils.image import rgb - from mapanything.utils.viz import script_add_rerun_args - - parser = get_parser() - script_add_rerun_args( - parser - ) # Options: --headless, --connect, --serve, --addr, --save, --stdout - args = parser.parse_args() - - dataset = GTASfMWAI( - num_views=args.num_of_views, - split="train", - covisibility_thres=0.25, - ROOT=args.root_dir, - dataset_metadata_dir=args.dataset_metadata_dir, - resolution=(518, 392), - aug_crop=16, - transform="colorjitter+grayscale+gaublur", - data_norm_type="dinov2", - ) - # dataset = GTASfMWAI( - # num_views=args.num_of_views, - # split="val", - # covisibility_thres=0.25, - # ROOT=args.root_dir, - # dataset_metadata_dir=args.dataset_metadata_dir, - # resolution=(518, 392), - # seed=777, - # transform="imgnorm", - # data_norm_type="dinov2", - # ) - print(dataset.get_stats()) - - if args.viz: - rr.script_setup(args, "GTASfM_Dataloader") - rr.set_time("stable_time", sequence=0) - rr.log("world", rr.ViewCoordinates.RDF, static=True) - - sampled_indices = np.random.choice(len(dataset), size=len(dataset), replace=False) - - for num, idx in enumerate(tqdm(sampled_indices)): - views = dataset[idx] - assert len(views) == args.num_of_views - sample_name = f"{idx}" - for view_idx in range(args.num_of_views): - sample_name += f" {view_name(views[view_idx])}" - print(sample_name) - for view_idx in range(args.num_of_views): - image = rgb( - views[view_idx]["img"], norm_type=views[view_idx]["data_norm_type"] - ) - depthmap = views[view_idx]["depthmap"] - pose = views[view_idx]["camera_pose"] - intrinsics = views[view_idx]["camera_intrinsics"] - pts3d = views[view_idx]["pts3d"] - valid_mask = views[view_idx]["valid_mask"] - if "non_ambiguous_mask" in views[view_idx]: - non_ambiguous_mask = views[view_idx]["non_ambiguous_mask"] - else: - non_ambiguous_mask = None - if "prior_depth_along_ray" in views[view_idx]: - prior_depth_along_ray = views[view_idx]["prior_depth_along_ray"] - else: - prior_depth_along_ray = None - if args.viz: - rr.set_time("stable_time", sequence=num) - base_name = f"world/view_{view_idx}" - pts_name = f"world/view_{view_idx}_pointcloud" - # Log camera info and loaded data - height, width = image.shape[0], image.shape[1] - rr.log( - base_name, - rr.Transform3D( - translation=pose[:3, 3], - mat3x3=pose[:3, :3], - ), - ) - rr.log( - f"{base_name}/pinhole", - rr.Pinhole( - image_from_camera=intrinsics, - height=height, - 
width=width, - camera_xyz=rr.ViewCoordinates.RDF, - ), - ) - rr.log( - f"{base_name}/pinhole/rgb", - rr.Image(image), - ) - rr.log( - f"{base_name}/pinhole/depth", - rr.DepthImage(depthmap), - ) - if prior_depth_along_ray is not None: - rr.log( - f"prior_depth_along_ray_{view_idx}", - rr.DepthImage(prior_depth_along_ray), - ) - if non_ambiguous_mask is not None: - rr.log( - f"{base_name}/pinhole/non_ambiguous_mask", - rr.SegmentationImage(non_ambiguous_mask.astype(int)), - ) - # Log points in 3D - filtered_pts = pts3d[valid_mask] - filtered_pts_col = image[valid_mask] - rr.log( - pts_name, - rr.Points3D( - positions=filtered_pts.reshape(-1, 3), - colors=filtered_pts_col.reshape(-1, 3), - ), - ) diff --git a/mapanything/datasets/wai/matrixcity.py b/mapanything/datasets/wai/matrixcity.py deleted file mode 100644 index ff1b9262d81c5e79b11cc9720bb3750b89c475c5..0000000000000000000000000000000000000000 --- a/mapanything/datasets/wai/matrixcity.py +++ /dev/null @@ -1,307 +0,0 @@ -""" -Matrix City Dataset using WAI format data. -""" - -import os - -import numpy as np - -from mapanything.datasets.base.base_dataset import BaseDataset -from mapanything.utils.wai.core import load_data, load_frame - - -class MatrixCityWAI(BaseDataset): - """ - Matrix City dataset containing large scale aerial & street-view urban synthetic scenes. - Depth maps are antialiased and there are floaters at all object boundaries due to interpolation. - https://github.com/city-super/MatrixCity/issues/4#issuecomment-3027961575 - Normal based edge masking doesn't fix this issue completely. - """ - - def __init__( - self, - *args, - ROOT, - dataset_metadata_dir, - split, - overfit_num_sets=None, - sample_specific_scene: bool = False, - specific_scene_name: str = None, - **kwargs, - ): - """ - Initialize the dataset attributes. - Args: - ROOT: Root directory of the dataset. - dataset_metadata_dir: Path to the dataset metadata directory. - split: Dataset split (train, val, test). - overfit_num_sets: If None, use all sets. Else, the dataset will be truncated to this number of sets. - sample_specific_scene: Whether to sample a specific scene from the dataset. - specific_scene_name: Name of the specific scene to sample. 
- """ - # Initialize the dataset attributes - super().__init__(*args, **kwargs) - self.ROOT = ROOT - self.dataset_metadata_dir = dataset_metadata_dir - self.split = split - self.overfit_num_sets = overfit_num_sets - self.sample_specific_scene = sample_specific_scene - self.specific_scene_name = specific_scene_name - self._load_data() - - # Define the dataset type flags - self.is_metric_scale = True - self.is_synthetic = True - - def _load_data(self): - "Load the precomputed dataset metadata" - # Load the dataset metadata corresponding to the split - split_metadata_path = os.path.join( - self.dataset_metadata_dir, - self.split, - f"matrixcity_scene_list_{self.split}.npy", - ) - split_scene_list = np.load(split_metadata_path, allow_pickle=True) - - # Get the list of all scenes - if not self.sample_specific_scene: - self.scenes = list(split_scene_list) - else: - self.scenes = [self.specific_scene_name] - self.num_of_scenes = len(self.scenes) - - def _get_views(self, sampled_idx, num_views_to_sample, resolution): - # Get the scene name of the sampled index - scene_index = sampled_idx - scene_name = self.scenes[scene_index] - - # Get the metadata corresponding to the scene - scene_root = os.path.join(self.ROOT, scene_name) - scene_meta = load_data( - os.path.join(scene_root, "scene_meta.json"), "scene_meta" - ) - scene_file_names = list(scene_meta["frame_names"].keys()) - num_views_in_scene = len(scene_file_names) - - # Load the scene pairwise covisibility mmap - covisibility_version_key = "v0" - covisibility_map_dir = os.path.join( - scene_root, "covisibility", covisibility_version_key - ) - # Assumes only npy file in directory is covisbility map - covisibility_map_name = next( - f for f in os.listdir(covisibility_map_dir) if f.endswith(".npy") - ) - covisibility_map_path = os.path.join( - scene_root, "covisibility", covisibility_version_key, covisibility_map_name - ) - pairwise_covisibility = load_data(covisibility_map_path, "mmap") - - # Get the indices of the N views in the scene - view_indices = self._sample_view_indices( - num_views_to_sample, num_views_in_scene, pairwise_covisibility - ) - - # Get the views corresponding to the selected view indices - views = [] - for view_index in view_indices: - # Load the data corresponding to the view - view_file_name = scene_file_names[view_index] - view_data = load_frame( - scene_root, - view_file_name, - modalities=["image", "depth"], - scene_meta=scene_meta, - ) - - # Convert necessary data to numpy - image = view_data["image"].permute(1, 2, 0).numpy() - image = image[:, :, :3] # RGBA to RGB - image = (image * 255).astype(np.uint8) - depthmap = view_data["depth"].numpy().astype(np.float32) - intrinsics = view_data["intrinsics"].numpy().astype(np.float32) - c2w_pose = view_data["extrinsics"].numpy().astype(np.float32) - - # Ensure that the depthmap has all valid values - depthmap = np.nan_to_num(depthmap, nan=0.0, posinf=0.0, neginf=0.0) - - # Get the non ambiguous mask (zero depth pixels are sky or ambiguous) - non_ambiguous_mask = (depthmap > 0).astype(int) - - # Mask out the outlier depth (horizon depth) - percentile_depth = np.percentile(depthmap, 95) - depthmap[depthmap > percentile_depth] = 0 - - # Resize the data to match the desired resolution - additional_quantities_to_resize = [non_ambiguous_mask] - image, depthmap, intrinsics, additional_quantities_to_resize = ( - self._crop_resize_if_necessary( - image=image, - resolution=resolution, - depthmap=depthmap, - intrinsics=intrinsics, - additional_quantities=additional_quantities_to_resize, 
- ) - ) - non_ambiguous_mask = additional_quantities_to_resize[0] - - # Append the view dictionary to the list of views - views.append( - dict( - img=image, - depthmap=depthmap, - camera_pose=c2w_pose, # cam2world - camera_intrinsics=intrinsics, - non_ambiguous_mask=non_ambiguous_mask, - dataset="MatrixCity", - label=scene_name, - instance=os.path.join("images", str(view_file_name)), - ) - ) - - return views - - -def get_parser(): - import argparse - - parser = argparse.ArgumentParser() - parser.add_argument( - "-rd", "--root_dir", default="/fsx/xrtech/data/matrixcity", type=str - ) - parser.add_argument( - "-dmd", - "--dataset_metadata_dir", - default="/fsx/nkeetha/mapanything_dataset_metadata", - type=str, - ) - parser.add_argument( - "-nv", - "--num_of_views", - default=2, - type=int, - ) - parser.add_argument("--viz", action="store_true") - - return parser - - -if __name__ == "__main__": - import rerun as rr - from tqdm import tqdm - - from mapanything.datasets.base.base_dataset import view_name - from mapanything.utils.image import rgb - from mapanything.utils.viz import script_add_rerun_args - - parser = get_parser() - script_add_rerun_args( - parser - ) # Options: --headless, --connect, --serve, --addr, --save, --stdout - args = parser.parse_args() - - dataset = MatrixCityWAI( - num_views=args.num_of_views, - split="train", - covisibility_thres=0.25, - ROOT=args.root_dir, - dataset_metadata_dir=args.dataset_metadata_dir, - resolution=(518, 294), - aug_crop=16, - transform="colorjitter+grayscale+gaublur", - data_norm_type="dinov2", - ) - # dataset = MatrixCityWAI( - # num_views=args.num_of_views, - # split="val", - # covisibility_thres=0.25, - # ROOT=args.root_dir, - # dataset_metadata_dir=args.dataset_metadata_dir, - # resolution=(518, 294), - # seed=777, - # transform="imgnorm", - # data_norm_type="dinov2", - # ) - print(dataset.get_stats()) - - if args.viz: - rr.script_setup(args, "MatrixCity_Dataloader") - rr.set_time("stable_time", sequence=0) - rr.log("world", rr.ViewCoordinates.RDF, static=True) - - sampled_indices = np.random.choice(len(dataset), size=len(dataset), replace=False) - - for num, idx in enumerate(tqdm(sampled_indices)): - views = dataset[idx] - assert len(views) == args.num_of_views - sample_name = f"{idx}" - for view_idx in range(args.num_of_views): - sample_name += f" {view_name(views[view_idx])}" - print(sample_name) - for view_idx in range(args.num_of_views): - image = rgb( - views[view_idx]["img"], norm_type=views[view_idx]["data_norm_type"] - ) - depthmap = views[view_idx]["depthmap"] - pose = views[view_idx]["camera_pose"] - intrinsics = views[view_idx]["camera_intrinsics"] - pts3d = views[view_idx]["pts3d"] - valid_mask = views[view_idx]["valid_mask"] - if "non_ambiguous_mask" in views[view_idx]: - non_ambiguous_mask = views[view_idx]["non_ambiguous_mask"] - else: - non_ambiguous_mask = None - if "prior_depth_along_ray" in views[view_idx]: - prior_depth_along_ray = views[view_idx]["prior_depth_along_ray"] - else: - prior_depth_along_ray = None - if args.viz: - rr.set_time("stable_time", sequence=num) - base_name = f"world/view_{view_idx}" - pts_name = f"world/view_{view_idx}_pointcloud" - # Log camera info and loaded data - height, width = image.shape[0], image.shape[1] - rr.log( - base_name, - rr.Transform3D( - translation=pose[:3, 3], - mat3x3=pose[:3, :3], - ), - ) - rr.log( - f"{base_name}/pinhole", - rr.Pinhole( - image_from_camera=intrinsics, - height=height, - width=width, - camera_xyz=rr.ViewCoordinates.RDF, - ), - ) - rr.log( - 
f"{base_name}/pinhole/rgb", - rr.Image(image), - ) - rr.log( - f"{base_name}/pinhole/depth", - rr.DepthImage(depthmap), - ) - if prior_depth_along_ray is not None: - rr.log( - f"prior_depth_along_ray_{view_idx}", - rr.DepthImage(prior_depth_along_ray), - ) - if non_ambiguous_mask is not None: - rr.log( - f"{base_name}/pinhole/non_ambiguous_mask", - rr.SegmentationImage(non_ambiguous_mask.astype(int)), - ) - # Log points in 3D - filtered_pts = pts3d[valid_mask] - filtered_pts_col = image[valid_mask] - rr.log( - pts_name, - rr.Points3D( - positions=filtered_pts.reshape(-1, 3), - colors=filtered_pts_col.reshape(-1, 3), - ), - ) diff --git a/mapanything/datasets/wai/megadepth.py b/mapanything/datasets/wai/megadepth.py index deb87f45872231ad7088fa781a22fd72bed8a084..cc5a6a9193b9e56303df1a8cea9bbad33f6a5b0e 100644 --- a/mapanything/datasets/wai/megadepth.py +++ b/mapanything/datasets/wai/megadepth.py @@ -1,3 +1,8 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + """ MegaDepth Dataset using WAI format data. """ @@ -87,7 +92,7 @@ class MegaDepthWAI(BaseDataset): covisibility_map_dir = os.path.join( scene_root, "covisibility", covisibility_version_key ) - # Assumes only npy file in directory is covisbility map + # Assumes only npy file in directory is covisibility map covisibility_map_name = next( f for f in os.listdir(covisibility_map_dir) if f.endswith(".npy") ) @@ -169,12 +174,12 @@ def get_parser(): parser = argparse.ArgumentParser() parser.add_argument( - "-rd", "--root_dir", default="/fsx/xrtech/data/megadepth", type=str + "-rd", "--root_dir", default="/ai4rl/fsx/xrtech/data/megadepth", type=str ) parser.add_argument( "-dmd", "--dataset_metadata_dir", - default="/fsx/nkeetha/mapanything_dataset_metadata", + default="/ai4rl/fsx/nkeetha/mapanything_dataset_metadata", type=str, ) parser.add_argument( diff --git a/mapanything/datasets/wai/mpsd.py b/mapanything/datasets/wai/mpsd.py index efba1a506fdd9a770d4e94d0ebacd5724ee77ff4..5206912824346a3c6aa30ba9aa20f02b339efe0c 100644 --- a/mapanything/datasets/wai/mpsd.py +++ b/mapanything/datasets/wai/mpsd.py @@ -1,3 +1,8 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + """ MPSD Dataset using WAI format data. 
""" @@ -86,7 +91,7 @@ class MPSDWAI(BaseDataset): covisibility_map_dir = os.path.join( scene_root, "covisibility", covisibility_version_key ) - # Assumes only npy file in directory is covisbility map + # Assumes only npy file in directory is covisibility map covisibility_map_name = next( f for f in os.listdir(covisibility_map_dir) if f.endswith(".npy") ) @@ -167,11 +172,13 @@ def get_parser(): import argparse parser = argparse.ArgumentParser() - parser.add_argument("-rd", "--root_dir", default="/fsx/xrtech/data/mpsd", type=str) + parser.add_argument( + "-rd", "--root_dir", default="/ai4rl/fsx/xrtech/data/mpsd", type=str + ) parser.add_argument( "-dmd", "--dataset_metadata_dir", - default="/fsx/nkeetha/mapanything_dataset_metadata", + default="/ai4rl/fsx/nkeetha/mapanything_dataset_metadata", type=str, ) parser.add_argument( diff --git a/mapanything/datasets/wai/mvs_synth.py b/mapanything/datasets/wai/mvs_synth.py index ceb6e9c8aa5d430a968626ed67a2abf6cb6ae6b6..9f1f36e1c0a3e763b5db650f72bb84edc56fed0e 100644 --- a/mapanything/datasets/wai/mvs_synth.py +++ b/mapanything/datasets/wai/mvs_synth.py @@ -1,3 +1,8 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + """ MVS Synth Dataset using WAI format data. """ @@ -85,7 +90,7 @@ class MVSSynthWAI(BaseDataset): covisibility_map_dir = os.path.join( scene_root, "covisibility", covisibility_version_key ) - # Assumes only npy file in directory is covisbility map + # Assumes only npy file in directory is covisibility map covisibility_map_name = next( f for f in os.listdir(covisibility_map_dir) if f.endswith(".npy") ) @@ -163,12 +168,12 @@ def get_parser(): parser = argparse.ArgumentParser() parser.add_argument( - "-rd", "--root_dir", default="/fsx/xrtech/data/mvs_synth", type=str + "-rd", "--root_dir", default="/ai4rl/fsx/xrtech/data/mvs_synth", type=str ) parser.add_argument( "-dmd", "--dataset_metadata_dir", - default="/fsx/nkeetha/mapanything_dataset_metadata", + default="/ai4rl/fsx/nkeetha/mapanything_dataset_metadata", type=str, ) parser.add_argument( diff --git a/mapanything/datasets/wai/paralleldomain4d.py b/mapanything/datasets/wai/paralleldomain4d.py index 091a53e83ffa742347f075993afaa63fa3b194be..58adcc710ffc6274ba44c2e0cfa6c141d007ae4b 100644 --- a/mapanything/datasets/wai/paralleldomain4d.py +++ b/mapanything/datasets/wai/paralleldomain4d.py @@ -1,3 +1,8 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + """ Parallel Domain 4D Dataset using WAI format data. 
""" @@ -85,7 +90,7 @@ class ParallelDomain4DWAI(BaseDataset): covisibility_map_dir = os.path.join( scene_root, "covisibility", covisibility_version_key ) - # Assumes only npy file in directory is covisbility map + # Assumes only npy file in directory is covisibility map covisibility_map_name = next( f for f in os.listdir(covisibility_map_dir) if f.endswith(".npy") ) @@ -164,12 +169,12 @@ def get_parser(): parser = argparse.ArgumentParser() parser.add_argument( - "-rd", "--root_dir", default="/fsx/xrtech/data/paralleldomain4d", type=str + "-rd", "--root_dir", default="/ai4rl/fsx/xrtech/data/paralleldomain4d", type=str ) parser.add_argument( "-dmd", "--dataset_metadata_dir", - default="/fsx/nkeetha/mapanything_dataset_metadata", + default="/ai4rl/fsx/nkeetha/mapanything_dataset_metadata", type=str, ) parser.add_argument( diff --git a/mapanything/datasets/wai/sailvos3d.py b/mapanything/datasets/wai/sailvos3d.py index eff48a6b3a2b5190b4e61386d58ad056bca504b4..bffb051a390c6324e65c0358372299fa2ff2431c 100644 --- a/mapanything/datasets/wai/sailvos3d.py +++ b/mapanything/datasets/wai/sailvos3d.py @@ -1,3 +1,8 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + """ SAIL-VOS 3D Dataset using WAI format data. """ @@ -85,7 +90,7 @@ class SAILVOS3DWAI(BaseDataset): covisibility_map_dir = os.path.join( scene_root, "covisibility", covisibility_version_key ) - # Assumes only npy file in directory is covisbility map + # Assumes only npy file in directory is covisibility map covisibility_map_name = next( f for f in os.listdir(covisibility_map_dir) if f.endswith(".npy") ) @@ -163,12 +168,12 @@ def get_parser(): parser = argparse.ArgumentParser() parser.add_argument( - "-rd", "--root_dir", default="/fsx/xrtech/data/sailvos3d", type=str + "-rd", "--root_dir", default="/ai4rl/fsx/xrtech/data/sailvos3d", type=str ) parser.add_argument( "-dmd", "--dataset_metadata_dir", - default="/fsx/nkeetha/mapanything_dataset_metadata", + default="/ai4rl/fsx/nkeetha/mapanything_dataset_metadata", type=str, ) parser.add_argument( diff --git a/mapanything/datasets/wai/scannetpp.py b/mapanything/datasets/wai/scannetpp.py index 8ecde6dfe22c08162cbe9ffb1e53c0315187c2da..e9ecce3e0d9ef817b15e30482237a3902b7b0d91 100644 --- a/mapanything/datasets/wai/scannetpp.py +++ b/mapanything/datasets/wai/scannetpp.py @@ -1,3 +1,8 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + """ ScanNet++V2 Dataset using WAI format data. 
""" @@ -85,7 +90,7 @@ class ScanNetPPWAI(BaseDataset): covisibility_map_dir = os.path.join( scene_root, "covisibility", covisibility_version_key ) - # Assumes only npy file in directory is covisbility map + # Assumes only npy file in directory is covisibility map covisibility_map_name = next( f for f in os.listdir(covisibility_map_dir) if f.endswith(".npy") ) @@ -151,12 +156,12 @@ def get_parser(): parser = argparse.ArgumentParser() parser.add_argument( - "-rd", "--root_dir", default="/fsx/xrtech/data/scannetppv2", type=str + "-rd", "--root_dir", default="/ai4rl/fsx/xrtech/data/scannetppv2", type=str ) parser.add_argument( "-dmd", "--dataset_metadata_dir", - default="/fsx/nkeetha/mapanything_dataset_metadata", + default="/ai4rl/fsx/nkeetha/mapanything_dataset_metadata", type=str, ) parser.add_argument( diff --git a/mapanything/datasets/wai/spring.py b/mapanything/datasets/wai/spring.py index 3bbfd375eb03ab662e838505e5159dbf64d8eef2..cf15e18dd31e564000e3ef6192775d4c514ba70d 100644 --- a/mapanything/datasets/wai/spring.py +++ b/mapanything/datasets/wai/spring.py @@ -1,3 +1,8 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + """ Spring Dataset using WAI format data. """ @@ -88,7 +93,7 @@ class SpringWAI(BaseDataset): ) covisibility_map_name = next( f for f in os.listdir(covisibility_map_dir) if f.endswith(".npy") - ) # Assumes only npy file in directory is covisbility map + ) # Assumes only npy file in directory is covisibility map covisibility_map_path = os.path.join( scene_root, "covisibility", covisibility_version_key, covisibility_map_name ) @@ -171,12 +176,12 @@ def get_parser(): parser = argparse.ArgumentParser() parser.add_argument( - "-rd", "--root_dir", default="/fsx/xrtech/data/spring", type=str + "-rd", "--root_dir", default="/ai4rl/fsx/xrtech/data/spring", type=str ) parser.add_argument( "-dmd", "--dataset_metadata_dir", - default="/fsx/nkeetha/mapanything_dataset_metadata", + default="/ai4rl/fsx/nkeetha/mapanything_dataset_metadata", type=str, ) parser.add_argument( diff --git a/mapanything/datasets/wai/structured3d.py b/mapanything/datasets/wai/structured3d.py deleted file mode 100644 index d8b54723714b4f3da530a236684b45cc8e6f7bf7..0000000000000000000000000000000000000000 --- a/mapanything/datasets/wai/structured3d.py +++ /dev/null @@ -1,292 +0,0 @@ -""" -Structured3D Dataset using WAI format data. -""" - -import os - -import numpy as np - -from mapanything.datasets.base.base_dataset import BaseDataset -from mapanything.utils.wai.core import load_data, load_frame - - -class Structured3DWAI(BaseDataset): - """ - Structured3D dataset containing large diversity of synthetic multi-room indoor scenes. - """ - - def __init__( - self, - *args, - ROOT, - dataset_metadata_dir, - split, - overfit_num_sets=None, - sample_specific_scene: bool = False, - specific_scene_name: str = None, - **kwargs, - ): - """ - Initialize the dataset attributes. - Args: - ROOT: Root directory of the dataset. - dataset_metadata_dir: Path to the dataset metadata directory. - split: Dataset split (train, val, test). - overfit_num_sets: If None, use all sets. Else, the dataset will be truncated to this number of sets. - sample_specific_scene: Whether to sample a specific scene from the dataset. - specific_scene_name: Name of the specific scene to sample. 
- """ - # Initialize the dataset attributes - super().__init__(*args, **kwargs) - self.ROOT = ROOT - self.dataset_metadata_dir = dataset_metadata_dir - self.split = split - self.overfit_num_sets = overfit_num_sets - self.sample_specific_scene = sample_specific_scene - self.specific_scene_name = specific_scene_name - self._load_data() - - # Define the dataset type flags - self.is_metric_scale = True - self.is_synthetic = True - - def _load_data(self): - "Load the precomputed dataset metadata" - # Load the dataset metadata corresponding to the split - split_metadata_path = os.path.join( - self.dataset_metadata_dir, - self.split, - f"structured3d_scene_list_{self.split}.npy", - ) - split_scene_list = np.load(split_metadata_path, allow_pickle=True) - - # Get the list of all scenes - if not self.sample_specific_scene: - self.scenes = list(split_scene_list) - else: - self.scenes = [self.specific_scene_name] - self.num_of_scenes = len(self.scenes) - - def _get_views(self, sampled_idx, num_views_to_sample, resolution): - # Get the scene name of the sampled index - scene_index = sampled_idx - scene_name = self.scenes[scene_index] - - # Get the metadata corresponding to the scene - scene_root = os.path.join(self.ROOT, scene_name) - scene_meta = load_data( - os.path.join(scene_root, "scene_meta.json"), "scene_meta" - ) - scene_file_names = list(scene_meta["frame_names"].keys()) - num_views_in_scene = len(scene_file_names) - - # Load the scene pairwise covisibility mmap - covisibility_version_key = "v0" - covisibility_map_dir = os.path.join( - scene_root, "covisibility", covisibility_version_key - ) - # Assumes only npy file in directory is covisbility map - covisibility_map_name = next( - f for f in os.listdir(covisibility_map_dir) if f.endswith(".npy") - ) - covisibility_map_path = os.path.join( - scene_root, "covisibility", covisibility_version_key, covisibility_map_name - ) - pairwise_covisibility = load_data(covisibility_map_path, "mmap") - - # Get the indices of the N views in the scene - view_indices = self._sample_view_indices( - num_views_to_sample, num_views_in_scene, pairwise_covisibility - ) - - # Get the views corresponding to the selected view indices - views = [] - for view_index in view_indices: - # Load the data corresponding to the view - view_file_name = scene_file_names[view_index] - view_data = load_frame( - scene_root, - view_file_name, - modalities=["image", "depth"], - scene_meta=scene_meta, - ) - - # Convert necessary data to numpy - image = view_data["image"].permute(1, 2, 0).numpy() - image = image[:, :, :3] # RGBA to RGB - image = (image * 255).astype(np.uint8) - depthmap = view_data["depth"].numpy().astype(np.float32) - intrinsics = view_data["intrinsics"].numpy().astype(np.float32) - c2w_pose = view_data["extrinsics"].numpy().astype(np.float32) - - # Ensure that the depthmap has all valid values - depthmap = np.nan_to_num(depthmap, nan=0.0, posinf=0.0, neginf=0.0) - - # Resize the data to match the desired resolution - image, depthmap, intrinsics = self._crop_resize_if_necessary( - image=image, - resolution=resolution, - depthmap=depthmap, - intrinsics=intrinsics, - additional_quantities=None, - ) - - # Append the view dictionary to the list of views - views.append( - dict( - img=image, - depthmap=depthmap, - camera_pose=c2w_pose, # cam2world - camera_intrinsics=intrinsics, - dataset="Structured3D", - label=scene_name, - instance=os.path.join("images", str(view_file_name)), - ) - ) - - return views - - -def get_parser(): - import argparse - - parser = 
argparse.ArgumentParser() - parser.add_argument( - "-rd", "--root_dir", default="/fsx/xrtech/data/structured3d", type=str - ) - parser.add_argument( - "-dmd", - "--dataset_metadata_dir", - default="/fsx/nkeetha/mapanything_dataset_metadata", - type=str, - ) - parser.add_argument( - "-nv", - "--num_of_views", - default=2, - type=int, - ) - parser.add_argument("--viz", action="store_true") - - return parser - - -if __name__ == "__main__": - import rerun as rr - from tqdm import tqdm - - from mapanything.datasets.base.base_dataset import view_name - from mapanything.utils.image import rgb - from mapanything.utils.viz import script_add_rerun_args - - parser = get_parser() - script_add_rerun_args( - parser - ) # Options: --headless, --connect, --serve, --addr, --save, --stdout - args = parser.parse_args() - - dataset = Structured3DWAI( - num_views=args.num_of_views, - split="train", - covisibility_thres=0.25, - ROOT=args.root_dir, - dataset_metadata_dir=args.dataset_metadata_dir, - resolution=(518, 294), - aug_crop=16, - transform="colorjitter+grayscale+gaublur", - data_norm_type="dinov2", - ) - # dataset = Structured3DWAI( - # num_views=args.num_of_views, - # split="val", - # covisibility_thres=0.25, - # ROOT=args.root_dir, - # dataset_metadata_dir=args.dataset_metadata_dir, - # resolution=(518, 294), - # seed=777, - # transform="imgnorm", - # data_norm_type="dinov2", - # ) - print(dataset.get_stats()) - - if args.viz: - rr.script_setup(args, "Structured3D_Dataloader") - rr.set_time("stable_time", sequence=0) - rr.log("world", rr.ViewCoordinates.RDF, static=True) - - sampled_indices = np.random.choice(len(dataset), size=len(dataset), replace=False) - - for num, idx in enumerate(tqdm(sampled_indices)): - views = dataset[idx] - assert len(views) == args.num_of_views - sample_name = f"{idx}" - for view_idx in range(args.num_of_views): - sample_name += f" {view_name(views[view_idx])}" - print(sample_name) - for view_idx in range(args.num_of_views): - image = rgb( - views[view_idx]["img"], norm_type=views[view_idx]["data_norm_type"] - ) - depthmap = views[view_idx]["depthmap"] - pose = views[view_idx]["camera_pose"] - intrinsics = views[view_idx]["camera_intrinsics"] - pts3d = views[view_idx]["pts3d"] - valid_mask = views[view_idx]["valid_mask"] - if "non_ambiguous_mask" in views[view_idx]: - non_ambiguous_mask = views[view_idx]["non_ambiguous_mask"] - else: - non_ambiguous_mask = None - if "prior_depth_along_ray" in views[view_idx]: - prior_depth_along_ray = views[view_idx]["prior_depth_along_ray"] - else: - prior_depth_along_ray = None - if args.viz: - rr.set_time("stable_time", sequence=num) - base_name = f"world/view_{view_idx}" - pts_name = f"world/view_{view_idx}_pointcloud" - # Log camera info and loaded data - height, width = image.shape[0], image.shape[1] - rr.log( - base_name, - rr.Transform3D( - translation=pose[:3, 3], - mat3x3=pose[:3, :3], - ), - ) - rr.log( - f"{base_name}/pinhole", - rr.Pinhole( - image_from_camera=intrinsics, - height=height, - width=width, - camera_xyz=rr.ViewCoordinates.RDF, - ), - ) - rr.log( - f"{base_name}/pinhole/rgb", - rr.Image(image), - ) - rr.log( - f"{base_name}/pinhole/depth", - rr.DepthImage(depthmap), - ) - if prior_depth_along_ray is not None: - rr.log( - f"prior_depth_along_ray_{view_idx}", - rr.DepthImage(prior_depth_along_ray), - ) - if non_ambiguous_mask is not None: - rr.log( - f"{base_name}/pinhole/non_ambiguous_mask", - rr.SegmentationImage(non_ambiguous_mask.astype(int)), - ) - # Log points in 3D - filtered_pts = pts3d[valid_mask] - 
filtered_pts_col = image[valid_mask] - rr.log( - pts_name, - rr.Points3D( - positions=filtered_pts.reshape(-1, 3), - colors=filtered_pts_col.reshape(-1, 3), - ), - ) diff --git a/mapanything/datasets/wai/tav2_wb.py b/mapanything/datasets/wai/tav2_wb.py index f19bc5c00de55b7b16eb53252260519448443dee..f3c445026c60ce7dda9a5022650d49a5e22eafdb 100644 --- a/mapanything/datasets/wai/tav2_wb.py +++ b/mapanything/datasets/wai/tav2_wb.py @@ -1,3 +1,8 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + """ TartanAirV2-WB Dataset using WAI format data. """ @@ -86,7 +91,7 @@ class TartanAirV2WBWAI(BaseDataset): covisibility_map_dir = os.path.join( scene_root, "covisibility", covisibility_version_key ) - # Assumes only npy file in directory is covisbility map + # Assumes only npy file in directory is covisibility map covisibility_map_name = next( f for f in os.listdir(covisibility_map_dir) if f.endswith(".npy") ) @@ -172,12 +177,12 @@ def get_parser(): parser = argparse.ArgumentParser() parser.add_argument( - "-rd", "--root_dir", default="/fsx/xrtech/data/tav2_wb", type=str + "-rd", "--root_dir", default="/ai4rl/fsx/xrtech/data/tav2_wb", type=str ) parser.add_argument( "-dmd", "--dataset_metadata_dir", - default="/fsx/nkeetha/mapanything_dataset_metadata", + default="/ai4rl/fsx/nkeetha/mapanything_dataset_metadata", type=str, ) parser.add_argument( diff --git a/mapanything/datasets/wai/unrealstereo4k.py b/mapanything/datasets/wai/unrealstereo4k.py index 99b9d32a094b416509261450bc11058c634001a3..633b5f1e1ec8b8417a480f70453b50ae262bc6e7 100644 --- a/mapanything/datasets/wai/unrealstereo4k.py +++ b/mapanything/datasets/wai/unrealstereo4k.py @@ -1,3 +1,8 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + """ UnrealStereo4K Dataset using WAI format data. """ @@ -85,7 +90,7 @@ class UnrealStereo4KWAI(BaseDataset): covisibility_map_dir = os.path.join( scene_root, "covisibility", covisibility_version_key ) - # Assumes only npy file in directory is covisbility map + # Assumes only npy file in directory is covisibility map covisibility_map_name = next( f for f in os.listdir(covisibility_map_dir) if f.endswith(".npy") ) @@ -164,12 +169,12 @@ def get_parser(): parser = argparse.ArgumentParser() parser.add_argument( - "-rd", "--root_dir", default="/fsx/xrtech/data/unrealstereo4k", type=str + "-rd", "--root_dir", default="/ai4rl/fsx/xrtech/data/unrealstereo4k", type=str ) parser.add_argument( "-dmd", "--dataset_metadata_dir", - default="/fsx/nkeetha/mapanything_dataset_metadata", + default="/ai4rl/fsx/nkeetha/mapanything_dataset_metadata", type=str, ) parser.add_argument( diff --git a/mapanything/datasets/wai/xrooms.py b/mapanything/datasets/wai/xrooms.py deleted file mode 100644 index d8e31b8c59c02507dc1aa90b5b81d4883bf18a35..0000000000000000000000000000000000000000 --- a/mapanything/datasets/wai/xrooms.py +++ /dev/null @@ -1,300 +0,0 @@ -""" -XRooms Dataset using WAI format data. -""" - -import os - -import numpy as np - -from mapanything.datasets.base.base_dataset import BaseDataset -from mapanything.utils.wai.core import load_data, load_frame - - -class XRoomsWAI(BaseDataset): - """ - XRooms dataset containing large diversity of synthetic re-lightable indoor scenes. 
- """ - - def __init__( - self, - *args, - ROOT, - dataset_metadata_dir, - split, - overfit_num_sets=None, - sample_specific_scene: bool = False, - specific_scene_name: str = None, - **kwargs, - ): - """ - Initialize the dataset attributes. - Args: - ROOT: Root directory of the dataset. - dataset_metadata_dir: Path to the dataset metadata directory. - split: Dataset split (train, val, test). - overfit_num_sets: If None, use all sets. Else, the dataset will be truncated to this number of sets. - sample_specific_scene: Whether to sample a specific scene from the dataset. - specific_scene_name: Name of the specific scene to sample. - """ - # Initialize the dataset attributes - super().__init__(*args, **kwargs) - self.ROOT = ROOT - self.dataset_metadata_dir = dataset_metadata_dir - self.split = split - self.overfit_num_sets = overfit_num_sets - self.sample_specific_scene = sample_specific_scene - self.specific_scene_name = specific_scene_name - self._load_data() - - # Define the dataset type flags - self.is_metric_scale = True - self.is_synthetic = True - - def _load_data(self): - "Load the precomputed dataset metadata" - # Load the dataset metadata corresponding to the split - split_metadata_path = os.path.join( - self.dataset_metadata_dir, - self.split, - f"xrooms_scene_list_{self.split}.npy", - ) - split_scene_list = np.load(split_metadata_path, allow_pickle=True) - - # Get the list of all scenes - if not self.sample_specific_scene: - self.scenes = list(split_scene_list) - else: - self.scenes = [self.specific_scene_name] - self.num_of_scenes = len(self.scenes) - - def _get_views(self, sampled_idx, num_views_to_sample, resolution): - # Get the scene name of the sampled index - scene_index = sampled_idx - scene_name = self.scenes[scene_index] - - # Get the metadata corresponding to the scene - scene_root = os.path.join(self.ROOT, scene_name) - scene_meta = load_data( - os.path.join(scene_root, "scene_meta.json"), "scene_meta" - ) - scene_file_names = list(scene_meta["frame_names"].keys()) - num_views_in_scene = len(scene_file_names) - - # Load the scene pairwise covisibility mmap - covisibility_version_key = "v0" - covisibility_map_dir = os.path.join( - scene_root, "covisibility", covisibility_version_key - ) - # Assumes only npy file in directory is covisbility map - covisibility_map_name = next( - f for f in os.listdir(covisibility_map_dir) if f.endswith(".npy") - ) - covisibility_map_path = os.path.join( - scene_root, "covisibility", covisibility_version_key, covisibility_map_name - ) - pairwise_covisibility = load_data(covisibility_map_path, "mmap") - - ### HOTFIX HACK for incompatible covisibility in a few scenes - ### TODO: Re-mine covisibility on errorenous scenes - if len(pairwise_covisibility) == num_views_in_scene: - # Get the indices of the N views in the scene - view_indices = self._sample_view_indices( - num_views_to_sample, num_views_in_scene, pairwise_covisibility - ) - else: - # Get a random view index - view_indices = self._rng.choice(num_views_in_scene, size=1, replace=False) - # Repeat the view index to get the desired number of views - view_indices = np.repeat(view_indices, num_views_to_sample) - ### END HOTFIX HACK - - # Get the views corresponding to the selected view indices - views = [] - for view_index in view_indices: - # Load the data corresponding to the view - view_file_name = scene_file_names[view_index] - view_data = load_frame( - scene_root, - view_file_name, - modalities=["image", "depth"], - scene_meta=scene_meta, - ) - - # Convert necessary data to numpy - 
image = view_data["image"].permute(1, 2, 0).numpy() - image = (image * 255).astype(np.uint8) - depthmap = view_data["depth"].numpy().astype(np.float32) - intrinsics = view_data["intrinsics"].numpy().astype(np.float32) - c2w_pose = view_data["extrinsics"].numpy().astype(np.float32) - - # Ensure that the depthmap has all valid values - depthmap = np.nan_to_num(depthmap, nan=0.0, posinf=0.0, neginf=0.0) - - # Resize the data to match the desired resolution - image, depthmap, intrinsics = self._crop_resize_if_necessary( - image=image, - resolution=resolution, - depthmap=depthmap, - intrinsics=intrinsics, - additional_quantities=None, - ) - - # Append the view dictionary to the list of views - views.append( - dict( - img=image, - depthmap=depthmap, - camera_pose=c2w_pose, # cam2world - camera_intrinsics=intrinsics, - dataset="XRooms", - label=scene_name, - instance=os.path.join("images", str(view_file_name)), - ) - ) - - return views - - -def get_parser(): - import argparse - - parser = argparse.ArgumentParser() - parser.add_argument( - "-rd", "--root_dir", default="/fsx/xrtech/data/xrooms", type=str - ) - parser.add_argument( - "-dmd", - "--dataset_metadata_dir", - default="/fsx/nkeetha/mapanything_dataset_metadata", - type=str, - ) - parser.add_argument( - "-nv", - "--num_of_views", - default=2, - type=int, - ) - parser.add_argument("--viz", action="store_true") - - return parser - - -if __name__ == "__main__": - import rerun as rr - from tqdm import tqdm - - from mapanything.datasets.base.base_dataset import view_name - from mapanything.utils.image import rgb - from mapanything.utils.viz import script_add_rerun_args - - parser = get_parser() - script_add_rerun_args( - parser - ) # Options: --headless, --connect, --serve, --addr, --save, --stdout - args = parser.parse_args() - - dataset = XRoomsWAI( - num_views=args.num_of_views, - split="train", - covisibility_thres=0.25, - ROOT=args.root_dir, - dataset_metadata_dir=args.dataset_metadata_dir, - resolution=(518, 518), - aug_crop=16, - transform="colorjitter+grayscale+gaublur", - data_norm_type="dinov2", - ) - # dataset = XRoomsWAI( - # num_views=args.num_of_views, - # split="val", - # covisibility_thres=0.25, - # ROOT=args.root_dir, - # dataset_metadata_dir=args.dataset_metadata_dir, - # resolution=(518, 518), - # seed=777, - # transform="imgnorm", - # data_norm_type="dinov2", - # ) - print(dataset.get_stats()) - - if args.viz: - rr.script_setup(args, "XRooms_Dataloader") - rr.set_time("stable_time", sequence=0) - rr.log("world", rr.ViewCoordinates.RDF, static=True) - - sampled_indices = np.random.choice(len(dataset), size=10, replace=False) - - for num, idx in enumerate(tqdm(sampled_indices)): - views = dataset[idx] - assert len(views) == args.num_of_views - sample_name = f"{idx}" - for view_idx in range(args.num_of_views): - sample_name += f" {view_name(views[view_idx])}" - print(sample_name) - for view_idx in range(args.num_of_views): - image = rgb( - views[view_idx]["img"], norm_type=views[view_idx]["data_norm_type"] - ) - depthmap = views[view_idx]["depthmap"] - pose = views[view_idx]["camera_pose"] - intrinsics = views[view_idx]["camera_intrinsics"] - pts3d = views[view_idx]["pts3d"] - valid_mask = views[view_idx]["valid_mask"] - if "non_ambiguous_mask" in views[view_idx]: - non_ambiguous_mask = views[view_idx]["non_ambiguous_mask"] - else: - non_ambiguous_mask = None - if "prior_depth_along_ray" in views[view_idx]: - prior_depth_along_ray = views[view_idx]["prior_depth_along_ray"] - else: - prior_depth_along_ray = None - if args.viz: - 
rr.set_time("stable_time", sequence=num) - base_name = f"world/view_{view_idx}" - pts_name = f"world/view_{view_idx}_pointcloud" - # Log camera info and loaded data - height, width = image.shape[0], image.shape[1] - rr.log( - base_name, - rr.Transform3D( - translation=pose[:3, 3], - mat3x3=pose[:3, :3], - ), - ) - rr.log( - f"{base_name}/pinhole", - rr.Pinhole( - image_from_camera=intrinsics, - height=height, - width=width, - camera_xyz=rr.ViewCoordinates.RDF, - ), - ) - rr.log( - f"{base_name}/pinhole/rgb", - rr.Image(image), - ) - rr.log( - f"{base_name}/pinhole/depth", - rr.DepthImage(depthmap), - ) - if prior_depth_along_ray is not None: - rr.log( - f"prior_depth_along_ray_{view_idx}", - rr.DepthImage(prior_depth_along_ray), - ) - if non_ambiguous_mask is not None: - rr.log( - f"{base_name}/pinhole/non_ambiguous_mask", - rr.SegmentationImage(non_ambiguous_mask.astype(int)), - ) - # Log points in 3D - filtered_pts = pts3d[valid_mask] - filtered_pts_col = image[valid_mask] - rr.log( - pts_name, - rr.Points3D( - positions=filtered_pts.reshape(-1, 3), - colors=filtered_pts_col.reshape(-1, 3), - ), - ) diff --git a/mapanything/models/__init__.py b/mapanything/models/__init__.py index f1b9de65a1cb8ff5cb15fab527bc5a7194896f8a..7c1417f13318640e87414909fe3eae1d78585519 100644 --- a/mapanything/models/__init__.py +++ b/mapanything/models/__init__.py @@ -1,3 +1,8 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + """ Model Factory for MapAnything """ @@ -71,6 +76,10 @@ MODEL_CONFIGS = { "module": "mapanything.models.external.anycalib", "class_name": "AnyCalibWrapper", }, + "da3": { + "module": "mapanything.models.external.da3", + "class_name": "DA3Wrapper", + }, "dust3r": { "module": "mapanything.models.external.dust3r", "class_name": "DUSt3RBAWrapper", diff --git a/mapanything/models/mapanything/__init__.py b/mapanything/models/mapanything/__init__.py index 13e6352b2a6e6a5374d5e8a079ef65fc503816e8..885076daecd0c83d1a85769a9224490295b71547 100644 --- a/mapanything/models/mapanything/__init__.py +++ b/mapanything/models/mapanything/__init__.py @@ -1,3 +1,8 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + from mapanything.models.mapanything.ablations import MapAnythingAblations from mapanything.models.mapanything.model import MapAnything from mapanything.models.mapanything.modular_dust3r import ModularDUSt3R diff --git a/mapanything/models/mapanything/ablations.py b/mapanything/models/mapanything/ablations.py index 5bf4cb633d30b09815157d17087cc3eca93b39a4..9429780c05454d2ec13f7f8d5f323d794e077d60 100644 --- a/mapanything/models/mapanything/ablations.py +++ b/mapanything/models/mapanything/ablations.py @@ -1,3 +1,8 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + """ MapAnything Ablation model classes defined using UniCeption modules. 
""" @@ -104,7 +109,7 @@ class MapAnythingAblations(nn.Module): """ super().__init__() - # Initalize the attributes + # Initialize the attributes self.name = name self.encoder_config = encoder_config self.info_sharing_config = info_sharing_config @@ -230,7 +235,7 @@ class MapAnythingAblations(nn.Module): else: self.custom_positional_encoding = None - # Add dependecies to info_sharing_config + # Add dependencies to info_sharing_config info_sharing_config["module_args"]["input_embed_dim"] = ( self.encoder.enc_embed_dim ) @@ -241,7 +246,7 @@ class MapAnythingAblations(nn.Module): # Initialize Multi-View Transformer if self.info_sharing_return_type == "no_intermediate_features": # Returns only normalized last layer features - # Intialize multi-view transformer based on type + # Initialize multi-view transformer based on type if self.info_sharing_type == "cross_attention": self.info_sharing = MultiViewCrossAttentionTransformer( **info_sharing_config["module_args"] @@ -343,7 +348,7 @@ class MapAnythingAblations(nn.Module): # Initialize Dense Prediction Head for all views self.dense_head = LinearFeature(**pred_head_config["feature_head"]) elif "dpt" in self.pred_head_type: - # Initialze Dense Predction Head for all views + # Initialize Dense Prediction Head for all views self.dpt_feature_head = DPTFeature(**pred_head_config["feature_head"]) self.dpt_regressor_head = DPTRegressionProcessor( **pred_head_config["regressor_head"] diff --git a/mapanything/models/mapanything/model.py b/mapanything/models/mapanything/model.py index 2de75cf5be57a05fa97b41f877264113f27e11c3..3b4e88edbac2e18ca23a88acaad22e37b95229af 100644 --- a/mapanything/models/mapanything/model.py +++ b/mapanything/models/mapanything/model.py @@ -1,3 +1,8 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + """ MapAnything model class defined using UniCeption modules. """ @@ -71,6 +76,7 @@ from uniception.models.prediction_heads.dpt import DPTFeature, DPTRegressionProc from uniception.models.prediction_heads.linear import LinearFeature from uniception.models.prediction_heads.mlp_head import MLPHead from uniception.models.prediction_heads.pose_head import PoseHead +from uniception.models.utils.transformer_blocks import Mlp, SwiGLUFFNFused # Enable TF32 precision if supported (for GPU >= Ampere and PyTorch >= 1.12) if hasattr(torch.backends.cuda, "matmul") and hasattr( @@ -96,6 +102,8 @@ class MapAnything(nn.Module, PyTorchModelHubMixin): load_specific_pretrained_submodules: bool = False, specific_pretrained_submodules: list = None, torch_hub_force_reload: bool = False, + use_register_tokens_from_encoder: bool = False, + info_sharing_mlp_layer_str: str = "mlp", ): """ Multi-view model containing an image encoder fused with optional geometric modalities followed by a multi-view attention transformer and respective downstream heads. @@ -113,10 +121,12 @@ class MapAnything(nn.Module, PyTorchModelHubMixin): load_specific_pretrained_submodules (bool): Whether to load specific pretrained submodules. (default: False) specific_pretrained_submodules (list): List of specific pretrained submodules to load. Must be provided when load_specific_pretrained_submodules is True. (default: None) torch_hub_force_reload (bool): Whether to force reload the encoder from torch hub. (default: False) + use_register_tokens_from_encoder (bool): Whether to use register tokens from encoder. 
(default: False) + info_sharing_mlp_layer_str (str): Type of MLP layer to use in the multi-view transformer. Useful for DINO init of the multi-view transformer. Options: "mlp" or "swiglufused". (default: "mlp") """ super().__init__() - # Initalize the attributes + # Initialize the attributes self.name = name self.encoder_config = encoder_config self.info_sharing_config = info_sharing_config @@ -126,6 +136,8 @@ class MapAnything(nn.Module, PyTorchModelHubMixin): self.load_specific_pretrained_submodules = load_specific_pretrained_submodules self.specific_pretrained_submodules = specific_pretrained_submodules self.torch_hub_force_reload = torch_hub_force_reload + self.use_register_tokens_from_encoder = use_register_tokens_from_encoder + self.info_sharing_mlp_layer_str = info_sharing_mlp_layer_str self.class_init_args = { "name": self.name, "encoder_config": self.encoder_config, @@ -136,6 +148,8 @@ class MapAnything(nn.Module, PyTorchModelHubMixin): "load_specific_pretrained_submodules": self.load_specific_pretrained_submodules, "specific_pretrained_submodules": self.specific_pretrained_submodules, "torch_hub_force_reload": self.torch_hub_force_reload, + "use_register_tokens_from_encoder": self.use_register_tokens_from_encoder, + "info_sharing_mlp_layer_str": self.info_sharing_mlp_layer_str, } # Get relevant parameters from the configs @@ -196,6 +210,16 @@ class MapAnything(nn.Module, PyTorchModelHubMixin): self.scale_token = nn.Parameter(torch.zeros(self.encoder.enc_embed_dim)) torch.nn.init.trunc_normal_(self.scale_token, std=0.02) + # Set the MLP layer config for the info sharing transformer + if info_sharing_mlp_layer_str == "mlp": + info_sharing_config["module_args"]["mlp_layer"] = Mlp + elif info_sharing_mlp_layer_str == "swiglufused": + info_sharing_config["module_args"]["mlp_layer"] = SwiGLUFFNFused + else: + raise ValueError( + f"Invalid info_sharing_mlp_layer_str: {info_sharing_mlp_layer_str}. Valid options: ['mlp', 'swiglufused']" + ) + # Initialize the info sharing module (multi-view transformer) self._initialize_info_sharing(info_sharing_config) @@ -251,7 +275,7 @@ class MapAnything(nn.Module, PyTorchModelHubMixin): else: self.custom_positional_encoding = None - # Add dependecies to info_sharing_config + # Add dependencies to info_sharing_config info_sharing_config["module_args"]["input_embed_dim"] = ( self.encoder.enc_embed_dim ) @@ -262,7 +286,7 @@ class MapAnything(nn.Module, PyTorchModelHubMixin): # Initialize Multi-View Transformer if self.info_sharing_return_type == "no_intermediate_features": # Returns only normalized last layer features - # Intialize multi-view transformer based on type + # Initialize multi-view transformer based on type if self.info_sharing_type == "cross_attention": self.info_sharing = MultiViewCrossAttentionTransformer( **info_sharing_config["module_args"] @@ -365,7 +389,7 @@ class MapAnything(nn.Module, PyTorchModelHubMixin): # Initialize Dense Prediction Head for all views self.dense_head = LinearFeature(**pred_head_config["feature_head"]) elif "dpt" in self.pred_head_type: - # Initialze Dense Predction Head for all views + # Initialize Dense Prediction Head for all views self.dpt_feature_head = DPTFeature(**pred_head_config["feature_head"]) self.dpt_regressor_head = DPTRegressionProcessor( **pred_head_config["regressor_head"] @@ -623,7 +647,9 @@ class MapAnything(nn.Module, PyTorchModelHubMixin): views (List[dict]): List of dictionaries containing the input views' images and instance information. 
Returns: - List[torch.Tensor]: A list containing the encoded features for all N views. + A tuple containing: + List[torch.Tensor]: A list containing the encoded features for all N views. + List[torch.Tensor]: A list containing the encoded per-view registers for all N views. """ num_views = len(views) data_norm_type = views[0]["data_norm_type"][0] @@ -636,8 +662,16 @@ class MapAnything(nn.Module, PyTorchModelHubMixin): all_encoder_features_across_views = encoder_output.features.chunk( num_views, dim=0 ) + all_encoder_registers_across_views = None + if ( + self.use_register_tokens_from_encoder + and encoder_output.registers is not None + ): + all_encoder_registers_across_views = encoder_output.registers.chunk( + num_views, dim=0 + ) - return all_encoder_features_across_views + return all_encoder_features_across_views, all_encoder_registers_across_views def _compute_pose_quats_and_trans_for_across_views_in_ref_view( self, @@ -1504,7 +1538,9 @@ class MapAnything(nn.Module, PyTorchModelHubMixin): num_views = len(views) # Run the image encoder on all the input views - all_encoder_features_across_views = self._encode_n_views(views) + all_encoder_features_across_views, all_encoder_registers_across_views = ( + self._encode_n_views(views) + ) # Encode the optional geometric inputs and fuse with the encoded features from the N input views # Use high precision to prevent NaN values after layer norm in dense representation encoder (due to high variance in last dim of features) @@ -1526,6 +1562,7 @@ class MapAnything(nn.Module, PyTorchModelHubMixin): # Output is a list containing the encoded features for all N views after information sharing. info_sharing_input = MultiViewTransformerInput( features=all_encoder_features_across_views, + additional_input_tokens_per_view=all_encoder_registers_across_views, additional_input_tokens=input_scale_token, ) if self.info_sharing_return_type == "no_intermediate_features": @@ -2069,7 +2106,13 @@ class MapAnything(nn.Module, PyTorchModelHubMixin): for name in view.keys(): if name in ignore_keys: continue - view[name] = view[name].to(self.device, non_blocking=True) + val = view[name] + if name == "camera_poses" and isinstance(val, tuple): + view[name] = tuple( + x.to(self.device, non_blocking=True) for x in val + ) + elif hasattr(val, "to"): + view[name] = val.to(self.device, non_blocking=True) # Pre-process the input views processed_views = preprocess_input_views_for_inference(validated_views) diff --git a/mapanything/models/mapanything/modular_dust3r.py b/mapanything/models/mapanything/modular_dust3r.py index 00914e7830eaf9e11c91a9d0384461fb3536ae98..c9f6172842c59488c1abb8256409ff4551e3ee44 100644 --- a/mapanything/models/mapanything/modular_dust3r.py +++ b/mapanything/models/mapanything/modular_dust3r.py @@ -1,3 +1,8 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + """ Modular DUSt3R class defined using UniCeption modules. 
""" @@ -70,7 +75,7 @@ class ModularDUSt3R(nn.Module): """ super().__init__(*args, **kwargs) - # Initalize the attributes + # Initialize the attributes self.name = name self.encoder_config = encoder_config self.info_sharing_config = info_sharing_config @@ -126,7 +131,7 @@ class ModularDUSt3R(nn.Module): else: self.custom_positional_encoding = None - # Add dependecies to info_sharing_config + # Add dependencies to info_sharing_config info_sharing_config["module_args"]["input_embed_dim"] = ( self.encoder.enc_embed_dim ) @@ -137,7 +142,7 @@ class ModularDUSt3R(nn.Module): # Initialize Multi-View Transformer if self.info_sharing_return_type == "no_intermediate_features": # Returns only normalized last layer features - # Intialize multi-view transformer based on type + # Initialize multi-view transformer based on type if self.info_sharing_type == "cross_attention": self.info_sharing = MultiViewCrossAttentionTransformer( **info_sharing_config["module_args"] @@ -217,7 +222,7 @@ class ModularDUSt3R(nn.Module): # Initialize Prediction Head 2 self.head2 = LinearFeature(**pred_head_config["feature_head"]) elif self.pred_head_type == "dpt": - # Initialze Predction Head 1 + # Initialize Prediction Head 1 self.dpt_feature_head1 = DPTFeature(**pred_head_config["feature_head"]) self.dpt_regressor_head1 = DPTRegressionProcessor( **pred_head_config["regressor_head"] diff --git a/mapanything/train/losses.py b/mapanything/train/losses.py index 097b5684fe006969287fdb89d8b6f020a50f1fab..2c57da80fcb96991e2c429e02deaf2233c623994 100644 --- a/mapanything/train/losses.py +++ b/mapanything/train/losses.py @@ -1,3 +1,8 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + """ Multi-view geometric losses for training 3D reconstruction models. @@ -1499,7 +1504,7 @@ class PointsPlusScaleRegr3D(Criterion, MultiLoss): ): """ Initialize the loss criterion for World Frame Pointmaps & Scale. - The predicited scene representation is always normalized w.r.t. the frame of view0. + The predicted scene representation is always normalized w.r.t. the frame of view0. Loss is applied between the predicted metric scale and the ground truth metric scale. Args: @@ -1532,7 +1537,7 @@ class PointsPlusScaleRegr3D(Criterion, MultiLoss): n_views = len(batch) # Everything is normalized w.r.t. camera of view0 - # Intialize lists to store data for all views + # Initialize lists to store data for all views # Ground truth quantities in_camera0 = closed_form_pose_inverse(batch[0]["camera_pose"]) no_norm_gt_pts = [] @@ -1982,6 +1987,7 @@ class FactoredGeometryRegr3D(Criterion, MultiLoss): ray_directions_loss_weight=1, pose_quats_loss_weight=1, pose_trans_loss_weight=1, + compute_absolute_pose_loss=True, compute_pairwise_relative_pose_loss=False, convert_predictions_to_view0_frame=False, compute_world_frame_points_loss=True, @@ -2016,6 +2022,7 @@ class FactoredGeometryRegr3D(Criterion, MultiLoss): ray_directions_loss_weight (float): Weight to use for the ray directions loss. Default: 1. pose_quats_loss_weight (float): Weight to use for the pose quats loss. Default: 1. pose_trans_loss_weight (float): Weight to use for the pose trans loss. Default: 1. + compute_absolute_pose_loss (bool): If True, compute the absolute pose loss. Default: True. compute_pairwise_relative_pose_loss (bool): If True, the pose loss is computed on the exhaustive pairwise relative poses. Default: False. 
convert_predictions_to_view0_frame (bool): If True, convert predictions to view0 frame. @@ -2037,14 +2044,16 @@ class FactoredGeometryRegr3D(Criterion, MultiLoss): self.loss_in_log = loss_in_log self.flatten_across_image_only = flatten_across_image_only self.depth_type_for_loss = depth_type_for_loss - assert self.depth_type_for_loss in ["depth_along_ray", "depth_z"], ( - "depth_type_for_loss must be one of ['depth_along_ray', 'depth_z']" - ) + assert self.depth_type_for_loss in [ + "depth_along_ray", + "depth_z", + ], "depth_type_for_loss must be one of ['depth_along_ray', 'depth_z']" self.cam_frame_points_loss_weight = cam_frame_points_loss_weight self.depth_loss_weight = depth_loss_weight self.ray_directions_loss_weight = ray_directions_loss_weight self.pose_quats_loss_weight = pose_quats_loss_weight self.pose_trans_loss_weight = pose_trans_loss_weight + self.compute_absolute_pose_loss = compute_absolute_pose_loss self.compute_pairwise_relative_pose_loss = compute_pairwise_relative_pose_loss self.convert_predictions_to_view0_frame = convert_predictions_to_view0_frame self.compute_world_frame_points_loss = compute_world_frame_points_loss @@ -2058,7 +2067,7 @@ class FactoredGeometryRegr3D(Criterion, MultiLoss): n_views = len(batch) # Everything is normalized w.r.t. camera of view0 - # Intialize lists to store data for all views + # Initialize lists to store data for all views # Ground truth quantities in_camera0 = closed_form_pose_inverse(batch[0]["camera_pose"]) no_norm_gt_pts = [] @@ -2413,7 +2422,35 @@ class FactoredGeometryRegr3D(Criterion, MultiLoss): gt_pts3d = apply_log_to_norm(gt_pts3d) pred_pts3d = apply_log_to_norm(pred_pts3d) - if self.compute_pairwise_relative_pose_loss: + # Compute pose loss + if ( + self.compute_absolute_pose_loss + and self.compute_pairwise_relative_pose_loss + ): + # Compute the absolute pose loss + # Get the pose info for the current view + pred_pose_trans = pred_info[i]["pose_trans"][valid_norm_factor_masks[i]] + gt_pose_trans = gt_info[i]["pose_trans"][valid_norm_factor_masks[i]] + pred_pose_quats = pred_info[i]["pose_quats"] + gt_pose_quats = gt_info[i]["pose_quats"] + + # Compute pose translation loss + abs_pose_trans_loss = self.criterion( + pred_pose_trans, gt_pose_trans, factor="pose_trans" + ) + abs_pose_trans_loss = abs_pose_trans_loss * self.pose_trans_loss_weight + + # Compute pose rotation loss + # Handle quaternion two-to-one mapping + abs_pose_quats_loss = torch.minimum( + self.criterion(pred_pose_quats, gt_pose_quats, factor="pose_quats"), + self.criterion( + pred_pose_quats, -gt_pose_quats, factor="pose_quats" + ), + ) + abs_pose_quats_loss = abs_pose_quats_loss * self.pose_quats_loss_weight + + # Compute the pairwise relative pose loss # Get the inverse of current view predicted pose pred_inv_curr_view_pose_quats = quaternion_inverse( pred_info[i]["pose_quats"] @@ -2496,15 +2533,14 @@ class FactoredGeometryRegr3D(Criterion, MultiLoss): gt_rel_pose_trans = torch.cat(gt_rel_pose_trans, dim=0) # Compute pose translation loss - pose_trans_loss = self.criterion( + rel_pose_trans_loss = self.criterion( pred_rel_pose_trans, gt_rel_pose_trans, factor="pose_trans" ) - pose_trans_loss = pose_trans_loss * self.pose_trans_loss_weight - pose_trans_losses.append(pose_trans_loss) + rel_pose_trans_loss = rel_pose_trans_loss * self.pose_trans_loss_weight # Compute pose rotation loss # Handle quaternion two-to-one mapping - pose_quats_loss = torch.minimum( + rel_pose_quats_loss = torch.minimum( self.criterion( pred_rel_pose_quats, gt_rel_pose_quats, 
factor="pose_quats" ), @@ -2512,9 +2548,18 @@ class FactoredGeometryRegr3D(Criterion, MultiLoss): pred_rel_pose_quats, -gt_rel_pose_quats, factor="pose_quats" ), ) - pose_quats_loss = pose_quats_loss * self.pose_quats_loss_weight + rel_pose_quats_loss = rel_pose_quats_loss * self.pose_quats_loss_weight + + # Concatenate the absolute and relative pose losses together + pose_trans_loss = torch.cat( + [abs_pose_trans_loss, rel_pose_trans_loss], dim=0 + ) + pose_quats_loss = torch.cat( + [abs_pose_quats_loss, rel_pose_quats_loss], dim=0 + ) + pose_trans_losses.append(pose_trans_loss) pose_quats_losses.append(pose_quats_loss) - else: + elif self.compute_absolute_pose_loss: # Get the pose info for the current view pred_pose_trans = pred_info[i]["pose_trans"][valid_norm_factor_masks[i]] gt_pose_trans = gt_info[i]["pose_trans"][valid_norm_factor_masks[i]] @@ -2538,6 +2583,112 @@ class FactoredGeometryRegr3D(Criterion, MultiLoss): ) pose_quats_loss = pose_quats_loss * self.pose_quats_loss_weight pose_quats_losses.append(pose_quats_loss) + elif self.compute_pairwise_relative_pose_loss: + # Get the inverse of current view predicted pose + pred_inv_curr_view_pose_quats = quaternion_inverse( + pred_info[i]["pose_quats"] + ) + pred_inv_curr_view_pose_rot_mat = quaternion_to_rotation_matrix( + pred_inv_curr_view_pose_quats + ) + pred_inv_curr_view_pose_trans = -1 * ein.einsum( + pred_inv_curr_view_pose_rot_mat, + pred_info[i]["pose_trans"], + "b i j, b j -> b i", + ) + + # Get the inverse of the current view GT pose + gt_inv_curr_view_pose_quats = quaternion_inverse( + gt_info[i]["pose_quats"] + ) + gt_inv_curr_view_pose_rot_mat = quaternion_to_rotation_matrix( + gt_inv_curr_view_pose_quats + ) + gt_inv_curr_view_pose_trans = -1 * ein.einsum( + gt_inv_curr_view_pose_rot_mat, + gt_info[i]["pose_trans"], + "b i j, b j -> b i", + ) + + # Get the other N-1 relative poses using the current pose as reference frame + pred_rel_pose_quats = [] + pred_rel_pose_trans = [] + gt_rel_pose_quats = [] + gt_rel_pose_trans = [] + for ov_idx in range(n_views): + if ov_idx == i: + continue + # Get the relative predicted pose + pred_ov_rel_pose_quats = quaternion_multiply( + pred_inv_curr_view_pose_quats, pred_info[ov_idx]["pose_quats"] + ) + pred_ov_rel_pose_trans = ( + ein.einsum( + pred_inv_curr_view_pose_rot_mat, + pred_info[ov_idx]["pose_trans"], + "b i j, b j -> b i", + ) + + pred_inv_curr_view_pose_trans + ) + + # Get the relative GT pose + gt_ov_rel_pose_quats = quaternion_multiply( + gt_inv_curr_view_pose_quats, gt_info[ov_idx]["pose_quats"] + ) + gt_ov_rel_pose_trans = ( + ein.einsum( + gt_inv_curr_view_pose_rot_mat, + gt_info[ov_idx]["pose_trans"], + "b i j, b j -> b i", + ) + + gt_inv_curr_view_pose_trans + ) + + # Get the valid translations using valid_norm_factor_masks for current view and other view + overall_valid_mask_for_trans = ( + valid_norm_factor_masks[i] & valid_norm_factor_masks[ov_idx] + ) + + # Append the relative poses + pred_rel_pose_quats.append(pred_ov_rel_pose_quats) + pred_rel_pose_trans.append( + pred_ov_rel_pose_trans[overall_valid_mask_for_trans] + ) + gt_rel_pose_quats.append(gt_ov_rel_pose_quats) + gt_rel_pose_trans.append( + gt_ov_rel_pose_trans[overall_valid_mask_for_trans] + ) + + # Cat the N-1 relative poses along the batch dimension + pred_rel_pose_quats = torch.cat(pred_rel_pose_quats, dim=0) + pred_rel_pose_trans = torch.cat(pred_rel_pose_trans, dim=0) + gt_rel_pose_quats = torch.cat(gt_rel_pose_quats, dim=0) + gt_rel_pose_trans = torch.cat(gt_rel_pose_trans, dim=0) + + # Compute 
pose translation loss + pose_trans_loss = self.criterion( + pred_rel_pose_trans, gt_rel_pose_trans, factor="pose_trans" + ) + pose_trans_loss = pose_trans_loss * self.pose_trans_loss_weight + pose_trans_losses.append(pose_trans_loss) + + # Compute pose rotation loss + # Handle quaternion two-to-one mapping + pose_quats_loss = torch.minimum( + self.criterion( + pred_rel_pose_quats, gt_rel_pose_quats, factor="pose_quats" + ), + self.criterion( + pred_rel_pose_quats, -gt_rel_pose_quats, factor="pose_quats" + ), + ) + pose_quats_loss = pose_quats_loss * self.pose_quats_loss_weight + pose_quats_losses.append(pose_quats_loss) + else: + # Error + raise ValueError( + "compute_absolute_pose_loss and compute_pairwise_relative_pose_loss cannot both be False" + ) # Compute ray direction loss ray_directions_loss = self.criterion( @@ -2673,6 +2824,7 @@ class FactoredGeometryRegr3DPlusNormalGMLoss(FactoredGeometryRegr3D): ray_directions_loss_weight=1, pose_quats_loss_weight=1, pose_trans_loss_weight=1, + compute_absolute_pose_loss=True, compute_pairwise_relative_pose_loss=False, convert_predictions_to_view0_frame=False, compute_world_frame_points_loss=True, @@ -2708,6 +2860,7 @@ class FactoredGeometryRegr3DPlusNormalGMLoss(FactoredGeometryRegr3D): ray_directions_loss_weight (float): Weight to use for the ray directions loss. Default: 1. pose_quats_loss_weight (float): Weight to use for the pose quats loss. Default: 1. pose_trans_loss_weight (float): Weight to use for the pose trans loss. Default: 1. + compute_absolute_pose_loss (bool): If True, compute the absolute pose loss. Default: True. compute_pairwise_relative_pose_loss (bool): If True, the pose loss is computed on the exhaustive pairwise relative poses. Default: False. convert_predictions_to_view0_frame (bool): If True, convert predictions to view0 frame. 
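
The pairwise-relative branch repeated across these loss classes expresses every other view's pose in the current view's frame by composing with the inverse of the current pose, for both predictions and ground truth. A compact homogeneous-matrix sketch of the same composition is given below for reference only; the actual code stays in quaternion/translation form via `quaternion_inverse`, `quaternion_multiply`, `quaternion_to_rotation_matrix`, and `ein.einsum`.

```python
import torch


def relative_poses(ref_c2w: torch.Tensor, other_c2w: torch.Tensor) -> torch.Tensor:
    """Express other views' cam2world poses in the reference view's frame.

    ref_c2w:   (B, 4, 4) reference-view cam2world pose
    other_c2w: (B, N, 4, 4) cam2world poses of the remaining N views
    Returns:   (B, N, 4, 4) relative poses T_rel = T_ref^-1 @ T_other
    """
    ref_w2c = torch.inverse(ref_c2w)  # world2cam of the reference view
    # Batched matrix product; illustrative equivalent of the quaternion-based code
    return torch.einsum("bij,bnjk->bnik", ref_w2c, other_c2w)
```

In the diff itself, the relative translations are additionally masked with the intersection of the two views' `valid_norm_factor_masks` before the criterion is applied, so view pairs without a valid normalization factor do not contribute to the translation term.
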
@@ -2733,6 +2886,7 @@ class FactoredGeometryRegr3DPlusNormalGMLoss(FactoredGeometryRegr3D): ray_directions_loss_weight=ray_directions_loss_weight, pose_quats_loss_weight=pose_quats_loss_weight, pose_trans_loss_weight=pose_trans_loss_weight, + compute_absolute_pose_loss=compute_absolute_pose_loss, compute_pairwise_relative_pose_loss=compute_pairwise_relative_pose_loss, convert_predictions_to_view0_frame=convert_predictions_to_view0_frame, compute_world_frame_points_loss=compute_world_frame_points_loss, @@ -2859,7 +3013,35 @@ class FactoredGeometryRegr3DPlusNormalGMLoss(FactoredGeometryRegr3D): gt_pts3d = apply_log_to_norm(gt_pts3d) pred_pts3d = apply_log_to_norm(pred_pts3d) - if self.compute_pairwise_relative_pose_loss: + # Compute pose loss + if ( + self.compute_absolute_pose_loss + and self.compute_pairwise_relative_pose_loss + ): + # Compute the absolute pose loss + # Get the pose info for the current view + pred_pose_trans = pred_info[i]["pose_trans"][valid_norm_factor_masks[i]] + gt_pose_trans = gt_info[i]["pose_trans"][valid_norm_factor_masks[i]] + pred_pose_quats = pred_info[i]["pose_quats"] + gt_pose_quats = gt_info[i]["pose_quats"] + + # Compute pose translation loss + abs_pose_trans_loss = self.criterion( + pred_pose_trans, gt_pose_trans, factor="pose_trans" + ) + abs_pose_trans_loss = abs_pose_trans_loss * self.pose_trans_loss_weight + + # Compute pose rotation loss + # Handle quaternion two-to-one mapping + abs_pose_quats_loss = torch.minimum( + self.criterion(pred_pose_quats, gt_pose_quats, factor="pose_quats"), + self.criterion( + pred_pose_quats, -gt_pose_quats, factor="pose_quats" + ), + ) + abs_pose_quats_loss = abs_pose_quats_loss * self.pose_quats_loss_weight + + # Compute the pairwise relative pose loss # Get the inverse of current view predicted pose pred_inv_curr_view_pose_quats = quaternion_inverse( pred_info[i]["pose_quats"] @@ -2942,15 +3124,14 @@ class FactoredGeometryRegr3DPlusNormalGMLoss(FactoredGeometryRegr3D): gt_rel_pose_trans = torch.cat(gt_rel_pose_trans, dim=0) # Compute pose translation loss - pose_trans_loss = self.criterion( + rel_pose_trans_loss = self.criterion( pred_rel_pose_trans, gt_rel_pose_trans, factor="pose_trans" ) - pose_trans_loss = pose_trans_loss * self.pose_trans_loss_weight - pose_trans_losses.append(pose_trans_loss) + rel_pose_trans_loss = rel_pose_trans_loss * self.pose_trans_loss_weight # Compute pose rotation loss # Handle quaternion two-to-one mapping - pose_quats_loss = torch.minimum( + rel_pose_quats_loss = torch.minimum( self.criterion( pred_rel_pose_quats, gt_rel_pose_quats, factor="pose_quats" ), @@ -2958,9 +3139,18 @@ class FactoredGeometryRegr3DPlusNormalGMLoss(FactoredGeometryRegr3D): pred_rel_pose_quats, -gt_rel_pose_quats, factor="pose_quats" ), ) - pose_quats_loss = pose_quats_loss * self.pose_quats_loss_weight + rel_pose_quats_loss = rel_pose_quats_loss * self.pose_quats_loss_weight + + # Concatenate the absolute and relative pose losses together + pose_trans_loss = torch.cat( + [abs_pose_trans_loss, rel_pose_trans_loss], dim=0 + ) + pose_quats_loss = torch.cat( + [abs_pose_quats_loss, rel_pose_quats_loss], dim=0 + ) + pose_trans_losses.append(pose_trans_loss) pose_quats_losses.append(pose_quats_loss) - else: + elif self.compute_absolute_pose_loss: # Get the pose info for the current view pred_pose_trans = pred_info[i]["pose_trans"][valid_norm_factor_masks[i]] gt_pose_trans = gt_info[i]["pose_trans"][valid_norm_factor_masks[i]] @@ -2984,56 +3174,162 @@ class FactoredGeometryRegr3DPlusNormalGMLoss(FactoredGeometryRegr3D): 
) pose_quats_loss = pose_quats_loss * self.pose_quats_loss_weight pose_quats_losses.append(pose_quats_loss) + elif self.compute_pairwise_relative_pose_loss: + # Get the inverse of current view predicted pose + pred_inv_curr_view_pose_quats = quaternion_inverse( + pred_info[i]["pose_quats"] + ) + pred_inv_curr_view_pose_rot_mat = quaternion_to_rotation_matrix( + pred_inv_curr_view_pose_quats + ) + pred_inv_curr_view_pose_trans = -1 * ein.einsum( + pred_inv_curr_view_pose_rot_mat, + pred_info[i]["pose_trans"], + "b i j, b j -> b i", + ) - # Compute ray direction loss - ray_directions_loss = self.criterion( - pred_ray_directions, gt_ray_directions, factor="ray_directions" - ) - ray_directions_loss = ray_directions_loss * self.ray_directions_loss_weight - ray_directions_losses.append(ray_directions_loss) - - # Compute depth loss - depth_loss = self.criterion(pred_depth, gt_depth, factor="depth") - depth_loss = depth_loss * self.depth_loss_weight - depth_losses.append(depth_loss) - - # Compute camera frame point loss - cam_pts3d_loss = self.criterion( - pred_cam_pts3d, gt_cam_pts3d, factor="points" - ) - cam_pts3d_loss = cam_pts3d_loss * self.cam_frame_points_loss_weight - cam_pts3d_losses.append(cam_pts3d_loss) - - if self.compute_world_frame_points_loss: - # Compute point loss - pts3d_loss = self.criterion(pred_pts3d, gt_pts3d, factor="points") - pts3d_loss = pts3d_loss * self.world_frame_points_loss_weight - pts3d_losses.append(pts3d_loss) + # Get the inverse of the current view GT pose + gt_inv_curr_view_pose_quats = quaternion_inverse( + gt_info[i]["pose_quats"] + ) + gt_inv_curr_view_pose_rot_mat = quaternion_to_rotation_matrix( + gt_inv_curr_view_pose_quats + ) + gt_inv_curr_view_pose_trans = -1 * ein.einsum( + gt_inv_curr_view_pose_rot_mat, + gt_info[i]["pose_trans"], + "b i j, b j -> b i", + ) - # Handle ambiguous pixels - if self.ambiguous_loss_value > 0: - if not self.flatten_across_image_only: - depth_losses[i] = torch.where( - ambiguous_masks[i][valid_masks[i]], - self.ambiguous_loss_value, - depth_losses[i], - ) - cam_pts3d_losses[i] = torch.where( - ambiguous_masks[i][valid_masks[i]], - self.ambiguous_loss_value, - cam_pts3d_losses[i], + # Get the other N-1 relative poses using the current pose as reference frame + pred_rel_pose_quats = [] + pred_rel_pose_trans = [] + gt_rel_pose_quats = [] + gt_rel_pose_trans = [] + for ov_idx in range(n_views): + if ov_idx == i: + continue + # Get the relative predicted pose + pred_ov_rel_pose_quats = quaternion_multiply( + pred_inv_curr_view_pose_quats, pred_info[ov_idx]["pose_quats"] ) - if self.compute_world_frame_points_loss: - pts3d_losses[i] = torch.where( - ambiguous_masks[i][valid_masks[i]], - self.ambiguous_loss_value, - pts3d_losses[i], + pred_ov_rel_pose_trans = ( + ein.einsum( + pred_inv_curr_view_pose_rot_mat, + pred_info[ov_idx]["pose_trans"], + "b i j, b j -> b i", ) - else: - depth_losses[i] = torch.where( - ambiguous_masks[i].view(ambiguous_masks[i].shape[0], -1), - self.ambiguous_loss_value, - depth_losses[i], + + pred_inv_curr_view_pose_trans + ) + + # Get the relative GT pose + gt_ov_rel_pose_quats = quaternion_multiply( + gt_inv_curr_view_pose_quats, gt_info[ov_idx]["pose_quats"] + ) + gt_ov_rel_pose_trans = ( + ein.einsum( + gt_inv_curr_view_pose_rot_mat, + gt_info[ov_idx]["pose_trans"], + "b i j, b j -> b i", + ) + + gt_inv_curr_view_pose_trans + ) + + # Get the valid translations using valid_norm_factor_masks for current view and other view + overall_valid_mask_for_trans = ( + valid_norm_factor_masks[i] & 
valid_norm_factor_masks[ov_idx] + ) + + # Append the relative poses + pred_rel_pose_quats.append(pred_ov_rel_pose_quats) + pred_rel_pose_trans.append( + pred_ov_rel_pose_trans[overall_valid_mask_for_trans] + ) + gt_rel_pose_quats.append(gt_ov_rel_pose_quats) + gt_rel_pose_trans.append( + gt_ov_rel_pose_trans[overall_valid_mask_for_trans] + ) + + # Cat the N-1 relative poses along the batch dimension + pred_rel_pose_quats = torch.cat(pred_rel_pose_quats, dim=0) + pred_rel_pose_trans = torch.cat(pred_rel_pose_trans, dim=0) + gt_rel_pose_quats = torch.cat(gt_rel_pose_quats, dim=0) + gt_rel_pose_trans = torch.cat(gt_rel_pose_trans, dim=0) + + # Compute pose translation loss + pose_trans_loss = self.criterion( + pred_rel_pose_trans, gt_rel_pose_trans, factor="pose_trans" + ) + pose_trans_loss = pose_trans_loss * self.pose_trans_loss_weight + pose_trans_losses.append(pose_trans_loss) + + # Compute pose rotation loss + # Handle quaternion two-to-one mapping + pose_quats_loss = torch.minimum( + self.criterion( + pred_rel_pose_quats, gt_rel_pose_quats, factor="pose_quats" + ), + self.criterion( + pred_rel_pose_quats, -gt_rel_pose_quats, factor="pose_quats" + ), + ) + pose_quats_loss = pose_quats_loss * self.pose_quats_loss_weight + pose_quats_losses.append(pose_quats_loss) + else: + # Error + raise ValueError( + "compute_absolute_pose_loss and compute_pairwise_relative_pose_loss cannot both be False" + ) + + # Compute ray direction loss + ray_directions_loss = self.criterion( + pred_ray_directions, gt_ray_directions, factor="ray_directions" + ) + ray_directions_loss = ray_directions_loss * self.ray_directions_loss_weight + ray_directions_losses.append(ray_directions_loss) + + # Compute depth loss + depth_loss = self.criterion(pred_depth, gt_depth, factor="depth") + depth_loss = depth_loss * self.depth_loss_weight + depth_losses.append(depth_loss) + + # Compute camera frame point loss + cam_pts3d_loss = self.criterion( + pred_cam_pts3d, gt_cam_pts3d, factor="points" + ) + cam_pts3d_loss = cam_pts3d_loss * self.cam_frame_points_loss_weight + cam_pts3d_losses.append(cam_pts3d_loss) + + if self.compute_world_frame_points_loss: + # Compute point loss + pts3d_loss = self.criterion(pred_pts3d, gt_pts3d, factor="points") + pts3d_loss = pts3d_loss * self.world_frame_points_loss_weight + pts3d_losses.append(pts3d_loss) + + # Handle ambiguous pixels + if self.ambiguous_loss_value > 0: + if not self.flatten_across_image_only: + depth_losses[i] = torch.where( + ambiguous_masks[i][valid_masks[i]], + self.ambiguous_loss_value, + depth_losses[i], + ) + cam_pts3d_losses[i] = torch.where( + ambiguous_masks[i][valid_masks[i]], + self.ambiguous_loss_value, + cam_pts3d_losses[i], + ) + if self.compute_world_frame_points_loss: + pts3d_losses[i] = torch.where( + ambiguous_masks[i][valid_masks[i]], + self.ambiguous_loss_value, + pts3d_losses[i], + ) + else: + depth_losses[i] = torch.where( + ambiguous_masks[i].view(ambiguous_masks[i].shape[0], -1), + self.ambiguous_loss_value, + depth_losses[i], ) cam_pts3d_losses[i] = torch.where( ambiguous_masks[i].view(ambiguous_masks[i].shape[0], -1), @@ -3129,6 +3425,7 @@ class FactoredGeometryScaleRegr3D(Criterion, MultiLoss): pose_quats_loss_weight=1, pose_trans_loss_weight=1, scale_loss_weight=1, + compute_absolute_pose_loss=True, compute_pairwise_relative_pose_loss=False, convert_predictions_to_view0_frame=False, compute_world_frame_points_loss=True, @@ -3162,6 +3459,7 @@ class FactoredGeometryScaleRegr3D(Criterion, MultiLoss): pose_quats_loss_weight (float): Weight to use for 
the pose quats loss. Default: 1. pose_trans_loss_weight (float): Weight to use for the pose trans loss. Default: 1. scale_loss_weight (float): Weight to use for the scale loss. Default: 1. + compute_absolute_pose_loss (bool): If True, compute the absolute pose loss. Default: True. compute_pairwise_relative_pose_loss (bool): If True, the pose loss is computed on the exhaustive pairwise relative poses. Default: False. convert_predictions_to_view0_frame (bool): If True, convert predictions to view0 frame. @@ -3176,15 +3474,17 @@ class FactoredGeometryScaleRegr3D(Criterion, MultiLoss): self.loss_in_log = loss_in_log self.flatten_across_image_only = flatten_across_image_only self.depth_type_for_loss = depth_type_for_loss - assert self.depth_type_for_loss in ["depth_along_ray", "depth_z"], ( - "depth_type_for_loss must be one of ['depth_along_ray', 'depth_z']" - ) + assert self.depth_type_for_loss in [ + "depth_along_ray", + "depth_z", + ], "depth_type_for_loss must be one of ['depth_along_ray', 'depth_z']" self.cam_frame_points_loss_weight = cam_frame_points_loss_weight self.depth_loss_weight = depth_loss_weight self.ray_directions_loss_weight = ray_directions_loss_weight self.pose_quats_loss_weight = pose_quats_loss_weight self.pose_trans_loss_weight = pose_trans_loss_weight self.scale_loss_weight = scale_loss_weight + self.compute_absolute_pose_loss = compute_absolute_pose_loss self.compute_pairwise_relative_pose_loss = compute_pairwise_relative_pose_loss self.convert_predictions_to_view0_frame = convert_predictions_to_view0_frame self.compute_world_frame_points_loss = compute_world_frame_points_loss @@ -3543,7 +3843,35 @@ class FactoredGeometryScaleRegr3D(Criterion, MultiLoss): gt_pts3d = apply_log_to_norm(gt_pts3d) pred_pts3d = apply_log_to_norm(pred_pts3d) - if self.compute_pairwise_relative_pose_loss: + # Compute pose loss + if ( + self.compute_absolute_pose_loss + and self.compute_pairwise_relative_pose_loss + ): + # Compute the absolute pose loss + # Get the pose info for the current view + pred_pose_trans = pred_info[i]["pose_trans"][valid_norm_factor_masks[i]] + gt_pose_trans = gt_info[i]["pose_trans"][valid_norm_factor_masks[i]] + pred_pose_quats = pred_info[i]["pose_quats"] + gt_pose_quats = gt_info[i]["pose_quats"] + + # Compute pose translation loss + abs_pose_trans_loss = self.criterion( + pred_pose_trans, gt_pose_trans, factor="pose_trans" + ) + abs_pose_trans_loss = abs_pose_trans_loss * self.pose_trans_loss_weight + + # Compute pose rotation loss + # Handle quaternion two-to-one mapping + abs_pose_quats_loss = torch.minimum( + self.criterion(pred_pose_quats, gt_pose_quats, factor="pose_quats"), + self.criterion( + pred_pose_quats, -gt_pose_quats, factor="pose_quats" + ), + ) + abs_pose_quats_loss = abs_pose_quats_loss * self.pose_quats_loss_weight + + # Compute the pairwise relative pose loss # Get the inverse of current view predicted pose pred_inv_curr_view_pose_quats = quaternion_inverse( pred_info[i]["pose_quats"] @@ -3626,15 +3954,14 @@ class FactoredGeometryScaleRegr3D(Criterion, MultiLoss): gt_rel_pose_trans = torch.cat(gt_rel_pose_trans, dim=0) # Compute pose translation loss - pose_trans_loss = self.criterion( + rel_pose_trans_loss = self.criterion( pred_rel_pose_trans, gt_rel_pose_trans, factor="pose_trans" ) - pose_trans_loss = pose_trans_loss * self.pose_trans_loss_weight - pose_trans_losses.append(pose_trans_loss) + rel_pose_trans_loss = rel_pose_trans_loss * self.pose_trans_loss_weight # Compute pose rotation loss # Handle quaternion two-to-one mapping - 
pose_quats_loss = torch.minimum( + rel_pose_quats_loss = torch.minimum( self.criterion( pred_rel_pose_quats, gt_rel_pose_quats, factor="pose_quats" ), @@ -3642,9 +3969,18 @@ class FactoredGeometryScaleRegr3D(Criterion, MultiLoss): pred_rel_pose_quats, -gt_rel_pose_quats, factor="pose_quats" ), ) - pose_quats_loss = pose_quats_loss * self.pose_quats_loss_weight + rel_pose_quats_loss = rel_pose_quats_loss * self.pose_quats_loss_weight + + # Concatenate the absolute and relative pose losses together + pose_trans_loss = torch.cat( + [abs_pose_trans_loss, rel_pose_trans_loss], dim=0 + ) + pose_quats_loss = torch.cat( + [abs_pose_quats_loss, rel_pose_quats_loss], dim=0 + ) + pose_trans_losses.append(pose_trans_loss) pose_quats_losses.append(pose_quats_loss) - else: + elif self.compute_absolute_pose_loss: # Get the pose info for the current view pred_pose_trans = pred_info[i]["pose_trans"][valid_norm_factor_masks[i]] gt_pose_trans = gt_info[i]["pose_trans"][valid_norm_factor_masks[i]] @@ -3668,6 +4004,112 @@ class FactoredGeometryScaleRegr3D(Criterion, MultiLoss): ) pose_quats_loss = pose_quats_loss * self.pose_quats_loss_weight pose_quats_losses.append(pose_quats_loss) + elif self.compute_pairwise_relative_pose_loss: + # Get the inverse of current view predicted pose + pred_inv_curr_view_pose_quats = quaternion_inverse( + pred_info[i]["pose_quats"] + ) + pred_inv_curr_view_pose_rot_mat = quaternion_to_rotation_matrix( + pred_inv_curr_view_pose_quats + ) + pred_inv_curr_view_pose_trans = -1 * ein.einsum( + pred_inv_curr_view_pose_rot_mat, + pred_info[i]["pose_trans"], + "b i j, b j -> b i", + ) + + # Get the inverse of the current view GT pose + gt_inv_curr_view_pose_quats = quaternion_inverse( + gt_info[i]["pose_quats"] + ) + gt_inv_curr_view_pose_rot_mat = quaternion_to_rotation_matrix( + gt_inv_curr_view_pose_quats + ) + gt_inv_curr_view_pose_trans = -1 * ein.einsum( + gt_inv_curr_view_pose_rot_mat, + gt_info[i]["pose_trans"], + "b i j, b j -> b i", + ) + + # Get the other N-1 relative poses using the current pose as reference frame + pred_rel_pose_quats = [] + pred_rel_pose_trans = [] + gt_rel_pose_quats = [] + gt_rel_pose_trans = [] + for ov_idx in range(n_views): + if ov_idx == i: + continue + # Get the relative predicted pose + pred_ov_rel_pose_quats = quaternion_multiply( + pred_inv_curr_view_pose_quats, pred_info[ov_idx]["pose_quats"] + ) + pred_ov_rel_pose_trans = ( + ein.einsum( + pred_inv_curr_view_pose_rot_mat, + pred_info[ov_idx]["pose_trans"], + "b i j, b j -> b i", + ) + + pred_inv_curr_view_pose_trans + ) + + # Get the relative GT pose + gt_ov_rel_pose_quats = quaternion_multiply( + gt_inv_curr_view_pose_quats, gt_info[ov_idx]["pose_quats"] + ) + gt_ov_rel_pose_trans = ( + ein.einsum( + gt_inv_curr_view_pose_rot_mat, + gt_info[ov_idx]["pose_trans"], + "b i j, b j -> b i", + ) + + gt_inv_curr_view_pose_trans + ) + + # Get the valid translations using valid_norm_factor_masks for current view and other view + overall_valid_mask_for_trans = ( + valid_norm_factor_masks[i] & valid_norm_factor_masks[ov_idx] + ) + + # Append the relative poses + pred_rel_pose_quats.append(pred_ov_rel_pose_quats) + pred_rel_pose_trans.append( + pred_ov_rel_pose_trans[overall_valid_mask_for_trans] + ) + gt_rel_pose_quats.append(gt_ov_rel_pose_quats) + gt_rel_pose_trans.append( + gt_ov_rel_pose_trans[overall_valid_mask_for_trans] + ) + + # Cat the N-1 relative poses along the batch dimension + pred_rel_pose_quats = torch.cat(pred_rel_pose_quats, dim=0) + pred_rel_pose_trans = torch.cat(pred_rel_pose_trans, 
dim=0) + gt_rel_pose_quats = torch.cat(gt_rel_pose_quats, dim=0) + gt_rel_pose_trans = torch.cat(gt_rel_pose_trans, dim=0) + + # Compute pose translation loss + pose_trans_loss = self.criterion( + pred_rel_pose_trans, gt_rel_pose_trans, factor="pose_trans" + ) + pose_trans_loss = pose_trans_loss * self.pose_trans_loss_weight + pose_trans_losses.append(pose_trans_loss) + + # Compute pose rotation loss + # Handle quaternion two-to-one mapping + pose_quats_loss = torch.minimum( + self.criterion( + pred_rel_pose_quats, gt_rel_pose_quats, factor="pose_quats" + ), + self.criterion( + pred_rel_pose_quats, -gt_rel_pose_quats, factor="pose_quats" + ), + ) + pose_quats_loss = pose_quats_loss * self.pose_quats_loss_weight + pose_quats_losses.append(pose_quats_loss) + else: + # Error + raise ValueError( + "compute_absolute_pose_loss and compute_pairwise_relative_pose_loss cannot both be False" + ) # Compute ray direction loss ray_directions_loss = self.criterion( @@ -3822,6 +4264,7 @@ class FactoredGeometryScaleRegr3DPlusNormalGMLoss(FactoredGeometryScaleRegr3D): pose_quats_loss_weight=1, pose_trans_loss_weight=1, scale_loss_weight=1, + compute_absolute_pose_loss=True, compute_pairwise_relative_pose_loss=False, convert_predictions_to_view0_frame=False, compute_world_frame_points_loss=True, @@ -3858,6 +4301,7 @@ class FactoredGeometryScaleRegr3DPlusNormalGMLoss(FactoredGeometryScaleRegr3D): exhaustive pairwise relative poses. Default: False. convert_predictions_to_view0_frame (bool): If True, convert predictions to view0 frame. Use this if the predictions are not already in the view0 frame. Default: False. + compute_absolute_pose_loss (bool): If True, compute the absolute pose loss. Default: True. compute_world_frame_points_loss (bool): If True, compute the world frame pointmap loss. Default: True. world_frame_points_loss_weight (float): Weight to use for the world frame pointmap loss. Default: 1. apply_normal_and_gm_loss_to_synthetic_data_only (bool): If True, apply the normal and gm loss only to synthetic data. 
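The pose rotation terms in the hunks above exploit the fact that unit quaternions double-cover rotation space: q and -q encode the same rotation, so the loss is taken as the element-wise minimum of the criterion evaluated against the ground-truth quaternion and its negation. Below is a minimal, self-contained sketch of that pattern only; the `l1_per_element` helper stands in for the class criterion and is illustrative, not part of the codebase.

import torch

def l1_per_element(pred, gt):
    # Per-sample L1 distance, reduced over the quaternion dimension,
    # standing in for the regression criterion used in the loss classes.
    return (pred - gt).abs().sum(dim=-1)

def quaternion_sign_invariant_loss(pred_quats, gt_quats):
    # q and -q represent the same rotation, so score both signs of the
    # ground truth and keep the smaller error per sample.
    return torch.minimum(
        l1_per_element(pred_quats, gt_quats),
        l1_per_element(pred_quats, -gt_quats),
    )

# Example: predictions matching the GT up to a global sign incur ~zero loss.
gt = torch.nn.functional.normalize(torch.randn(4, 4), dim=-1)
pred = -gt  # flipped sign, identical rotations
print(quaternion_sign_invariant_loss(pred, gt))  # ~0 for every sample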
@@ -3879,6 +4323,7 @@ class FactoredGeometryScaleRegr3DPlusNormalGMLoss(FactoredGeometryScaleRegr3D): pose_quats_loss_weight=pose_quats_loss_weight, pose_trans_loss_weight=pose_trans_loss_weight, scale_loss_weight=scale_loss_weight, + compute_absolute_pose_loss=compute_absolute_pose_loss, compute_pairwise_relative_pose_loss=compute_pairwise_relative_pose_loss, convert_predictions_to_view0_frame=convert_predictions_to_view0_frame, compute_world_frame_points_loss=compute_world_frame_points_loss, @@ -4010,7 +4455,35 @@ class FactoredGeometryScaleRegr3DPlusNormalGMLoss(FactoredGeometryScaleRegr3D): gt_pts3d = apply_log_to_norm(gt_pts3d) pred_pts3d = apply_log_to_norm(pred_pts3d) - if self.compute_pairwise_relative_pose_loss: + # Compute pose loss + if ( + self.compute_absolute_pose_loss + and self.compute_pairwise_relative_pose_loss + ): + # Compute the absolute pose loss + # Get the pose info for the current view + pred_pose_trans = pred_info[i]["pose_trans"][valid_norm_factor_masks[i]] + gt_pose_trans = gt_info[i]["pose_trans"][valid_norm_factor_masks[i]] + pred_pose_quats = pred_info[i]["pose_quats"] + gt_pose_quats = gt_info[i]["pose_quats"] + + # Compute pose translation loss + abs_pose_trans_loss = self.criterion( + pred_pose_trans, gt_pose_trans, factor="pose_trans" + ) + abs_pose_trans_loss = abs_pose_trans_loss * self.pose_trans_loss_weight + + # Compute pose rotation loss + # Handle quaternion two-to-one mapping + abs_pose_quats_loss = torch.minimum( + self.criterion(pred_pose_quats, gt_pose_quats, factor="pose_quats"), + self.criterion( + pred_pose_quats, -gt_pose_quats, factor="pose_quats" + ), + ) + abs_pose_quats_loss = abs_pose_quats_loss * self.pose_quats_loss_weight + + # Compute the pairwise relative pose loss # Get the inverse of current view predicted pose pred_inv_curr_view_pose_quats = quaternion_inverse( pred_info[i]["pose_quats"] @@ -4093,15 +4566,14 @@ class FactoredGeometryScaleRegr3DPlusNormalGMLoss(FactoredGeometryScaleRegr3D): gt_rel_pose_trans = torch.cat(gt_rel_pose_trans, dim=0) # Compute pose translation loss - pose_trans_loss = self.criterion( + rel_pose_trans_loss = self.criterion( pred_rel_pose_trans, gt_rel_pose_trans, factor="pose_trans" ) - pose_trans_loss = pose_trans_loss * self.pose_trans_loss_weight - pose_trans_losses.append(pose_trans_loss) + rel_pose_trans_loss = rel_pose_trans_loss * self.pose_trans_loss_weight # Compute pose rotation loss # Handle quaternion two-to-one mapping - pose_quats_loss = torch.minimum( + rel_pose_quats_loss = torch.minimum( self.criterion( pred_rel_pose_quats, gt_rel_pose_quats, factor="pose_quats" ), @@ -4109,9 +4581,18 @@ class FactoredGeometryScaleRegr3DPlusNormalGMLoss(FactoredGeometryScaleRegr3D): pred_rel_pose_quats, -gt_rel_pose_quats, factor="pose_quats" ), ) - pose_quats_loss = pose_quats_loss * self.pose_quats_loss_weight + rel_pose_quats_loss = rel_pose_quats_loss * self.pose_quats_loss_weight + + # Concatenate the absolute and relative pose losses together + pose_trans_loss = torch.cat( + [abs_pose_trans_loss, rel_pose_trans_loss], dim=0 + ) + pose_quats_loss = torch.cat( + [abs_pose_quats_loss, rel_pose_quats_loss], dim=0 + ) + pose_trans_losses.append(pose_trans_loss) pose_quats_losses.append(pose_quats_loss) - else: + elif self.compute_absolute_pose_loss: # Get the pose info for the current view pred_pose_trans = pred_info[i]["pose_trans"][valid_norm_factor_masks[i]] gt_pose_trans = gt_info[i]["pose_trans"][valid_norm_factor_masks[i]] @@ -4135,6 +4616,112 @@ class 
FactoredGeometryScaleRegr3DPlusNormalGMLoss(FactoredGeometryScaleRegr3D): ) pose_quats_loss = pose_quats_loss * self.pose_quats_loss_weight pose_quats_losses.append(pose_quats_loss) + elif self.compute_pairwise_relative_pose_loss: + # Get the inverse of current view predicted pose + pred_inv_curr_view_pose_quats = quaternion_inverse( + pred_info[i]["pose_quats"] + ) + pred_inv_curr_view_pose_rot_mat = quaternion_to_rotation_matrix( + pred_inv_curr_view_pose_quats + ) + pred_inv_curr_view_pose_trans = -1 * ein.einsum( + pred_inv_curr_view_pose_rot_mat, + pred_info[i]["pose_trans"], + "b i j, b j -> b i", + ) + + # Get the inverse of the current view GT pose + gt_inv_curr_view_pose_quats = quaternion_inverse( + gt_info[i]["pose_quats"] + ) + gt_inv_curr_view_pose_rot_mat = quaternion_to_rotation_matrix( + gt_inv_curr_view_pose_quats + ) + gt_inv_curr_view_pose_trans = -1 * ein.einsum( + gt_inv_curr_view_pose_rot_mat, + gt_info[i]["pose_trans"], + "b i j, b j -> b i", + ) + + # Get the other N-1 relative poses using the current pose as reference frame + pred_rel_pose_quats = [] + pred_rel_pose_trans = [] + gt_rel_pose_quats = [] + gt_rel_pose_trans = [] + for ov_idx in range(n_views): + if ov_idx == i: + continue + # Get the relative predicted pose + pred_ov_rel_pose_quats = quaternion_multiply( + pred_inv_curr_view_pose_quats, pred_info[ov_idx]["pose_quats"] + ) + pred_ov_rel_pose_trans = ( + ein.einsum( + pred_inv_curr_view_pose_rot_mat, + pred_info[ov_idx]["pose_trans"], + "b i j, b j -> b i", + ) + + pred_inv_curr_view_pose_trans + ) + + # Get the relative GT pose + gt_ov_rel_pose_quats = quaternion_multiply( + gt_inv_curr_view_pose_quats, gt_info[ov_idx]["pose_quats"] + ) + gt_ov_rel_pose_trans = ( + ein.einsum( + gt_inv_curr_view_pose_rot_mat, + gt_info[ov_idx]["pose_trans"], + "b i j, b j -> b i", + ) + + gt_inv_curr_view_pose_trans + ) + + # Get the valid translations using valid_norm_factor_masks for current view and other view + overall_valid_mask_for_trans = ( + valid_norm_factor_masks[i] & valid_norm_factor_masks[ov_idx] + ) + + # Append the relative poses + pred_rel_pose_quats.append(pred_ov_rel_pose_quats) + pred_rel_pose_trans.append( + pred_ov_rel_pose_trans[overall_valid_mask_for_trans] + ) + gt_rel_pose_quats.append(gt_ov_rel_pose_quats) + gt_rel_pose_trans.append( + gt_ov_rel_pose_trans[overall_valid_mask_for_trans] + ) + + # Cat the N-1 relative poses along the batch dimension + pred_rel_pose_quats = torch.cat(pred_rel_pose_quats, dim=0) + pred_rel_pose_trans = torch.cat(pred_rel_pose_trans, dim=0) + gt_rel_pose_quats = torch.cat(gt_rel_pose_quats, dim=0) + gt_rel_pose_trans = torch.cat(gt_rel_pose_trans, dim=0) + + # Compute pose translation loss + pose_trans_loss = self.criterion( + pred_rel_pose_trans, gt_rel_pose_trans, factor="pose_trans" + ) + pose_trans_loss = pose_trans_loss * self.pose_trans_loss_weight + pose_trans_losses.append(pose_trans_loss) + + # Compute pose rotation loss + # Handle quaternion two-to-one mapping + pose_quats_loss = torch.minimum( + self.criterion( + pred_rel_pose_quats, gt_rel_pose_quats, factor="pose_quats" + ), + self.criterion( + pred_rel_pose_quats, -gt_rel_pose_quats, factor="pose_quats" + ), + ) + pose_quats_loss = pose_quats_loss * self.pose_quats_loss_weight + pose_quats_losses.append(pose_quats_loss) + else: + # Error + raise ValueError( + "compute_absolute_pose_loss and compute_pairwise_relative_pose_loss cannot both be False" + ) # Compute ray direction loss ray_directions_loss = self.criterion( @@ -4333,9 +4920,10 @@ class 
DisentangledFactoredGeometryScaleRegr3D(Criterion, MultiLoss): self.loss_in_log = loss_in_log self.flatten_across_image_only = flatten_across_image_only self.depth_type_for_loss = depth_type_for_loss - assert self.depth_type_for_loss in ["depth_along_ray", "depth_z"], ( - "depth_type_for_loss must be one of ['depth_along_ray', 'depth_z']" - ) + assert self.depth_type_for_loss in [ + "depth_along_ray", + "depth_z", + ], "depth_type_for_loss must be one of ['depth_along_ray', 'depth_z']" self.depth_loss_weight = depth_loss_weight self.ray_directions_loss_weight = ray_directions_loss_weight self.pose_quats_loss_weight = pose_quats_loss_weight diff --git a/mapanything/train/profile_dataloading.py b/mapanything/train/profile_dataloading.py index 8fb9b7615be4cbb56e07d2d182cbbd16d103070e..d73c414f5c8bb097006d357d721a6622f441e7ec 100644 --- a/mapanything/train/profile_dataloading.py +++ b/mapanything/train/profile_dataloading.py @@ -1,3 +1,8 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + """ Debug script to profile dataloading for MapAnything training. diff --git a/mapanything/train/training.py b/mapanything/train/training.py index f5dd83cd08caa07c9a893aba6a8e07ebf7c05c26..5e14e8edfee5782c454185ec7954d65affac9747 100644 --- a/mapanything/train/training.py +++ b/mapanything/train/training.py @@ -1,3 +1,8 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + """ Training Code for MapAnything. diff --git a/requirements.txt b/requirements.txt index 90e37a3ef4741cff67e074ea9b8047fe64d5cc47..6bfc23af0e52c9f9b0cd3e3b77391df5777b68f3 100644 --- a/requirements.txt +++ b/requirements.txt @@ -20,4 +20,4 @@ psutil pillow-heif tqdm safetensors -uniception==0.1.4 \ No newline at end of file +uniception==0.1.6 \ No newline at end of file
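The patch makes the pose supervision mode explicit in both loss classes: `compute_absolute_pose_loss` and `compute_pairwise_relative_pose_loss` can be enabled individually or together (in which case the absolute and relative pose terms are concatenated along the batch dimension before reduction), but enabling neither raises a ValueError. The following hedged sketch mirrors only that branching; `abs_term` and `rel_term` are simplified stand-ins for the per-sample criterion outputs, not the real loss tensors.

import torch

def combine_pose_terms(abs_term, rel_term, use_abs, use_rel):
    # Mirrors the branching added in FactoredGeometryScaleRegr3D: both flags on
    # concatenates the terms, a single flag keeps that term, both off is invalid.
    if use_abs and use_rel:
        return torch.cat([abs_term, rel_term], dim=0)
    elif use_abs:
        return abs_term
    elif use_rel:
        return rel_term
    else:
        raise ValueError(
            "compute_absolute_pose_loss and compute_pairwise_relative_pose_loss "
            "cannot both be False"
        )

# Example: a batch of 2 absolute terms plus 2 * (N - 1) = 6 relative terms.
abs_term = torch.rand(2)
rel_term = torch.rand(6)
print(combine_pose_terms(abs_term, rel_term, True, True).shape)  # torch.Size([8])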