BFZD233 committed on
Commit
5b3b0f4
·
1 Parent(s): a209eeb
This view is limited to 50 files because it contains too many changes.
Files changed (50)
  1. .gitignore +5 -0
  2. Depth-Anything-V2/DA-2K.md +51 -0
  3. Depth-Anything-V2/LICENSE +201 -0
  4. Depth-Anything-V2/README.md +201 -0
  5. Depth-Anything-V2/app.py +88 -0
  6. Depth-Anything-V2/depth_anything_v2/__pycache__/dinov2.cpython-310.pyc +0 -0
  7. Depth-Anything-V2/depth_anything_v2/__pycache__/dpt.cpython-310.pyc +0 -0
  8. Depth-Anything-V2/depth_anything_v2/dinov2.py +415 -0
  9. Depth-Anything-V2/depth_anything_v2/dinov2_layers/__init__.py +11 -0
  10. Depth-Anything-V2/depth_anything_v2/dinov2_layers/__pycache__/__init__.cpython-310.pyc +0 -0
  11. Depth-Anything-V2/depth_anything_v2/dinov2_layers/__pycache__/attention.cpython-310.pyc +0 -0
  12. Depth-Anything-V2/depth_anything_v2/dinov2_layers/__pycache__/block.cpython-310.pyc +0 -0
  13. Depth-Anything-V2/depth_anything_v2/dinov2_layers/__pycache__/drop_path.cpython-310.pyc +0 -0
  14. Depth-Anything-V2/depth_anything_v2/dinov2_layers/__pycache__/layer_scale.cpython-310.pyc +0 -0
  15. Depth-Anything-V2/depth_anything_v2/dinov2_layers/__pycache__/mlp.cpython-310.pyc +0 -0
  16. Depth-Anything-V2/depth_anything_v2/dinov2_layers/__pycache__/patch_embed.cpython-310.pyc +0 -0
  17. Depth-Anything-V2/depth_anything_v2/dinov2_layers/__pycache__/swiglu_ffn.cpython-310.pyc +0 -0
  18. Depth-Anything-V2/depth_anything_v2/dinov2_layers/attention.py +83 -0
  19. Depth-Anything-V2/depth_anything_v2/dinov2_layers/block.py +252 -0
  20. Depth-Anything-V2/depth_anything_v2/dinov2_layers/drop_path.py +35 -0
  21. Depth-Anything-V2/depth_anything_v2/dinov2_layers/layer_scale.py +28 -0
  22. Depth-Anything-V2/depth_anything_v2/dinov2_layers/mlp.py +41 -0
  23. Depth-Anything-V2/depth_anything_v2/dinov2_layers/patch_embed.py +89 -0
  24. Depth-Anything-V2/depth_anything_v2/dinov2_layers/swiglu_ffn.py +63 -0
  25. Depth-Anything-V2/depth_anything_v2/dpt.py +233 -0
  26. Depth-Anything-V2/depth_anything_v2/util/__pycache__/blocks.cpython-310.pyc +0 -0
  27. Depth-Anything-V2/depth_anything_v2/util/__pycache__/transform.cpython-310.pyc +0 -0
  28. Depth-Anything-V2/depth_anything_v2/util/blocks.py +148 -0
  29. Depth-Anything-V2/depth_anything_v2/util/transform.py +158 -0
  30. Depth-Anything-V2/requirements.txt +6 -0
  31. Depth-Anything-V2/run.py +73 -0
  32. Depth-Anything-V2/run_video.py +92 -0
  33. LICENSE +21 -0
  34. README.md +369 -6
  35. abs_cost/abs_cost_kernel.cu +191 -0
  36. app.py +103 -0
  37. core/ManStereo.py +302 -0
  38. core/__init__.py +0 -0
  39. core/__pycache__/__init__.cpython-310.pyc +0 -0
  40. core/__pycache__/confidence.cpython-310.pyc +0 -0
  41. core/__pycache__/corr.cpython-310.pyc +0 -0
  42. core/__pycache__/extractor.cpython-310.pyc +0 -0
  43. core/__pycache__/extractor_depthany.cpython-310.pyc +0 -0
  44. core/__pycache__/fusion.cpython-310.pyc +0 -0
  45. core/__pycache__/geometry.cpython-310.pyc +0 -0
  46. core/__pycache__/raft_stereo_depthbeta_refine.cpython-310.pyc +0 -0
  47. core/__pycache__/update_disp.cpython-310.pyc +0 -0
  48. core/confidence.py +169 -0
  49. core/corr.py +309 -0
  50. core/extractor.py +300 -0
.gitignore ADDED
@@ -0,0 +1,5 @@
1
+ dav2_models
2
+ ckpts
3
+ mast3r
4
+ Metric3D
5
+ Depth-Anything-V2/metric_depth
Depth-Anything-V2/DA-2K.md ADDED
@@ -0,0 +1,51 @@
1
+ # DA-2K Evaluation Benchmark
2
+
3
+ ## Introduction
4
+
5
+ ![DA-2K](assets/DA-2K.png)
6
+
7
+ DA-2K is proposed in [Depth Anything V2](https://depth-anything-v2.github.io) to evaluate relative depth estimation capability. It covers eight representative scenarios: `indoor`, `outdoor`, `non_real`, `transparent_reflective`, `adverse_style`, `aerial`, `underwater`, and `object`. It consists of 1K diverse, high-quality images and 2K precise pair-wise relative depth annotations.
8
+
9
+ Please refer to our [paper](https://arxiv.org/abs/2406.09414) for details on how this benchmark was constructed.
10
+
11
+
12
+ ## Usage
13
+
14
+ Please first [download the benchmark](https://huggingface.co/datasets/depth-anything/DA-2K/tree/main).
15
+
16
+ All annotations are stored in `annotations.json`. The annotation file is a JSON object where each key is the path to an image file, and the value is a list of annotations associated with that image. Each annotation describes two points and identifies which point is closer to the camera. The structure is detailed below:
17
+
18
+ ```
19
+ {
20
+ "image_path": [
21
+ {
22
+ "point1": [h1, w1], # (vertical position, horizontal position)
23
+ "point2": [h2, w2], # (vertical position, horizontal position)
24
+ "closer_point": "point1" # we always set "point1" as the closer one
25
+ },
26
+ ...
27
+ ],
28
+ ...
29
+ }
30
+ ```
31
+
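As a quick illustration of consuming this format, here is a minimal sketch (not the official evaluation script). It assumes `annotations.json` sits at the benchmark root, that the image keys are paths relative to that root, and that `model` is a loaded Depth Anything V2 instance whose `infer_image` returns an HxW map where larger values mean closer (disparity-like output), as in the repository README.

```python
import json
from pathlib import Path

import cv2

# Assumed layout: DA-2K/annotations.json plus the images referenced by its keys.
BENCH_ROOT = Path("DA-2K")

with open(BENCH_ROOT / "annotations.json") as f:
    annotations = json.load(f)

correct = total = 0
for image_path, pairs in annotations.items():
    image = cv2.imread(str(BENCH_ROOT / image_path))
    depth = model.infer_image(image)  # assumed: HxW relative depth, larger = closer
    for pair in pairs:
        (h1, w1), (h2, w2) = pair["point1"], pair["point2"]
        pred_closer = "point1" if depth[h1, w1] > depth[h2, w2] else "point2"
        correct += int(pred_closer == pair["closer_point"])
        total += 1

print(f"Pairwise accuracy: {correct / total:.3f}")
```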
32
+ To visualize the annotations:
33
+ ```bash
34
+ python visualize.py [--scene-type <type>]
35
+ ```
36
+
37
+ **Options**
38
+ - `--scene-type <type>` (optional): Specify the scene type (`indoor`, `outdoor`, `non_real`, `transparent_reflective`, `adverse_style`, `aerial`, `underwater`, and `object`). Skip this argument or set `<type>` to `""` to include all scene types.
39
+
40
+ ## Citation
41
+
42
+ If you find this benchmark useful, please consider citing:
43
+
44
+ ```bibtex
45
+ @article{depth_anything_v2,
46
+ title={Depth Anything V2},
47
+ author={Yang, Lihe and Kang, Bingyi and Huang, Zilong and Zhao, Zhen and Xu, Xiaogang and Feng, Jiashi and Zhao, Hengshuang},
48
+ journal={arXiv:2406.09414},
49
+ year={2024}
50
+ }
51
+ ```
Depth-Anything-V2/LICENSE ADDED
@@ -0,0 +1,201 @@
1
+ Apache License
2
+ Version 2.0, January 2004
3
+ http://www.apache.org/licenses/
4
+
5
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6
+
7
+ 1. Definitions.
8
+
9
+ "License" shall mean the terms and conditions for use, reproduction,
10
+ and distribution as defined by Sections 1 through 9 of this document.
11
+
12
+ "Licensor" shall mean the copyright owner or entity authorized by
13
+ the copyright owner that is granting the License.
14
+
15
+ "Legal Entity" shall mean the union of the acting entity and all
16
+ other entities that control, are controlled by, or are under common
17
+ control with that entity. For the purposes of this definition,
18
+ "control" means (i) the power, direct or indirect, to cause the
19
+ direction or management of such entity, whether by contract or
20
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
21
+ outstanding shares, or (iii) beneficial ownership of such entity.
22
+
23
+ "You" (or "Your") shall mean an individual or Legal Entity
24
+ exercising permissions granted by this License.
25
+
26
+ "Source" form shall mean the preferred form for making modifications,
27
+ including but not limited to software source code, documentation
28
+ source, and configuration files.
29
+
30
+ "Object" form shall mean any form resulting from mechanical
31
+ transformation or translation of a Source form, including but
32
+ not limited to compiled object code, generated documentation,
33
+ and conversions to other media types.
34
+
35
+ "Work" shall mean the work of authorship, whether in Source or
36
+ Object form, made available under the License, as indicated by a
37
+ copyright notice that is included in or attached to the work
38
+ (an example is provided in the Appendix below).
39
+
40
+ "Derivative Works" shall mean any work, whether in Source or Object
41
+ form, that is based on (or derived from) the Work and for which the
42
+ editorial revisions, annotations, elaborations, or other modifications
43
+ represent, as a whole, an original work of authorship. For the purposes
44
+ of this License, Derivative Works shall not include works that remain
45
+ separable from, or merely link (or bind by name) to the interfaces of,
46
+ the Work and Derivative Works thereof.
47
+
48
+ "Contribution" shall mean any work of authorship, including
49
+ the original version of the Work and any modifications or additions
50
+ to that Work or Derivative Works thereof, that is intentionally
51
+ submitted to Licensor for inclusion in the Work by the copyright owner
52
+ or by an individual or Legal Entity authorized to submit on behalf of
53
+ the copyright owner. For the purposes of this definition, "submitted"
54
+ means any form of electronic, verbal, or written communication sent
55
+ to the Licensor or its representatives, including but not limited to
56
+ communication on electronic mailing lists, source code control systems,
57
+ and issue tracking systems that are managed by, or on behalf of, the
58
+ Licensor for the purpose of discussing and improving the Work, but
59
+ excluding communication that is conspicuously marked or otherwise
60
+ designated in writing by the copyright owner as "Not a Contribution."
61
+
62
+ "Contributor" shall mean Licensor and any individual or Legal Entity
63
+ on behalf of whom a Contribution has been received by Licensor and
64
+ subsequently incorporated within the Work.
65
+
66
+ 2. Grant of Copyright License. Subject to the terms and conditions of
67
+ this License, each Contributor hereby grants to You a perpetual,
68
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69
+ copyright license to reproduce, prepare Derivative Works of,
70
+ publicly display, publicly perform, sublicense, and distribute the
71
+ Work and such Derivative Works in Source or Object form.
72
+
73
+ 3. Grant of Patent License. Subject to the terms and conditions of
74
+ this License, each Contributor hereby grants to You a perpetual,
75
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
+ (except as stated in this section) patent license to make, have made,
77
+ use, offer to sell, sell, import, and otherwise transfer the Work,
78
+ where such license applies only to those patent claims licensable
79
+ by such Contributor that are necessarily infringed by their
80
+ Contribution(s) alone or by combination of their Contribution(s)
81
+ with the Work to which such Contribution(s) was submitted. If You
82
+ institute patent litigation against any entity (including a
83
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
84
+ or a Contribution incorporated within the Work constitutes direct
85
+ or contributory patent infringement, then any patent licenses
86
+ granted to You under this License for that Work shall terminate
87
+ as of the date such litigation is filed.
88
+
89
+ 4. Redistribution. You may reproduce and distribute copies of the
90
+ Work or Derivative Works thereof in any medium, with or without
91
+ modifications, and in Source or Object form, provided that You
92
+ meet the following conditions:
93
+
94
+ (a) You must give any other recipients of the Work or
95
+ Derivative Works a copy of this License; and
96
+
97
+ (b) You must cause any modified files to carry prominent notices
98
+ stating that You changed the files; and
99
+
100
+ (c) You must retain, in the Source form of any Derivative Works
101
+ that You distribute, all copyright, patent, trademark, and
102
+ attribution notices from the Source form of the Work,
103
+ excluding those notices that do not pertain to any part of
104
+ the Derivative Works; and
105
+
106
+ (d) If the Work includes a "NOTICE" text file as part of its
107
+ distribution, then any Derivative Works that You distribute must
108
+ include a readable copy of the attribution notices contained
109
+ within such NOTICE file, excluding those notices that do not
110
+ pertain to any part of the Derivative Works, in at least one
111
+ of the following places: within a NOTICE text file distributed
112
+ as part of the Derivative Works; within the Source form or
113
+ documentation, if provided along with the Derivative Works; or,
114
+ within a display generated by the Derivative Works, if and
115
+ wherever such third-party notices normally appear. The contents
116
+ of the NOTICE file are for informational purposes only and
117
+ do not modify the License. You may add Your own attribution
118
+ notices within Derivative Works that You distribute, alongside
119
+ or as an addendum to the NOTICE text from the Work, provided
120
+ that such additional attribution notices cannot be construed
121
+ as modifying the License.
122
+
123
+ You may add Your own copyright statement to Your modifications and
124
+ may provide additional or different license terms and conditions
125
+ for use, reproduction, or distribution of Your modifications, or
126
+ for any such Derivative Works as a whole, provided Your use,
127
+ reproduction, and distribution of the Work otherwise complies with
128
+ the conditions stated in this License.
129
+
130
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
131
+ any Contribution intentionally submitted for inclusion in the Work
132
+ by You to the Licensor shall be under the terms and conditions of
133
+ this License, without any additional terms or conditions.
134
+ Notwithstanding the above, nothing herein shall supersede or modify
135
+ the terms of any separate license agreement you may have executed
136
+ with Licensor regarding such Contributions.
137
+
138
+ 6. Trademarks. This License does not grant permission to use the trade
139
+ names, trademarks, service marks, or product names of the Licensor,
140
+ except as required for reasonable and customary use in describing the
141
+ origin of the Work and reproducing the content of the NOTICE file.
142
+
143
+ 7. Disclaimer of Warranty. Unless required by applicable law or
144
+ agreed to in writing, Licensor provides the Work (and each
145
+ Contributor provides its Contributions) on an "AS IS" BASIS,
146
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147
+ implied, including, without limitation, any warranties or conditions
148
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149
+ PARTICULAR PURPOSE. You are solely responsible for determining the
150
+ appropriateness of using or redistributing the Work and assume any
151
+ risks associated with Your exercise of permissions under this License.
152
+
153
+ 8. Limitation of Liability. In no event and under no legal theory,
154
+ whether in tort (including negligence), contract, or otherwise,
155
+ unless required by applicable law (such as deliberate and grossly
156
+ negligent acts) or agreed to in writing, shall any Contributor be
157
+ liable to You for damages, including any direct, indirect, special,
158
+ incidental, or consequential damages of any character arising as a
159
+ result of this License or out of the use or inability to use the
160
+ Work (including but not limited to damages for loss of goodwill,
161
+ work stoppage, computer failure or malfunction, or any and all
162
+ other commercial damages or losses), even if such Contributor
163
+ has been advised of the possibility of such damages.
164
+
165
+ 9. Accepting Warranty or Additional Liability. While redistributing
166
+ the Work or Derivative Works thereof, You may choose to offer,
167
+ and charge a fee for, acceptance of support, warranty, indemnity,
168
+ or other liability obligations and/or rights consistent with this
169
+ License. However, in accepting such obligations, You may act only
170
+ on Your own behalf and on Your sole responsibility, not on behalf
171
+ of any other Contributor, and only if You agree to indemnify,
172
+ defend, and hold each Contributor harmless for any liability
173
+ incurred by, or claims asserted against, such Contributor by reason
174
+ of your accepting any such warranty or additional liability.
175
+
176
+ END OF TERMS AND CONDITIONS
177
+
178
+ APPENDIX: How to apply the Apache License to your work.
179
+
180
+ To apply the Apache License to your work, attach the following
181
+ boilerplate notice, with the fields enclosed by brackets "[]"
182
+ replaced with your own identifying information. (Don't include
183
+ the brackets!) The text should be enclosed in the appropriate
184
+ comment syntax for the file format. We also recommend that a
185
+ file or class name and description of purpose be included on the
186
+ same "printed page" as the copyright notice for easier
187
+ identification within third-party archives.
188
+
189
+ Copyright [yyyy] [name of copyright owner]
190
+
191
+ Licensed under the Apache License, Version 2.0 (the "License");
192
+ you may not use this file except in compliance with the License.
193
+ You may obtain a copy of the License at
194
+
195
+ http://www.apache.org/licenses/LICENSE-2.0
196
+
197
+ Unless required by applicable law or agreed to in writing, software
198
+ distributed under the License is distributed on an "AS IS" BASIS,
199
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200
+ See the License for the specific language governing permissions and
201
+ limitations under the License.
Depth-Anything-V2/README.md ADDED
@@ -0,0 +1,201 @@
1
+ <div align="center">
2
+ <h1>Depth Anything V2</h1>
3
+
4
+ [**Lihe Yang**](https://liheyoung.github.io/)<sup>1</sup> · [**Bingyi Kang**](https://bingykang.github.io/)<sup>2&dagger;</sup> · [**Zilong Huang**](http://speedinghzl.github.io/)<sup>2</sup>
5
+ <br>
6
+ [**Zhen Zhao**](http://zhaozhen.me/) · [**Xiaogang Xu**](https://xiaogang00.github.io/) · [**Jiashi Feng**](https://sites.google.com/site/jshfeng/)<sup>2</sup> · [**Hengshuang Zhao**](https://hszhao.github.io/)<sup>1*</sup>
7
+
8
+ <sup>1</sup>HKU&emsp;&emsp;&emsp;<sup>2</sup>TikTok
9
+ <br>
10
+ &dagger;project lead&emsp;*corresponding author
11
+
12
+ <a href="https://arxiv.org/abs/2406.09414"><img src='https://img.shields.io/badge/arXiv-Depth Anything V2-red' alt='Paper PDF'></a>
13
+ <a href='https://depth-anything-v2.github.io'><img src='https://img.shields.io/badge/Project_Page-Depth Anything V2-green' alt='Project Page'></a>
14
+ <a href='https://huggingface.co/spaces/depth-anything/Depth-Anything-V2'><img src='https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Demo-blue'></a>
15
+ <a href='https://huggingface.co/datasets/depth-anything/DA-2K'><img src='https://img.shields.io/badge/Benchmark-DA--2K-yellow' alt='Benchmark'></a>
16
+ </div>
17
+
18
+ This work presents Depth Anything V2. It significantly outperforms [V1](https://github.com/LiheYoung/Depth-Anything) in fine-grained details and robustness. Compared with SD-based models, it enjoys faster inference speed, fewer parameters, and higher depth accuracy.
19
+
20
+ ![teaser](assets/teaser.png)
21
+
22
+
23
+ ## News
24
+ - **2025-01-22:** [Video Depth Anything](https://videodepthanything.github.io) has been released. It generates consistent depth maps for super-long videos (e.g., over 5 minutes).
25
+ - **2024-12-22:** [Prompt Depth Anything](https://promptda.github.io/) has been released. It supports 4K resolution metric depth estimation when low-res LiDAR is used to prompt the DA models.
26
+ - **2024-07-06:** Depth Anything V2 is supported in [Transformers](https://github.com/huggingface/transformers/). See the [instructions](https://huggingface.co/docs/transformers/main/en/model_doc/depth_anything_v2) for convenient usage.
27
+ - **2024-06-25:** Depth Anything is integrated into [Apple Core ML Models](https://developer.apple.com/machine-learning/models/). See the instructions ([V1](https://huggingface.co/apple/coreml-depth-anything-small), [V2](https://huggingface.co/apple/coreml-depth-anything-v2-small)) for usage.
28
+ - **2024-06-22:** We release [smaller metric depth models](https://github.com/DepthAnything/Depth-Anything-V2/tree/main/metric_depth#pre-trained-models) based on Depth-Anything-V2-Small and Base.
29
+ - **2024-06-20:** Our repository and project page were flagged by GitHub and removed from public view for 6 days. Sorry for the inconvenience.
30
+ - **2024-06-14:** Paper, project page, code, models, demo, and benchmark are all released.
31
+
32
+
33
+ ## Pre-trained Models
34
+
35
+ We provide **four models** of varying scales for robust relative depth estimation:
36
+
37
+ | Model | Params | Checkpoint |
38
+ |:-|-:|:-:|
39
+ | Depth-Anything-V2-Small | 24.8M | [Download](https://huggingface.co/depth-anything/Depth-Anything-V2-Small/resolve/main/depth_anything_v2_vits.pth?download=true) |
40
+ | Depth-Anything-V2-Base | 97.5M | [Download](https://huggingface.co/depth-anything/Depth-Anything-V2-Base/resolve/main/depth_anything_v2_vitb.pth?download=true) |
41
+ | Depth-Anything-V2-Large | 335.3M | [Download](https://huggingface.co/depth-anything/Depth-Anything-V2-Large/resolve/main/depth_anything_v2_vitl.pth?download=true) |
42
+ | Depth-Anything-V2-Giant | 1.3B | Coming soon |
43
+
44
+
45
+ ## Usage
46
+
47
+ ### Preparation
48
+
49
+ ```bash
50
+ git clone https://github.com/DepthAnything/Depth-Anything-V2
51
+ cd Depth-Anything-V2
52
+ pip install -r requirements.txt
53
+ ```
54
+
55
+ Download the checkpoints listed [here](#pre-trained-models) and put them under the `checkpoints` directory.
56
+
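If you prefer to fetch a checkpoint programmatically instead of downloading it manually as described above, a minimal sketch using `huggingface_hub` (an extra dependency not listed in `requirements.txt`) could look like this; the repo id and filename are taken from the Large row of the table above.

```python
from huggingface_hub import hf_hub_download

# Downloads depth_anything_v2_vitl.pth into ./checkpoints.
ckpt_path = hf_hub_download(
    repo_id="depth-anything/Depth-Anything-V2-Large",
    filename="depth_anything_v2_vitl.pth",
    local_dir="checkpoints",
)
print(ckpt_path)  # e.g. checkpoints/depth_anything_v2_vitl.pth
```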
57
+ ### Use our models
58
+ ```python
59
+ import cv2
60
+ import torch
61
+
62
+ from depth_anything_v2.dpt import DepthAnythingV2
63
+
64
+ DEVICE = 'cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu'
65
+
66
+ model_configs = {
67
+ 'vits': {'encoder': 'vits', 'features': 64, 'out_channels': [48, 96, 192, 384]},
68
+ 'vitb': {'encoder': 'vitb', 'features': 128, 'out_channels': [96, 192, 384, 768]},
69
+ 'vitl': {'encoder': 'vitl', 'features': 256, 'out_channels': [256, 512, 1024, 1024]},
70
+ 'vitg': {'encoder': 'vitg', 'features': 384, 'out_channels': [1536, 1536, 1536, 1536]}
71
+ }
72
+
73
+ encoder = 'vitl' # or 'vits', 'vitb', 'vitg'
74
+
75
+ model = DepthAnythingV2(**model_configs[encoder])
76
+ model.load_state_dict(torch.load(f'checkpoints/depth_anything_v2_{encoder}.pth', map_location='cpu'))
77
+ model = model.to(DEVICE).eval()
78
+
79
+ raw_img = cv2.imread('your/image/path')
80
+ depth = model.infer_image(raw_img) # HxW raw depth map in numpy
81
+ ```
82
+
83
+ If you do not want to clone this repository, you can also load our models through [Transformers](https://github.com/huggingface/transformers/). Below is a simple code snippet. Please refer to the [official page](https://huggingface.co/docs/transformers/main/en/model_doc/depth_anything_v2) for more details.
84
+
85
+ - Note 1: Make sure you can connect to Hugging Face and have installed the latest Transformers.
86
+ - Note 2: Due to the [upsampling difference](https://github.com/huggingface/transformers/pull/31522#issuecomment-2184123463) between OpenCV (which we use) and Pillow (which Hugging Face uses), predictions may differ slightly, so we recommend loading our models in the way introduced above.
87
+ ```python
88
+ from transformers import pipeline
89
+ from PIL import Image
90
+
91
+ pipe = pipeline(task="depth-estimation", model="depth-anything/Depth-Anything-V2-Small-hf")
92
+ image = Image.open('your/image/path')
93
+ depth = pipe(image)["depth"]
94
+ ```
95
+
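As a small follow-up sketch, the pipeline output can be saved or inspected directly. This assumes the Transformers depth-estimation pipeline's usual output keys: `"depth"` as a PIL image scaled for visualization and `"predicted_depth"` as the raw tensor.

```python
result = pipe(image)                    # reuses `pipe` and `image` from the snippet above
result["depth"].save("depth_gray.png")  # assumed: a PIL image normalized for viewing

raw = result["predicted_depth"].squeeze().cpu().numpy()  # raw float depth map
print(raw.shape, raw.min(), raw.max())
```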
96
+ ### Running script on *images*
97
+
98
+ ```bash
99
+ python run.py \
100
+ --encoder <vits | vitb | vitl | vitg> \
101
+ --img-path <path> --outdir <outdir> \
102
+ [--input-size <size>] [--pred-only] [--grayscale]
103
+ ```
104
+ Options:
105
+ - `--img-path`: You can either 1) point it to a directory containing all images of interest, 2) point it to a single image, or 3) point it to a text file listing image paths.
106
+ - `--input-size` (optional): By default, we use input size `518` for model inference. ***You can increase the size for even more fine-grained results.***
107
+ - `--pred-only` (optional): Only save the predicted depth map, without the raw image.
108
+ - `--grayscale` (optional): Save the grayscale depth map, without applying a color palette.
109
+
110
+ For example:
111
+ ```bash
112
+ python run.py --encoder vitl --img-path assets/examples --outdir depth_vis
113
+ ```
114
+
115
+ ### Running script on *videos*
116
+
117
+ ```bash
118
+ python run_video.py \
119
+ --encoder <vits | vitb | vitl | vitg> \
120
+ --video-path assets/examples_video --outdir video_depth_vis \
121
+ [--input-size <size>] [--pred-only] [--grayscale]
122
+ ```
123
+
124
+ ***Our larger model has better temporal consistency on videos.***
125
+
126
+ ### Gradio demo
127
+
128
+ To use our Gradio demo locally:
129
+
130
+ ```bash
131
+ python app.py
132
+ ```
133
+
134
+ You can also try our [online demo](https://huggingface.co/spaces/Depth-Anything/Depth-Anything-V2).
135
+
136
+ ***Note: Compared to V1, we have made a minor modification to the DINOv2-DPT architecture (originating from this [issue](https://github.com/LiheYoung/Depth-Anything/issues/81)).*** In V1, we *unintentionally* used features from the last four layers of DINOv2 for decoding. In V2, we use [intermediate features](https://github.com/DepthAnything/Depth-Anything-V2/blob/2cbc36a8ce2cec41d38ee51153f112e87c8e42d8/depth_anything_v2/dpt.py#L164-L169) instead. Although this modification did not improve details or accuracy, we decided to follow this common practice.
137
+
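To make the note above concrete, here is a minimal sketch of querying intermediate features through the `get_intermediate_layers` interface of the DINOv2 wrapper shipped in this commit (`depth_anything_v2/dinov2.py`). The layer indices below are illustrative assumptions, not necessarily the exact ones used in `dpt.py`.

```python
import torch

from depth_anything_v2.dinov2 import DINOv2

encoder = DINOv2("vitl").eval()  # patch size 14, so H and W should be multiples of 14

x = torch.randn(1, 3, 518, 518)
with torch.no_grad():
    # Take four intermediate blocks (illustrative indices) instead of the last four layers.
    features = encoder.get_intermediate_layers(x, n=[4, 11, 17, 23], return_class_token=True)

for patch_tokens, cls_token in features:
    # For vitl at 518x518: patch_tokens is (1, 37*37, 1024), cls_token is (1, 1024).
    print(patch_tokens.shape, cls_token.shape)
```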
138
+
139
+ ## Fine-tuned to Metric Depth Estimation
140
+
141
+ Please refer to [metric depth estimation](./metric_depth).
142
+
143
+
144
+ ## DA-2K Evaluation Benchmark
145
+
146
+ Please refer to [DA-2K benchmark](./DA-2K.md).
147
+
148
+
149
+ ## Community Support
150
+
151
+ **We sincerely appreciate all the community support for our Depth Anything series. Thank you very much!**
152
+
153
+ - Apple Core ML:
154
+ - https://developer.apple.com/machine-learning/models
155
+ - https://huggingface.co/apple/coreml-depth-anything-v2-small
156
+ - https://huggingface.co/apple/coreml-depth-anything-small
157
+ - Transformers:
158
+ - https://huggingface.co/docs/transformers/main/en/model_doc/depth_anything_v2
159
+ - https://huggingface.co/docs/transformers/main/en/model_doc/depth_anything
160
+ - TensorRT:
161
+ - https://github.com/spacewalk01/depth-anything-tensorrt
162
+ - https://github.com/zhujiajian98/Depth-Anythingv2-TensorRT-python
163
+ - ONNX: https://github.com/fabio-sim/Depth-Anything-ONNX
164
+ - ComfyUI: https://github.com/kijai/ComfyUI-DepthAnythingV2
165
+ - Transformers.js (real-time depth estimation in the browser): https://huggingface.co/spaces/Xenova/webgpu-realtime-depth-estimation
166
+ - Android:
167
+ - https://github.com/shubham0204/Depth-Anything-Android
168
+ - https://github.com/FeiGeChuanShu/ncnn-android-depth_anything
169
+
170
+
171
+ ## Acknowledgement
172
+
173
+ We are sincerely grateful to the awesome Hugging Face team ([@Pedro Cuenca](https://huggingface.co/pcuenq), [@Niels Rogge](https://huggingface.co/nielsr), [@Merve Noyan](https://huggingface.co/merve), [@Amy Roberts](https://huggingface.co/amyeroberts), et al.) for their huge efforts in supporting our models in Transformers and Apple Core ML.
174
+
175
+ We also thank the [DINOv2](https://github.com/facebookresearch/dinov2) team for contributing such impressive models to our community.
176
+
177
+
178
+ ## LICENSE
179
+
180
+ Depth-Anything-V2-Small model is under the Apache-2.0 license. Depth-Anything-V2-Base/Large/Giant models are under the CC-BY-NC-4.0 license.
181
+
182
+
183
+ ## Citation
184
+
185
+ If you find this project useful, please consider citing:
186
+
187
+ ```bibtex
188
+ @article{depth_anything_v2,
189
+ title={Depth Anything V2},
190
+ author={Yang, Lihe and Kang, Bingyi and Huang, Zilong and Zhao, Zhen and Xu, Xiaogang and Feng, Jiashi and Zhao, Hengshuang},
191
+ journal={arXiv:2406.09414},
192
+ year={2024}
193
+ }
194
+
195
+ @inproceedings{depth_anything_v1,
196
+ title={Depth Anything: Unleashing the Power of Large-Scale Unlabeled Data},
197
+ author={Yang, Lihe and Kang, Bingyi and Huang, Zilong and Xu, Xiaogang and Feng, Jiashi and Zhao, Hengshuang},
198
+ booktitle={CVPR},
199
+ year={2024}
200
+ }
201
+ ```
Depth-Anything-V2/app.py ADDED
@@ -0,0 +1,88 @@
1
+ import glob
2
+ import gradio as gr
3
+ import matplotlib
4
+ import numpy as np
5
+ from PIL import Image
6
+ import torch
7
+ import tempfile
8
+ from gradio_imageslider import ImageSlider
9
+
10
+ from depth_anything_v2.dpt import DepthAnythingV2
11
+
12
+ css = """
13
+ #img-display-container {
14
+ max-height: 100vh;
15
+ }
16
+ #img-display-input {
17
+ max-height: 80vh;
18
+ }
19
+ #img-display-output {
20
+ max-height: 80vh;
21
+ }
22
+ #download {
23
+ height: 62px;
24
+ }
25
+ """
26
+ DEVICE = 'cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu'
27
+ model_configs = {
28
+ 'vits': {'encoder': 'vits', 'features': 64, 'out_channels': [48, 96, 192, 384]},
29
+ 'vitb': {'encoder': 'vitb', 'features': 128, 'out_channels': [96, 192, 384, 768]},
30
+ 'vitl': {'encoder': 'vitl', 'features': 256, 'out_channels': [256, 512, 1024, 1024]},
31
+ 'vitg': {'encoder': 'vitg', 'features': 384, 'out_channels': [1536, 1536, 1536, 1536]}
32
+ }
33
+ encoder = 'vitl'
34
+ model = DepthAnythingV2(**model_configs[encoder])
35
+ state_dict = torch.load(f'checkpoints/depth_anything_v2_{encoder}.pth', map_location="cpu")
36
+ model.load_state_dict(state_dict)
37
+ model = model.to(DEVICE).eval()
38
+
39
+ title = "# Depth Anything V2"
40
+ description = """Official demo for **Depth Anything V2**.
41
+ Please refer to our [paper](https://arxiv.org/abs/2406.09414), [project page](https://depth-anything-v2.github.io), or [github](https://github.com/DepthAnything/Depth-Anything-V2) for more details."""
42
+
43
+ def predict_depth(image):
44
+ return model.infer_image(image)
45
+
46
+ with gr.Blocks(css=css) as demo:
47
+ gr.Markdown(title)
48
+ gr.Markdown(description)
49
+ gr.Markdown("### Depth Prediction demo")
50
+
51
+ with gr.Row():
52
+ input_image = gr.Image(label="Input Image", type='numpy', elem_id='img-display-input')
53
+ depth_image_slider = ImageSlider(label="Depth Map with Slider View", elem_id='img-display-output', position=0.5)
54
+ submit = gr.Button(value="Compute Depth")
55
+ gray_depth_file = gr.File(label="Grayscale depth map", elem_id="download",)
56
+ raw_file = gr.File(label="16-bit raw output (can be considered as disparity)", elem_id="download",)
57
+
58
+ cmap = matplotlib.colormaps.get_cmap('Spectral_r')
59
+
60
+ def on_submit(image):
61
+ original_image = image.copy()
62
+
63
+ h, w = image.shape[:2]
64
+
65
+ depth = predict_depth(image[:, :, ::-1])
66
+
67
+ raw_depth = Image.fromarray(depth.astype('uint16'))
68
+ tmp_raw_depth = tempfile.NamedTemporaryFile(suffix='.png', delete=False)
69
+ raw_depth.save(tmp_raw_depth.name)
70
+
71
+ depth = (depth - depth.min()) / (depth.max() - depth.min()) * 255.0
72
+ depth = depth.astype(np.uint8)
73
+ colored_depth = (cmap(depth)[:, :, :3] * 255).astype(np.uint8)
74
+
75
+ gray_depth = Image.fromarray(depth)
76
+ tmp_gray_depth = tempfile.NamedTemporaryFile(suffix='.png', delete=False)
77
+ gray_depth.save(tmp_gray_depth.name)
78
+
79
+ return [(original_image, colored_depth), tmp_gray_depth.name, tmp_raw_depth.name]
80
+
81
+ submit.click(on_submit, inputs=[input_image], outputs=[depth_image_slider, gray_depth_file, raw_file])
82
+
83
+ example_files = glob.glob('assets/examples/*')
84
+ examples = gr.Examples(examples=example_files, inputs=[input_image], outputs=[depth_image_slider, gray_depth_file, raw_file], fn=on_submit)
85
+
86
+
87
+ if __name__ == '__main__':
88
+ demo.queue().launch()
Depth-Anything-V2/depth_anything_v2/__pycache__/dinov2.cpython-310.pyc ADDED
Binary file (12.2 kB).
Depth-Anything-V2/depth_anything_v2/__pycache__/dpt.cpython-310.pyc ADDED
Binary file (5.99 kB).
Depth-Anything-V2/depth_anything_v2/dinov2.py ADDED
@@ -0,0 +1,415 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ # References:
7
+ # https://github.com/facebookresearch/dino/blob/main/vision_transformer.py
8
+ # https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py
9
+
10
+ from functools import partial
11
+ import math
12
+ import logging
13
+ from typing import Sequence, Tuple, Union, Callable
14
+
15
+ import torch
16
+ import torch.nn as nn
17
+ import torch.utils.checkpoint
18
+ from torch.nn.init import trunc_normal_
19
+
20
+ from .dinov2_layers import Mlp, PatchEmbed, SwiGLUFFNFused, MemEffAttention, NestedTensorBlock as Block
21
+
22
+
23
+ logger = logging.getLogger("dinov2")
24
+
25
+
26
+ def named_apply(fn: Callable, module: nn.Module, name="", depth_first=True, include_root=False) -> nn.Module:
27
+ if not depth_first and include_root:
28
+ fn(module=module, name=name)
29
+ for child_name, child_module in module.named_children():
30
+ child_name = ".".join((name, child_name)) if name else child_name
31
+ named_apply(fn=fn, module=child_module, name=child_name, depth_first=depth_first, include_root=True)
32
+ if depth_first and include_root:
33
+ fn(module=module, name=name)
34
+ return module
35
+
36
+
37
+ class BlockChunk(nn.ModuleList):
38
+ def forward(self, x):
39
+ for b in self:
40
+ x = b(x)
41
+ return x
42
+
43
+
44
+ class DinoVisionTransformer(nn.Module):
45
+ def __init__(
46
+ self,
47
+ img_size=224,
48
+ patch_size=16,
49
+ in_chans=3,
50
+ embed_dim=768,
51
+ depth=12,
52
+ num_heads=12,
53
+ mlp_ratio=4.0,
54
+ qkv_bias=True,
55
+ ffn_bias=True,
56
+ proj_bias=True,
57
+ drop_path_rate=0.0,
58
+ drop_path_uniform=False,
59
+ init_values=None, # for layerscale: None or 0 => no layerscale
60
+ embed_layer=PatchEmbed,
61
+ act_layer=nn.GELU,
62
+ block_fn=Block,
63
+ ffn_layer="mlp",
64
+ block_chunks=1,
65
+ num_register_tokens=0,
66
+ interpolate_antialias=False,
67
+ interpolate_offset=0.1,
68
+ ):
69
+ """
70
+ Args:
71
+ img_size (int, tuple): input image size
72
+ patch_size (int, tuple): patch size
73
+ in_chans (int): number of input channels
74
+ embed_dim (int): embedding dimension
75
+ depth (int): depth of transformer
76
+ num_heads (int): number of attention heads
77
+ mlp_ratio (int): ratio of mlp hidden dim to embedding dim
78
+ qkv_bias (bool): enable bias for qkv if True
79
+ proj_bias (bool): enable bias for proj in attn if True
80
+ ffn_bias (bool): enable bias for ffn if True
81
+ drop_path_rate (float): stochastic depth rate
82
+ drop_path_uniform (bool): apply uniform drop rate across blocks
83
+ weight_init (str): weight init scheme
84
+ init_values (float): layer-scale init values
85
+ embed_layer (nn.Module): patch embedding layer
86
+ act_layer (nn.Module): MLP activation layer
87
+ block_fn (nn.Module): transformer block class
88
+ ffn_layer (str): "mlp", "swiglu", "swiglufused" or "identity"
89
+ block_chunks: (int) split block sequence into block_chunks units for FSDP wrap
90
+ num_register_tokens: (int) number of extra cls tokens (so-called "registers")
91
+ interpolate_antialias: (str) flag to apply anti-aliasing when interpolating positional embeddings
92
+ interpolate_offset: (float) work-around offset to apply when interpolating positional embeddings
93
+ """
94
+ super().__init__()
95
+ norm_layer = partial(nn.LayerNorm, eps=1e-6)
96
+
97
+ self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
98
+ self.num_tokens = 1
99
+ self.n_blocks = depth
100
+ self.num_heads = num_heads
101
+ self.patch_size = patch_size
102
+ self.num_register_tokens = num_register_tokens
103
+ self.interpolate_antialias = interpolate_antialias
104
+ self.interpolate_offset = interpolate_offset
105
+
106
+ self.patch_embed = embed_layer(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
107
+ num_patches = self.patch_embed.num_patches
108
+
109
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
110
+ self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim))
111
+ assert num_register_tokens >= 0
112
+ self.register_tokens = (
113
+ nn.Parameter(torch.zeros(1, num_register_tokens, embed_dim)) if num_register_tokens else None
114
+ )
115
+
116
+ if drop_path_uniform is True:
117
+ dpr = [drop_path_rate] * depth
118
+ else:
119
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
120
+
121
+ if ffn_layer == "mlp":
122
+ logger.info("using MLP layer as FFN")
123
+ ffn_layer = Mlp
124
+ elif ffn_layer == "swiglufused" or ffn_layer == "swiglu":
125
+ logger.info("using SwiGLU layer as FFN")
126
+ ffn_layer = SwiGLUFFNFused
127
+ elif ffn_layer == "identity":
128
+ logger.info("using Identity layer as FFN")
129
+
130
+ def f(*args, **kwargs):
131
+ return nn.Identity()
132
+
133
+ ffn_layer = f
134
+ else:
135
+ raise NotImplementedError
136
+
137
+ blocks_list = [
138
+ block_fn(
139
+ dim=embed_dim,
140
+ num_heads=num_heads,
141
+ mlp_ratio=mlp_ratio,
142
+ qkv_bias=qkv_bias,
143
+ proj_bias=proj_bias,
144
+ ffn_bias=ffn_bias,
145
+ drop_path=dpr[i],
146
+ norm_layer=norm_layer,
147
+ act_layer=act_layer,
148
+ ffn_layer=ffn_layer,
149
+ init_values=init_values,
150
+ )
151
+ for i in range(depth)
152
+ ]
153
+ if block_chunks > 0:
154
+ self.chunked_blocks = True
155
+ chunked_blocks = []
156
+ chunksize = depth // block_chunks
157
+ for i in range(0, depth, chunksize):
158
+ # this is to keep the block index consistent if we chunk the block list
159
+ chunked_blocks.append([nn.Identity()] * i + blocks_list[i : i + chunksize])
160
+ self.blocks = nn.ModuleList([BlockChunk(p) for p in chunked_blocks])
161
+ else:
162
+ self.chunked_blocks = False
163
+ self.blocks = nn.ModuleList(blocks_list)
164
+
165
+ self.norm = norm_layer(embed_dim)
166
+ self.head = nn.Identity()
167
+
168
+ self.mask_token = nn.Parameter(torch.zeros(1, embed_dim))
169
+
170
+ self.init_weights()
171
+
172
+ def init_weights(self):
173
+ trunc_normal_(self.pos_embed, std=0.02)
174
+ nn.init.normal_(self.cls_token, std=1e-6)
175
+ if self.register_tokens is not None:
176
+ nn.init.normal_(self.register_tokens, std=1e-6)
177
+ named_apply(init_weights_vit_timm, self)
178
+
179
+ def interpolate_pos_encoding(self, x, w, h):
180
+ previous_dtype = x.dtype
181
+ npatch = x.shape[1] - 1
182
+ N = self.pos_embed.shape[1] - 1
183
+ if npatch == N and w == h:
184
+ return self.pos_embed
185
+ pos_embed = self.pos_embed.float()
186
+ class_pos_embed = pos_embed[:, 0]
187
+ patch_pos_embed = pos_embed[:, 1:]
188
+ dim = x.shape[-1]
189
+ w0 = w // self.patch_size
190
+ h0 = h // self.patch_size
191
+ # we add a small number to avoid floating point error in the interpolation
192
+ # see discussion at https://github.com/facebookresearch/dino/issues/8
193
+ # DINOv2 with register modify the interpolate_offset from 0.1 to 0.0
194
+ w0, h0 = w0 + self.interpolate_offset, h0 + self.interpolate_offset
195
+ # w0, h0 = w0 + 0.1, h0 + 0.1
196
+
197
+ sqrt_N = math.sqrt(N)
198
+ sx, sy = float(w0) / sqrt_N, float(h0) / sqrt_N
199
+ patch_pos_embed = nn.functional.interpolate(
200
+ patch_pos_embed.reshape(1, int(sqrt_N), int(sqrt_N), dim).permute(0, 3, 1, 2),
201
+ scale_factor=(sx, sy),
202
+ # (int(w0), int(h0)), # to solve the upsampling shape issue
203
+ mode="bicubic",
204
+ antialias=self.interpolate_antialias
205
+ )
206
+
207
+ assert int(w0) == patch_pos_embed.shape[-2]
208
+ assert int(h0) == patch_pos_embed.shape[-1]
209
+ patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
210
+ return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1).to(previous_dtype)
211
+
212
+ def prepare_tokens_with_masks(self, x, masks=None):
213
+ B, nc, w, h = x.shape
214
+ x = self.patch_embed(x)
215
+ if masks is not None:
216
+ x = torch.where(masks.unsqueeze(-1), self.mask_token.to(x.dtype).unsqueeze(0), x)
217
+
218
+ x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
219
+ x = x + self.interpolate_pos_encoding(x, w, h)
220
+
221
+ if self.register_tokens is not None:
222
+ x = torch.cat(
223
+ (
224
+ x[:, :1],
225
+ self.register_tokens.expand(x.shape[0], -1, -1),
226
+ x[:, 1:],
227
+ ),
228
+ dim=1,
229
+ )
230
+
231
+ return x
232
+
233
+ def forward_features_list(self, x_list, masks_list):
234
+ x = [self.prepare_tokens_with_masks(x, masks) for x, masks in zip(x_list, masks_list)]
235
+ for blk in self.blocks:
236
+ x = blk(x)
237
+
238
+ all_x = x
239
+ output = []
240
+ for x, masks in zip(all_x, masks_list):
241
+ x_norm = self.norm(x)
242
+ output.append(
243
+ {
244
+ "x_norm_clstoken": x_norm[:, 0],
245
+ "x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1],
246
+ "x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :],
247
+ "x_prenorm": x,
248
+ "masks": masks,
249
+ }
250
+ )
251
+ return output
252
+
253
+ def forward_features(self, x, masks=None):
254
+ if isinstance(x, list):
255
+ return self.forward_features_list(x, masks)
256
+
257
+ x = self.prepare_tokens_with_masks(x, masks)
258
+
259
+ for blk in self.blocks:
260
+ x = blk(x)
261
+
262
+ x_norm = self.norm(x)
263
+ return {
264
+ "x_norm_clstoken": x_norm[:, 0],
265
+ "x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1],
266
+ "x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :],
267
+ "x_prenorm": x,
268
+ "masks": masks,
269
+ }
270
+
271
+ def _get_intermediate_layers_not_chunked(self, x, n=1):
272
+ x = self.prepare_tokens_with_masks(x)
273
+ # If n is an int, take the n last blocks. If it's a list, take them
274
+ output, total_block_len = [], len(self.blocks)
275
+ blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n
276
+ for i, blk in enumerate(self.blocks):
277
+ x = blk(x)
278
+ if i in blocks_to_take:
279
+ output.append(x)
280
+ assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found"
281
+ return output
282
+
283
+ def _get_intermediate_layers_chunked(self, x, n=1):
284
+ x = self.prepare_tokens_with_masks(x)
285
+ output, i, total_block_len = [], 0, len(self.blocks[-1])
286
+ # If n is an int, take the n last blocks. If it's a list, take them
287
+ blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n
288
+ for block_chunk in self.blocks:
289
+ for blk in block_chunk[i:]: # Passing the nn.Identity()
290
+ x = blk(x)
291
+ if i in blocks_to_take:
292
+ output.append(x)
293
+ i += 1
294
+ assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found"
295
+ return output
296
+
297
+ def get_intermediate_layers(
298
+ self,
299
+ x: torch.Tensor,
300
+ n: Union[int, Sequence] = 1, # Layers or n last layers to take
301
+ reshape: bool = False,
302
+ return_class_token: bool = False,
303
+ norm=True
304
+ ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]]]:
305
+ if self.chunked_blocks:
306
+ outputs = self._get_intermediate_layers_chunked(x, n)
307
+ else:
308
+ outputs = self._get_intermediate_layers_not_chunked(x, n)
309
+ if norm:
310
+ outputs = [self.norm(out) for out in outputs]
311
+ class_tokens = [out[:, 0] for out in outputs]
312
+ outputs = [out[:, 1 + self.num_register_tokens:] for out in outputs]
313
+ if reshape:
314
+ B, _, w, h = x.shape
315
+ outputs = [
316
+ out.reshape(B, w // self.patch_size, h // self.patch_size, -1).permute(0, 3, 1, 2).contiguous()
317
+ for out in outputs
318
+ ]
319
+ if return_class_token:
320
+ return tuple(zip(outputs, class_tokens))
321
+ return tuple(outputs)
322
+
323
+ def forward(self, *args, is_training=False, **kwargs):
324
+ ret = self.forward_features(*args, **kwargs)
325
+ if is_training:
326
+ return ret
327
+ else:
328
+ return self.head(ret["x_norm_clstoken"])
329
+
330
+
331
+ def init_weights_vit_timm(module: nn.Module, name: str = ""):
332
+ """ViT weight initialization, original timm impl (for reproducibility)"""
333
+ if isinstance(module, nn.Linear):
334
+ trunc_normal_(module.weight, std=0.02)
335
+ if module.bias is not None:
336
+ nn.init.zeros_(module.bias)
337
+
338
+
339
+ def vit_small(patch_size=16, num_register_tokens=0, **kwargs):
340
+ model = DinoVisionTransformer(
341
+ patch_size=patch_size,
342
+ embed_dim=384,
343
+ depth=12,
344
+ num_heads=6,
345
+ mlp_ratio=4,
346
+ block_fn=partial(Block, attn_class=MemEffAttention),
347
+ num_register_tokens=num_register_tokens,
348
+ **kwargs,
349
+ )
350
+ return model
351
+
352
+
353
+ def vit_base(patch_size=16, num_register_tokens=0, **kwargs):
354
+ model = DinoVisionTransformer(
355
+ patch_size=patch_size,
356
+ embed_dim=768,
357
+ depth=12,
358
+ num_heads=12,
359
+ mlp_ratio=4,
360
+ block_fn=partial(Block, attn_class=MemEffAttention),
361
+ num_register_tokens=num_register_tokens,
362
+ **kwargs,
363
+ )
364
+ return model
365
+
366
+
367
+ def vit_large(patch_size=16, num_register_tokens=0, **kwargs):
368
+ model = DinoVisionTransformer(
369
+ patch_size=patch_size,
370
+ embed_dim=1024,
371
+ depth=24,
372
+ num_heads=16,
373
+ mlp_ratio=4,
374
+ block_fn=partial(Block, attn_class=MemEffAttention),
375
+ num_register_tokens=num_register_tokens,
376
+ **kwargs,
377
+ )
378
+ return model
379
+
380
+
381
+ def vit_giant2(patch_size=16, num_register_tokens=0, **kwargs):
382
+ """
383
+ Close to ViT-giant, with embed-dim 1536 and 24 heads => embed-dim per head 64
384
+ """
385
+ model = DinoVisionTransformer(
386
+ patch_size=patch_size,
387
+ embed_dim=1536,
388
+ depth=40,
389
+ num_heads=24,
390
+ mlp_ratio=4,
391
+ block_fn=partial(Block, attn_class=MemEffAttention),
392
+ num_register_tokens=num_register_tokens,
393
+ **kwargs,
394
+ )
395
+ return model
396
+
397
+
398
+ def DINOv2(model_name):
399
+ model_zoo = {
400
+ "vits": vit_small,
401
+ "vitb": vit_base,
402
+ "vitl": vit_large,
403
+ "vitg": vit_giant2
404
+ }
405
+
406
+ return model_zoo[model_name](
407
+ img_size=518,
408
+ patch_size=14,
409
+ init_values=1.0,
410
+ ffn_layer="mlp" if model_name != "vitg" else "swiglufused",
411
+ block_chunks=0,
412
+ num_register_tokens=0,
413
+ interpolate_antialias=False,
414
+ interpolate_offset=0.1
415
+ )
Depth-Anything-V2/depth_anything_v2/dinov2_layers/__init__.py ADDED
@@ -0,0 +1,11 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ from .mlp import Mlp
8
+ from .patch_embed import PatchEmbed
9
+ from .swiglu_ffn import SwiGLUFFN, SwiGLUFFNFused
10
+ from .block import NestedTensorBlock
11
+ from .attention import MemEffAttention
Depth-Anything-V2/depth_anything_v2/dinov2_layers/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (438 Bytes).
Depth-Anything-V2/depth_anything_v2/dinov2_layers/__pycache__/attention.cpython-310.pyc ADDED
Binary file (2.41 kB).
Depth-Anything-V2/depth_anything_v2/dinov2_layers/__pycache__/block.cpython-310.pyc ADDED
Binary file (8.01 kB).
Depth-Anything-V2/depth_anything_v2/dinov2_layers/__pycache__/drop_path.cpython-310.pyc ADDED
Binary file (1.24 kB).
Depth-Anything-V2/depth_anything_v2/dinov2_layers/__pycache__/layer_scale.cpython-310.pyc ADDED
Binary file (1.04 kB).
Depth-Anything-V2/depth_anything_v2/dinov2_layers/__pycache__/mlp.cpython-310.pyc ADDED
Binary file (1.23 kB).
Depth-Anything-V2/depth_anything_v2/dinov2_layers/__pycache__/patch_embed.cpython-310.pyc ADDED
Binary file (2.68 kB).
Depth-Anything-V2/depth_anything_v2/dinov2_layers/__pycache__/swiglu_ffn.cpython-310.pyc ADDED
Binary file (2.03 kB).
Depth-Anything-V2/depth_anything_v2/dinov2_layers/attention.py ADDED
@@ -0,0 +1,83 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ # References:
8
+ # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
9
+ # https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py
10
+
11
+ import logging
12
+
13
+ from torch import Tensor
14
+ from torch import nn
15
+
16
+
17
+ logger = logging.getLogger("dinov2")
18
+
19
+
20
+ try:
21
+ from xformers.ops import memory_efficient_attention, unbind, fmha
22
+
23
+ XFORMERS_AVAILABLE = True
24
+ except ImportError:
25
+ logger.warning("xFormers not available")
26
+ XFORMERS_AVAILABLE = False
27
+
28
+
29
+ class Attention(nn.Module):
30
+ def __init__(
31
+ self,
32
+ dim: int,
33
+ num_heads: int = 8,
34
+ qkv_bias: bool = False,
35
+ proj_bias: bool = True,
36
+ attn_drop: float = 0.0,
37
+ proj_drop: float = 0.0,
38
+ ) -> None:
39
+ super().__init__()
40
+ self.num_heads = num_heads
41
+ head_dim = dim // num_heads
42
+ self.scale = head_dim**-0.5
43
+
44
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
45
+ self.attn_drop = nn.Dropout(attn_drop)
46
+ self.proj = nn.Linear(dim, dim, bias=proj_bias)
47
+ self.proj_drop = nn.Dropout(proj_drop)
48
+
49
+ def forward(self, x: Tensor) -> Tensor:
50
+ B, N, C = x.shape
51
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
52
+
53
+ q, k, v = qkv[0] * self.scale, qkv[1], qkv[2]
54
+ attn = q @ k.transpose(-2, -1)
55
+
56
+ attn = attn.softmax(dim=-1)
57
+ attn = self.attn_drop(attn)
58
+
59
+ x = (attn @ v).transpose(1, 2).reshape(B, N, C)
60
+ x = self.proj(x)
61
+ x = self.proj_drop(x)
62
+ return x
63
+
64
+
65
+ class MemEffAttention(Attention):
66
+ def forward(self, x: Tensor, attn_bias=None) -> Tensor:
67
+ if not XFORMERS_AVAILABLE:
68
+ assert attn_bias is None, "xFormers is required for nested tensors usage"
69
+ return super().forward(x)
70
+
71
+ B, N, C = x.shape
72
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads)
73
+
74
+ q, k, v = unbind(qkv, 2)
75
+
76
+ x = memory_efficient_attention(q, k, v, attn_bias=attn_bias)
77
+ x = x.reshape([B, N, C])
78
+
79
+ x = self.proj(x)
80
+ x = self.proj_drop(x)
81
+ return x
82
+
83
+
Depth-Anything-V2/depth_anything_v2/dinov2_layers/block.py ADDED
@@ -0,0 +1,252 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ # References:
8
+ # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
9
+ # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py
10
+
11
+ import logging
12
+ from typing import Callable, List, Any, Tuple, Dict
13
+
14
+ import torch
15
+ from torch import nn, Tensor
16
+
17
+ from .attention import Attention, MemEffAttention
18
+ from .drop_path import DropPath
19
+ from .layer_scale import LayerScale
20
+ from .mlp import Mlp
21
+
22
+
23
+ logger = logging.getLogger("dinov2")
24
+
25
+
26
+ try:
27
+ from xformers.ops import fmha
28
+ from xformers.ops import scaled_index_add, index_select_cat
29
+
30
+ XFORMERS_AVAILABLE = True
31
+ except ImportError:
32
+ logger.warning("xFormers not available")
33
+ XFORMERS_AVAILABLE = False
34
+
35
+
36
+ class Block(nn.Module):
37
+ def __init__(
38
+ self,
39
+ dim: int,
40
+ num_heads: int,
41
+ mlp_ratio: float = 4.0,
42
+ qkv_bias: bool = False,
43
+ proj_bias: bool = True,
44
+ ffn_bias: bool = True,
45
+ drop: float = 0.0,
46
+ attn_drop: float = 0.0,
47
+ init_values=None,
48
+ drop_path: float = 0.0,
49
+ act_layer: Callable[..., nn.Module] = nn.GELU,
50
+ norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
51
+ attn_class: Callable[..., nn.Module] = Attention,
52
+ ffn_layer: Callable[..., nn.Module] = Mlp,
53
+ ) -> None:
54
+ super().__init__()
55
+ # print(f"biases: qkv: {qkv_bias}, proj: {proj_bias}, ffn: {ffn_bias}")
56
+ self.norm1 = norm_layer(dim)
57
+ self.attn = attn_class(
58
+ dim,
59
+ num_heads=num_heads,
60
+ qkv_bias=qkv_bias,
61
+ proj_bias=proj_bias,
62
+ attn_drop=attn_drop,
63
+ proj_drop=drop,
64
+ )
65
+ self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
66
+ self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
67
+
68
+ self.norm2 = norm_layer(dim)
69
+ mlp_hidden_dim = int(dim * mlp_ratio)
70
+ self.mlp = ffn_layer(
71
+ in_features=dim,
72
+ hidden_features=mlp_hidden_dim,
73
+ act_layer=act_layer,
74
+ drop=drop,
75
+ bias=ffn_bias,
76
+ )
77
+ self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
78
+ self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
79
+
80
+ self.sample_drop_ratio = drop_path
81
+
82
+ def forward(self, x: Tensor) -> Tensor:
83
+ def attn_residual_func(x: Tensor) -> Tensor:
84
+ return self.ls1(self.attn(self.norm1(x)))
85
+
86
+ def ffn_residual_func(x: Tensor) -> Tensor:
87
+ return self.ls2(self.mlp(self.norm2(x)))
88
+
89
+ if self.training and self.sample_drop_ratio > 0.1:
90
+ # the overhead is compensated only for a drop path rate larger than 0.1
91
+ x = drop_add_residual_stochastic_depth(
92
+ x,
93
+ residual_func=attn_residual_func,
94
+ sample_drop_ratio=self.sample_drop_ratio,
95
+ )
96
+ x = drop_add_residual_stochastic_depth(
97
+ x,
98
+ residual_func=ffn_residual_func,
99
+ sample_drop_ratio=self.sample_drop_ratio,
100
+ )
101
+ elif self.training and self.sample_drop_ratio > 0.0:
102
+ x = x + self.drop_path1(attn_residual_func(x))
103
+ x = x + self.drop_path1(ffn_residual_func(x)) # FIXME: drop_path2
104
+ else:
105
+ x = x + attn_residual_func(x)
106
+ x = x + ffn_residual_func(x)
107
+ return x
108
+
109
+
110
+ def drop_add_residual_stochastic_depth(
111
+ x: Tensor,
112
+ residual_func: Callable[[Tensor], Tensor],
113
+ sample_drop_ratio: float = 0.0,
114
+ ) -> Tensor:
115
+ # 1) extract subset using permutation
116
+ b, n, d = x.shape
117
+ sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
118
+ brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
119
+ x_subset = x[brange]
120
+
121
+ # 2) apply residual_func to get residual
122
+ residual = residual_func(x_subset)
123
+
124
+ x_flat = x.flatten(1)
125
+ residual = residual.flatten(1)
126
+
127
+ residual_scale_factor = b / sample_subset_size
128
+
129
+ # 3) add the residual
130
+ x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor)
131
+ return x_plus_residual.view_as(x)
132
+
133
+
134
+ def get_branges_scales(x, sample_drop_ratio=0.0):
135
+ b, n, d = x.shape
136
+ sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
137
+ brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
138
+ residual_scale_factor = b / sample_subset_size
139
+ return brange, residual_scale_factor
140
+
141
+
142
+ def add_residual(x, brange, residual, residual_scale_factor, scaling_vector=None):
143
+ if scaling_vector is None:
144
+ x_flat = x.flatten(1)
145
+ residual = residual.flatten(1)
146
+ x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor)
147
+ else:
148
+ x_plus_residual = scaled_index_add(
149
+ x, brange, residual.to(dtype=x.dtype), scaling=scaling_vector, alpha=residual_scale_factor
150
+ )
151
+ return x_plus_residual
152
+
153
+
154
+ attn_bias_cache: Dict[Tuple, Any] = {}
155
+
156
+
157
+ def get_attn_bias_and_cat(x_list, branges=None):
158
+ """
159
+ this will perform the index select, cat the tensors, and provide the attn_bias from cache
160
+ """
161
+ batch_sizes = [b.shape[0] for b in branges] if branges is not None else [x.shape[0] for x in x_list]
162
+ all_shapes = tuple((b, x.shape[1]) for b, x in zip(batch_sizes, x_list))
163
+ if all_shapes not in attn_bias_cache.keys():
164
+ seqlens = []
165
+ for b, x in zip(batch_sizes, x_list):
166
+ for _ in range(b):
167
+ seqlens.append(x.shape[1])
168
+ attn_bias = fmha.BlockDiagonalMask.from_seqlens(seqlens)
169
+ attn_bias._batch_sizes = batch_sizes
170
+ attn_bias_cache[all_shapes] = attn_bias
171
+
172
+ if branges is not None:
173
+ cat_tensors = index_select_cat([x.flatten(1) for x in x_list], branges).view(1, -1, x_list[0].shape[-1])
174
+ else:
175
+ tensors_bs1 = tuple(x.reshape([1, -1, *x.shape[2:]]) for x in x_list)
176
+ cat_tensors = torch.cat(tensors_bs1, dim=1)
177
+
178
+ return attn_bias_cache[all_shapes], cat_tensors
179
+
180
+
181
+ def drop_add_residual_stochastic_depth_list(
182
+ x_list: List[Tensor],
183
+ residual_func: Callable[[Tensor, Any], Tensor],
184
+ sample_drop_ratio: float = 0.0,
185
+ scaling_vector=None,
186
+ ) -> Tensor:
187
+ # 1) generate random set of indices for dropping samples in the batch
188
+ branges_scales = [get_branges_scales(x, sample_drop_ratio=sample_drop_ratio) for x in x_list]
189
+ branges = [s[0] for s in branges_scales]
190
+ residual_scale_factors = [s[1] for s in branges_scales]
191
+
192
+ # 2) get attention bias and index+concat the tensors
193
+ attn_bias, x_cat = get_attn_bias_and_cat(x_list, branges)
194
+
195
+ # 3) apply residual_func to get residual, and split the result
196
+ residual_list = attn_bias.split(residual_func(x_cat, attn_bias=attn_bias)) # type: ignore
197
+
198
+ outputs = []
199
+ for x, brange, residual, residual_scale_factor in zip(x_list, branges, residual_list, residual_scale_factors):
200
+ outputs.append(add_residual(x, brange, residual, residual_scale_factor, scaling_vector).view_as(x))
201
+ return outputs
202
+
203
+
204
+ class NestedTensorBlock(Block):
205
+ def forward_nested(self, x_list: List[Tensor]) -> List[Tensor]:
206
+ """
207
+ x_list contains a list of tensors to nest together and run
208
+ """
209
+ assert isinstance(self.attn, MemEffAttention)
210
+
211
+ if self.training and self.sample_drop_ratio > 0.0:
212
+
213
+ def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
214
+ return self.attn(self.norm1(x), attn_bias=attn_bias)
215
+
216
+ def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
217
+ return self.mlp(self.norm2(x))
218
+
219
+ x_list = drop_add_residual_stochastic_depth_list(
220
+ x_list,
221
+ residual_func=attn_residual_func,
222
+ sample_drop_ratio=self.sample_drop_ratio,
223
+ scaling_vector=self.ls1.gamma if isinstance(self.ls1, LayerScale) else None,
224
+ )
225
+ x_list = drop_add_residual_stochastic_depth_list(
226
+ x_list,
227
+ residual_func=ffn_residual_func,
228
+ sample_drop_ratio=self.sample_drop_ratio,
229
+ scaling_vector=self.ls2.gamma if isinstance(self.ls2, LayerScale) else None,
230
+ )
231
+ return x_list
232
+ else:
233
+
234
+ def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
235
+ return self.ls1(self.attn(self.norm1(x), attn_bias=attn_bias))
236
+
237
+ def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
238
+ return self.ls2(self.mlp(self.norm2(x)))
239
+
240
+ attn_bias, x = get_attn_bias_and_cat(x_list)
241
+ x = x + attn_residual_func(x, attn_bias=attn_bias)
242
+ x = x + ffn_residual_func(x)
243
+ return attn_bias.split(x)
244
+
245
+ def forward(self, x_or_x_list):
246
+ if isinstance(x_or_x_list, Tensor):
247
+ return super().forward(x_or_x_list)
248
+ elif isinstance(x_or_x_list, list):
249
+ assert XFORMERS_AVAILABLE, "Please install xFormers for nested tensors usage"
250
+ return self.forward_nested(x_or_x_list)
251
+ else:
252
+ raise AssertionError
Depth-Anything-V2/depth_anything_v2/dinov2_layers/drop_path.py ADDED
@@ -0,0 +1,35 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ # References:
8
+ # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
9
+ # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/drop.py
10
+
11
+
12
+ from torch import nn
13
+
14
+
15
+ def drop_path(x, drop_prob: float = 0.0, training: bool = False):
16
+ if drop_prob == 0.0 or not training:
17
+ return x
18
+ keep_prob = 1 - drop_prob
19
+ shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
20
+ random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
21
+ if keep_prob > 0.0:
22
+ random_tensor.div_(keep_prob)
23
+ output = x * random_tensor
24
+ return output
25
+
26
+
27
+ class DropPath(nn.Module):
28
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
29
+
30
+ def __init__(self, drop_prob=None):
31
+ super(DropPath, self).__init__()
32
+ self.drop_prob = drop_prob
33
+
34
+ def forward(self, x):
35
+ return drop_path(x, self.drop_prob, self.training)
Depth-Anything-V2/depth_anything_v2/dinov2_layers/layer_scale.py ADDED
@@ -0,0 +1,28 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ # Modified from: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L103-L110
8
+
9
+ from typing import Union
10
+
11
+ import torch
12
+ from torch import Tensor
13
+ from torch import nn
14
+
15
+
16
+ class LayerScale(nn.Module):
17
+ def __init__(
18
+ self,
19
+ dim: int,
20
+ init_values: Union[float, Tensor] = 1e-5,
21
+ inplace: bool = False,
22
+ ) -> None:
23
+ super().__init__()
24
+ self.inplace = inplace
25
+ self.gamma = nn.Parameter(init_values * torch.ones(dim))
26
+
27
+ def forward(self, x: Tensor) -> Tensor:
28
+ return x.mul_(self.gamma) if self.inplace else x * self.gamma
Depth-Anything-V2/depth_anything_v2/dinov2_layers/mlp.py ADDED
@@ -0,0 +1,41 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ # References:
8
+ # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
9
+ # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/mlp.py
10
+
11
+
12
+ from typing import Callable, Optional
13
+
14
+ from torch import Tensor, nn
15
+
16
+
17
+ class Mlp(nn.Module):
18
+ def __init__(
19
+ self,
20
+ in_features: int,
21
+ hidden_features: Optional[int] = None,
22
+ out_features: Optional[int] = None,
23
+ act_layer: Callable[..., nn.Module] = nn.GELU,
24
+ drop: float = 0.0,
25
+ bias: bool = True,
26
+ ) -> None:
27
+ super().__init__()
28
+ out_features = out_features or in_features
29
+ hidden_features = hidden_features or in_features
30
+ self.fc1 = nn.Linear(in_features, hidden_features, bias=bias)
31
+ self.act = act_layer()
32
+ self.fc2 = nn.Linear(hidden_features, out_features, bias=bias)
33
+ self.drop = nn.Dropout(drop)
34
+
35
+ def forward(self, x: Tensor) -> Tensor:
36
+ x = self.fc1(x)
37
+ x = self.act(x)
38
+ x = self.drop(x)
39
+ x = self.fc2(x)
40
+ x = self.drop(x)
41
+ return x
Depth-Anything-V2/depth_anything_v2/dinov2_layers/patch_embed.py ADDED
@@ -0,0 +1,89 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ # References:
8
+ # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
9
+ # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py
10
+
11
+ from typing import Callable, Optional, Tuple, Union
12
+
13
+ from torch import Tensor
14
+ import torch.nn as nn
15
+
16
+
17
+ def make_2tuple(x):
18
+ if isinstance(x, tuple):
19
+ assert len(x) == 2
20
+ return x
21
+
22
+ assert isinstance(x, int)
23
+ return (x, x)
24
+
25
+
26
+ class PatchEmbed(nn.Module):
27
+ """
28
+ 2D image to patch embedding: (B,C,H,W) -> (B,N,D)
29
+
30
+ Args:
31
+ img_size: Image size.
32
+ patch_size: Patch token size.
33
+ in_chans: Number of input image channels.
34
+ embed_dim: Number of linear projection output channels.
35
+ norm_layer: Normalization layer.
36
+ """
37
+
38
+ def __init__(
39
+ self,
40
+ img_size: Union[int, Tuple[int, int]] = 224,
41
+ patch_size: Union[int, Tuple[int, int]] = 16,
42
+ in_chans: int = 3,
43
+ embed_dim: int = 768,
44
+ norm_layer: Optional[Callable] = None,
45
+ flatten_embedding: bool = True,
46
+ ) -> None:
47
+ super().__init__()
48
+
49
+ image_HW = make_2tuple(img_size)
50
+ patch_HW = make_2tuple(patch_size)
51
+ patch_grid_size = (
52
+ image_HW[0] // patch_HW[0],
53
+ image_HW[1] // patch_HW[1],
54
+ )
55
+
56
+ self.img_size = image_HW
57
+ self.patch_size = patch_HW
58
+ self.patches_resolution = patch_grid_size
59
+ self.num_patches = patch_grid_size[0] * patch_grid_size[1]
60
+
61
+ self.in_chans = in_chans
62
+ self.embed_dim = embed_dim
63
+
64
+ self.flatten_embedding = flatten_embedding
65
+
66
+ self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_HW, stride=patch_HW)
67
+ self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
68
+
69
+ def forward(self, x: Tensor) -> Tensor:
70
+ _, _, H, W = x.shape
71
+ patch_H, patch_W = self.patch_size
72
+
73
+ assert H % patch_H == 0, f"Input image height {H} is not a multiple of patch height {patch_H}"
74
+ assert W % patch_W == 0, f"Input image width {W} is not a multiple of patch width: {patch_W}"
75
+
76
+ x = self.proj(x) # B C H W
77
+ H, W = x.size(2), x.size(3)
78
+ x = x.flatten(2).transpose(1, 2) # B HW C
79
+ x = self.norm(x)
80
+ if not self.flatten_embedding:
81
+ x = x.reshape(-1, H, W, self.embed_dim) # B H W C
82
+ return x
83
+
84
+ def flops(self) -> float:
85
+ Ho, Wo = self.patches_resolution
86
+ flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1])
87
+ if self.norm is not None:
88
+ flops += Ho * Wo * self.embed_dim
89
+ return flops
Depth-Anything-V2/depth_anything_v2/dinov2_layers/swiglu_ffn.py ADDED
@@ -0,0 +1,63 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ from typing import Callable, Optional
8
+
9
+ from torch import Tensor, nn
10
+ import torch.nn.functional as F
11
+
12
+
13
+ class SwiGLUFFN(nn.Module):
14
+ def __init__(
15
+ self,
16
+ in_features: int,
17
+ hidden_features: Optional[int] = None,
18
+ out_features: Optional[int] = None,
19
+ act_layer: Callable[..., nn.Module] = None,
20
+ drop: float = 0.0,
21
+ bias: bool = True,
22
+ ) -> None:
23
+ super().__init__()
24
+ out_features = out_features or in_features
25
+ hidden_features = hidden_features or in_features
26
+ self.w12 = nn.Linear(in_features, 2 * hidden_features, bias=bias)
27
+ self.w3 = nn.Linear(hidden_features, out_features, bias=bias)
28
+
29
+ def forward(self, x: Tensor) -> Tensor:
30
+ x12 = self.w12(x)
31
+ x1, x2 = x12.chunk(2, dim=-1)
32
+ hidden = F.silu(x1) * x2
33
+ return self.w3(hidden)
34
+
35
+
36
+ try:
37
+ from xformers.ops import SwiGLU
38
+
39
+ XFORMERS_AVAILABLE = True
40
+ except ImportError:
41
+ SwiGLU = SwiGLUFFN
42
+ XFORMERS_AVAILABLE = False
43
+
44
+
45
+ class SwiGLUFFNFused(SwiGLU):
46
+ def __init__(
47
+ self,
48
+ in_features: int,
49
+ hidden_features: Optional[int] = None,
50
+ out_features: Optional[int] = None,
51
+ act_layer: Callable[..., nn.Module] = None,
52
+ drop: float = 0.0,
53
+ bias: bool = True,
54
+ ) -> None:
55
+ out_features = out_features or in_features
56
+ hidden_features = hidden_features or in_features
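+ # Shrink the hidden width to roughly 2/3 and round it up to a multiple of 8 before building the fused SwiGLU layers.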
57
+ hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8
58
+ super().__init__(
59
+ in_features=in_features,
60
+ hidden_features=hidden_features,
61
+ out_features=out_features,
62
+ bias=bias,
63
+ )
Depth-Anything-V2/depth_anything_v2/dpt.py ADDED
@@ -0,0 +1,233 @@
1
+ import cv2
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+ from torchvision.transforms import Compose
6
+
7
+ from .dinov2 import DINOv2
8
+ from .util.blocks import FeatureFusionBlock, _make_scratch
9
+ from .util.transform import Resize, NormalizeImage, PrepareForNet
10
+
11
+
12
+ def _make_fusion_block(features, use_bn, size=None):
13
+ return FeatureFusionBlock(
14
+ features,
15
+ nn.ReLU(False),
16
+ deconv=False,
17
+ bn=use_bn,
18
+ expand=False,
19
+ align_corners=True,
20
+ size=size,
21
+ )
22
+
23
+
24
+ class ConvBlock(nn.Module):
25
+ def __init__(self, in_feature, out_feature):
26
+ super().__init__()
27
+
28
+ self.conv_block = nn.Sequential(
29
+ nn.Conv2d(in_feature, out_feature, kernel_size=3, stride=1, padding=1),
30
+ nn.BatchNorm2d(out_feature),
31
+ nn.ReLU(True)
32
+ )
33
+
34
+ def forward(self, x):
35
+ return self.conv_block(x)
36
+
37
+
38
+ class DPTHead(nn.Module):
39
+ def __init__(
40
+ self,
41
+ in_channels,
42
+ features=256,
43
+ use_bn=False,
44
+ out_channels=[256, 512, 1024, 1024],
45
+ use_clstoken=False
46
+ ):
47
+ super(DPTHead, self).__init__()
48
+
49
+ self.use_clstoken = use_clstoken
50
+
51
+ self.projects = nn.ModuleList([
52
+ nn.Conv2d(
53
+ in_channels=in_channels,
54
+ out_channels=out_channel,
55
+ kernel_size=1,
56
+ stride=1,
57
+ padding=0,
58
+ ) for out_channel in out_channels
59
+ ])
60
+
61
+ self.resize_layers = nn.ModuleList([
62
+ nn.ConvTranspose2d(
63
+ in_channels=out_channels[0],
64
+ out_channels=out_channels[0],
65
+ kernel_size=4,
66
+ stride=4,
67
+ padding=0),
68
+ nn.ConvTranspose2d(
69
+ in_channels=out_channels[1],
70
+ out_channels=out_channels[1],
71
+ kernel_size=2,
72
+ stride=2,
73
+ padding=0),
74
+ nn.Identity(),
75
+ nn.Conv2d(
76
+ in_channels=out_channels[3],
77
+ out_channels=out_channels[3],
78
+ kernel_size=3,
79
+ stride=2,
80
+ padding=1)
81
+ ])
82
+
83
+ if use_clstoken:
84
+ self.readout_projects = nn.ModuleList()
85
+ for _ in range(len(self.projects)):
86
+ self.readout_projects.append(
87
+ nn.Sequential(
88
+ nn.Linear(2 * in_channels, in_channels),
89
+ nn.GELU()))
90
+
91
+ self.scratch = _make_scratch(
92
+ out_channels,
93
+ features,
94
+ groups=1,
95
+ expand=False,
96
+ )
97
+
98
+ self.scratch.stem_transpose = None
99
+
100
+ self.scratch.refinenet1 = _make_fusion_block(features, use_bn)
101
+ self.scratch.refinenet2 = _make_fusion_block(features, use_bn)
102
+ self.scratch.refinenet3 = _make_fusion_block(features, use_bn)
103
+ self.scratch.refinenet4 = _make_fusion_block(features, use_bn)
104
+
105
+ head_features_1 = features
106
+ head_features_2 = 32
107
+
108
+ self.scratch.output_conv1 = nn.Conv2d(head_features_1, head_features_1 // 2, kernel_size=3, stride=1, padding=1)
109
+ self.scratch.output_conv2 = nn.Sequential(
110
+ nn.Conv2d(head_features_1 // 2, head_features_2, kernel_size=3, stride=1, padding=1),
111
+ nn.ReLU(True),
112
+ nn.Conv2d(head_features_2, 1, kernel_size=1, stride=1, padding=0),
113
+ nn.ReLU(True),
114
+ nn.Identity(),
115
+ )
116
+
117
+ def forward(self, out_features, patch_h, patch_w):
118
+ out = []
119
+ for i, x in enumerate(out_features):
120
+ if self.use_clstoken:
121
+ x, cls_token = x[0], x[1]
122
+ readout = cls_token.unsqueeze(1).expand_as(x)
123
+ x = self.readout_projects[i](torch.cat((x, readout), -1))
124
+ else:
125
+ x = x[0]
126
+
127
+ x = x.permute(0, 2, 1).reshape((x.shape[0], x.shape[-1], patch_h, patch_w))
128
+
129
+ x = self.projects[i](x)
130
+ x = self.resize_layers[i](x)
131
+
132
+ # project&resize 0: torch.Size([1, 256, 148, 216])
133
+ # project&resize 1: torch.Size([1, 512, 74, 108])
134
+ # project&resize 2: torch.Size([1, 1024, 37, 54])
135
+ # project&resize 3: torch.Size([1, 1024, 19, 27])
136
+
137
+ out.append(x)
138
+
139
+ layer_1, layer_2, layer_3, layer_4 = out
140
+
141
+ layer_1_rn = self.scratch.layer1_rn(layer_1)
142
+ layer_2_rn = self.scratch.layer2_rn(layer_2)
143
+ layer_3_rn = self.scratch.layer3_rn(layer_3)
144
+ layer_4_rn = self.scratch.layer4_rn(layer_4)
145
+
146
+ path_4 = self.scratch.refinenet4(layer_4_rn, size=layer_3_rn.shape[2:])
147
+ path_3 = self.scratch.refinenet3(path_4, layer_3_rn, size=layer_2_rn.shape[2:])
148
+ path_2 = self.scratch.refinenet2(path_3, layer_2_rn, size=layer_1_rn.shape[2:])
149
+ path_1 = self.scratch.refinenet1(path_2, layer_1_rn)
150
+
151
+ out_fea = self.scratch.output_conv1(path_1)
152
+ # scratch.output_conv1: torch.Size([1, 128, 296, 432])
153
+ out = F.interpolate(out_fea, (int(patch_h * 14), int(patch_w * 14)), mode="bilinear", align_corners=True)
154
+ # interpolate: torch.Size([1, 128, 518, 756])
155
+ out = self.scratch.output_conv2(out)
156
+ # scratch.output_conv2: torch.Size([1, 1, 518, 756])
157
+
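+ # Project-specific modification: return the pre-head DPT feature map (out_fea) together with the depth prediction (see README).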
158
+ return out, out_fea
159
+
160
+
161
+ class DepthAnythingV2(nn.Module):
162
+ def __init__(
163
+ self,
164
+ encoder='vitl',
165
+ features=256,
166
+ out_channels=[256, 512, 1024, 1024],
167
+ use_bn=False,
168
+ use_clstoken=False
169
+ ):
170
+ super(DepthAnythingV2, self).__init__()
171
+
172
+ self.intermediate_layer_idx = {
173
+ 'vits': [2, 5, 8, 11],
174
+ 'vitb': [2, 5, 8, 11],
175
+ 'vitl': [4, 11, 17, 23],
176
+ 'vitg': [9, 19, 29, 39]
177
+ }
178
+
179
+ self.encoder = encoder
180
+ self.pretrained = DINOv2(model_name=encoder)
181
+
182
+ self.depth_head = DPTHead(self.pretrained.embed_dim, features, use_bn, out_channels=out_channels, use_clstoken=use_clstoken)
183
+
184
+ def forward(self, x):
185
+ patch_h, patch_w = x.shape[-2] // 14, x.shape[-1] // 14
186
+
187
+ # features 0: torch.Size([1, 1998, 1024]) torch.Size([1, 1024])
188
+ # features 1: torch.Size([1, 1998, 1024]) torch.Size([1, 1024])
189
+ # features 2: torch.Size([1, 1998, 1024]) torch.Size([1, 1024])
190
+ # features 3: torch.Size([1, 1998, 1024]) torch.Size([1, 1024])
191
+ features = self.pretrained.get_intermediate_layers(x, self.intermediate_layer_idx[self.encoder], return_class_token=True)
192
+
193
+ depth, out_fea = self.depth_head(features, patch_h, patch_w)
194
+ depth = F.relu(depth)
195
+
196
+ return depth, out_fea
197
+
198
+ @torch.no_grad()
199
+ def infer_image(self, raw_image, input_size=518):
200
+ image, (h, w) = self.image2tensor(raw_image, input_size)
201
+
202
+ depth, _ = self.forward(image)
203
+
204
+ depth = F.interpolate(depth, (h, w), mode="bilinear", align_corners=True)[0, 0]
205
+
206
+ return depth.cpu().numpy()
207
+
208
+ def image2tensor(self, raw_image, input_size=518):
209
+ transform = Compose([
210
+ Resize(
211
+ width=input_size,
212
+ height=input_size,
213
+ resize_target=False,
214
+ keep_aspect_ratio=True,
215
+ ensure_multiple_of=14,
216
+ resize_method='lower_bound',
217
+ image_interpolation_method=cv2.INTER_CUBIC,
218
+ ),
219
+ NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
220
+ PrepareForNet(),
221
+ ])
222
+
223
+ h, w = raw_image.shape[:2]
224
+
225
+ image = cv2.cvtColor(raw_image, cv2.COLOR_BGR2RGB) / 255.0
226
+
227
+ image = transform({'image': image})['image']
228
+ image = torch.from_numpy(image).unsqueeze(0)
229
+
230
+ DEVICE = 'cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu'
231
+ image = image.to(DEVICE)
232
+
233
+ return image, (h, w)
Depth-Anything-V2/depth_anything_v2/util/__pycache__/blocks.cpython-310.pyc ADDED
Binary file (3.3 kB). View file
 
Depth-Anything-V2/depth_anything_v2/util/__pycache__/transform.cpython-310.pyc ADDED
Binary file (4.74 kB). View file
 
Depth-Anything-V2/depth_anything_v2/util/blocks.py ADDED
@@ -0,0 +1,148 @@
1
+ import torch.nn as nn
2
+
3
+
4
+ def _make_scratch(in_shape, out_shape, groups=1, expand=False):
5
+ scratch = nn.Module()
6
+
7
+ out_shape1 = out_shape
8
+ out_shape2 = out_shape
9
+ out_shape3 = out_shape
10
+ if len(in_shape) >= 4:
11
+ out_shape4 = out_shape
12
+
13
+ if expand:
14
+ out_shape1 = out_shape
15
+ out_shape2 = out_shape * 2
16
+ out_shape3 = out_shape * 4
17
+ if len(in_shape) >= 4:
18
+ out_shape4 = out_shape * 8
19
+
20
+ scratch.layer1_rn = nn.Conv2d(in_shape[0], out_shape1, kernel_size=3, stride=1, padding=1, bias=False, groups=groups)
21
+ scratch.layer2_rn = nn.Conv2d(in_shape[1], out_shape2, kernel_size=3, stride=1, padding=1, bias=False, groups=groups)
22
+ scratch.layer3_rn = nn.Conv2d(in_shape[2], out_shape3, kernel_size=3, stride=1, padding=1, bias=False, groups=groups)
23
+ if len(in_shape) >= 4:
24
+ scratch.layer4_rn = nn.Conv2d(in_shape[3], out_shape4, kernel_size=3, stride=1, padding=1, bias=False, groups=groups)
25
+
26
+ return scratch
27
+
28
+
29
+ class ResidualConvUnit(nn.Module):
30
+ """Residual convolution module.
31
+ """
32
+
33
+ def __init__(self, features, activation, bn):
34
+ """Init.
35
+
36
+ Args:
37
+ features (int): number of features
38
+ """
39
+ super().__init__()
40
+
41
+ self.bn = bn
42
+
43
+ self.groups=1
44
+
45
+ self.conv1 = nn.Conv2d(features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups)
46
+
47
+ self.conv2 = nn.Conv2d(features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups)
48
+
49
+ if self.bn == True:
50
+ self.bn1 = nn.BatchNorm2d(features)
51
+ self.bn2 = nn.BatchNorm2d(features)
52
+
53
+ self.activation = activation
54
+
55
+ self.skip_add = nn.quantized.FloatFunctional()
56
+
57
+ def forward(self, x):
58
+ """Forward pass.
59
+
60
+ Args:
61
+ x (tensor): input
62
+
63
+ Returns:
64
+ tensor: output
65
+ """
66
+
67
+ out = self.activation(x)
68
+ out = self.conv1(out)
69
+ if self.bn == True:
70
+ out = self.bn1(out)
71
+
72
+ out = self.activation(out)
73
+ out = self.conv2(out)
74
+ if self.bn == True:
75
+ out = self.bn2(out)
76
+
77
+ if self.groups > 1:
78
+ out = self.conv_merge(out)
79
+
80
+ return self.skip_add.add(out, x)
81
+
82
+
83
+ class FeatureFusionBlock(nn.Module):
84
+ """Feature fusion block.
85
+ """
86
+
87
+ def __init__(
88
+ self,
89
+ features,
90
+ activation,
91
+ deconv=False,
92
+ bn=False,
93
+ expand=False,
94
+ align_corners=True,
95
+ size=None
96
+ ):
97
+ """Init.
98
+
99
+ Args:
100
+ features (int): number of features
101
+ """
102
+ super(FeatureFusionBlock, self).__init__()
103
+
104
+ self.deconv = deconv
105
+ self.align_corners = align_corners
106
+
107
+ self.groups=1
108
+
109
+ self.expand = expand
110
+ out_features = features
111
+ if self.expand == True:
112
+ out_features = features // 2
113
+
114
+ self.out_conv = nn.Conv2d(features, out_features, kernel_size=1, stride=1, padding=0, bias=True, groups=1)
115
+
116
+ self.resConfUnit1 = ResidualConvUnit(features, activation, bn)
117
+ self.resConfUnit2 = ResidualConvUnit(features, activation, bn)
118
+
119
+ self.skip_add = nn.quantized.FloatFunctional()
120
+
121
+ self.size=size
122
+
123
+ def forward(self, *xs, size=None):
124
+ """Forward pass.
125
+
126
+ Returns:
127
+ tensor: output
128
+ """
129
+ output = xs[0]
130
+
131
+ if len(xs) == 2:
132
+ res = self.resConfUnit1(xs[1])
133
+ output = self.skip_add.add(output, res)
134
+
135
+ output = self.resConfUnit2(output)
136
+
137
+ if (size is None) and (self.size is None):
138
+ modifier = {"scale_factor": 2}
139
+ elif size is None:
140
+ modifier = {"size": self.size}
141
+ else:
142
+ modifier = {"size": size}
143
+
144
+ output = nn.functional.interpolate(output, **modifier, mode="bilinear", align_corners=self.align_corners)
145
+
146
+ output = self.out_conv(output)
147
+
148
+ return output
Depth-Anything-V2/depth_anything_v2/util/transform.py ADDED
@@ -0,0 +1,158 @@
1
+ import numpy as np
2
+ import cv2
3
+
4
+
5
+ class Resize(object):
6
+ """Resize sample to given size (width, height).
7
+ """
8
+
9
+ def __init__(
10
+ self,
11
+ width,
12
+ height,
13
+ resize_target=True,
14
+ keep_aspect_ratio=False,
15
+ ensure_multiple_of=1,
16
+ resize_method="lower_bound",
17
+ image_interpolation_method=cv2.INTER_AREA,
18
+ ):
19
+ """Init.
20
+
21
+ Args:
22
+ width (int): desired output width
23
+ height (int): desired output height
24
+ resize_target (bool, optional):
25
+ True: Resize the full sample (image, mask, target).
26
+ False: Resize image only.
27
+ Defaults to True.
28
+ keep_aspect_ratio (bool, optional):
29
+ True: Keep the aspect ratio of the input sample.
30
+ Output sample might not have the given width and height, and
31
+ resize behaviour depends on the parameter 'resize_method'.
32
+ Defaults to False.
33
+ ensure_multiple_of (int, optional):
34
+ Output width and height is constrained to be multiple of this parameter.
35
+ Defaults to 1.
36
+ resize_method (str, optional):
37
+ "lower_bound": Output will be at least as large as the given size.
38
+ "upper_bound": Output will be at max as large as the given size. (Output size might be smaller than given size.)
39
+ "minimal": Scale as least as possible. (Output size might be smaller than given size.)
40
+ Defaults to "lower_bound".
41
+ """
42
+ self.__width = width
43
+ self.__height = height
44
+
45
+ self.__resize_target = resize_target
46
+ self.__keep_aspect_ratio = keep_aspect_ratio
47
+ self.__multiple_of = ensure_multiple_of
48
+ self.__resize_method = resize_method
49
+ self.__image_interpolation_method = image_interpolation_method
50
+
51
+ def constrain_to_multiple_of(self, x, min_val=0, max_val=None):
52
+ y = (np.round(x / self.__multiple_of) * self.__multiple_of).astype(int)
53
+
54
+ if max_val is not None and y > max_val:
55
+ y = (np.floor(x / self.__multiple_of) * self.__multiple_of).astype(int)
56
+
57
+ if y < min_val:
58
+ y = (np.ceil(x / self.__multiple_of) * self.__multiple_of).astype(int)
59
+
60
+ return y
61
+
62
+ def get_size(self, width, height):
63
+ # determine new height and width
64
+ scale_height = self.__height / height
65
+ scale_width = self.__width / width
66
+
67
+ if self.__keep_aspect_ratio:
68
+ if self.__resize_method == "lower_bound":
69
+ # scale such that output size is lower bound
70
+ if scale_width > scale_height:
71
+ # fit width
72
+ scale_height = scale_width
73
+ else:
74
+ # fit height
75
+ scale_width = scale_height
76
+ elif self.__resize_method == "upper_bound":
77
+ # scale such that output size is upper bound
78
+ if scale_width < scale_height:
79
+ # fit width
80
+ scale_height = scale_width
81
+ else:
82
+ # fit height
83
+ scale_width = scale_height
84
+ elif self.__resize_method == "minimal":
85
+ # scale as little as possible
86
+ if abs(1 - scale_width) < abs(1 - scale_height):
87
+ # fit width
88
+ scale_height = scale_width
89
+ else:
90
+ # fit height
91
+ scale_width = scale_height
92
+ else:
93
+ raise ValueError(f"resize_method {self.__resize_method} not implemented")
94
+
95
+ if self.__resize_method == "lower_bound":
96
+ new_height = self.constrain_to_multiple_of(scale_height * height, min_val=self.__height)
97
+ new_width = self.constrain_to_multiple_of(scale_width * width, min_val=self.__width)
98
+ elif self.__resize_method == "upper_bound":
99
+ new_height = self.constrain_to_multiple_of(scale_height * height, max_val=self.__height)
100
+ new_width = self.constrain_to_multiple_of(scale_width * width, max_val=self.__width)
101
+ elif self.__resize_method == "minimal":
102
+ new_height = self.constrain_to_multiple_of(scale_height * height)
103
+ new_width = self.constrain_to_multiple_of(scale_width * width)
104
+ else:
105
+ raise ValueError(f"resize_method {self.__resize_method} not implemented")
106
+
107
+ return (new_width, new_height)
108
+
109
+ def __call__(self, sample):
110
+ width, height = self.get_size(sample["image"].shape[1], sample["image"].shape[0])
111
+
112
+ # resize sample
113
+ sample["image"] = cv2.resize(sample["image"], (width, height), interpolation=self.__image_interpolation_method)
114
+
115
+ if self.__resize_target:
116
+ if "depth" in sample:
117
+ sample["depth"] = cv2.resize(sample["depth"], (width, height), interpolation=cv2.INTER_NEAREST)
118
+
119
+ if "mask" in sample:
120
+ sample["mask"] = cv2.resize(sample["mask"].astype(np.float32), (width, height), interpolation=cv2.INTER_NEAREST)
121
+
122
+ return sample
123
+
124
+
125
+ class NormalizeImage(object):
126
+ """Normlize image by given mean and std.
127
+ """
128
+
129
+ def __init__(self, mean, std):
130
+ self.__mean = mean
131
+ self.__std = std
132
+
133
+ def __call__(self, sample):
134
+ sample["image"] = (sample["image"] - self.__mean) / self.__std
135
+
136
+ return sample
137
+
138
+
139
+ class PrepareForNet(object):
140
+ """Prepare sample for usage as network input.
141
+ """
142
+
143
+ def __init__(self):
144
+ pass
145
+
146
+ def __call__(self, sample):
147
+ image = np.transpose(sample["image"], (2, 0, 1))
148
+ sample["image"] = np.ascontiguousarray(image).astype(np.float32)
149
+
150
+ if "depth" in sample:
151
+ depth = sample["depth"].astype(np.float32)
152
+ sample["depth"] = np.ascontiguousarray(depth)
153
+
154
+ if "mask" in sample:
155
+ sample["mask"] = sample["mask"].astype(np.float32)
156
+ sample["mask"] = np.ascontiguousarray(sample["mask"])
157
+
158
+ return sample
Depth-Anything-V2/requirements.txt ADDED
@@ -0,0 +1,6 @@
1
+ gradio_imageslider
2
+ gradio==4.29.0
3
+ matplotlib
4
+ opencv-python
5
+ torch
6
+ torchvision
Depth-Anything-V2/run.py ADDED
@@ -0,0 +1,73 @@
1
+ import argparse
2
+ import cv2
3
+ import glob
4
+ import matplotlib
5
+ import numpy as np
6
+ import os
7
+ import torch
8
+
9
+ from depth_anything_v2.dpt import DepthAnythingV2
10
+
11
+
12
+ if __name__ == '__main__':
13
+ parser = argparse.ArgumentParser(description='Depth Anything V2')
14
+
15
+ parser.add_argument('--img-path', type=str)
16
+ parser.add_argument('--input-size', type=int, default=518)
17
+ parser.add_argument('--outdir', type=str, default='./vis_depth')
18
+
19
+ parser.add_argument('--encoder', type=str, default='vitl', choices=['vits', 'vitb', 'vitl', 'vitg'])
20
+
21
+ parser.add_argument('--pred-only', dest='pred_only', action='store_true', help='only display the prediction')
22
+ parser.add_argument('--grayscale', dest='grayscale', action='store_true', help='do not apply colorful palette')
23
+
24
+ args = parser.parse_args()
25
+
26
+ DEVICE = 'cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu'
27
+
28
+ model_configs = {
29
+ 'vits': {'encoder': 'vits', 'features': 64, 'out_channels': [48, 96, 192, 384]},
30
+ 'vitb': {'encoder': 'vitb', 'features': 128, 'out_channels': [96, 192, 384, 768]},
31
+ 'vitl': {'encoder': 'vitl', 'features': 256, 'out_channels': [256, 512, 1024, 1024]},
32
+ 'vitg': {'encoder': 'vitg', 'features': 384, 'out_channels': [1536, 1536, 1536, 1536]}
33
+ }
34
+
35
+ depth_anything = DepthAnythingV2(**model_configs[args.encoder])
36
+ depth_anything.load_state_dict(torch.load(f'checkpoints/depth_anything_v2_{args.encoder}.pth', map_location='cpu'))
37
+ depth_anything = depth_anything.to(DEVICE).eval()
38
+
39
+ if os.path.isfile(args.img_path):
40
+ if args.img_path.endswith('txt'):
41
+ with open(args.img_path, 'r') as f:
42
+ filenames = f.read().splitlines()
43
+ else:
44
+ filenames = [args.img_path]
45
+ else:
46
+ filenames = glob.glob(os.path.join(args.img_path, '**/*'), recursive=True)
47
+
48
+ os.makedirs(args.outdir, exist_ok=True)
49
+
50
+ cmap = matplotlib.colormaps.get_cmap('Spectral_r')
51
+
52
+ for k, filename in enumerate(filenames):
53
+ print(f'Progress {k+1}/{len(filenames)}: {filename}')
54
+
55
+ raw_image = cv2.imread(filename)
56
+
57
+ depth = depth_anything.infer_image(raw_image, args.input_size)
58
+
59
+ depth = (depth - depth.min()) / (depth.max() - depth.min()) * 255.0
60
+ depth = depth.astype(np.uint8)
61
+
62
+ if args.grayscale:
63
+ depth = np.repeat(depth[..., np.newaxis], 3, axis=-1)
64
+ else:
65
+ depth = (cmap(depth)[:, :, :3] * 255)[:, :, ::-1].astype(np.uint8)
66
+
67
+ if args.pred_only:
68
+ cv2.imwrite(os.path.join(args.outdir, os.path.splitext(os.path.basename(filename))[0] + '.png'), depth)
69
+ else:
70
+ split_region = np.ones((raw_image.shape[0], 50, 3), dtype=np.uint8) * 255
71
+ combined_result = cv2.hconcat([raw_image, split_region, depth])
72
+
73
+ cv2.imwrite(os.path.join(args.outdir, os.path.splitext(os.path.basename(filename))[0] + '.png'), combined_result)
Depth-Anything-V2/run_video.py ADDED
@@ -0,0 +1,92 @@
1
+ import argparse
2
+ import cv2
3
+ import glob
4
+ import matplotlib
5
+ import numpy as np
6
+ import os
7
+ import torch
8
+
9
+ from depth_anything_v2.dpt import DepthAnythingV2
10
+
11
+
12
+ if __name__ == '__main__':
13
+ parser = argparse.ArgumentParser(description='Depth Anything V2')
14
+
15
+ parser.add_argument('--video-path', type=str)
16
+ parser.add_argument('--input-size', type=int, default=518)
17
+ parser.add_argument('--outdir', type=str, default='./vis_video_depth')
18
+
19
+ parser.add_argument('--encoder', type=str, default='vitl', choices=['vits', 'vitb', 'vitl', 'vitg'])
20
+
21
+ parser.add_argument('--pred-only', dest='pred_only', action='store_true', help='only display the prediction')
22
+ parser.add_argument('--grayscale', dest='grayscale', action='store_true', help='do not apply colorful palette')
23
+
24
+ args = parser.parse_args()
25
+
26
+ DEVICE = 'cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu'
27
+
28
+ model_configs = {
29
+ 'vits': {'encoder': 'vits', 'features': 64, 'out_channels': [48, 96, 192, 384]},
30
+ 'vitb': {'encoder': 'vitb', 'features': 128, 'out_channels': [96, 192, 384, 768]},
31
+ 'vitl': {'encoder': 'vitl', 'features': 256, 'out_channels': [256, 512, 1024, 1024]},
32
+ 'vitg': {'encoder': 'vitg', 'features': 384, 'out_channels': [1536, 1536, 1536, 1536]}
33
+ }
34
+
35
+ depth_anything = DepthAnythingV2(**model_configs[args.encoder])
36
+ depth_anything.load_state_dict(torch.load(f'checkpoints/depth_anything_v2_{args.encoder}.pth', map_location='cpu'))
37
+ depth_anything = depth_anything.to(DEVICE).eval()
38
+
39
+ if os.path.isfile(args.video_path):
40
+ if args.video_path.endswith('txt'):
41
+ with open(args.video_path, 'r') as f:
42
+ filenames = f.read().splitlines()
43
+ else:
44
+ filenames = [args.video_path]
45
+ else:
46
+ filenames = glob.glob(os.path.join(args.video_path, '**/*'), recursive=True)
47
+
48
+ os.makedirs(args.outdir, exist_ok=True)
49
+
50
+ margin_width = 50
51
+ cmap = matplotlib.colormaps.get_cmap('Spectral_r')
52
+
53
+ for k, filename in enumerate(filenames):
54
+ print(f'Progress {k+1}/{len(filenames)}: {filename}')
55
+
56
+ raw_video = cv2.VideoCapture(filename)
57
+ frame_width, frame_height = int(raw_video.get(cv2.CAP_PROP_FRAME_WIDTH)), int(raw_video.get(cv2.CAP_PROP_FRAME_HEIGHT))
58
+ frame_rate = int(raw_video.get(cv2.CAP_PROP_FPS))
59
+
60
+ if args.pred_only:
61
+ output_width = frame_width
62
+ else:
63
+ output_width = frame_width * 2 + margin_width
64
+
65
+ output_path = os.path.join(args.outdir, os.path.splitext(os.path.basename(filename))[0] + '.mp4')
66
+ out = cv2.VideoWriter(output_path, cv2.VideoWriter_fourcc(*"mp4v"), frame_rate, (output_width, frame_height))
67
+
68
+ while raw_video.isOpened():
69
+ ret, raw_frame = raw_video.read()
70
+ if not ret:
71
+ break
72
+
73
+ depth = depth_anything.infer_image(raw_frame, args.input_size)
74
+
75
+ depth = (depth - depth.min()) / (depth.max() - depth.min()) * 255.0
76
+ depth = depth.astype(np.uint8)
77
+
78
+ if args.grayscale:
79
+ depth = np.repeat(depth[..., np.newaxis], 3, axis=-1)
80
+ else:
81
+ depth = (cmap(depth)[:, :, :3] * 255)[:, :, ::-1].astype(np.uint8)
82
+
83
+ if args.pred_only:
84
+ out.write(depth)
85
+ else:
86
+ split_region = np.ones((frame_height, margin_width, 3), dtype=np.uint8) * 255
87
+ combined_frame = cv2.hconcat([raw_frame, split_region, depth])
88
+
89
+ out.write(combined_frame)
90
+
91
+ raw_video.release()
92
+ out.release()
LICENSE ADDED
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2021 Princeton Vision & Learning Lab
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
README.md CHANGED
@@ -1,10 +1,8 @@
1
  ---
2
- title: >-
3
- Diving Into The Fusion Of Monocular Priors For Generalized Stereo Matching
4
- Demo
5
- emoji: 🏆
6
- colorFrom: indigo
7
- colorTo: blue
8
  sdk: gradio
9
  sdk_version: 5.38.0
10
  app_file: app.py
@@ -12,3 +10,368 @@ pinned: false
12
  ---
13
 
14
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
1
  ---
2
+ title: Diving Into The Fusion Of Monocular Priors For Generalized Stereo Matching
3
+ emoji: 😻
4
+ colorFrom: red
5
+ colorTo: indigo
 
 
6
  sdk: gradio
7
  sdk_version: 5.38.0
8
  app_file: app.py
 
10
  ---
11
 
12
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
13
+ # [ICCV25] Diving into the Fusion of Monocular Priors for Generalized Stereo Matching
14
+
15
+ Detailed images can be found on [Google Drive](https://drive.google.com/file/d/1u2u_-AgxkdtnkQENEf1d2JjtutwrtCPb/view?usp=sharing).
16
+
17
+ <!-- > ⚠️ **Warning**: It is highly recommended to view this markdown in a preview format! -->
18
+ <!-- > ⚠️ **Warning**: We strongly recommend researchers retrain the model on GPUs other than A40 for better results. -->
19
+
20
+
21
+ ## Requirements
22
+ ```Shell
23
+ conda env create -f envs/environment_GStereo.yaml
24
+ conda activate raftstereo
25
+ ```
26
+
27
+
28
+ ## Required Data
29
+ ```Shell
30
+ ├── datasets
31
+ ├── sceneflow
32
+ ├── driving
33
+ │   ├── disparity
34
+ │   ├── frames_cleanpass
35
+ │   └── frames_finalpass
36
+ ├── flying3d
37
+ │   ├── disparity
38
+ │   ├── frames_cleanpass
39
+ │   └── frames_finalpass
40
+ └── monkaa
41
+ ├── disparity
42
+ ├── frames_cleanpass
43
+ └── frames_finalpass
44
+ ├── Kitti15
45
+ ├── testing
46
+ │   ├── image_2
47
+ │   └── image_3
48
+ └── training
49
+ ├── disp_noc_0
50
+ ├── disp_noc_1
51
+ ├── disp_occ_0
52
+ ├── disp_occ_1
53
+ ├── flow_noc
54
+ ├── flow_occ
55
+ ├── image_2
56
+ ├── image_3
57
+ └── obj_map
58
+ ├── Kitti12
59
+ ├── testing
60
+ │   ├── calib
61
+ │   ├── colored_0
62
+ │   ├── colored_1
63
+ │   ├── disp_noc
64
+ │   ├── disp_occ
65
+ │   ├── flow_noc
66
+ │   ├── flow_occ
67
+ │   ├── image_0
68
+ │   └── image_1
69
+ └── training
70
+ ├── calib
71
+ ├── colored_0
72
+ └── colored_1
73
+ ├── Middlebury
74
+ └── MiddEval3
75
+ ├── testF
76
+ ├── testH
77
+ ├── testQ
78
+ ├── trainingF
79
+ ├── trainingH
80
+ └── trainingQ
81
+ ├── ETH3D
82
+ ├── two_view_testing
83
+ └── two_view_training
84
+    ├── delivery_area_1l
85
+    ├── delivery_area_1s
86
+    ├── delivery_area_2l
87
+ ├── Booster
88
+ ├── test
89
+ │   ├── balanced
90
+ │   └── unbalanced
91
+ └── train
92
+ ├── balanced
93
+ └── unbalanced
94
+ ```
95
+
96
+
97
+
98
+ ## Code
99
+ All code is provided here, including DepthAnything v2.
100
+ Since we modified `dpt.py` to get intermediate features and depth output, please use the modified code.
101
+
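+ For orientation, here is a minimal usage sketch of the modified model (our illustration only; the `vitl` configuration mirrors `run.py`, and the checkpoint path is an assumed example location):
+ ```Python
+ import torch
+ from depth_anything_v2.dpt import DepthAnythingV2
+
+ DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
+
+ # 'vitl' configuration taken from run.py; point the path to wherever the weights are stored.
+ model = DepthAnythingV2(encoder='vitl', features=256, out_channels=[256, 512, 1024, 1024])
+ model.load_state_dict(torch.load('checkpoints/depth_anything_v2_vitl.pth', map_location='cpu'))
+ model = model.to(DEVICE).eval()
+
+ # Input height and width must be multiples of 14 (the DINOv2 patch size).
+ image = torch.randn(1, 3, 518, 756, device=DEVICE)
+ with torch.no_grad():
+     depth, out_fea = model(image)
+ # Per the shape comments in dpt.py: depth is (1, 1, 518, 756) and out_fea is (1, 128, 296, 432).
+ ```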
102
+
103
+ - ### Training
104
+ All training scripts are provided in [script/train_stereo_raftstereo.sh](script/train_stereo_raftstereo.sh) and [script/train_stereo_raftstereo_depthany.sh](script/train_stereo_raftstereo_depthany.sh).
105
+ Please set the following variables in the scripts before training.
106
+ | variable | meaning |
107
+ |---------------|----------------------|
108
+ | `NCCL_P2P_DISABLE` | We set `NCCL_P2P_DISABLE=1` because distributed training failed on our `A40` GPUs. |
109
+ | `CUDA_VISIBLE_DEVICES` | available GPU IDs, e.g., `CUDA_VISIBLE_DEVICES=0,1,2,3` |
110
+ | `DATASET_ROOT` | the training dataset path, e.g., `./datasets/sceneflow` |
111
+ | `LOG_ROOT` | path to save log file |
112
+ | `TB_ROOT` | path to save tensorboard data |
113
+ | `CKPOINT_ROOT` | path to save checkpoint |
114
+
115
+
116
+ To reproduce our results, please download `depth_anything_v2_vitl.pth` from DepthAnything v2 before training and set `--depthany_model_dir` in the training script to the directory where `depth_anything_v2_vitl.pth` is saved. We do not provide the download link here, as it may conflict with the CVPR guidelines.
117
+ We also explain the code for the ablation study, in which each experiment is mainly controlled by the `--model_name` option used in the training script.
118
+ | `--model_name` | meaning |
119
+ |-----------------|-------------------------|
120
+ | `RaftStereo` | Original RaftStereo model |
121
+ | `RaftStereoDisp` | The output of GRU is a single channel for disparity instead of two channels for optical flow, `Baseline` in Table 3 of the main text. |
122
+ | `RAFTStereoMast3r` | The pre-trained MASt3R is used as the backbone, and its features are used for cost volume construction, `RaftStereo + backbone Mast3r` in supplemental text. |
123
+ | `RaftStereoNoCTX` | RaftStereo model without context network, `Baseline w/o mono feature` in Table 3 of the main text. |
124
+ | `RAFTStereoDepthAny` | RaftStereo model with our monocular encoder, `Baseline + ME` in Table 3 of the main text. |
125
+ | `RAFTStereoDepthFusion` | RaftStereo model with our monocular encoder, `Baseline + ME + IDF` in Table 3 of the main text. |
126
+ | `RAFTStereoDepthBeta` | RaftStereo model with our monocular encoder and iterative local fusion, `Baseline + ME + ILF` in Table 3 of the main text. |
127
+ | `RAFTStereoDepthBetaNoLBP` | RaftStereo model with our monocular encoder and iterative local fusion without LBPEncoder, `L(6)` and `L(7)` in Table 4 of the main text. |
128
+ | `RAFTStereoDepthMatch` | RaftStereo model with DepthAnything v2 as feature extractor for cost volume construction, `RaftStereo + backbone DepthAnything` in the supplemental text. |
129
+ | `RAFTStereoDepthPostFusion` | RaftStereo model with our monocular encoder, iterative local fusion and post fusion, `Baseline + ME + PF` in Table 3 of the main text. |
130
+ | `RAFTStereoDepthBetaRefine` | RaftStereo model with our monocular encoder, iterative local fusion, and global fusion, `Baseline + ME + ILF + GF` in Table 3 of the main text. |
131
+
132
+
133
+ | variable | meaning |
134
+ |--------------------------|-------------------------|
135
+ | `--lbp_neighbor_offsets` | control `LBP Kernel` used in Table 4 of the main text. |
136
+ | `--modulation_ratio` | control `r` amplitude parameter used in Table 4 of the main text. |
137
+ | `--conf_from_fea` | `Cost` or `Hybrid` for `Confidence` used in Table 4 of the main text. |
138
+ | `--refine_pool` | learning registration parameters via pooling in the supplemental text. |
139
+
140
+
141
+ Training is launched as follows:
142
+ ```Shell
143
+ bash ./script/train_stereo_raftstereo_depthany.sh EXP_NAME
144
+ ```
145
+ `EXP_NAME` specifies the experiment name. We use this name to separate the log files, tensorboard data, and checkpoints of different experiments. The corresponding file structure is as follows:
146
+ ```Shell
147
+ ├── runs
148
+    ├── ckpoint
149
+ │ ├── RaftStereoDepthAny
150
+ │ ├── RaftStereoMast3r
151
+ │ └── RaftStereoNoCTX
152
+    ├── log
153
+ │ ├── RaftStereoDepthAny
154
+ │ ├── RaftStereoMast3r
155
+ │ └── RaftStereoNoCTX
156
+    └── tboard
157
+ ├── RaftStereoDepthAny
158
+ ├── RaftStereoMast3r
159
+ └── RaftStereoNoCTX
160
+ ```
161
+ > ⚠️ **Warning**: **Please follow the training process mentioned in our main text.** We first train the model without the global fusion module. Then, we train the monocular registration part of the global fusion module, keeping the other modules frozen and initializing from the well-trained first-stage model. Finally, we train the entire global fusion module, keeping the other modules frozen and initializing from the well-trained second-stage model.
162
+
163
+ - ### Evaluation
164
+ The evaluation script is presented in [script/evaluate_stereo_raftstereo.sh](script/evaluate_stereo_raftstereo.sh).
165
+ We use `--test_exp_name` to specify the evaluation experiment name.
166
+ The results of each experiment are stored in `LOG_ROOT/eval.xlsx`. We also merge the results of all experiments into `LOG_ROOT/merged_eval.xlsx` via `python3 merge_sheet.py`.
167
+ The evaluation metrics remain the same for different methods.
168
+ The `mean ± std` is computed via [tools/get_statistics.py](tools/get_statistics.py).
169
+
170
+ - ### Visualization
171
+ We visualize the error map via [script/gen_sample_stereo_raftstereo.sh](script/gen_sample_stereo_raftstereo.sh) and intermediate results via [script/vis_inter_stereo_raftstereo.sh](script/vis_inter_stereo_raftstereo.sh).
172
+ We provide an easy-to-use visualization toolbox to fully understand each module.
173
+
174
+ - ### Demo
175
+ The model weights, pre-trained on SceneFlow, can be downloaded from [Google Drive](https://drive.google.com/file/d/1T1o7soh3p4C_tHzmUd0ZCtnQbVczPmXz/view?usp=sharing).
176
+ The demo used to infer disparity maps from custom image pairs is presented in `infer_stereo_raftstereo.py`. For specific usage, please refer to `script/infer_stereo_raftstereo.sh`.
177
+
178
+
179
+ ## More Results
180
+ The tables below report results after additionally training on our custom synthetic [Trans Dataset](https://github.com/BFZD233/TranScene), which is built for multi-label transparent scenes.
181
+
182
+ <table>
183
+ <thead>
184
+ <tr>
185
+ <th rowspan="3">Method</th>
186
+ <th colspan="21">Booster</th>
187
+ </tr>
188
+ <tr>
189
+ <th colspan="7">ALL</th>
190
+ <th colspan="7">Trans</th>
191
+ <th colspan="7">No_Trans</th>
192
+ </tr>
193
+ <tr>
194
+ <th>EPE</th>
195
+ <th>RMSE</th>
196
+ <th>2px</th>
197
+ <th>3px</th>
198
+ <th>5px</th>
199
+ <th>6px</th>
200
+ <th>8px</th>
201
+ <th>EPE</th>
202
+ <th>RMSE</th>
203
+ <th>2px</th>
204
+ <th>3px</th>
205
+ <th>5px</th>
206
+ <th>6px</th>
207
+ <th>8px</th>
208
+ <th>EPE</th>
209
+ <th>RMSE</th>
210
+ <th>2px</th>
211
+ <th>3px</th>
212
+ <th>5px</th>
213
+ <th>6px</th>
214
+ <th>8px</th>
215
+ </tr>
216
+ </thead>
217
+ <tbody>
218
+ <tr>
219
+ <td>Ours</td>
220
+ <td>2.26</td>
221
+ <td>5.60</td>
222
+ <td>11.02</td>
223
+ <td>8.59</td>
224
+ <td>6.60</td>
225
+ <td>6.00</td>
226
+ <td>5.35</td>
227
+ <td>7.93</td>
228
+ <td>11.03</td>
229
+ <td>59.83</td>
230
+ <td>50.36</td>
231
+ <td>38.44</td>
232
+ <td>33.87</td>
233
+ <td>27.56</td>
234
+ <td>1.52</td>
235
+ <td>3.93</td>
236
+ <td>6.98</td>
237
+ <td>4.97</td>
238
+ <td>3.64</td>
239
+ <td>3.27</td>
240
+ <td>2.89</td>
241
+ </tr>
242
+ <tr>
243
+ <td>Ours+Trans</td>
244
+ <td>1.24</td>
245
+ <td>4.19</td>
246
+ <td>7.91</td>
247
+ <td>5.97</td>
248
+ <td>4.52</td>
249
+ <td>4.08</td>
250
+ <td>3.44</td>
251
+ <td>5.67</td>
252
+ <td>8.42</td>
253
+ <td>46.78</td>
254
+ <td>38.55</td>
255
+ <td>28.65</td>
256
+ <td>25.41</td>
257
+ <td>21.30</td>
258
+ <td>0.75</td>
259
+ <td>3.07</td>
260
+ <td>4.77</td>
261
+ <td>3.23</td>
262
+ <td>2.29</td>
263
+ <td>2.01</td>
264
+ <td>1.59</td>
265
+ </tr>
266
+ </tbody>
267
+ </table>
268
+
269
+ <table>
270
+ <thead>
271
+ <tr>
272
+ <th rowspan="3">Method</th>
273
+ <th colspan="28">Booster</th>
274
+ </tr>
275
+ <tr>
276
+ <th colspan="7">Class 0</th>
277
+ <th colspan="7">Class 1</th>
278
+ <th colspan="7">Class 2</th>
279
+ <th colspan="7">Class 3</th>
280
+ </tr>
281
+ <tr>
282
+ <th>EPE</th>
283
+ <th>RMSE</th>
284
+ <th>2px</th>
285
+ <th>3px</th>
286
+ <th>5px</th>
287
+ <th>6px</th>
288
+ <th>8px</th>
289
+ <th>EPE</th>
290
+ <th>RMSE</th>
291
+ <th>2px</th>
292
+ <th>3px</th>
293
+ <th>5px</th>
294
+ <th>6px</th>
295
+ <th>8px</th>
296
+ <th>EPE</th>
297
+ <th>RMSE</th>
298
+ <th>2px</th>
299
+ <th>3px</th>
300
+ <th>5px</th>
301
+ <th>6px</th>
302
+ <th>8px</th>
303
+ <th>EPE</th>
304
+ <th>RMSE</th>
305
+ <th>2px</th>
306
+ <th>3px</th>
307
+ <th>5px</th>
308
+ <th>6px</th>
309
+ <th>8px</th>
310
+ </tr>
311
+ </thead>
312
+ <tbody>
313
+ <tr>
314
+ <td>Ours</td>
315
+ <td>0.79</td>
316
+ <td>3.02</td>
317
+ <td>5.90</td>
318
+ <td>4.57</td>
319
+ <td>3.17</td>
320
+ <td>2.58</td>
321
+ <td>1.45</td>
322
+ <td>1.53</td>
323
+ <td>4.70</td>
324
+ <td>12.67</td>
325
+ <td>7.80</td>
326
+ <td>4.88</td>
327
+ <td>3.96</td>
328
+ <td>3.14</td>
329
+ <td>5.32</td>
330
+ <td>6.39</td>
331
+ <td>23.34</td>
332
+ <td>17.62</td>
333
+ <td>13.50</td>
334
+ <td>12.80</td>
335
+ <td>12.15</td>
336
+ <td>7.93</td>
337
+ <td>11.03</td>
338
+ <td>59.83</td>
339
+ <td>50.36</td>
340
+ <td>38.44</td>
341
+ <td>33.87</td>
342
+ <td>27.56</td>
343
+ </tr>
344
+ <tr>
345
+ <td>Ours+Trans</td>
346
+ <td>0.75</td>
347
+ <td>2.99</td>
348
+ <td>5.15</td>
349
+ <td>4.08</td>
350
+ <td>3.00</td>
351
+ <td>2.59</td>
352
+ <td>1.73</td>
353
+ <td>1.40</td>
354
+ <td>4.74</td>
355
+ <td>9.17</td>
356
+ <td>5.63</td>
357
+ <td>3.80</td>
358
+ <td>3.37</td>
359
+ <td>2.86</td>
360
+ <td>1.62</td>
361
+ <td>2.26</td>
362
+ <td>13.51</td>
363
+ <td>10.23</td>
364
+ <td>7.40</td>
365
+ <td>6.50</td>
366
+ <td>4.93</td>
367
+ <td>5.67</td>
368
+ <td>8.42</td>
369
+ <td>46.78</td>
370
+ <td>38.55</td>
371
+ <td>28.65</td>
372
+ <td>25.41</td>
373
+ <td>21.30</td>
374
+ </tr>
375
+ </tbody>
376
+ </table>
377
+
abs_cost/abs_cost_kernel.cu ADDED
@@ -0,0 +1,191 @@
1
+ #include <torch/extension.h>
2
+ #include <cuda.h>
3
+ #include <cuda_runtime.h>
4
+ #include <vector>
5
+ #include <cuda_fp16.h>
6
+ #include <cuda_runtime.h>
7
+ #include <math.h>
8
+
9
+ #include <ATen/ATen.h>
10
+ #include <ATen/NativeFunctions.h>
11
+ #include <ATen/Parallel.h>
12
+
13
+ #define BLOCK 16
14
+
15
+ // (B,H,W1,C) (B,H,W2,C) -> (B,H,W1,W2)
16
+
17
+ __forceinline__ __device__ bool within_bounds(int h, int w1, int w2, int H, int W1, int W2) {
18
+ return h >= 0 && h < H && w1 >= 0 && w1 < W1 && w2 >= 0 && w2 < W2;
19
+ }
20
+
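+ // Each thread computes one cost entry: the L1 distance over the C feature channels between (b, h, w1) in fmap1 and (b, h, w2) in fmap2.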
21
+ template <typename scalar_t>
22
+ __global__ void absolute_difference_forward_kernel(
23
+ const torch::PackedTensorAccessor32<scalar_t,4,torch::RestrictPtrTraits> fmap1,
24
+ const torch::PackedTensorAccessor32<scalar_t,4,torch::RestrictPtrTraits> fmap2,
25
+ torch::PackedTensorAccessor32<scalar_t,4,torch::RestrictPtrTraits> result)
26
+ {
27
+ const int C = fmap1.size(3);
28
+ const int H = fmap1.size(1);
29
+ const int W1 = fmap1.size(2);
30
+ const int W2 = fmap2.size(2);
31
+
32
+ // Get the indices handled by the current thread
33
+ const int w1 = blockIdx.x * blockDim.x + threadIdx.x;
34
+ const int w2 = blockIdx.y * blockDim.y + threadIdx.y;
35
+ const int h = blockIdx.z % H;
36
+ const int b = blockIdx.z / H;
37
+
38
+ if (!within_bounds(h, w1, w2, H, W1, W2)) {
39
+ return;
40
+ }
41
+
42
+ scalar_t sum = 0.0;
43
+ for (int c = 0; c < C; ++c) {
44
+ scalar_t diff = fabs(fmap1[b][h][w1][c] - fmap2[b][h][w2][c]);
45
+ sum += diff;
46
+ }
47
+
48
+ result[b][h][w1][w2] = sum;
49
+ }
50
+
51
+ template <typename scalar_t>
52
+ __global__ void absolute_difference_backward_kernel_fmap1(
53
+ const torch::PackedTensorAccessor32<scalar_t,4,torch::RestrictPtrTraits> fmap1,
54
+ const torch::PackedTensorAccessor32<scalar_t,4,torch::RestrictPtrTraits> fmap2,
55
+ const torch::PackedTensorAccessor32<scalar_t,4,torch::RestrictPtrTraits> grad_output,
56
+ torch::PackedTensorAccessor32<scalar_t,4,torch::RestrictPtrTraits> grad_fmap1)
57
+ {
58
+ const int k = blockIdx.x * blockDim.x + threadIdx.x;
59
+ const int h = blockIdx.y * blockDim.y + threadIdx.y;
60
+ const int n = blockIdx.z;
61
+
62
+ const int i_size = fmap1.size(1);
63
+ const int j_size = fmap1.size(2);
64
+ const int k_size = fmap1.size(3);
65
+ const int h_size = fmap2.size(3);
66
+
67
+ if (k >= k_size || h >= h_size) {
68
+ return;
69
+ }
70
+
71
+ for (int i = 0; i < i_size; ++i) {
72
+ for (int j = 0; j < j_size; ++j) {
73
+ scalar_t grad = 0.0;
74
+
75
+ scalar_t diff = fmap1[n][i][j][k] - fmap2[n][i][j][h];
76
+ if (diff >= 0) {
77
+ grad = grad_output[n][h][k][h];
78
+ } else {
79
+ grad = -grad_output[n][h][k][h];
80
+ }
81
+
82
+ grad_fmap1[n][i][j][k] += grad;
83
+ }
84
+ }
85
+ }
86
+
87
+ template <typename scalar_t>
+ __global__ void absolute_difference_backward_kernel_fmap2(
+ const torch::PackedTensorAccessor32<scalar_t,4,torch::RestrictPtrTraits> fmap1,
+ const torch::PackedTensorAccessor32<scalar_t,4,torch::RestrictPtrTraits> fmap2,
+ const torch::PackedTensorAccessor32<scalar_t,4,torch::RestrictPtrTraits> grad_output,
+ torch::PackedTensorAccessor32<scalar_t,4,torch::RestrictPtrTraits> grad_fmap2)
+ {
+ // one thread per (channel c, column w2); each thread owns a single grad_fmap2
+ // entry per row, so no atomics are needed
+ const int c = blockIdx.x * blockDim.x + threadIdx.x;
+ const int w2 = blockIdx.y * blockDim.y + threadIdx.y;
+ const int b = blockIdx.z;
+
+ const int H = fmap1.size(1);
+ const int W1 = fmap1.size(2);
+ const int C = fmap1.size(3);
+ const int W2 = fmap2.size(2);
+
+ if (c >= C || w2 >= W2) {
+ return;
+ }
+
+ for (int h = 0; h < H; ++h) {
+ scalar_t grad = 0.0;
+ for (int w1 = 0; w1 < W1; ++w1) {
+ // d|a - b| / db = -sign(a - b) = sign(b - a)
+ scalar_t diff = fmap2[b][h][w2][c] - fmap1[b][h][w1][c];
+ if (diff >= 0) {
+ grad += grad_output[b][h][w1][w2];
+ } else {
+ grad -= grad_output[b][h][w1][w2];
+ }
+ }
+ grad_fmap2[b][h][w2][c] += grad;
+ }
+ }
122
+
123
+ /**
124
+ * compute correlation between each element (h,w1)~(h,w2).
125
+ * (B,H,W1,C) (B,H,W2,C) -> (B,H,W1,W2)
126
+ */
127
+ std::vector<torch::Tensor> absolute_difference_cuda_forward(
128
+ torch::Tensor fmap1,
129
+ torch::Tensor fmap2)
130
+ {
131
+ const auto B = fmap1.size(0);
132
+ const auto H = fmap1.size(1);
133
+ const auto W1 = fmap1.size(2);
134
+ const auto W2 = fmap2.size(2);
135
+
136
+ const dim3 blocks((W1 + BLOCK - 1) / BLOCK,
137
+ (W2 + BLOCK - 1) / BLOCK,
138
+ B*H);
139
+
140
+ const dim3 threads(BLOCK, BLOCK);
141
+
142
+ auto opts = fmap1.options();
143
+ torch::Tensor result = torch::zeros({B, H, W1, W2}, opts);
144
+
145
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(fmap1.scalar_type(), "absolute_difference_forward_kernel", ([&] {
146
+ absolute_difference_forward_kernel<scalar_t><<<blocks, threads>>>(
147
+ fmap1.packed_accessor32<scalar_t,4,torch::RestrictPtrTraits>(),
148
+ fmap2.packed_accessor32<scalar_t,4,torch::RestrictPtrTraits>(),
149
+ result.packed_accessor32<scalar_t,4,torch::RestrictPtrTraits>());
150
+ }));
151
+
152
+ return {result};
153
+ }
154
+
155
+ std::vector<torch::Tensor> absolute_difference_cuda_backward(
156
+ torch::Tensor fmap1,
157
+ torch::Tensor fmap2,
158
+ torch::Tensor grad_output)
159
+ {
160
+ const auto B = fmap1.size(0);
161
+ const auto H = fmap1.size(1);
162
+ const auto W1 = fmap1.size(2);
163
+ const auto W2 = fmap2.size(2);
164
+
165
+ auto grad_fmap1 = torch::zeros_like(fmap1);
166
+ auto grad_fmap2 = torch::zeros_like(fmap2);
167
+
168
+ const auto C = fmap1.size(3);
+
+ // one launch grid per backward kernel: x spans channels, y spans the columns that kernel writes
+ const dim3 blocks_fmap1((C + BLOCK - 1) / BLOCK,
+ (W1 + BLOCK - 1) / BLOCK,
+ B);
+ const dim3 blocks_fmap2((C + BLOCK - 1) / BLOCK,
+ (W2 + BLOCK - 1) / BLOCK,
+ B);
171
+
172
+ const dim3 threads(BLOCK, BLOCK);
173
+
174
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(fmap1.scalar_type(), "absolute_difference_backward_kernel_fmap1", ([&] {
175
+ absolute_difference_backward_kernel_fmap1<scalar_t><<<blocks_fmap1, threads>>>(
176
+ fmap1.packed_accessor32<scalar_t,4,torch::RestrictPtrTraits>(),
177
+ fmap2.packed_accessor32<scalar_t,4,torch::RestrictPtrTraits>(),
178
+ grad_output.packed_accessor32<scalar_t,4,torch::RestrictPtrTraits>(),
179
+ grad_fmap1.packed_accessor32<scalar_t,4,torch::RestrictPtrTraits>());
180
+ }));
181
+
182
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(fmap2.scalar_type(), "absolute_difference_backward_kernel_fmap2", ([&] {
183
+ absolute_difference_backward_kernel_fmap2<scalar_t><<<blocks_fmap2, threads>>>(
184
+ fmap1.packed_accessor32<scalar_t,4,torch::RestrictPtrTraits>(),
185
+ fmap2.packed_accessor32<scalar_t,4,torch::RestrictPtrTraits>(),
186
+ grad_output.packed_accessor32<scalar_t,4,torch::RestrictPtrTraits>(),
187
+ grad_fmap2.packed_accessor32<scalar_t,4,torch::RestrictPtrTraits>());
188
+ }));
189
+
190
+ return {grad_fmap1, grad_fmap2};
191
+ }
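
The launchers above only expose raw forward/backward entry points; in PyTorch they are usually wrapped in a small `torch.autograd.Function` so the L1 cost volume participates in autograd. A minimal sketch, assuming the extension is compiled under the hypothetical module name `abs_cost_cuda` with `forward`/`backward` bound to the two launchers in this file:

```python
import torch
import abs_cost_cuda  # hypothetical name of the compiled extension exposing the launchers above


class AbsCostFunction(torch.autograd.Function):
    """L1 (absolute-difference) cost: (B,H,W1,C) x (B,H,W2,C) -> (B,H,W1,W2)."""

    @staticmethod
    def forward(ctx, fmap1, fmap2):
        fmap1, fmap2 = fmap1.contiguous(), fmap2.contiguous()
        ctx.save_for_backward(fmap1, fmap2)
        result, = abs_cost_cuda.forward(fmap1, fmap2)
        return result

    @staticmethod
    def backward(ctx, grad_output):
        fmap1, fmap2 = ctx.saved_tensors
        grad_fmap1, grad_fmap2 = abs_cost_cuda.backward(
            fmap1, fmap2, grad_output.contiguous())
        return grad_fmap1, grad_fmap2
```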
app.py ADDED
@@ -0,0 +1,103 @@
1
+ from __future__ import print_function, division
2
+ import sys
3
+ sys.path.insert(0,'core')
4
+ sys.path.append('core/utils')
5
+
6
+ import os
7
+ import argparse
8
+ import gradio as gr
9
+ import cv2
10
+ from core.raft_stereo_depthbeta_refine import RAFTStereoDepthBetaRefine
11
+ import torch
12
+ import torch.nn as nn
13
+ from core.utils.utils import InputPadder
14
+ import matplotlib.pyplot as plt
15
+ from huggingface_hub import hf_hub_download
16
+
17
+
18
+ parser = argparse.ArgumentParser()
19
+ parser.add_argument('--root', help="dataset root", default=None)
20
+ parser.add_argument('--sv_root', help="visualization root", default=None)
21
+ parser.add_argument('--test_exp_name', default='', help="name your experiment in testing")
22
+ parser.add_argument('--mast3r_model_path', default='MASt3R_ViTLarge_BaseDecoder_512_catmlpdpt_metric.pth', help="pretrained model path for MaSt3R")
23
+ parser.add_argument('--depthany_model_dir', default='./dav2_models', help="directory of pretrained model path for DepthAnything")
24
+ parser.add_argument('--restore_ckpt', help="restore checkpoint", default="./ckpts/diving_stereo.pth")
25
+ parser.add_argument('--mixed_precision', action='store_true', help='use mixed precision')
26
+ parser.add_argument('--valid_iters', type=int, default=32, help='number of flow-field updates during forward pass')
27
+ parser.add_argument('--eval', action='store_true', help='evaluation mode')
28
+ parser.add_argument('--is_test', action='store_true', help='on testing')
29
+
30
+ # Architecture choices
31
+ parser.add_argument('--hidden_dims', nargs='+', type=int, default=[128]*3, help="hidden state and context dimensions")
32
+ parser.add_argument('--corr_implementation', choices=["reg", "alt", "reg_cuda", "alt_cuda"], default="reg", help="correlation volume implementation")
33
+ parser.add_argument('--shared_backbone', action='store_true', help="use a single backbone for the context and feature encoders")
34
+ parser.add_argument('--corr_levels', type=int, default=4, help="number of levels in the correlation pyramid")
35
+ parser.add_argument('--corr_radius', type=int, default=4, help="width of the correlation pyramid")
36
+ parser.add_argument('--n_downsample', type=int, default=2, help="resolution of the disparity field (1/2^K)")
37
+ parser.add_argument('--context_norm', type=str, default="batch", choices=['group', 'batch', 'instance', 'none'], help="normalization of context encoder")
38
+ parser.add_argument('--slow_fast_gru', action='store_true', help="iterate the low-res GRUs more frequently")
39
+ parser.add_argument('--n_gru_layers', type=int, default=3, help="number of hidden GRU levels")
40
+
41
+ parser.add_argument('--lbp_neighbor_offsets', default='(-5,-5), (5,5), (5,-5), (-5,5), (-3,0), (3,0), (0,-3), (0,3)', help="determine the neighbors used in LBP encoder")
42
+ parser.add_argument('--modulation_ratio', type=float, default=1., help="hyperparameters for modulation")
43
+ parser.add_argument('--modulation_alg', choices=["linear", "sigmoid"], default="linear", help="rescale modulation")
44
+ parser.add_argument('--conf_from_fea', action='store_true', help="confidence in refinement not only from cost volume but also from other features")
45
+ parser.add_argument('--refine_pool', action='store_true', help="use pooling in refinement")
46
+ parser.add_argument('--refine_unet', action='store_true', help="use EfficientUnet in refinement")
47
+
48
+ parser.add_argument('--improvement', action='store_true', help="visualize improvement map (error_map[i] - error_map[i-1])")
49
+ parser.add_argument('--movement', action='store_true', help="visualize movement map (flow_pr[i] - flow_pr[i-1])")
50
+ parser.add_argument('--acceleration', action='store_true', help="visualize acceleration map (movement_map[i] - movement_map[i-1])")
51
+ parser.add_argument('--mask', action='store_true', help="visualize mask")
52
+ parser.add_argument('--binary_thold', type=float, default=0.5, help="visualize binary mask")
53
+
54
+ args = parser.parse_args()
55
+ args.conf_from_fea = True
56
+ args.eval = True
57
+
58
+ model = RAFTStereoDepthBetaRefine(args)
59
+ model = torch.nn.DataParallel(model, device_ids=[0])
60
+
61
+
62
+ checkpoint_path = hf_hub_download(
63
+ repo_id="BFZD/Diving-into-the-Fusion-of-Monocular-Priors-for-Generalized-Stereo-Matching",
64
+ filename="ckpts/diving_stereo.pth",
65
+ )
66
+
67
+ checkpoint = torch.load(checkpoint_path, map_location=torch.device('cpu'))
68
+ # model.load_state_dict(checkpoint, strict=True)
69
+ new_state_dict = {}
70
+ for key, value in checkpoint.items():
71
+ if key.find("lbp_encoder.lbp_conv") != -1:
72
+ continue
73
+ new_state_dict[key] = value
74
+ # model.load_state_dict(new_state_dict, strict=True)
75
+ model.load_state_dict(new_state_dict, strict=False)
76
+
77
+ # model.cuda()
78
+ model.eval()
79
+
80
+
81
+
82
+ def predict(image1, image2):
83
+ with torch.no_grad():
84
+ image1 = torch.from_numpy(image1).permute(2, 0, 1).float()
85
+ image2 = torch.from_numpy(image2).permute(2, 0, 1).float()
86
+ image1 = image1[None][:,:3,:,:]
87
+ image2 = image2[None][:,:3,:,:]
88
+ padder = InputPadder(image1.shape, divis_by=32)
89
+ image1, image2 = padder.pad(image1, image2)
90
+ atom_dict = model(image1, image2, iters=args.valid_iters, test_mode=False, vis_mode=True)
91
+ output = atom_dict['disp_predictions'][-1].abs().cpu().numpy()
92
+ disp = padder.unpad(output)
93
+ disp = disp.squeeze()
94
+ normalized_disp = (disp - disp.min()) / (disp.max() - disp.min())
95
+ cmap = plt.get_cmap('jet')
96
+ colored_disp = cmap(normalized_disp)[:, :, :3] # Get RGB channels
97
+
98
+ return colored_disp
99
+ interface = gr.Interface(fn=predict,
100
+ inputs=[gr.Image(label="Left Image"),
101
+ gr.Image(label="Right Image")],
102
+ outputs="image")
103
+ interface.launch()
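
For reference, `predict` expects two H×W×3 RGB arrays (left and right views, as Gradio provides them) and returns a jet-colored disparity visualization in [0, 1]. A minimal local call outside the Gradio UI, assuming a rectified stereo pair on disk (file names are placeholders):

```python
import cv2

left = cv2.cvtColor(cv2.imread("left.png"), cv2.COLOR_BGR2RGB)    # placeholder path
right = cv2.cvtColor(cv2.imread("right.png"), cv2.COLOR_BGR2RGB)  # placeholder path

colored_disp = predict(left, right)                                # H x W x 3, float in [0, 1]
cv2.imwrite("disp_vis.png", (colored_disp[:, :, ::-1] * 255).astype("uint8"))
```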
core/ManStereo.py ADDED
@@ -0,0 +1,302 @@
1
+ import os
2
+ import sys
3
+ import logging
4
+ import numpy as np
5
+ from datetime import datetime
6
+
7
+ import torch
8
+ import torch.nn as nn
9
+ import torch.nn.functional as F
10
+
11
+ from core.update import ManifoldBasicMultiUpdateBlock
12
+ from core.extractor import BasicEncoder, MultiBasicEncoder, ResidualBlock
13
+ from core.corr import CorrBlock1D, PytorchAlternateCorrBlock1D, CorrBlockFast1D, AlternateCorrBlock
14
+ from core.utils.utils import coords_grid, upflow8, LoggerCommon
15
+ from core.confidence import OffsetConfidence
16
+ from core.refinement import Refinement, UpdateHistory
17
+ from core import geometry as GEO
18
+ from core.utils.plane import get_pos, convert2patch, predict_disp
19
+
20
+ logger = LoggerCommon("ARCHI")
21
+
22
+ try:
23
+ autocast = torch.cuda.amp.autocast
24
+ except:
25
+ # dummy autocast for PyTorch < 1.6
26
+ class autocast:
27
+ def __init__(self, enabled):
28
+ pass
29
+ def __enter__(self):
30
+ pass
31
+ def __exit__(self, *args):
32
+ pass
33
+
34
+ class RAFTStereo(nn.Module):
35
+ def __init__(self, args):
36
+ super().__init__()
37
+ self.args = args
38
+
39
+ context_dims = args.hidden_dims
40
+
41
+ self.cnet = MultiBasicEncoder(output_dim=[args.hidden_dims, context_dims], norm_fn=args.context_norm, downsample=args.n_downsample)
42
+ self.update_block = ManifoldBasicMultiUpdateBlock(self.args, hidden_dims=args.hidden_dims)
43
+
44
+ self.context_zqr_convs = nn.ModuleList([nn.Conv2d(context_dims[i], args.hidden_dims[i]*3, 3, padding=3//2) for i in range(self.args.n_gru_layers)])
45
+
46
+ if args.shared_backbone:
47
+ self.conv2 = nn.Sequential(
48
+ ResidualBlock(128, 128, 'instance', stride=1),
49
+ nn.Conv2d(128, 256, 3, padding=1))
50
+ else:
51
+ self.fnet = BasicEncoder(output_dim=256, norm_fn='instance', downsample=args.n_downsample)
52
+
53
+ if args.confidence:
54
+ self.confidence_computer = OffsetConfidence(args)
55
+
56
+ if args.geo_estimator=="geometry_mlp":
57
+ self.geometry_builder = GEO.Geometry_MLP(args)
58
+ elif args.geo_estimator=="geometry_conv":
59
+ self.geometry_builder = GEO.Geometry_Conv(args)
60
+ elif args.geo_estimator=="geometry_conv_split":
61
+ self.geometry_builder = GEO.Geometry_Conv_Split(args)
62
+
63
+ if args.refinement is not None and len(args.refinement)>0:
64
+ if self.args.slant is None or len(self.args.slant)==0 :
65
+ dim_disp = 1
66
+ elif self.args.slant in ["slant", "slant_local"] :
67
+ dim_disp = 6
68
+
69
+ if args.refinement.lower()=="refinement":
70
+ self.refine = Refinement(args, in_chans=256, dim_fea=96, dim_disp=dim_disp)
71
+ else:
72
+ raise Exception("No such refinement: {}".format(args.refinement))
73
+
74
+ if self.args.update_his:
75
+ self.update_hist = UpdateHistory(args, 128, dim_disp)
76
+
77
+ logger.info(f"RAFTStereo ~ " +\
78
+ f"Confidence: {args.confidence}, offset_memory_size: {args.offset_memory_size}, " +\
79
+ f"offset_memory_last_iter: {args.offset_memory_last_iter}, " +\
80
+ f"slant: {args.slant}, slant_norm: {args.slant_norm}, " +\
81
+ f"geo estimator: {args.geo_estimator}, geo_fusion: {args.geo_fusion}, " +\
82
+ f"refine: {args.refinement}, refine_win_size: {args.refine_win_size}, num_heads:{args.num_heads}, " +\
83
+ f"split_win: {args.split_win}, refine_start_itr: {args.refine_start_itr}, " +\
84
+ f"update_his: {args.update_his}, U_thold: {args.U_thold}, " +\
85
+ f"stop_freeze_bn: {args.stop_freeze_bn}" )
86
+
87
+ def freeze_bn(self):
88
+ for m in self.modules():
89
+ if isinstance(m, nn.BatchNorm2d):
90
+ m.eval()
91
+
92
+ def initialize_flow(self, img):
93
+ """ Flow is represented as difference between two coordinate grids flow = coords1 - coords0"""
94
+ N, _, H, W = img.shape
95
+
96
+ coords0 = coords_grid(N, H, W).to(img.device)
97
+ coords1 = coords_grid(N, H, W).to(img.device)
98
+
99
+ return coords0, coords1
100
+
101
+ def upsample_flow(self, flow, mask):
102
+ """ Upsample flow field [H/8, W/8, 2] -> [H, W, 2] using convex combination """
103
+ N, D, H, W = flow.shape
104
+ factor = 2 ** self.args.n_downsample
105
+ mask = mask.view(N, 1, 9, factor, factor, H, W)
106
+ mask = torch.softmax(mask, dim=2)
107
+
108
+ up_flow = F.unfold(factor * flow, [3,3], padding=1)
109
+ up_flow = up_flow.view(N, D, 9, 1, 1, H, W)
110
+ up_flow = torch.sum(mask * up_flow, dim=2)
111
+
112
+ img_coord = None
113
+ if self.args.geo_estimator is not None and len(self.args.geo_estimator)>0:
114
+ img_coord = get_pos(H*factor, W*factor, disp=None,
115
+ slant=self.args.slant,
116
+ slant_norm=self.args.slant_norm,
117
+ patch_size=factor,
118
+ device=flow.device) # (1,2,H*factor,W*factor)
119
+ img_coord = img_coord.repeat(N,1,1,1)
120
+
121
+ up_flow = up_flow.permute(0, 1, 4, 2, 5, 3)
122
+ return up_flow.reshape(N, D, factor*H, factor*W), img_coord
123
+
124
+ def upsample_geo(self, mask=None, mask_disp=None, params=None):
125
+ """ Upsample flow field [H/8, W/8, 2] -> [H, W, 2] using convex combination """
126
+ N, D, H, W = params.shape
127
+ factor = 2 ** self.args.n_downsample
128
+ if mask is not None:
129
+ mask = mask.view(N, 1, 9, factor, factor, H, W)
130
+ mask = torch.softmax(mask, dim=2) # (B,1,9,factor,factor,H,W)
131
+ if mask_disp is not None:
132
+ mask_disp = mask_disp.view(N, 1, 9, factor, factor, H, W)
133
+ mask_disp = torch.softmax(mask_disp, dim=2) # (B,1,9,factor,factor,H,W)
134
+
135
+ # d_p = a_q\cdot\Delta u_{q\to p} + b_q\cdot\Delta v_{q\to p} + d_q
136
+ delta_pq = get_pos(H*factor, W*factor, disp=None,
137
+ slant=self.args.slant,
138
+ slant_norm=self.args.slant_norm,
139
+ patch_size=factor,
140
+ device=params.device) # (1,2,H*factor,W*factor)
141
+ patch_delta_pq = convert2patch(delta_pq, patch_size=factor, div_last=False).detach() # (1,2,factor*factor,H,W)
142
+
143
+ disp = predict_disp(params, patch_delta_pq, patch_size=factor, mul_last=True) # (B,factor*factor,H,W)
144
+
145
+ if mask_disp is not None:
146
+ disp = F.unfold(disp, [3,3], padding=1) # (B,factor*factor*9,H,W)
147
+ disp = disp.view(N, 1, factor, factor, 9, H, W) # (B,1,factor,factor,9,H,W)
148
+ disp = disp.permute((0,1,4,2,3,5,6)) # (B,1,9,factor,factor,H,W)
149
+ disp = torch.sum(mask_disp * disp, dim=2) # (B,1,factor,factor,H,W)
150
+ disp = disp.permute(0, 1, 4, 2, 5, 3) # (B,1,H,factor,W,factor)
151
+ return disp.reshape(N, 1, factor*H, factor*W)
152
+
153
+ elif mask is not None:
154
+ disp = F.unfold(disp, [3,3], padding=1) # (B,factor*factor*9,H,W)
155
+ disp = disp.view(N, 1, factor, factor, 9, H, W) # (B,1,factor,factor,9,H,W)
156
+ disp = disp.permute((0,1,4,2,3,5,6)) # (B,1,9,factor,factor,H,W)
157
+ disp = torch.sum(mask * disp, dim=2) # (B,1,factor,factor,H,W)
158
+ disp = disp.permute(0, 1, 4, 2, 5, 3) # (B,1,H,factor,W,factor)
159
+ return disp.reshape(N, 1, factor*H, factor*W)
160
+
161
+ disp = F.fold(disp.flatten(-2,-1), (H*factor,W*factor), kernel_size=factor, stride=factor).view(N,1,H*factor,W*factor)
162
+ return disp
163
+
164
+
165
+ def forward(self, image1, image2, iters=12, flow_init=None,
166
+ test_mode=False, vis_mode=False, enable_refinement=True):
167
+ """ Estimate optical flow between pair of frames """
168
+
169
+ image1 = (2 * (image1 / 255.0) - 1.0).contiguous()
170
+ image2 = (2 * (image2 / 255.0) - 1.0).contiguous()
171
+
172
+ # run the context network
173
+ with autocast(enabled=self.args.mixed_precision):
174
+ if self.args.shared_backbone:
175
+ *cnet_list, x = self.cnet(torch.cat((image1, image2), dim=0), dual_inp=True, num_layers=self.args.n_gru_layers)
176
+ fmap1, fmap2 = self.conv2(x).split(dim=0, split_size=x.shape[0]//2)
177
+ else:
178
+ cnet_list = self.cnet(image1, num_layers=self.args.n_gru_layers)
179
+ fmap1, fmap2 = self.fnet([image1, image2])
180
+ net_list = [torch.tanh(x[0]) for x in cnet_list]
181
+ inp_list = [torch.relu(x[1]) for x in cnet_list]
182
+
183
+ # Rather than running the GRU's conv layers on the context features multiple times, we do it once at the beginning
184
+ inp_list = [list(conv(i).split(split_size=conv.out_channels//3, dim=1)) for i,conv in zip(inp_list, self.context_zqr_convs)]
185
+
186
+ if self.args.corr_implementation == "reg": # Default
187
+ corr_block = CorrBlock1D
188
+ fmap1, fmap2 = fmap1.float(), fmap2.float()
189
+ elif self.args.corr_implementation == "alt": # More memory efficient than reg
190
+ corr_block = PytorchAlternateCorrBlock1D
191
+ fmap1, fmap2 = fmap1.float(), fmap2.float()
192
+ elif self.args.corr_implementation == "reg_cuda": # Faster version of reg
193
+ corr_block = CorrBlockFast1D
194
+ elif self.args.corr_implementation == "alt_cuda": # Faster version of alt
195
+ corr_block = AlternateCorrBlock
196
+ corr_fn = corr_block(fmap1, fmap2, radius=self.args.corr_radius, num_levels=self.args.corr_levels)
197
+
198
+ coords0, coords1 = self.initialize_flow(net_list[0])
199
+
200
+ if flow_init is not None:
201
+ coords1 = coords1 + flow_init
202
+
203
+ flow_predictions = []
204
+ disp_predictions = []
205
+ disp_predictions_refine = []
206
+ params_list = []
207
+ params_list_refine = []
208
+ confidence_list = []
209
+ offset_memory = []
210
+ for itr in range(iters):
211
+ coords1 = coords1.detach()
212
+ corr = corr_fn(coords1) # index correlation volume
213
+ flow = coords1 - coords0
214
+
215
+ with autocast(enabled=self.args.mixed_precision):
216
+ ## first-stage in geometry estimation
217
+ if self.args.n_gru_layers == 3 and self.args.slow_fast_gru: # Update low-res GRU
218
+ net_list = self.update_block(net_list, inp_list, iter32=True, iter16=False, iter08=False, update=False)
219
+ if self.args.n_gru_layers >= 2 and self.args.slow_fast_gru:# Update low-res GRU and mid-res GRU
220
+ net_list = self.update_block(net_list, inp_list, iter32=self.args.n_gru_layers==3, iter16=True, iter08=False, update=False)
221
+ net_list, up_mask, delta_flow, up_mask_disp = self.update_block(net_list, inp_list, corr, flow, iter32=self.args.n_gru_layers==3, iter16=self.args.n_gru_layers>=2)
222
+
223
+ ## region detection: acquire confidence
224
+ if self.args.confidence:
225
+ offset_memory.append(delta_flow[:,0:2])
226
+ if itr<self.args.offset_memory_size:
227
+ confidence = None
228
+ else:
229
+ if self.args.offset_memory_last_iter<0 or itr<=self.args.offset_memory_last_iter:
230
+ input_offset_mem = offset_memory[-self.args.offset_memory_size:]
231
+ else:
232
+ start_itr = self.args.offset_memory_last_iter - self.args.offset_memory_size
233
+ end_itr = self.args.offset_memory_last_iter
234
+ input_offset_mem = offset_memory[start_itr:end_itr]
235
+ confidence = self.confidence_computer(inp_list[0], input_offset_mem)
236
+ else:
237
+ confidence = None
238
+ confidence_list.append(confidence)
239
+
240
+ # in stereo mode, project flow onto epipolar
241
+ delta_flow[:,1] = 0.0
242
+
243
+ # F(t+1) = F(t) + \Delta(t)
244
+ coords1 = coords1 + delta_flow
245
+ flow = coords1 - coords0
246
+
247
+ # We do not need to upsample or output intermediate results in test_mode for raftStereo
248
+ if test_mode and itr < iters-1 and \
249
+ (self.args.refinement is None or len(self.args.refinement)==0):
250
+ continue
251
+
252
+ # upsample disparity map
253
+ if up_mask is None:
254
+ flow_up = upflow8(flow)
255
+ else:
256
+ flow_up, img_coord = self.upsample_flow(flow, up_mask)
257
+ flow_up = flow_up[:,:1]
258
+ flow_predictions.append(flow_up)
259
+
260
+ # second-stage in geometry estimation
261
+ geo_params = None
262
+ disparity = -flow[:,:1]
263
+ if self.args.geo_estimator is not None and len(self.args.geo_estimator)>0:
264
+ geo_params = self.geometry_builder(img_coord, -flow_up, disparity)
265
+
266
+ # disp_up = self.upsample_geo(up_mask, params=geo_params)
267
+ disp_up = self.upsample_geo(mask=None, mask_disp=up_mask_disp, params=geo_params)
268
+ params_list.append(geo_params)
269
+ disp_predictions.append(disp_up)
270
+
271
+ ## curvature-aware propagation
272
+ disparity_refine = None
273
+ geo_params_refine = None
274
+ if self.args.refinement is not None and len(self.args.refinement)>0 and enable_refinement:
275
+ if itr>=self.args.refine_start_itr:
276
+ geo_params_refine = self.refine(geo_params, inp_list[0], confidence,
277
+ if_shift=(itr-self.args.refine_start_itr)%2>0)
278
+ coords1 = coords0 - geo_params_refine[:,:1]
279
+ disparity_refine = geo_params_refine[:,:1]
280
+ ### update hidden state
281
+ if self.args.update_his:
282
+ net_list[0] = self.update_hist(net_list[0], -disparity_refine)
283
+ params_list_refine.append(geo_params_refine)
284
+
285
+ # upsample refinement
286
+ disp_up_refine = None
287
+ if geo_params_refine is not None:
288
+ # disp_up_refine = self.upsample_geo(up_mask, params=geo_params_refine)
289
+ disp_up_refine = self.upsample_geo(mask=None, mask_disp=up_mask_disp, params=geo_params_refine)
290
+ # disp_up_refine = disp_up_refine[:,:1]
291
+ disp_predictions_refine.append(disp_up_refine)
292
+
293
+ if test_mode:
294
+ if self.args.refinement is not None and len(self.args.refinement)>0 and enable_refinement:
295
+ return coords1 - coords0, -disp_up_refine  # refined disparity, sign-matched to flow_up
296
+ return coords1 - coords0, flow_up
297
+ # return coords1 - coords0, -disp_up
298
+
299
+ if vis_mode:
300
+ return flow_predictions, disp_predictions, disp_predictions_refine, confidence_list
301
+
302
+ return flow_predictions, disp_predictions, disp_predictions_refine, confidence_list, params_list, params_list_refine
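
`upsample_flow` and `upsample_geo` both rely on RAFT-style convex upsampling: every full-resolution pixel is a softmax-weighted (convex) combination of a 3×3 low-resolution neighborhood, with the weights predicted by the update block. A self-contained sketch of that operation for a single-channel disparity map (the function name and shapes are illustrative only):

```python
import torch
import torch.nn.functional as F


def convex_upsample(disp, mask, factor=4):
    """Upsample (N,1,H,W) disparity to (N,1,factor*H,factor*W) with learned convex weights.

    mask: (N, 9*factor*factor, H, W) raw logits from the update block.
    """
    N, _, H, W = disp.shape
    mask = mask.view(N, 1, 9, factor, factor, H, W)
    mask = torch.softmax(mask, dim=2)                 # convex weights over each 3x3 window

    up = F.unfold(factor * disp, [3, 3], padding=1)   # gather 3x3 neighborhoods (values rescaled by factor)
    up = up.view(N, 1, 9, 1, 1, H, W)
    up = torch.sum(mask * up, dim=2)                  # weighted combination -> (N,1,factor,factor,H,W)
    up = up.permute(0, 1, 4, 2, 5, 3)                 # interleave the factor x factor sub-pixels
    return up.reshape(N, 1, factor * H, factor * W)
```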
core/__init__.py ADDED
File without changes
core/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (151 Bytes). View file
 
core/__pycache__/confidence.cpython-310.pyc ADDED
Binary file (5.17 kB). View file
 
core/__pycache__/corr.cpython-310.pyc ADDED
Binary file (9.64 kB). View file
 
core/__pycache__/extractor.cpython-310.pyc ADDED
Binary file (7.14 kB). View file
 
core/__pycache__/extractor_depthany.cpython-310.pyc ADDED
Binary file (6.43 kB). View file
 
core/__pycache__/fusion.cpython-310.pyc ADDED
Binary file (5.16 kB). View file
 
core/__pycache__/geometry.cpython-310.pyc ADDED
Binary file (5.86 kB). View file
 
core/__pycache__/raft_stereo_depthbeta_refine.cpython-310.pyc ADDED
Binary file (7.77 kB). View file
 
core/__pycache__/update_disp.cpython-310.pyc ADDED
Binary file (5.95 kB). View file
 
core/confidence.py ADDED
@@ -0,0 +1,169 @@
1
+ import os
2
+ import sys
3
+ import logging
4
+ import numpy as np
5
+ from collections import OrderedDict
6
+
7
+ import torch
8
+ import torch.nn as nn
9
+ import torch.nn.functional as F
10
+
11
+
12
+
13
+ class OffsetConfidence(nn.Module):
14
+ def __init__(self, args):
15
+ super(OffsetConfidence, self).__init__()
16
+ self.detach = args.detach_in_confidence
17
+ self.offset_memory_size = args.offset_memory_size
18
+ self.conv_fea = nn.Conv2d(256, 16, 3, padding=1)
19
+ self.conv_offset = nn.Conv2d(2*args.offset_memory_size, 16, 3, padding=1)
20
+ self.fusion = nn.Sequential(OrderedDict([
21
+ ('conv1', nn.Conv2d(32, 8, 3, padding=1)),
22
+ ('relu1', nn.LeakyReLU(inplace=True)),
23
+ ('conv2', nn.Conv2d(8, 2, 3, padding=1)),
24
+ ('relu2', nn.LeakyReLU(inplace=True)),
25
+ ('conv3', nn.Conv2d(2, 1, 1, padding=0)),
26
+ ]))
27
+
28
+ if "local_rank" not in args or args.local_rank==0 :
29
+ logging.info(f"OffsetConfidence: " + \
30
+ f"detach: {args.detach_in_confidence}")
31
+
32
+ def forward(self, fea, offset_memory):
33
+ if type(fea) is list:
34
+ fea = torch.cat(fea, dim=1)
35
+ context = self.conv_fea(fea.detach() if self.detach else fea)
36
+ offset_memory = torch.cat([offset.detach() if self.detach else offset for offset in offset_memory], dim=1)
37
+ confidence = self.conv_offset( -offset_memory )
38
+ confidence = self.fusion( torch.cat([confidence,context], dim=1) )
39
+ return confidence
40
+
41
+
42
+
43
+ class MBConvBlockSimple(nn.Module):
44
+ def __init__(self, in_channels, out_channels, expand_ratio=1, kernel_size=3, stride=1, se_ratio=0.25):
45
+ super(MBConvBlockSimple, self).__init__()
46
+
47
+ self.has_se = se_ratio is not None and 0 < se_ratio <= 1
48
+ self.expand_ratio = expand_ratio
49
+ mid_channels = in_channels * expand_ratio
50
+ if expand_ratio != 1:
51
+ self.expand_conv = nn.Conv2d(in_channels, mid_channels, kernel_size=1, bias=False)
52
+ self.bn0 = nn.BatchNorm2d(mid_channels)
53
+
54
+ self.depthwise_conv = nn.Conv2d(mid_channels, mid_channels, kernel_size=kernel_size, stride=stride,
55
+ padding=kernel_size // 2, groups=mid_channels, bias=False)
56
+ self.bn1 = nn.BatchNorm2d(mid_channels)
57
+
58
+ if self.has_se:
59
+ se_channels = max(1, int(in_channels * se_ratio))
60
+ self.se_reduce = nn.Conv2d(mid_channels, se_channels, kernel_size=1)
61
+ self.se_expand = nn.Conv2d(se_channels, mid_channels, kernel_size=1)
62
+
63
+ self.project_conv = nn.Conv2d(mid_channels, out_channels, kernel_size=1, bias=False)
64
+ self.bn2 = nn.BatchNorm2d(out_channels)
65
+
66
+ self.swish = nn.SiLU(inplace=True)
67
+ self.use_residual = (stride == 1 and in_channels == out_channels)
68
+
69
+ def forward(self, x):
70
+ identity = x
71
+ if self.expand_ratio != 1:
72
+ x = self.swish(self.bn0(self.expand_conv(x)))
73
+
74
+ x = self.swish(self.bn1(self.depthwise_conv(x)))
75
+
76
+ if self.has_se:
77
+ se = F.adaptive_avg_pool2d(x, 1)
78
+ se = self.swish(self.se_reduce(se))
79
+ se = torch.sigmoid(self.se_expand(se))
80
+ x = x * se
81
+
82
+ x = self.bn2(self.project_conv(x))
83
+
84
+ if self.use_residual:
85
+ x = x + identity
86
+
87
+ return x
88
+
89
+
90
+ class EfficientNetB1SimpleEncoder(nn.Module):
91
+ def __init__(self, in_C=2):
92
+ super(EfficientNetB1SimpleEncoder, self).__init__()
93
+
94
+ self.pre_pro = nn.Sequential(
95
+ nn.Conv2d(in_C, 8, 3, padding=1),
96
+ nn.BatchNorm2d(8),
97
+ nn.SiLU(inplace=True),
98
+ nn.Conv2d(8, 8, 3, padding=1),
99
+ nn.BatchNorm2d(8),
100
+ nn.SiLU(inplace=True),
101
+ )
102
+
103
+ # Stem, first downsampling
104
+ self.stem = nn.Sequential(
105
+ nn.Conv2d(8, 32, kernel_size=3, stride=2, padding=1, bias=False),
106
+ nn.BatchNorm2d(32),
107
+ nn.SiLU(inplace=True)
108
+ )
109
+
110
+ # EfficientNet-B1 Layers Configuration
111
+ layers_config = [
112
+ (32, 16, 1, 3, 1, 1), # Stage 1 (no downsampling)
113
+ (16, 24, 6, 3, 2, 2), # Stage 2 (second downsampling)
114
+ (24, 40, 6, 5, 2, 2), # Stage 3 (third downsampling)
115
+ ]
116
+
117
+ # Building EfficientNet-B1 stages
118
+ self.blocks = nn.ModuleList()
119
+ for in_channels, out_channels, expand_ratio, kernel_size, stride, repeats in layers_config:
120
+ block_layers = []
121
+ block_layers.append(MBConvBlockSimple(in_channels, out_channels, expand_ratio, kernel_size, stride))
122
+ for _ in range(repeats - 1):
123
+ block_layers.append(MBConvBlockSimple(out_channels, out_channels, expand_ratio, kernel_size, stride=1))
124
+ self.blocks.append(nn.Sequential(*block_layers))
125
+
126
+ def forward(self, x):
127
+ features = []
128
+ x = self.pre_pro(x)
129
+ features.append(x) # Store features for skip connections
130
+ x = self.stem(x)
131
+ for block in self.blocks:
132
+ x = block(x)
133
+ features.append(x) # Store features for skip connections
134
+ return features
135
+
136
+
137
+ class EfficientUNetSimple(nn.Module):
138
+ def __init__(self, num_classes=1):
139
+ super(EfficientUNetSimple, self).__init__()
140
+
141
+ # Encoder using EfficientNet-B1 with only three stages
142
+ self.encoder = EfficientNetB1SimpleEncoder()
143
+
144
+ # Decoder layers (Upsampling)
145
+ self.upconv3 = nn.Conv2d(40, 24, kernel_size=1)
146
+ self.up3 = nn.ConvTranspose2d(24, 24, kernel_size=2, stride=2)
147
+
148
+ self.upconv2 = nn.Conv2d(24, 16, kernel_size=1)
149
+ self.up2 = nn.ConvTranspose2d(16, 16, kernel_size=2, stride=2)
150
+
151
+ self.upconv1 = nn.Conv2d(16, 8, kernel_size=1)
152
+ self.up1 = nn.ConvTranspose2d(8, 8, kernel_size=2, stride=2)
153
+
154
+ # Final conv layer
155
+ self.final_conv = nn.Conv2d(8, num_classes, kernel_size=1)
156
+
157
+ def forward(self, x):
158
+ # Encoder
159
+ features = self.encoder(x)
160
+ # print("-"*30, features[-1].shape, features[-2].shape, features[-3].shape, features[-4].shape)
161
+
162
+ # Decoder with skip connections
163
+ x = self.up3(self.upconv3(features[-1])) + features[-2] # 1/8 ~ 1/4
164
+ x = self.up2(self.upconv2(x)) + features[-3] # 1/4 ~ 1/2
165
+ x = self.up1(self.upconv1(x)) + features[-4] # 1/2 ~ 1
166
+
167
+ # Final output layer
168
+ x = self.final_conv(x)
169
+ return x
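
As a quick shape check for the refinement helpers above: `EfficientUNetSimple` maps a 2-channel input (e.g. disparity stacked with confidence) back to a single-channel map at the input resolution, provided H and W are divisible by 8 so the three stride-2 stages can be undone by the transposed convolutions. A minimal smoke test (input size chosen only for illustration):

```python
import torch

net = EfficientUNetSimple(num_classes=1)   # encoder default: in_C=2
net.eval()

x = torch.randn(2, 2, 64, 128)             # (B, 2, H, W) with H, W divisible by 8
with torch.no_grad():
    y = net(x)
print(y.shape)                             # torch.Size([2, 1, 64, 128])
```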
core/corr.py ADDED
@@ -0,0 +1,309 @@
1
+ import torch
2
+ import torch.nn.functional as F
3
+ from core.utils.utils import bilinear_sampler
4
+
5
+ try:
6
+ import corr_sampler
7
+ except:
8
+ pass
9
+
10
+ try:
11
+ import alt_cuda_corr
12
+ except:
13
+ # alt_cuda_corr is not compiled
14
+ pass
15
+
16
+
17
+ class CorrSampler(torch.autograd.Function):
18
+ @staticmethod
19
+ def forward(ctx, volume, coords, radius):
20
+ ctx.save_for_backward(volume,coords)
21
+ ctx.radius = radius
22
+ corr, = corr_sampler.forward(volume, coords, radius)
23
+ return corr
24
+ @staticmethod
25
+ def backward(ctx, grad_output):
26
+ volume, coords = ctx.saved_tensors
27
+ grad_output = grad_output.contiguous()
28
+ grad_volume, = corr_sampler.backward(volume, coords, grad_output, ctx.radius)
29
+ return grad_volume, None, None
30
+
31
+ class CorrBlockFast1D:
32
+ def __init__(self, fmap1, fmap2, num_levels=4, radius=4):
33
+ self.num_levels = num_levels
34
+ self.radius = radius
35
+ self.corr_pyramid = []
36
+ # all pairs correlation
37
+ corr = CorrBlockFast1D.corr(fmap1, fmap2)
38
+ batch, h1, w1, dim, w2 = corr.shape
39
+ corr = corr.reshape(batch*h1*w1, dim, 1, w2)
40
+ for i in range(self.num_levels):
41
+ self.corr_pyramid.append(corr.view(batch, h1, w1, -1, w2//2**i))
42
+ corr = F.avg_pool2d(corr, [1,2], stride=[1,2])
43
+
44
+ def __call__(self, coords):
45
+ out_pyramid = []
46
+ bz, _, ht, wd = coords.shape
47
+ coords = coords[:, [0]]
48
+ for i in range(self.num_levels):
49
+ corr = CorrSampler.apply(self.corr_pyramid[i].squeeze(3), coords/2**i, self.radius)
50
+ out_pyramid.append(corr.view(bz, -1, ht, wd))
51
+ return torch.cat(out_pyramid, dim=1)
52
+
53
+ @staticmethod
54
+ def corr(fmap1, fmap2):
55
+ B, D, H, W1 = fmap1.shape
56
+ _, _, _, W2 = fmap2.shape
57
+ fmap1 = fmap1.view(B, D, H, W1)
58
+ fmap2 = fmap2.view(B, D, H, W2)
59
+ corr = torch.einsum('aijk,aijh->ajkh', fmap1, fmap2)
60
+ corr = corr.reshape(B, H, W1, 1, W2).contiguous()
61
+ return corr / torch.sqrt(torch.tensor(D).float())
62
+
63
+
64
+ class PytorchAlternateCorrBlock1D:
65
+ def __init__(self, fmap1, fmap2, num_levels=4, radius=4):
66
+ self.num_levels = num_levels
67
+ self.radius = radius
68
+ self.corr_pyramid = []
69
+ self.fmap1 = fmap1
70
+ self.fmap2 = fmap2
71
+
72
+ def corr(self, fmap1, fmap2, coords):
73
+ B, D, H, W = fmap2.shape
74
+ # map grid coordinates to [-1,1]
75
+ xgrid, ygrid = coords.split([1,1], dim=-1)
76
+ xgrid = 2*xgrid/(W-1) - 1
77
+ ygrid = 2*ygrid/(H-1) - 1
78
+
79
+ grid = torch.cat([xgrid, ygrid], dim=-1)
80
+ output_corr = []
81
+ for grid_slice in grid.unbind(3):
82
+ fmapw_mini = F.grid_sample(fmap2, grid_slice, align_corners=True)
83
+ corr = torch.sum(fmapw_mini * fmap1, dim=1)
84
+ output_corr.append(corr)
85
+ corr = torch.stack(output_corr, dim=1).permute(0,2,3,1)
86
+
87
+ return corr / torch.sqrt(torch.tensor(D).float())
88
+
89
+ def __call__(self, coords):
90
+ r = self.radius
91
+ coords = coords.permute(0, 2, 3, 1)
92
+ batch, h1, w1, _ = coords.shape
93
+ fmap1 = self.fmap1
94
+ fmap2 = self.fmap2
95
+ out_pyramid = []
96
+ for i in range(self.num_levels):
97
+ dx = torch.zeros(1)
98
+ dy = torch.linspace(-r, r, 2*r+1)
99
+ delta = torch.stack(torch.meshgrid(dy, dx), axis=-1).to(coords.device)
100
+ centroid_lvl = coords.reshape(batch, h1, w1, 1, 2).clone()
101
+ centroid_lvl[...,0] = centroid_lvl[...,0] / 2**i
102
+ coords_lvl = centroid_lvl + delta.view(-1, 2)
103
+ corr = self.corr(fmap1, fmap2, coords_lvl)
104
+ fmap2 = F.avg_pool2d(fmap2, [1, 2], stride=[1, 2])
105
+ out_pyramid.append(corr)
106
+ out = torch.cat(out_pyramid, dim=-1)
107
+ return out.permute(0, 3, 1, 2).contiguous().float()
108
+
109
+
110
+ class PytorchAlternateAbsCorrBlock1D:
111
+ def __init__(self, fmap1, fmap2, num_levels=4, radius=4):
112
+ self.num_levels = num_levels
113
+ self.radius = radius
114
+ self.corr_pyramid = []
115
+ self.fmap1 = fmap1
116
+
117
+ self.fmap2_pyramid = [fmap2]
118
+ for i in range(num_levels):
119
+ fmap2 = F.avg_pool2d(fmap2, [1, 2], stride=[1, 2])
120
+ self.fmap2_pyramid.append(fmap2)
121
+
122
+ def corr(self, fmap1, fmap2, coords):
123
+ B, C, H, W = fmap1.shape
124
+ # map grid coordinates to [-1,1]
125
+ xgrid, ygrid = coords.split([1,1], dim=-1)
126
+ xgrid = 2*xgrid/(W-1) - 1
127
+ ygrid = 2*ygrid/(H-1) - 1
128
+
129
+ grid = torch.cat([xgrid, ygrid], dim=-1)
130
+
131
+ disp_num = 2 * self.radius + 1
132
+ fmapw_mini = F.grid_sample(fmap2, grid.view(B, H, W*disp_num, 2), mode='bilinear',
133
+ padding_mode='zeros').view(B, C, H, W, disp_num) # (B, C, H, W, S)
134
+ corr = torch.sum(fmap1.unsqueeze(-1) * fmapw_mini, dim=1)
135
+
136
+ return corr / torch.sqrt(torch.tensor(C).float())
137
+
138
+ def __call__(self, coords):
139
+ print(f"当前显存消耗量: {torch.distributed.get_rank()} {torch.cuda.memory_allocated() / 1024 / 1024:.2f} MB")
140
+
141
+ # in case of only disparity used in coordinates
142
+ B, D, H, W = coords.shape
143
+ if D==1:
144
+ y_coord = torch.arange(H).unsqueeze(1).float().repeat(B, 1, 1, W).to(coords.device)
145
+ coords = torch.cat([coords,y_coord], dim=1)
146
+
147
+ r = self.radius
148
+ coords = coords.permute(0, 2, 3, 1)
149
+ batch, h1, w1, _ = coords.shape
150
+
151
+ fmap1 = self.fmap1
152
+ out_pyramid = []
153
+ for i in range(self.num_levels):
154
+ fmap2 = self.fmap2_pyramid[i]
155
+
156
+ dx = torch.zeros(1)
157
+ dy = torch.linspace(-r, r, 2*r+1)
158
+ delta = torch.stack(torch.meshgrid(dy, dx), axis=-1).to(coords.device)
159
+ centroid_lvl = coords.reshape(batch, h1, w1, 1, 2).clone()
160
+ centroid_lvl[...,0] = centroid_lvl[...,0] / 2**i
161
+ coords_lvl = centroid_lvl + delta.view(-1, 2)
162
+
163
+ corr = self.corr(fmap1, fmap2, coords_lvl)
164
+ out_pyramid.append(corr)
165
+ out = torch.cat(out_pyramid, dim=-1)
166
+ return out.permute(0, 3, 1, 2).contiguous().float()
167
+
168
+
169
+ class CorrBlock1D:
170
+ def __init__(self, fmap1, fmap2, num_levels=4, radius=4):
171
+ self.num_levels = num_levels
172
+ self.radius = radius
173
+ self.corr_pyramid = []
174
+
175
+ # all pairs correlation
176
+ corr = CorrBlock1D.corr(fmap1, fmap2)
177
+
178
+ batch, h1, w1, _, w2 = corr.shape
179
+ corr = corr.reshape(batch*h1*w1, 1, 1, w2)
180
+
181
+ self.corr_pyramid.append(corr)
182
+ for i in range(self.num_levels):
183
+ corr = F.avg_pool2d(corr, [1,2], stride=[1,2])
184
+ self.corr_pyramid.append(corr)
185
+
186
+ def __call__(self, coords):
187
+ r = self.radius
188
+ coords = coords[:, :1].permute(0, 2, 3, 1)
189
+ batch, h1, w1, _ = coords.shape
190
+
191
+ # print(f"当前显存消耗量: {torch.distributed.get_rank()} {torch.cuda.memory_allocated() / 1024 / 1024:.2f} MB")
192
+
193
+ out_pyramid = []
194
+ for i in range(self.num_levels):
195
+ corr = self.corr_pyramid[i]
196
+ dx = torch.linspace(-r, r, 2*r+1)
197
+ dx = dx.view(2*r+1, 1).to(coords.device)
198
+ x0 = dx + coords.reshape(batch*h1*w1, 1, 1, 1) / 2**i
199
+ y0 = torch.zeros_like(x0)
200
+
201
+ coords_lvl = torch.cat([x0,y0], dim=-1)
202
+ corr = bilinear_sampler(corr, coords_lvl)
203
+ corr = corr.view(batch, h1, w1, -1)
204
+ out_pyramid.append(corr)
205
+
206
+ out = torch.cat(out_pyramid, dim=-1)
207
+ return out.permute(0, 3, 1, 2).contiguous().float()
208
+
209
+ @staticmethod
210
+ def corr(fmap1, fmap2):
211
+ B, D, H, W1 = fmap1.shape
212
+ _, _, _, W2 = fmap2.shape
213
+ fmap1 = fmap1.view(B, D, H, W1)
214
+ fmap2 = fmap2.view(B, D, H, W2)
215
+ corr = torch.einsum('aijk,aijh->ajkh', fmap1, fmap2)
216
+ corr = corr.reshape(B, H, W1, 1, W2).contiguous()
217
+ return corr / torch.sqrt(torch.tensor(D).float())
218
+
219
+ class AbsCorrBlock1D:
220
+ def __init__(self, fmap1, fmap2, num_levels=4, radius=4):
221
+ self.num_levels = num_levels
222
+ self.radius = radius
223
+ self.abs_corr_matrix_pyramid = []
224
+
225
+ # all pairs correlation
226
+ abs_corr_matrix = AbsCorrBlock1D.abs_corr(fmap1, fmap2)
227
+
228
+ batch, h1, w1, _, w2 = abs_corr_matrix.shape
229
+ abs_corr_matrix = abs_corr_matrix.reshape(batch*h1*w1, 1, 1, w2)
230
+
231
+ self.abs_corr_matrix_pyramid.append(abs_corr_matrix)
232
+ for i in range(self.num_levels):
233
+ abs_corr_matrix = F.avg_pool2d(abs_corr_matrix, [1,2], stride=[1,2])
234
+ self.abs_corr_matrix_pyramid.append(abs_corr_matrix)
235
+
236
+ def __call__(self, coords):
237
+ r = self.radius
238
+ coords = coords[:, :1].permute(0, 2, 3, 1)
239
+ batch, h1, w1, _ = coords.shape
240
+
241
+ out_pyramid = []
242
+ for i in range(self.num_levels):
243
+ abs_corr_matrix = self.abs_corr_matrix_pyramid[i]
244
+ dx = torch.linspace(-r, r, 2*r+1)
245
+ dx = dx.view(2*r+1, 1).to(coords.device)
246
+ x0 = dx + coords.reshape(batch*h1*w1, 1, 1, 1) / 2**i
247
+ y0 = torch.zeros_like(x0)
248
+
249
+ coords_lvl = torch.cat([x0,y0], dim=-1)
250
+ abs_corr_matrix = bilinear_sampler(abs_corr_matrix, coords_lvl)
251
+ abs_corr_matrix = abs_corr_matrix.view(batch, h1, w1, -1)
252
+ out_pyramid.append(abs_corr_matrix)
253
+
254
+ out = torch.cat(out_pyramid, dim=-1)
255
+ return out.permute(0, 3, 1, 2).contiguous().float()
256
+
257
+ @staticmethod
258
+ def abs_corr(fmap1, fmap2):
259
+ """fucntion: build the correlation matrix (not traditional cost volume) for each pixel in the same line.
260
+ args:
261
+ fmap1: feature maps from left view, B*C*H*W1;
262
+ fmap2: feature maps from right view, B*C*H*W2;
263
+ return:
264
+ the correlation matrix, B*H*W1*W2;
265
+ """
266
+ B, D, H, W1 = fmap1.shape
267
+ _, _, _, W2 = fmap2.shape
268
+
269
+ # compute the L1 matching cost
270
+ # corr_matrix = torch.einsum('aijk,aijh->ajkh', fmap1, fmap2)
271
+ # corr_matrix = torch.sum(torch.abs(fmap1.unsqueeze(-1) - fmap2.unsqueeze(-2)), dim=1) # shape (B, H, W1, W2)
272
+ corr_matrix = (fmap1.unsqueeze(-1) - fmap2.unsqueeze(-2)).abs_().sum(dim=1) # shape (B, H, W1, W2)
273
+ # corr_matrix = fmap1.sum(dim=1).unsqueeze(-1) - fmap2.sum(dim=1).unsqueeze(-2) # shape (B, H, W1, W2)
274
+ print("-"*10, " AbsCorrBlock1D: {} ".format(corr_matrix.shape), "-"*10)
275
+ print(f"当前显存消耗量: {torch.distributed.get_rank()} {torch.cuda.memory_allocated() / 1024 / 1024:.2f} MB")
276
+
277
+ corr_matrix = corr_matrix.reshape(B, H, W1, 1, W2).contiguous()
278
+ return corr_matrix / torch.sqrt(torch.tensor(D).float())
279
+
280
+ class AlternateCorrBlock:
281
+ def __init__(self, fmap1, fmap2, num_levels=4, radius=4):
282
+ raise NotImplementedError
283
+ self.num_levels = num_levels
284
+ self.radius = radius
285
+
286
+ self.pyramid = [(fmap1, fmap2)]
287
+ for i in range(self.num_levels):
288
+ fmap1 = F.avg_pool2d(fmap1, 2, stride=2)
289
+ fmap2 = F.avg_pool2d(fmap2, 2, stride=2)
290
+ self.pyramid.append((fmap1, fmap2))
291
+
292
+ def __call__(self, coords):
293
+ coords = coords.permute(0, 2, 3, 1)
294
+ B, H, W, _ = coords.shape
295
+ dim = self.pyramid[0][0].shape[1]
296
+
297
+ corr_list = []
298
+ for i in range(self.num_levels):
299
+ r = self.radius
300
+ fmap1_i = self.pyramid[0][0].permute(0, 2, 3, 1).contiguous()
301
+ fmap2_i = self.pyramid[i][1].permute(0, 2, 3, 1).contiguous()
302
+
303
+ coords_i = (coords / 2**i).reshape(B, 1, H, W, 2).contiguous()
304
+ corr, = alt_cuda_corr.forward(fmap1_i, fmap2_i, coords_i, r)
305
+ corr_list.append(corr.squeeze(1))
306
+
307
+ corr = torch.stack(corr_list, dim=1)
308
+ corr = corr.reshape(B, -1, H, W)
309
+ return corr / torch.sqrt(torch.tensor(dim).float())
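
The 1-D correlation blocks above are built once per feature-map pair and then indexed every GRU iteration with the current horizontal match coordinates. A minimal usage sketch for `CorrBlock1D` with the default pyramid settings (shapes are illustrative only; initializing coordinates to each pixel's own column corresponds to zero disparity):

```python
import torch

B, C, H, W = 1, 256, 48, 64                        # illustration only
fmap1 = torch.randn(B, C, H, W)
fmap2 = torch.randn(B, C, H, W)

corr_fn = CorrBlock1D(fmap1, fmap2, num_levels=4, radius=4)

# x-coordinate of the current match for every pixel (zero-disparity initialization)
coords = torch.arange(W).float().view(1, 1, 1, W).repeat(B, 1, H, 1)

cost = corr_fn(coords)                             # (B, num_levels*(2*radius+1), H, W)
print(cost.shape)                                  # torch.Size([1, 36, 48, 64])
```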
core/extractor.py ADDED
@@ -0,0 +1,300 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+
6
+ class ResidualBlock(nn.Module):
7
+ def __init__(self, in_planes, planes, norm_fn='group', stride=1):
8
+ super(ResidualBlock, self).__init__()
9
+
10
+ self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, padding=1, stride=stride)
11
+ self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1)
12
+ self.relu = nn.ReLU(inplace=True)
13
+
14
+ num_groups = planes // 8
15
+
16
+ if norm_fn == 'group':
17
+ self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
18
+ self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
19
+ if not (stride == 1 and in_planes == planes):
20
+ self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
21
+
22
+ elif norm_fn == 'batch':
23
+ self.norm1 = nn.BatchNorm2d(planes)
24
+ self.norm2 = nn.BatchNorm2d(planes)
25
+ if not (stride == 1 and in_planes == planes):
26
+ self.norm3 = nn.BatchNorm2d(planes)
27
+
28
+ elif norm_fn == 'instance':
29
+ self.norm1 = nn.InstanceNorm2d(planes)
30
+ self.norm2 = nn.InstanceNorm2d(planes)
31
+ if not (stride == 1 and in_planes == planes):
32
+ self.norm3 = nn.InstanceNorm2d(planes)
33
+
34
+ elif norm_fn == 'none':
35
+ self.norm1 = nn.Sequential()
36
+ self.norm2 = nn.Sequential()
37
+ if not (stride == 1 and in_planes == planes):
38
+ self.norm3 = nn.Sequential()
39
+
40
+ if stride == 1 and in_planes == planes:
41
+ self.downsample = None
42
+
43
+ else:
44
+ self.downsample = nn.Sequential(
45
+ nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm3)
46
+
47
+
48
+ def forward(self, x):
49
+ y = x
50
+ y = self.conv1(y)
51
+ y = self.norm1(y)
52
+ y = self.relu(y)
53
+ y = self.conv2(y)
54
+ y = self.norm2(y)
55
+ y = self.relu(y)
56
+
57
+ if self.downsample is not None:
58
+ x = self.downsample(x)
59
+
60
+ return self.relu(x+y)
61
+
62
+
63
+
64
+ class BottleneckBlock(nn.Module):
65
+ def __init__(self, in_planes, planes, norm_fn='group', stride=1):
66
+ super(BottleneckBlock, self).__init__()
67
+
68
+ self.conv1 = nn.Conv2d(in_planes, planes//4, kernel_size=1, padding=0)
69
+ self.conv2 = nn.Conv2d(planes//4, planes//4, kernel_size=3, padding=1, stride=stride)
70
+ self.conv3 = nn.Conv2d(planes//4, planes, kernel_size=1, padding=0)
71
+ self.relu = nn.ReLU(inplace=True)
72
+
73
+ num_groups = planes // 8
74
+
75
+ if norm_fn == 'group':
76
+ self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes//4)
77
+ self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes//4)
78
+ self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
79
+ if not stride == 1:
80
+ self.norm4 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
81
+
82
+ elif norm_fn == 'batch':
83
+ self.norm1 = nn.BatchNorm2d(planes//4)
84
+ self.norm2 = nn.BatchNorm2d(planes//4)
85
+ self.norm3 = nn.BatchNorm2d(planes)
86
+ if not stride == 1:
87
+ self.norm4 = nn.BatchNorm2d(planes)
88
+
89
+ elif norm_fn == 'instance':
90
+ self.norm1 = nn.InstanceNorm2d(planes//4)
91
+ self.norm2 = nn.InstanceNorm2d(planes//4)
92
+ self.norm3 = nn.InstanceNorm2d(planes)
93
+ if not stride == 1:
94
+ self.norm4 = nn.InstanceNorm2d(planes)
95
+
96
+ elif norm_fn == 'none':
97
+ self.norm1 = nn.Sequential()
98
+ self.norm2 = nn.Sequential()
99
+ self.norm3 = nn.Sequential()
100
+ if not stride == 1:
101
+ self.norm4 = nn.Sequential()
102
+
103
+ if stride == 1:
104
+ self.downsample = None
105
+
106
+ else:
107
+ self.downsample = nn.Sequential(
108
+ nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm4)
109
+
110
+
111
+ def forward(self, x):
112
+ y = x
113
+ y = self.relu(self.norm1(self.conv1(y)))
114
+ y = self.relu(self.norm2(self.conv2(y)))
115
+ y = self.relu(self.norm3(self.conv3(y)))
116
+
117
+ if self.downsample is not None:
118
+ x = self.downsample(x)
119
+
120
+ return self.relu(x+y)
121
+
122
+ class BasicEncoder(nn.Module):
123
+ def __init__(self, output_dim=128, norm_fn='batch', dropout=0.0, downsample=3):
124
+ super(BasicEncoder, self).__init__()
125
+ self.norm_fn = norm_fn
126
+ self.downsample = downsample
127
+
128
+ if self.norm_fn == 'group':
129
+ self.norm1 = nn.GroupNorm(num_groups=8, num_channels=64)
130
+
131
+ elif self.norm_fn == 'batch':
132
+ self.norm1 = nn.BatchNorm2d(64)
133
+
134
+ elif self.norm_fn == 'instance':
135
+ self.norm1 = nn.InstanceNorm2d(64)
136
+
137
+ elif self.norm_fn == 'none':
138
+ self.norm1 = nn.Sequential()
139
+
140
+ self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=1 + (downsample > 2), padding=3)
141
+ self.relu1 = nn.ReLU(inplace=True)
142
+
143
+ self.in_planes = 64
144
+ self.layer1 = self._make_layer(64, stride=1)
145
+ self.layer2 = self._make_layer(96, stride=1 + (downsample > 1))
146
+ self.layer3 = self._make_layer(128, stride=1 + (downsample > 0))
147
+
148
+ # output convolution
149
+ self.conv2 = nn.Conv2d(128, output_dim, kernel_size=1)
150
+
151
+ self.dropout = None
152
+ if dropout > 0:
153
+ self.dropout = nn.Dropout2d(p=dropout)
154
+
155
+ for m in self.modules():
156
+ if isinstance(m, nn.Conv2d):
157
+ nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
158
+ elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)):
159
+ if m.weight is not None:
160
+ nn.init.constant_(m.weight, 1)
161
+ if m.bias is not None:
162
+ nn.init.constant_(m.bias, 0)
163
+
164
+ def _make_layer(self, dim, stride=1):
165
+ layer1 = ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride)
166
+ layer2 = ResidualBlock(dim, dim, self.norm_fn, stride=1)
167
+ layers = (layer1, layer2)
168
+
169
+ self.in_planes = dim
170
+ return nn.Sequential(*layers)
171
+
172
+
173
+ def forward(self, x, dual_inp=False):
174
+
175
+ # if input is list, combine batch dimension
176
+ is_list = isinstance(x, tuple) or isinstance(x, list)
177
+ if is_list:
178
+ batch_dim = x[0].shape[0]
179
+ x = torch.cat(x, dim=0)
180
+
181
+ x = self.conv1(x)
182
+ x = self.norm1(x)
183
+ x = self.relu1(x)
184
+
185
+ x = self.layer1(x)
186
+ x = self.layer2(x)
187
+ x = self.layer3(x)
188
+
189
+ x = self.conv2(x)
190
+
191
+ if self.training and self.dropout is not None:
192
+ x = self.dropout(x)
193
+
194
+ if is_list:
195
+ x = x.split(split_size=batch_dim, dim=0)
196
+
197
+ return x
198
+
199
+ class MultiBasicEncoder(nn.Module):
200
+ def __init__(self, output_dim=[128], norm_fn='batch', dropout=0.0, downsample=3):
201
+ super(MultiBasicEncoder, self).__init__()
202
+ self.norm_fn = norm_fn
203
+ self.downsample = downsample
204
+
205
+ if self.norm_fn == 'group':
206
+ self.norm1 = nn.GroupNorm(num_groups=8, num_channels=64)
207
+
208
+ elif self.norm_fn == 'batch':
209
+ self.norm1 = nn.BatchNorm2d(64)
210
+
211
+ elif self.norm_fn == 'instance':
212
+ self.norm1 = nn.InstanceNorm2d(64)
213
+
214
+ elif self.norm_fn == 'none':
215
+ self.norm1 = nn.Sequential()
216
+
217
+ self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=1 + (downsample > 2), padding=3)
218
+ self.relu1 = nn.ReLU(inplace=True)
219
+
220
+ self.in_planes = 64
221
+ self.layer1 = self._make_layer(64, stride=1)
222
+ self.layer2 = self._make_layer(96, stride=1 + (downsample > 1))
223
+ self.layer3 = self._make_layer(128, stride=1 + (downsample > 0))
224
+ self.layer4 = self._make_layer(128, stride=2)
225
+ self.layer5 = self._make_layer(128, stride=2)
226
+
227
+ output_list = []
228
+ for dim in output_dim:
229
+ conv_out = nn.Sequential(
230
+ ResidualBlock(128, 128, self.norm_fn, stride=1),
231
+ nn.Conv2d(128, dim[2], 3, padding=1))
232
+ output_list.append(conv_out)
233
+
234
+ self.outputs08 = nn.ModuleList(output_list)
235
+
236
+ output_list = []
237
+ for dim in output_dim:
238
+ conv_out = nn.Sequential(
239
+ ResidualBlock(128, 128, self.norm_fn, stride=1),
240
+ nn.Conv2d(128, dim[1], 3, padding=1))
241
+ output_list.append(conv_out)
242
+
243
+ self.outputs16 = nn.ModuleList(output_list)
244
+
245
+ output_list = []
246
+ for dim in output_dim:
247
+ conv_out = nn.Conv2d(128, dim[0], 3, padding=1)
248
+ output_list.append(conv_out)
249
+
250
+ self.outputs32 = nn.ModuleList(output_list)
251
+
252
+ if dropout > 0:
253
+ self.dropout = nn.Dropout2d(p=dropout)
254
+ else:
255
+ self.dropout = None
256
+
257
+ for m in self.modules():
258
+ if isinstance(m, nn.Conv2d):
259
+ nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
260
+ elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)):
261
+ if m.weight is not None:
262
+ nn.init.constant_(m.weight, 1)
263
+ if m.bias is not None:
264
+ nn.init.constant_(m.bias, 0)
265
+
266
+ def _make_layer(self, dim, stride=1):
267
+ layer1 = ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride)
268
+ layer2 = ResidualBlock(dim, dim, self.norm_fn, stride=1)
269
+ layers = (layer1, layer2)
270
+
271
+ self.in_planes = dim
272
+ return nn.Sequential(*layers)
273
+
274
+ def forward(self, x, dual_inp=False, num_layers=3):
275
+
276
+ x = self.conv1(x)
277
+ x = self.norm1(x)
278
+ x = self.relu1(x)
279
+
280
+ x = self.layer1(x)
281
+ x = self.layer2(x)
282
+ x = self.layer3(x)
283
+ if dual_inp:
284
+ v = x
285
+ x = x[:(x.shape[0]//2)]
286
+
287
+ outputs08 = [f(x) for f in self.outputs08]
288
+ if num_layers == 1:
289
+ return (outputs08, v) if dual_inp else (outputs08,)
290
+
291
+ y = self.layer4(x)
292
+ outputs16 = [f(y) for f in self.outputs16]
293
+
294
+ if num_layers == 2:
295
+ return (outputs08, outputs16, v) if dual_inp else (outputs08, outputs16)
296
+
297
+ z = self.layer5(y)
298
+ outputs32 = [f(z) for f in self.outputs32]
299
+
300
+ return (outputs08, outputs16, outputs32, v) if dual_inp else (outputs08, outputs16, outputs32)
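
For orientation, `MultiBasicEncoder` returns one list of tensors per pyramid level, with one tensor per entry of `output_dim`. With `downsample=2` (the default `n_downsample` elsewhere in this repo) the `outputs08/16/32` heads actually sit at 1/4, 1/8 and 1/16 of the input resolution; the names refer to the `downsample=3` case. A minimal shape check (sizes chosen only for illustration):

```python
import torch

# two output heads of 128 channels each, matching hidden_dims = context_dims = [128]*3
cnet = MultiBasicEncoder(output_dim=[[128, 128, 128], [128, 128, 128]],
                         norm_fn='batch', downsample=2)
cnet.eval()

x = torch.randn(1, 3, 256, 320)
with torch.no_grad():
    out4, out8, out16 = cnet(x, num_layers=3)      # one list per scale

print(out4[0].shape)    # torch.Size([1, 128, 64, 80])  -> 1/4 resolution
print(out8[0].shape)    # torch.Size([1, 128, 32, 40])  -> 1/8 resolution
print(out16[0].shape)   # torch.Size([1, 128, 16, 20])  -> 1/16 resolution
```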