jhj0517 commited on
Commit
f5ba9ea
1 Parent(s): 75adb90

Add segment-anything-2

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitignore +2 -0
  2. segment-anything-2/.clang-format +85 -0
  3. segment-anything-2/.gitignore +10 -0
  4. segment-anything-2/CODE_OF_CONDUCT.md +80 -0
  5. segment-anything-2/CONTRIBUTING.md +31 -0
  6. segment-anything-2/INSTALL.md +89 -0
  7. segment-anything-2/LICENSE +201 -0
  8. segment-anything-2/LICENSE_cctorch +29 -0
  9. segment-anything-2/README.md +147 -0
  10. segment-anything-2/notebooks/automatic_mask_generator_example.ipynb +0 -0
  11. segment-anything-2/notebooks/image_predictor_example.ipynb +0 -0
  12. segment-anything-2/pyproject.toml +6 -0
  13. segment-anything-2/sam2/__init__.py +9 -0
  14. segment-anything-2/sam2/automatic_mask_generator.py +434 -0
  15. segment-anything-2/sam2/build_sam.py +89 -0
  16. segment-anything-2/sam2/csrc/connected_components.cu +289 -0
  17. segment-anything-2/sam2/modeling/__init__.py +5 -0
  18. segment-anything-2/sam2/modeling/backbones/__init__.py +5 -0
  19. segment-anything-2/sam2/modeling/backbones/hieradet.py +295 -0
  20. segment-anything-2/sam2/modeling/backbones/image_encoder.py +133 -0
  21. segment-anything-2/sam2/modeling/backbones/utils.py +95 -0
  22. segment-anything-2/sam2/modeling/memory_attention.py +169 -0
  23. segment-anything-2/sam2/modeling/memory_encoder.py +181 -0
  24. segment-anything-2/sam2/modeling/position_encoding.py +216 -0
  25. segment-anything-2/sam2/modeling/sam/__init__.py +5 -0
  26. segment-anything-2/sam2/modeling/sam/mask_decoder.py +295 -0
  27. segment-anything-2/sam2/modeling/sam/prompt_encoder.py +182 -0
  28. segment-anything-2/sam2/modeling/sam/transformer.py +327 -0
  29. segment-anything-2/sam2/modeling/sam2_base.py +829 -0
  30. segment-anything-2/sam2/modeling/sam2_utils.py +149 -0
  31. segment-anything-2/sam2/sam2_image_predictor.py +446 -0
  32. segment-anything-2/sam2/sam2_video_predictor.py +898 -0
  33. segment-anything-2/sam2/utils/__init__.py +5 -0
  34. segment-anything-2/sam2/utils/amg.py +348 -0
  35. segment-anything-2/sam2/utils/misc.py +238 -0
  36. segment-anything-2/sam2/utils/transforms.py +99 -0
  37. segment-anything-2/sam2_configs/__init__.py +5 -0
  38. segment-anything-2/sam2_configs/sam2_hiera_l.yaml +117 -0
  39. segment-anything-2/sam2_configs/sam2_hiera_s.yaml +116 -0
  40. segment-anything-2/sam2_configs/sam2_hiera_t.yaml +118 -0
  41. segment-anything-2/sav_dataset/LICENSE +30 -0
  42. segment-anything-2/sav_dataset/LICENSE_DAVIS +29 -0
  43. segment-anything-2/sav_dataset/LICENSE_VOS_BENCHMARK +7 -0
  44. segment-anything-2/sav_dataset/README.md +164 -0
  45. segment-anything-2/sav_dataset/requirements.txt +7 -0
  46. segment-anything-2/sav_dataset/sav_evaluator.py +89 -0
  47. segment-anything-2/sav_dataset/sav_visualization_example.ipynb +0 -0
  48. segment-anything-2/sav_dataset/utils/sav_benchmark.py +488 -0
  49. segment-anything-2/sav_dataset/utils/sav_utils.py +175 -0
  50. segment-anything-2/setup.py +73 -0
.gitignore CHANGED
@@ -1,5 +1,7 @@
 
1
  outputs/
2
  models/
 
3
  *.png
4
  *.jpg
5
  *.mp4
 
1
+ .idea/
2
  outputs/
3
  models/
4
+ venv/
5
  *.png
6
  *.jpg
7
  *.mp4
segment-anything-2/.clang-format ADDED
@@ -0,0 +1,85 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ AccessModifierOffset: -1
2
+ AlignAfterOpenBracket: AlwaysBreak
3
+ AlignConsecutiveAssignments: false
4
+ AlignConsecutiveDeclarations: false
5
+ AlignEscapedNewlinesLeft: true
6
+ AlignOperands: false
7
+ AlignTrailingComments: false
8
+ AllowAllParametersOfDeclarationOnNextLine: false
9
+ AllowShortBlocksOnASingleLine: false
10
+ AllowShortCaseLabelsOnASingleLine: false
11
+ AllowShortFunctionsOnASingleLine: Empty
12
+ AllowShortIfStatementsOnASingleLine: false
13
+ AllowShortLoopsOnASingleLine: false
14
+ AlwaysBreakAfterReturnType: None
15
+ AlwaysBreakBeforeMultilineStrings: true
16
+ AlwaysBreakTemplateDeclarations: true
17
+ BinPackArguments: false
18
+ BinPackParameters: false
19
+ BraceWrapping:
20
+ AfterClass: false
21
+ AfterControlStatement: false
22
+ AfterEnum: false
23
+ AfterFunction: false
24
+ AfterNamespace: false
25
+ AfterObjCDeclaration: false
26
+ AfterStruct: false
27
+ AfterUnion: false
28
+ BeforeCatch: false
29
+ BeforeElse: false
30
+ IndentBraces: false
31
+ BreakBeforeBinaryOperators: None
32
+ BreakBeforeBraces: Attach
33
+ BreakBeforeTernaryOperators: true
34
+ BreakConstructorInitializersBeforeComma: false
35
+ BreakAfterJavaFieldAnnotations: false
36
+ BreakStringLiterals: false
37
+ ColumnLimit: 80
38
+ CommentPragmas: '^ IWYU pragma:'
39
+ ConstructorInitializerAllOnOneLineOrOnePerLine: true
40
+ ConstructorInitializerIndentWidth: 4
41
+ ContinuationIndentWidth: 4
42
+ Cpp11BracedListStyle: true
43
+ DerivePointerAlignment: false
44
+ DisableFormat: false
45
+ ForEachMacros: [ FOR_EACH, FOR_EACH_R, FOR_EACH_RANGE, ]
46
+ IncludeCategories:
47
+ - Regex: '^<.*\.h(pp)?>'
48
+ Priority: 1
49
+ - Regex: '^<.*'
50
+ Priority: 2
51
+ - Regex: '.*'
52
+ Priority: 3
53
+ IndentCaseLabels: true
54
+ IndentWidth: 2
55
+ IndentWrappedFunctionNames: false
56
+ KeepEmptyLinesAtTheStartOfBlocks: false
57
+ MacroBlockBegin: ''
58
+ MacroBlockEnd: ''
59
+ MaxEmptyLinesToKeep: 1
60
+ NamespaceIndentation: None
61
+ ObjCBlockIndentWidth: 2
62
+ ObjCSpaceAfterProperty: false
63
+ ObjCSpaceBeforeProtocolList: false
64
+ PenaltyBreakBeforeFirstCallParameter: 1
65
+ PenaltyBreakComment: 300
66
+ PenaltyBreakFirstLessLess: 120
67
+ PenaltyBreakString: 1000
68
+ PenaltyExcessCharacter: 1000000
69
+ PenaltyReturnTypeOnItsOwnLine: 200
70
+ PointerAlignment: Left
71
+ ReflowComments: true
72
+ SortIncludes: true
73
+ SpaceAfterCStyleCast: false
74
+ SpaceBeforeAssignmentOperators: true
75
+ SpaceBeforeParens: ControlStatements
76
+ SpaceInEmptyParentheses: false
77
+ SpacesBeforeTrailingComments: 1
78
+ SpacesInAngles: false
79
+ SpacesInContainerLiterals: true
80
+ SpacesInCStyleCastParentheses: false
81
+ SpacesInParentheses: false
82
+ SpacesInSquareBrackets: false
83
+ Standard: Cpp11
84
+ TabWidth: 8
85
+ UseTab: Never
segment-anything-2/.gitignore ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ .vscode/
2
+ .DS_Store
3
+ __pycache__/
4
+ *-checkpoint.ipynb
5
+ .venv
6
+ *.egg*
7
+ build/*
8
+ _C.*
9
+ outputs/*
10
+ checkpoints/*.pt
segment-anything-2/CODE_OF_CONDUCT.md ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Code of Conduct
2
+
3
+ ## Our Pledge
4
+
5
+ In the interest of fostering an open and welcoming environment, we as
6
+ contributors and maintainers pledge to make participation in our project and
7
+ our community a harassment-free experience for everyone, regardless of age, body
8
+ size, disability, ethnicity, sex characteristics, gender identity and expression,
9
+ level of experience, education, socio-economic status, nationality, personal
10
+ appearance, race, religion, or sexual identity and orientation.
11
+
12
+ ## Our Standards
13
+
14
+ Examples of behavior that contributes to creating a positive environment
15
+ include:
16
+
17
+ * Using welcoming and inclusive language
18
+ * Being respectful of differing viewpoints and experiences
19
+ * Gracefully accepting constructive criticism
20
+ * Focusing on what is best for the community
21
+ * Showing empathy towards other community members
22
+
23
+ Examples of unacceptable behavior by participants include:
24
+
25
+ * The use of sexualized language or imagery and unwelcome sexual attention or
26
+ advances
27
+ * Trolling, insulting/derogatory comments, and personal or political attacks
28
+ * Public or private harassment
29
+ * Publishing others' private information, such as a physical or electronic
30
+ address, without explicit permission
31
+ * Other conduct which could reasonably be considered inappropriate in a
32
+ professional setting
33
+
34
+ ## Our Responsibilities
35
+
36
+ Project maintainers are responsible for clarifying the standards of acceptable
37
+ behavior and are expected to take appropriate and fair corrective action in
38
+ response to any instances of unacceptable behavior.
39
+
40
+ Project maintainers have the right and responsibility to remove, edit, or
41
+ reject comments, commits, code, wiki edits, issues, and other contributions
42
+ that are not aligned to this Code of Conduct, or to ban temporarily or
43
+ permanently any contributor for other behaviors that they deem inappropriate,
44
+ threatening, offensive, or harmful.
45
+
46
+ ## Scope
47
+
48
+ This Code of Conduct applies within all project spaces, and it also applies when
49
+ an individual is representing the project or its community in public spaces.
50
+ Examples of representing a project or community include using an official
51
+ project e-mail address, posting via an official social media account, or acting
52
+ as an appointed representative at an online or offline event. Representation of
53
+ a project may be further defined and clarified by project maintainers.
54
+
55
+ This Code of Conduct also applies outside the project spaces when there is a
56
+ reasonable belief that an individual's behavior may have a negative impact on
57
+ the project or its community.
58
+
59
+ ## Enforcement
60
+
61
+ Instances of abusive, harassing, or otherwise unacceptable behavior may be
62
+ reported by contacting the project team at <[email protected]>. All
63
+ complaints will be reviewed and investigated and will result in a response that
64
+ is deemed necessary and appropriate to the circumstances. The project team is
65
+ obligated to maintain confidentiality with regard to the reporter of an incident.
66
+ Further details of specific enforcement policies may be posted separately.
67
+
68
+ Project maintainers who do not follow or enforce the Code of Conduct in good
69
+ faith may face temporary or permanent repercussions as determined by other
70
+ members of the project's leadership.
71
+
72
+ ## Attribution
73
+
74
+ This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4,
75
+ available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html
76
+
77
+ [homepage]: https://www.contributor-covenant.org
78
+
79
+ For answers to common questions about this code of conduct, see
80
+ https://www.contributor-covenant.org/faq
segment-anything-2/CONTRIBUTING.md ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Contributing to segment-anything
2
+ We want to make contributing to this project as easy and transparent as
3
+ possible.
4
+
5
+ ## Pull Requests
6
+ We actively welcome your pull requests.
7
+
8
+ 1. Fork the repo and create your branch from `main`.
9
+ 2. If you've added code that should be tested, add tests.
10
+ 3. If you've changed APIs, update the documentation.
11
+ 4. Ensure the test suite passes.
12
+ 5. Make sure your code lints, using the `ufmt format` command. Linting requires `black==24.2.0`, `usort==1.0.2`, and `ufmt==2.0.0b2`, which can be installed via `pip install -e ".[dev]"`.
13
+ 6. If you haven't already, complete the Contributor License Agreement ("CLA").
14
+
15
+ ## Contributor License Agreement ("CLA")
16
+ In order to accept your pull request, we need you to submit a CLA. You only need
17
+ to do this once to work on any of Facebook's open source projects.
18
+
19
+ Complete your CLA here: <https://code.facebook.com/cla>
20
+
21
+ ## Issues
22
+ We use GitHub issues to track public bugs. Please ensure your description is
23
+ clear and has sufficient instructions to be able to reproduce the issue.
24
+
25
+ Facebook has a [bounty program](https://www.facebook.com/whitehat/) for the safe
26
+ disclosure of security bugs. In those cases, please go through the process
27
+ outlined on that page and do not file a public issue.
28
+
29
+ ## License
30
+ By contributing to segment-anything, you agree that your contributions will be licensed
31
+ under the LICENSE file in the root directory of this source tree.
segment-anything-2/INSTALL.md ADDED
@@ -0,0 +1,89 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ## Installation
2
+
3
+ ### Requirements
4
+
5
+ - Linux with Python ≥ 3.10, PyTorch ≥ 2.3.1 and [torchvision](https://github.com/pytorch/vision/) that matches the PyTorch installation. Install them together at https://pytorch.org to ensure this.
6
+ * Note older versions of Python or PyTorch may also work. However, the versions above are strongly recommended to provide all features such as `torch.compile`.
7
+ - [CUDA toolkits](https://developer.nvidia.com/cuda-toolkit-archive) that match the CUDA version for your PyTorch installation. This should typically be CUDA 12.1 if you follow the default installation command.
8
+
9
+ Then, install SAM 2 from the root of this repository via
10
+ ```bash
11
+ pip install -e ".[demo]"
12
+ ```
13
+
14
+ ### Common Installation Issues
15
+
16
+ Click each issue for its solutions:
17
+
18
+ <details>
19
+ <summary>
20
+ I got `ImportError: cannot import name '_C' from 'sam2'`
21
+ </summary>
22
+ <br/>
23
+
24
+ This is usually because you haven't run the `pip install -e ".[demo]"` step above or the installation failed. Please install SAM 2 first, and see the other issues if your installation fails.
25
+ </details>
26
+
27
+ <details>
28
+ <summary>
29
+ I got `MissingConfigException: Cannot find primary config 'sam2_hiera_l.yaml'`
30
+ </summary>
31
+ <br/>
32
+
33
+ This is usually because you haven't run the `pip install -e .` step above, so `sam2_configs` isn't in your Python's `sys.path`. Please run this installation step. In case it still fails after the installation step, you may try manually adding the root of this repo to `PYTHONPATH` via
34
+ ```bash
35
+ export SAM2_REPO_ROOT=/path/to/segment-anything # path to this repo
36
+ export PYTHONPATH="${SAM2_REPO_ROOT}:${PYTHONPATH}"
37
+ ```
38
+ to manually add `sam2_configs` into your Python's `sys.path`.
39
+
40
+ </details>
41
+
42
+ <details>
43
+ <summary>
44
+ My installation failed with `CUDA_HOME environment variable is not set`
45
+ </summary>
46
+ <br/>
47
+
48
+ This usually happens because the installation step cannot find the CUDA toolkits (that contain the NVCC compiler) to build a custom CUDA kernel in SAM 2. Please install [CUDA toolkits](https://developer.nvidia.com/cuda-toolkit-archive) or the version that matches the CUDA version for your PyTorch installation. If the error persists after installing CUDA toolkits, you may explicitly specify `CUDA_HOME` via
49
+ ```
50
+ export CUDA_HOME=/usr/local/cuda # change to your CUDA toolkit path
51
+ ```
52
+ and rerun the installation.
53
+
54
+ Also, you should make sure
55
+ ```
56
+ python -c 'import torch; from torch.utils.cpp_extension import CUDA_HOME; print(torch.cuda.is_available(), CUDA_HOME)'
57
+ ```
58
+ print `(True, a directory with cuda)` to verify that the CUDA toolkits are correctly set up.
59
+ </details>
60
+
61
+ <details>
62
+ <summary>
63
+ I got `undefined symbol: _ZN3c1015SmallVectorBaseIjE8grow_podEPKvmm` (or similar errors)
64
+ </summary>
65
+ <br/>
66
+
67
+ This usually happens because you have multiple versions of dependencies (PyTorch or CUDA) in your environment. During installation, the SAM 2 library is compiled against one version library while at run time it links against another version. This might be due to that you have different versions of PyTorch or CUDA installed separately via `pip` or `conda`. You may delete one of the duplicates to only keep a single PyTorch and CUDA version.
68
+
69
+ In particular, if you have a lower PyTorch version than 2.3.1, it's recommended to upgrade to PyTorch 2.3.1 or higher first. Otherwise, the installation script will try to upgrade to the latest PyTorch using `pip`, which could sometimes lead to duplicated PyTorch installation if you have previously installed another PyTorch version using `conda`.
70
+
71
+ We have been building SAM 2 against PyTorch 2.3.1 internally. However, a few user comments (e.g. https://github.com/facebookresearch/segment-anything-2/issues/22, https://github.com/facebookresearch/segment-anything-2/issues/14) suggested that downgrading to PyTorch 2.1.0 might resolve this problem. In case the error persists, you may try changing the restriction from `torch>=2.3.1` to `torch>=2.1.0` in both [`pyproject.toml`](pyproject.toml) and [`setup.py`](setup.py) to allow PyTorch 2.1.0.
72
+ </details>
73
+
74
+ <details>
75
+ <summary>
76
+ I got `RuntimeError: No available kernel. Aborting execution.` (or similar errors)
77
+ </summary>
78
+ <br/>
79
+
80
+ This is probably because your machine doesn't have a GPU or a compatible PyTorch version for Flash Attention (see also https://discuss.pytorch.org/t/using-f-scaled-dot-product-attention-gives-the-error-runtimeerror-no-available-kernel-aborting-execution/180900 for a discussion in PyTorch forum). You may be able to resolve this error by replacing the line
81
+ ```python
82
+ OLD_GPU, USE_FLASH_ATTN, MATH_KERNEL_ON = get_sdpa_settings()
83
+ ```
84
+ in [`sam2/modeling/sam/transformer.py`](sam2/modeling/sam/transformer.py) with
85
+ ```python
86
+ OLD_GPU, USE_FLASH_ATTN, MATH_KERNEL_ON = True, True, True
87
+ ```
88
+ to relax the attention kernel setting and use other kernels than Flash Attention.
89
+ </details>
segment-anything-2/LICENSE ADDED
@@ -0,0 +1,201 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Apache License
2
+ Version 2.0, January 2004
3
+ http://www.apache.org/licenses/
4
+
5
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6
+
7
+ 1. Definitions.
8
+
9
+ "License" shall mean the terms and conditions for use, reproduction,
10
+ and distribution as defined by Sections 1 through 9 of this document.
11
+
12
+ "Licensor" shall mean the copyright owner or entity authorized by
13
+ the copyright owner that is granting the License.
14
+
15
+ "Legal Entity" shall mean the union of the acting entity and all
16
+ other entities that control, are controlled by, or are under common
17
+ control with that entity. For the purposes of this definition,
18
+ "control" means (i) the power, direct or indirect, to cause the
19
+ direction or management of such entity, whether by contract or
20
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
21
+ outstanding shares, or (iii) beneficial ownership of such entity.
22
+
23
+ "You" (or "Your") shall mean an individual or Legal Entity
24
+ exercising permissions granted by this License.
25
+
26
+ "Source" form shall mean the preferred form for making modifications,
27
+ including but not limited to software source code, documentation
28
+ source, and configuration files.
29
+
30
+ "Object" form shall mean any form resulting from mechanical
31
+ transformation or translation of a Source form, including but
32
+ not limited to compiled object code, generated documentation,
33
+ and conversions to other media types.
34
+
35
+ "Work" shall mean the work of authorship, whether in Source or
36
+ Object form, made available under the License, as indicated by a
37
+ copyright notice that is included in or attached to the work
38
+ (an example is provided in the Appendix below).
39
+
40
+ "Derivative Works" shall mean any work, whether in Source or Object
41
+ form, that is based on (or derived from) the Work and for which the
42
+ editorial revisions, annotations, elaborations, or other modifications
43
+ represent, as a whole, an original work of authorship. For the purposes
44
+ of this License, Derivative Works shall not include works that remain
45
+ separable from, or merely link (or bind by name) to the interfaces of,
46
+ the Work and Derivative Works thereof.
47
+
48
+ "Contribution" shall mean any work of authorship, including
49
+ the original version of the Work and any modifications or additions
50
+ to that Work or Derivative Works thereof, that is intentionally
51
+ submitted to Licensor for inclusion in the Work by the copyright owner
52
+ or by an individual or Legal Entity authorized to submit on behalf of
53
+ the copyright owner. For the purposes of this definition, "submitted"
54
+ means any form of electronic, verbal, or written communication sent
55
+ to the Licensor or its representatives, including but not limited to
56
+ communication on electronic mailing lists, source code control systems,
57
+ and issue tracking systems that are managed by, or on behalf of, the
58
+ Licensor for the purpose of discussing and improving the Work, but
59
+ excluding communication that is conspicuously marked or otherwise
60
+ designated in writing by the copyright owner as "Not a Contribution."
61
+
62
+ "Contributor" shall mean Licensor and any individual or Legal Entity
63
+ on behalf of whom a Contribution has been received by Licensor and
64
+ subsequently incorporated within the Work.
65
+
66
+ 2. Grant of Copyright License. Subject to the terms and conditions of
67
+ this License, each Contributor hereby grants to You a perpetual,
68
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69
+ copyright license to reproduce, prepare Derivative Works of,
70
+ publicly display, publicly perform, sublicense, and distribute the
71
+ Work and such Derivative Works in Source or Object form.
72
+
73
+ 3. Grant of Patent License. Subject to the terms and conditions of
74
+ this License, each Contributor hereby grants to You a perpetual,
75
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
+ (except as stated in this section) patent license to make, have made,
77
+ use, offer to sell, sell, import, and otherwise transfer the Work,
78
+ where such license applies only to those patent claims licensable
79
+ by such Contributor that are necessarily infringed by their
80
+ Contribution(s) alone or by combination of their Contribution(s)
81
+ with the Work to which such Contribution(s) was submitted. If You
82
+ institute patent litigation against any entity (including a
83
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
84
+ or a Contribution incorporated within the Work constitutes direct
85
+ or contributory patent infringement, then any patent licenses
86
+ granted to You under this License for that Work shall terminate
87
+ as of the date such litigation is filed.
88
+
89
+ 4. Redistribution. You may reproduce and distribute copies of the
90
+ Work or Derivative Works thereof in any medium, with or without
91
+ modifications, and in Source or Object form, provided that You
92
+ meet the following conditions:
93
+
94
+ (a) You must give any other recipients of the Work or
95
+ Derivative Works a copy of this License; and
96
+
97
+ (b) You must cause any modified files to carry prominent notices
98
+ stating that You changed the files; and
99
+
100
+ (c) You must retain, in the Source form of any Derivative Works
101
+ that You distribute, all copyright, patent, trademark, and
102
+ attribution notices from the Source form of the Work,
103
+ excluding those notices that do not pertain to any part of
104
+ the Derivative Works; and
105
+
106
+ (d) If the Work includes a "NOTICE" text file as part of its
107
+ distribution, then any Derivative Works that You distribute must
108
+ include a readable copy of the attribution notices contained
109
+ within such NOTICE file, excluding those notices that do not
110
+ pertain to any part of the Derivative Works, in at least one
111
+ of the following places: within a NOTICE text file distributed
112
+ as part of the Derivative Works; within the Source form or
113
+ documentation, if provided along with the Derivative Works; or,
114
+ within a display generated by the Derivative Works, if and
115
+ wherever such third-party notices normally appear. The contents
116
+ of the NOTICE file are for informational purposes only and
117
+ do not modify the License. You may add Your own attribution
118
+ notices within Derivative Works that You distribute, alongside
119
+ or as an addendum to the NOTICE text from the Work, provided
120
+ that such additional attribution notices cannot be construed
121
+ as modifying the License.
122
+
123
+ You may add Your own copyright statement to Your modifications and
124
+ may provide additional or different license terms and conditions
125
+ for use, reproduction, or distribution of Your modifications, or
126
+ for any such Derivative Works as a whole, provided Your use,
127
+ reproduction, and distribution of the Work otherwise complies with
128
+ the conditions stated in this License.
129
+
130
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
131
+ any Contribution intentionally submitted for inclusion in the Work
132
+ by You to the Licensor shall be under the terms and conditions of
133
+ this License, without any additional terms or conditions.
134
+ Notwithstanding the above, nothing herein shall supersede or modify
135
+ the terms of any separate license agreement you may have executed
136
+ with Licensor regarding such Contributions.
137
+
138
+ 6. Trademarks. This License does not grant permission to use the trade
139
+ names, trademarks, service marks, or product names of the Licensor,
140
+ except as required for reasonable and customary use in describing the
141
+ origin of the Work and reproducing the content of the NOTICE file.
142
+
143
+ 7. Disclaimer of Warranty. Unless required by applicable law or
144
+ agreed to in writing, Licensor provides the Work (and each
145
+ Contributor provides its Contributions) on an "AS IS" BASIS,
146
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147
+ implied, including, without limitation, any warranties or conditions
148
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149
+ PARTICULAR PURPOSE. You are solely responsible for determining the
150
+ appropriateness of using or redistributing the Work and assume any
151
+ risks associated with Your exercise of permissions under this License.
152
+
153
+ 8. Limitation of Liability. In no event and under no legal theory,
154
+ whether in tort (including negligence), contract, or otherwise,
155
+ unless required by applicable law (such as deliberate and grossly
156
+ negligent acts) or agreed to in writing, shall any Contributor be
157
+ liable to You for damages, including any direct, indirect, special,
158
+ incidental, or consequential damages of any character arising as a
159
+ result of this License or out of the use or inability to use the
160
+ Work (including but not limited to damages for loss of goodwill,
161
+ work stoppage, computer failure or malfunction, or any and all
162
+ other commercial damages or losses), even if such Contributor
163
+ has been advised of the possibility of such damages.
164
+
165
+ 9. Accepting Warranty or Additional Liability. While redistributing
166
+ the Work or Derivative Works thereof, You may choose to offer,
167
+ and charge a fee for, acceptance of support, warranty, indemnity,
168
+ or other liability obligations and/or rights consistent with this
169
+ License. However, in accepting such obligations, You may act only
170
+ on Your own behalf and on Your sole responsibility, not on behalf
171
+ of any other Contributor, and only if You agree to indemnify,
172
+ defend, and hold each Contributor harmless for any liability
173
+ incurred by, or claims asserted against, such Contributor by reason
174
+ of your accepting any such warranty or additional liability.
175
+
176
+ END OF TERMS AND CONDITIONS
177
+
178
+ APPENDIX: How to apply the Apache License to your work.
179
+
180
+ To apply the Apache License to your work, attach the following
181
+ boilerplate notice, with the fields enclosed by brackets "[]"
182
+ replaced with your own identifying information. (Don't include
183
+ the brackets!) The text should be enclosed in the appropriate
184
+ comment syntax for the file format. We also recommend that a
185
+ file or class name and description of purpose be included on the
186
+ same "printed page" as the copyright notice for easier
187
+ identification within third-party archives.
188
+
189
+ Copyright [yyyy] [name of copyright owner]
190
+
191
+ Licensed under the Apache License, Version 2.0 (the "License");
192
+ you may not use this file except in compliance with the License.
193
+ You may obtain a copy of the License at
194
+
195
+ http://www.apache.org/licenses/LICENSE-2.0
196
+
197
+ Unless required by applicable law or agreed to in writing, software
198
+ distributed under the License is distributed on an "AS IS" BASIS,
199
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200
+ See the License for the specific language governing permissions and
201
+ limitations under the License.
segment-anything-2/LICENSE_cctorch ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ BSD 3-Clause License
2
+
3
+ Copyright (c) 2020, the respective contributors, as shown by the AUTHORS file.
4
+ All rights reserved.
5
+
6
+ Redistribution and use in source and binary forms, with or without
7
+ modification, are permitted provided that the following conditions are met:
8
+
9
+ 1. Redistributions of source code must retain the above copyright notice, this
10
+ list of conditions and the following disclaimer.
11
+
12
+ 2. Redistributions in binary form must reproduce the above copyright notice,
13
+ this list of conditions and the following disclaimer in the documentation
14
+ and/or other materials provided with the distribution.
15
+
16
+ 3. Neither the name of the copyright holder nor the names of its
17
+ contributors may be used to endorse or promote products derived from
18
+ this software without specific prior written permission.
19
+
20
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21
+ AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22
+ IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23
+ DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24
+ FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25
+ DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26
+ SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27
+ CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28
+ OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29
+ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
segment-anything-2/README.md ADDED
@@ -0,0 +1,147 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SAM 2: Segment Anything in Images and Videos
2
+
3
+ **[AI at Meta, FAIR](https://ai.meta.com/research/)**
4
+
5
+ [Nikhila Ravi](https://nikhilaravi.com/), [Valentin Gabeur](https://gabeur.github.io/), [Yuan-Ting Hu](https://scholar.google.com/citations?user=E8DVVYQAAAAJ&hl=en), [Ronghang Hu](https://ronghanghu.com/), [Chaitanya Ryali](https://scholar.google.com/citations?user=4LWx24UAAAAJ&hl=en), [Tengyu Ma](https://scholar.google.com/citations?user=VeTSl0wAAAAJ&hl=en), [Haitham Khedr](https://hkhedr.com/), [Roman Rädle](https://scholar.google.de/citations?user=Tpt57v0AAAAJ&hl=en), [Chloe Rolland](https://scholar.google.com/citations?hl=fr&user=n-SnMhoAAAAJ), [Laura Gustafson](https://scholar.google.com/citations?user=c8IpF9gAAAAJ&hl=en), [Eric Mintun](https://ericmintun.github.io/), [Junting Pan](https://junting.github.io/), [Kalyan Vasudev Alwala](https://scholar.google.co.in/citations?user=m34oaWEAAAAJ&hl=en), [Nicolas Carion](https://www.nicolascarion.com/), [Chao-Yuan Wu](https://chaoyuan.org/), [Ross Girshick](https://www.rossgirshick.info/), [Piotr Dollár](https://pdollar.github.io/), [Christoph Feichtenhofer](https://feichtenhofer.github.io/)
6
+
7
+ [[`Paper`](https://ai.meta.com/research/publications/sam-2-segment-anything-in-images-and-videos/)] [[`Project`](https://ai.meta.com/sam2)] [[`Demo`](https://sam2.metademolab.com/)] [[`Dataset`](https://ai.meta.com/datasets/segment-anything-video)] [[`Blog`](https://ai.meta.com/blog/segment-anything-2)] [[`BibTeX`](#citing-sam-2)]
8
+
9
+ ![SAM 2 architecture](assets/model_diagram.png?raw=true)
10
+
11
+ **Segment Anything Model 2 (SAM 2)** is a foundation model towards solving promptable visual segmentation in images and videos. We extend SAM to video by considering images as a video with a single frame. The model design is a simple transformer architecture with streaming memory for real-time video processing. We build a model-in-the-loop data engine, which improves model and data via user interaction, to collect [**our SA-V dataset**](https://ai.meta.com/datasets/segment-anything-video), the largest video segmentation dataset to date. SAM 2 trained on our data provides strong performance across a wide range of tasks and visual domains.
12
+
13
+ ![SA-V dataset](assets/sa_v_dataset.jpg?raw=true)
14
+
15
+ ## Installation
16
+
17
+ SAM 2 needs to be installed first before use. The code requires `python>=3.10`, as well as `torch>=2.3.1` and `torchvision>=0.18.1`. Please follow the instructions [here](https://pytorch.org/get-started/locally/) to install both PyTorch and TorchVision dependencies. You can install SAM 2 on a GPU machine using:
18
+
19
+ ```bash
20
+ git clone https://github.com/facebookresearch/segment-anything-2.git
21
+
22
+ cd segment-anything-2; pip install -e .
23
+ ```
24
+
25
+ To use the SAM 2 predictor and run the example notebooks, `jupyter` and `matplotlib` are required and can be installed by:
26
+
27
+ ```bash
28
+ pip install -e ".[demo]"
29
+ ```
30
+
31
+ Note:
32
+ 1. It's recommended to create a new Python environment for this installation and install PyTorch 2.3.1 (or higher) via `pip` following https://pytorch.org/. If you have a PyTorch version lower than 2.3.1 in your current environment, the installation command above will try to upgrade it to the latest PyTorch version using `pip`.
33
+ 2. The step above requires compiling a custom CUDA kernel with the `nvcc` compiler. If it isn't already available on your machine, please install the [CUDA toolkits](https://developer.nvidia.com/cuda-toolkit-archive) with a version that matches your PyTorch CUDA version.
34
+
35
+ Please see [`INSTALL.md`](./INSTALL.md) for FAQs on potential issues and solutions.
36
+
37
+ ## Getting Started
38
+
39
+ ### Download Checkpoints
40
+
41
+ First, we need to download a model checkpoint. All the model checkpoints can be downloaded by running:
42
+
43
+ ```bash
44
+ cd checkpoints
45
+ ./download_ckpts.sh
46
+ ```
47
+
48
+ or individually from:
49
+
50
+ - [sam2_hiera_tiny.pt](https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_tiny.pt)
51
+ - [sam2_hiera_small.pt](https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_small.pt)
52
+ - [sam2_hiera_base_plus.pt](https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_base_plus.pt)
53
+ - [sam2_hiera_large.pt](https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_large.pt)
54
+
55
+ Then SAM 2 can be used in a few lines as follows for image and video prediction.
56
+
57
+ ### Image prediction
58
+
59
+ SAM 2 has all the capabilities of [SAM](https://github.com/facebookresearch/segment-anything) on static images, and we provide image prediction APIs that closely resemble SAM for image use cases. The `SAM2ImagePredictor` class has an easy interface for image prompting.
60
+
61
+ ```python
62
+ import torch
63
+ from sam2.build_sam import build_sam2
64
+ from sam2.sam2_image_predictor import SAM2ImagePredictor
65
+
66
+ checkpoint = "./checkpoints/sam2_hiera_large.pt"
67
+ model_cfg = "sam2_hiera_l.yaml"
68
+ predictor = SAM2ImagePredictor(build_sam2(model_cfg, checkpoint))
69
+
70
+ with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
71
+ predictor.set_image(<your_image>)
72
+ masks, _, _ = predictor.predict(<input_prompts>)
73
+ ```
74
+
75
+ Please refer to the examples in [image_predictor_example.ipynb](./notebooks/image_predictor_example.ipynb) for static image use cases.
76
+
77
+ SAM 2 also supports automatic mask generation on images just like SAM. Please see [automatic_mask_generator_example.ipynb](./notebooks/automatic_mask_generator_example.ipynb) for automatic mask generation in images.
78
+
79
+ ### Video prediction
80
+
81
+ For promptable segmentation and tracking in videos, we provide a video predictor with APIs for example to add prompts and propagate masklets throughout a video. SAM 2 supports video inference on multiple objects and uses an inference state to keep track of the interactions in each video.
82
+
83
+ ```python
84
+ import torch
85
+ from sam2.build_sam import build_sam2_video_predictor
86
+
87
+ checkpoint = "./checkpoints/sam2_hiera_large.pt"
88
+ model_cfg = "sam2_hiera_l.yaml"
89
+ predictor = build_sam2_video_predictor(model_cfg, checkpoint)
90
+
91
+ with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
92
+ state = predictor.init_state(<your_video>)
93
+
94
+ # add new prompts and instantly get the output on the same frame
95
+ frame_idx, object_ids, masks = predictor.add_new_points(state, <your_prompts>):
96
+
97
+ # propagate the prompts to get masklets throughout the video
98
+ for frame_idx, object_ids, masks in predictor.propagate_in_video(state):
99
+ ...
100
+ ```
101
+
102
+ Please refer to the examples in [video_predictor_example.ipynb](./notebooks/video_predictor_example.ipynb) for details on how to add prompts, make refinements, and track multiple objects in videos.
103
+
104
+ ## Model Description
105
+
106
+ | **Model** | **Size (M)** | **Speed (FPS)** | **SA-V test (J&F)** | **MOSE val (J&F)** | **LVOS v2 (J&F)** |
107
+ | :------------------: | :----------: | :--------------------: | :-----------------: | :----------------: | :---------------: |
108
+ | sam2_hiera_tiny | 38.9 | 47.2 | 75.0 | 70.9 | 75.3 |
109
+ | sam2_hiera_small | 46 | 43.3 (53.0 compiled\*) | 74.9 | 71.5 | 76.4 |
110
+ | sam2_hiera_base_plus | 80.8 | 34.8 (43.8 compiled\*) | 74.7 | 72.8 | 75.8 |
111
+ | sam2_hiera_large | 224.4 | 24.2 (30.2 compiled\*) | 76.0 | 74.6 | 79.8 |
112
+
113
+ \* Compile the model by setting `compile_image_encoder: True` in the config.
114
+
115
+ ## Segment Anything Video Dataset
116
+
117
+ See [sav_dataset/README.md](sav_dataset/README.md) for details.
118
+
119
+ ## License
120
+
121
+ The models are licensed under the [Apache 2.0 license](./LICENSE). Please refer to our research paper for more details on the models.
122
+
123
+ ## Contributing
124
+
125
+ See [contributing](CONTRIBUTING.md) and the [code of conduct](CODE_OF_CONDUCT.md).
126
+
127
+ ## Contributors
128
+
129
+ The SAM 2 project was made possible with the help of many contributors (alphabetical):
130
+
131
+ Karen Bergan, Daniel Bolya, Alex Bosenberg, Kai Brown, Vispi Cassod, Christopher Chedeau, Ida Cheng, Luc Dahlin, Shoubhik Debnath, Rene Martinez Doehner, Grant Gardner, Sahir Gomez, Rishi Godugu, Baishan Guo, Caleb Ho, Andrew Huang, Somya Jain, Bob Kamma, Amanda Kallet, Jake Kinney, Alexander Kirillov, Shiva Koduvayur, Devansh Kukreja, Robert Kuo, Aohan Lin, Parth Malani, Jitendra Malik, Mallika Malhotra, Miguel Martin, Alexander Miller, Sasha Mitts, William Ngan, George Orlin, Joelle Pineau, Kate Saenko, Rodrick Shepard, Azita Shokrpour, David Soofian, Jonathan Torres, Jenny Truong, Sagar Vaze, Meng Wang, Claudette Ward, Pengchuan Zhang.
132
+
133
+ Third-party code: we use a GPU-based connected component algorithm adapted from [`cc_torch`](https://github.com/zsef123/Connected_components_PyTorch) (with its license in [`LICENSE_cctorch`](./LICENSE_cctorch)) as an optional post-processing step for the mask predictions.
134
+
135
+ ## Citing SAM 2
136
+
137
+ If you use SAM 2 or the SA-V dataset in your research, please use the following BibTeX entry.
138
+
139
+ ```bibtex
140
+ @article{ravi2024sam2,
141
+ title={SAM 2: Segment Anything in Images and Videos},
142
+ author={Ravi, Nikhila and Gabeur, Valentin and Hu, Yuan-Ting and Hu, Ronghang and Ryali, Chaitanya and Ma, Tengyu and Khedr, Haitham and R{\"a}dle, Roman and Rolland, Chloe and Gustafson, Laura and Mintun, Eric and Pan, Junting and Alwala, Kalyan Vasudev and Carion, Nicolas and Wu, Chao-Yuan and Girshick, Ross and Doll{\'a}r, Piotr and Feichtenhofer, Christoph},
143
+ journal={arXiv preprint arXiv:2408.00714},
144
+ url={https://arxiv.org/abs/2408.00714},
145
+ year={2024}
146
+ }
147
+ ```
segment-anything-2/notebooks/automatic_mask_generator_example.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
segment-anything-2/notebooks/image_predictor_example.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
segment-anything-2/pyproject.toml ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ [build-system]
2
+ requires = [
3
+ "setuptools>=61.0",
4
+ "torch>=2.3.1",
5
+ ]
6
+ build-backend = "setuptools.build_meta"
segment-anything-2/sam2/__init__.py ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ from hydra import initialize_config_module
8
+
9
+ initialize_config_module("sam2_configs", version_base="1.2")
segment-anything-2/sam2/automatic_mask_generator.py ADDED
@@ -0,0 +1,434 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ # Adapted from https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/automatic_mask_generator.py
8
+ from typing import Any, Dict, List, Optional, Tuple
9
+
10
+ import numpy as np
11
+ import torch
12
+ from torchvision.ops.boxes import batched_nms, box_area # type: ignore
13
+
14
+ from sam2.modeling.sam2_base import SAM2Base
15
+ from sam2.sam2_image_predictor import SAM2ImagePredictor
16
+ from sam2.utils.amg import (
17
+ area_from_rle,
18
+ batch_iterator,
19
+ batched_mask_to_box,
20
+ box_xyxy_to_xywh,
21
+ build_all_layer_point_grids,
22
+ calculate_stability_score,
23
+ coco_encode_rle,
24
+ generate_crop_boxes,
25
+ is_box_near_crop_edge,
26
+ mask_to_rle_pytorch,
27
+ MaskData,
28
+ remove_small_regions,
29
+ rle_to_mask,
30
+ uncrop_boxes_xyxy,
31
+ uncrop_masks,
32
+ uncrop_points,
33
+ )
34
+
35
+
36
+ class SAM2AutomaticMaskGenerator:
37
+ def __init__(
38
+ self,
39
+ model: SAM2Base,
40
+ points_per_side: Optional[int] = 32,
41
+ points_per_batch: int = 64,
42
+ pred_iou_thresh: float = 0.8,
43
+ stability_score_thresh: float = 0.95,
44
+ stability_score_offset: float = 1.0,
45
+ mask_threshold: float = 0.0,
46
+ box_nms_thresh: float = 0.7,
47
+ crop_n_layers: int = 0,
48
+ crop_nms_thresh: float = 0.7,
49
+ crop_overlap_ratio: float = 512 / 1500,
50
+ crop_n_points_downscale_factor: int = 1,
51
+ point_grids: Optional[List[np.ndarray]] = None,
52
+ min_mask_region_area: int = 0,
53
+ output_mode: str = "binary_mask",
54
+ use_m2m: bool = False,
55
+ multimask_output: bool = True,
56
+ ) -> None:
57
+ """
58
+ Using a SAM 2 model, generates masks for the entire image.
59
+ Generates a grid of point prompts over the image, then filters
60
+ low quality and duplicate masks. The default settings are chosen
61
+ for SAM 2 with a HieraL backbone.
62
+
63
+ Arguments:
64
+ model (Sam): The SAM 2 model to use for mask prediction.
65
+ points_per_side (int or None): The number of points to be sampled
66
+ along one side of the image. The total number of points is
67
+ points_per_side**2. If None, 'point_grids' must provide explicit
68
+ point sampling.
69
+ points_per_batch (int): Sets the number of points run simultaneously
70
+ by the model. Higher numbers may be faster but use more GPU memory.
71
+ pred_iou_thresh (float): A filtering threshold in [0,1], using the
72
+ model's predicted mask quality.
73
+ stability_score_thresh (float): A filtering threshold in [0,1], using
74
+ the stability of the mask under changes to the cutoff used to binarize
75
+ the model's mask predictions.
76
+ stability_score_offset (float): The amount to shift the cutoff when
77
+ calculated the stability score.
78
+ mask_threshold (float): Threshold for binarizing the mask logits
79
+ box_nms_thresh (float): The box IoU cutoff used by non-maximal
80
+ suppression to filter duplicate masks.
81
+ crop_n_layers (int): If >0, mask prediction will be run again on
82
+ crops of the image. Sets the number of layers to run, where each
83
+ layer has 2**i_layer number of image crops.
84
+ crop_nms_thresh (float): The box IoU cutoff used by non-maximal
85
+ suppression to filter duplicate masks between different crops.
86
+ crop_overlap_ratio (float): Sets the degree to which crops overlap.
87
+ In the first crop layer, crops will overlap by this fraction of
88
+ the image length. Later layers with more crops scale down this overlap.
89
+ crop_n_points_downscale_factor (int): The number of points-per-side
90
+ sampled in layer n is scaled down by crop_n_points_downscale_factor**n.
91
+ point_grids (list(np.ndarray) or None): A list over explicit grids
92
+ of points used for sampling, normalized to [0,1]. The nth grid in the
93
+ list is used in the nth crop layer. Exclusive with points_per_side.
94
+ min_mask_region_area (int): If >0, postprocessing will be applied
95
+ to remove disconnected regions and holes in masks with area smaller
96
+ than min_mask_region_area. Requires opencv.
97
+ output_mode (str): The form masks are returned in. Can be 'binary_mask',
98
+ 'uncompressed_rle', or 'coco_rle'. 'coco_rle' requires pycocotools.
99
+ For large resolutions, 'binary_mask' may consume large amounts of
100
+ memory.
101
+ use_m2m (bool): Whether to add a one step refinement using previous mask predictions.
102
+ multimask_output (bool): Whether to output multimask at each point of the grid.
103
+ """
104
+
105
+ assert (points_per_side is None) != (
106
+ point_grids is None
107
+ ), "Exactly one of points_per_side or point_grid must be provided."
108
+ if points_per_side is not None:
109
+ self.point_grids = build_all_layer_point_grids(
110
+ points_per_side,
111
+ crop_n_layers,
112
+ crop_n_points_downscale_factor,
113
+ )
114
+ elif point_grids is not None:
115
+ self.point_grids = point_grids
116
+ else:
117
+ raise ValueError("Can't have both points_per_side and point_grid be None.")
118
+
119
+ assert output_mode in [
120
+ "binary_mask",
121
+ "uncompressed_rle",
122
+ "coco_rle",
123
+ ], f"Unknown output_mode {output_mode}."
124
+ if output_mode == "coco_rle":
125
+ try:
126
+ from pycocotools import mask as mask_utils # type: ignore # noqa: F401
127
+ except ImportError as e:
128
+ print("Please install pycocotools")
129
+ raise e
130
+
131
+ self.predictor = SAM2ImagePredictor(
132
+ model,
133
+ max_hole_area=min_mask_region_area,
134
+ max_sprinkle_area=min_mask_region_area,
135
+ )
136
+ self.points_per_batch = points_per_batch
137
+ self.pred_iou_thresh = pred_iou_thresh
138
+ self.stability_score_thresh = stability_score_thresh
139
+ self.stability_score_offset = stability_score_offset
140
+ self.mask_threshold = mask_threshold
141
+ self.box_nms_thresh = box_nms_thresh
142
+ self.crop_n_layers = crop_n_layers
143
+ self.crop_nms_thresh = crop_nms_thresh
144
+ self.crop_overlap_ratio = crop_overlap_ratio
145
+ self.crop_n_points_downscale_factor = crop_n_points_downscale_factor
146
+ self.min_mask_region_area = min_mask_region_area
147
+ self.output_mode = output_mode
148
+ self.use_m2m = use_m2m
149
+ self.multimask_output = multimask_output
150
+
151
+ @torch.no_grad()
152
+ def generate(self, image: np.ndarray) -> List[Dict[str, Any]]:
153
+ """
154
+ Generates masks for the given image.
155
+
156
+ Arguments:
157
+ image (np.ndarray): The image to generate masks for, in HWC uint8 format.
158
+
159
+ Returns:
160
+ list(dict(str, any)): A list over records for masks. Each record is
161
+ a dict containing the following keys:
162
+ segmentation (dict(str, any) or np.ndarray): The mask. If
163
+ output_mode='binary_mask', is an array of shape HW. Otherwise,
164
+ is a dictionary containing the RLE.
165
+ bbox (list(float)): The box around the mask, in XYWH format.
166
+ area (int): The area in pixels of the mask.
167
+ predicted_iou (float): The model's own prediction of the mask's
168
+ quality. This is filtered by the pred_iou_thresh parameter.
169
+ point_coords (list(list(float))): The point coordinates input
170
+ to the model to generate this mask.
171
+ stability_score (float): A measure of the mask's quality. This
172
+ is filtered on using the stability_score_thresh parameter.
173
+ crop_box (list(float)): The crop of the image used to generate
174
+ the mask, given in XYWH format.
175
+ """
176
+
177
+ # Generate masks
178
+ mask_data = self._generate_masks(image)
179
+
180
+ # Encode masks
181
+ if self.output_mode == "coco_rle":
182
+ mask_data["segmentations"] = [
183
+ coco_encode_rle(rle) for rle in mask_data["rles"]
184
+ ]
185
+ elif self.output_mode == "binary_mask":
186
+ mask_data["segmentations"] = [rle_to_mask(rle) for rle in mask_data["rles"]]
187
+ else:
188
+ mask_data["segmentations"] = mask_data["rles"]
189
+
190
+ # Write mask records
191
+ curr_anns = []
192
+ for idx in range(len(mask_data["segmentations"])):
193
+ ann = {
194
+ "segmentation": mask_data["segmentations"][idx],
195
+ "area": area_from_rle(mask_data["rles"][idx]),
196
+ "bbox": box_xyxy_to_xywh(mask_data["boxes"][idx]).tolist(),
197
+ "predicted_iou": mask_data["iou_preds"][idx].item(),
198
+ "point_coords": [mask_data["points"][idx].tolist()],
199
+ "stability_score": mask_data["stability_score"][idx].item(),
200
+ "crop_box": box_xyxy_to_xywh(mask_data["crop_boxes"][idx]).tolist(),
201
+ }
202
+ curr_anns.append(ann)
203
+
204
+ return curr_anns
205
+
206
+ def _generate_masks(self, image: np.ndarray) -> MaskData:
207
+ orig_size = image.shape[:2]
208
+ crop_boxes, layer_idxs = generate_crop_boxes(
209
+ orig_size, self.crop_n_layers, self.crop_overlap_ratio
210
+ )
211
+
212
+ # Iterate over image crops
213
+ data = MaskData()
214
+ for crop_box, layer_idx in zip(crop_boxes, layer_idxs):
215
+ crop_data = self._process_crop(image, crop_box, layer_idx, orig_size)
216
+ data.cat(crop_data)
217
+
218
+ # Remove duplicate masks between crops
219
+ if len(crop_boxes) > 1:
220
+ # Prefer masks from smaller crops
221
+ scores = 1 / box_area(data["crop_boxes"])
222
+ scores = scores.to(data["boxes"].device)
223
+ keep_by_nms = batched_nms(
224
+ data["boxes"].float(),
225
+ scores,
226
+ torch.zeros_like(data["boxes"][:, 0]), # categories
227
+ iou_threshold=self.crop_nms_thresh,
228
+ )
229
+ data.filter(keep_by_nms)
230
+ data.to_numpy()
231
+ return data
232
+
233
+ def _process_crop(
234
+ self,
235
+ image: np.ndarray,
236
+ crop_box: List[int],
237
+ crop_layer_idx: int,
238
+ orig_size: Tuple[int, ...],
239
+ ) -> MaskData:
240
+ # Crop the image and calculate embeddings
241
+ x0, y0, x1, y1 = crop_box
242
+ cropped_im = image[y0:y1, x0:x1, :]
243
+ cropped_im_size = cropped_im.shape[:2]
244
+ self.predictor.set_image(cropped_im)
245
+
246
+ # Get points for this crop
247
+ points_scale = np.array(cropped_im_size)[None, ::-1]
248
+ points_for_image = self.point_grids[crop_layer_idx] * points_scale
249
+
250
+ # Generate masks for this crop in batches
251
+ data = MaskData()
252
+ for (points,) in batch_iterator(self.points_per_batch, points_for_image):
253
+ batch_data = self._process_batch(
254
+ points, cropped_im_size, crop_box, orig_size, normalize=True
255
+ )
256
+ data.cat(batch_data)
257
+ del batch_data
258
+ self.predictor.reset_predictor()
259
+
260
+ # Remove duplicates within this crop.
261
+ keep_by_nms = batched_nms(
262
+ data["boxes"].float(),
263
+ data["iou_preds"],
264
+ torch.zeros_like(data["boxes"][:, 0]), # categories
265
+ iou_threshold=self.box_nms_thresh,
266
+ )
267
+ data.filter(keep_by_nms)
268
+
269
+ # Return to the original image frame
270
+ data["boxes"] = uncrop_boxes_xyxy(data["boxes"], crop_box)
271
+ data["points"] = uncrop_points(data["points"], crop_box)
272
+ data["crop_boxes"] = torch.tensor([crop_box for _ in range(len(data["rles"]))])
273
+
274
+ return data
275
+
276
+ def _process_batch(
277
+ self,
278
+ points: np.ndarray,
279
+ im_size: Tuple[int, ...],
280
+ crop_box: List[int],
281
+ orig_size: Tuple[int, ...],
282
+ normalize=False,
283
+ ) -> MaskData:
284
+ orig_h, orig_w = orig_size
285
+
286
+ # Run model on this batch
287
+ points = torch.as_tensor(points, device=self.predictor.device)
288
+ in_points = self.predictor._transforms.transform_coords(
289
+ points, normalize=normalize, orig_hw=im_size
290
+ )
291
+ in_labels = torch.ones(
292
+ in_points.shape[0], dtype=torch.int, device=in_points.device
293
+ )
294
+ masks, iou_preds, low_res_masks = self.predictor._predict(
295
+ in_points[:, None, :],
296
+ in_labels[:, None],
297
+ multimask_output=self.multimask_output,
298
+ return_logits=True,
299
+ )
300
+
301
+ # Serialize predictions and store in MaskData
302
+ data = MaskData(
303
+ masks=masks.flatten(0, 1),
304
+ iou_preds=iou_preds.flatten(0, 1),
305
+ points=points.repeat_interleave(masks.shape[1], dim=0),
306
+ low_res_masks=low_res_masks.flatten(0, 1),
307
+ )
308
+ del masks
309
+
310
+ if not self.use_m2m:
311
+ # Filter by predicted IoU
312
+ if self.pred_iou_thresh > 0.0:
313
+ keep_mask = data["iou_preds"] > self.pred_iou_thresh
314
+ data.filter(keep_mask)
315
+
316
+ # Calculate and filter by stability score
317
+ data["stability_score"] = calculate_stability_score(
318
+ data["masks"], self.mask_threshold, self.stability_score_offset
319
+ )
320
+ if self.stability_score_thresh > 0.0:
321
+ keep_mask = data["stability_score"] >= self.stability_score_thresh
322
+ data.filter(keep_mask)
323
+ else:
324
+ # One step refinement using previous mask predictions
325
+ in_points = self.predictor._transforms.transform_coords(
326
+ data["points"], normalize=normalize, orig_hw=im_size
327
+ )
328
+ labels = torch.ones(
329
+ in_points.shape[0], dtype=torch.int, device=in_points.device
330
+ )
331
+ masks, ious = self.refine_with_m2m(
332
+ in_points, labels, data["low_res_masks"], self.points_per_batch
333
+ )
334
+ data["masks"] = masks.squeeze(1)
335
+ data["iou_preds"] = ious.squeeze(1)
336
+
337
+ if self.pred_iou_thresh > 0.0:
338
+ keep_mask = data["iou_preds"] > self.pred_iou_thresh
339
+ data.filter(keep_mask)
340
+
341
+ data["stability_score"] = calculate_stability_score(
342
+ data["masks"], self.mask_threshold, self.stability_score_offset
343
+ )
344
+ if self.stability_score_thresh > 0.0:
345
+ keep_mask = data["stability_score"] >= self.stability_score_thresh
346
+ data.filter(keep_mask)
347
+
348
+ # Threshold masks and calculate boxes
349
+ data["masks"] = data["masks"] > self.mask_threshold
350
+ data["boxes"] = batched_mask_to_box(data["masks"])
351
+
352
+ # Filter boxes that touch crop boundaries
353
+ keep_mask = ~is_box_near_crop_edge(
354
+ data["boxes"], crop_box, [0, 0, orig_w, orig_h]
355
+ )
356
+ if not torch.all(keep_mask):
357
+ data.filter(keep_mask)
358
+
359
+ # Compress to RLE
360
+ data["masks"] = uncrop_masks(data["masks"], crop_box, orig_h, orig_w)
361
+ data["rles"] = mask_to_rle_pytorch(data["masks"])
362
+ del data["masks"]
363
+
364
+ return data
365
+
366
+ @staticmethod
367
+ def postprocess_small_regions(
368
+ mask_data: MaskData, min_area: int, nms_thresh: float
369
+ ) -> MaskData:
370
+ """
371
+ Removes small disconnected regions and holes in masks, then reruns
372
+ box NMS to remove any new duplicates.
373
+
374
+ Edits mask_data in place.
375
+
376
+ Requires open-cv as a dependency.
377
+ """
378
+ if len(mask_data["rles"]) == 0:
379
+ return mask_data
380
+
381
+ # Filter small disconnected regions and holes
382
+ new_masks = []
383
+ scores = []
384
+ for rle in mask_data["rles"]:
385
+ mask = rle_to_mask(rle)
386
+
387
+ mask, changed = remove_small_regions(mask, min_area, mode="holes")
388
+ unchanged = not changed
389
+ mask, changed = remove_small_regions(mask, min_area, mode="islands")
390
+ unchanged = unchanged and not changed
391
+
392
+ new_masks.append(torch.as_tensor(mask).unsqueeze(0))
393
+ # Give score=0 to changed masks and score=1 to unchanged masks
394
+ # so NMS will prefer ones that didn't need postprocessing
395
+ scores.append(float(unchanged))
396
+
397
+ # Recalculate boxes and remove any new duplicates
398
+ masks = torch.cat(new_masks, dim=0)
399
+ boxes = batched_mask_to_box(masks)
400
+ keep_by_nms = batched_nms(
401
+ boxes.float(),
402
+ torch.as_tensor(scores),
403
+ torch.zeros_like(boxes[:, 0]), # categories
404
+ iou_threshold=nms_thresh,
405
+ )
406
+
407
+ # Only recalculate RLEs for masks that have changed
408
+ for i_mask in keep_by_nms:
409
+ if scores[i_mask] == 0.0:
410
+ mask_torch = masks[i_mask].unsqueeze(0)
411
+ mask_data["rles"][i_mask] = mask_to_rle_pytorch(mask_torch)[0]
412
+ mask_data["boxes"][i_mask] = boxes[i_mask] # update res directly
413
+ mask_data.filter(keep_by_nms)
414
+
415
+ return mask_data
416
+
417
+ def refine_with_m2m(self, points, point_labels, low_res_masks, points_per_batch):
418
+ new_masks = []
419
+ new_iou_preds = []
420
+
421
+ for cur_points, cur_point_labels, low_res_mask in batch_iterator(
422
+ points_per_batch, points, point_labels, low_res_masks
423
+ ):
424
+ best_masks, best_iou_preds, _ = self.predictor._predict(
425
+ cur_points[:, None, :],
426
+ cur_point_labels[:, None],
427
+ mask_input=low_res_mask[:, None, :],
428
+ multimask_output=False,
429
+ return_logits=True,
430
+ )
431
+ new_masks.append(best_masks)
432
+ new_iou_preds.append(best_iou_preds)
433
+ masks = torch.cat(new_masks, dim=0)
434
+ return masks, torch.cat(new_iou_preds, dim=0)
segment-anything-2/sam2/build_sam.py ADDED
@@ -0,0 +1,89 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import logging
8
+
9
+ import torch
10
+ from hydra import compose
11
+ from hydra.utils import instantiate
12
+ from omegaconf import OmegaConf
13
+
14
+
15
+ def build_sam2(
16
+ config_file,
17
+ ckpt_path=None,
18
+ device="cuda",
19
+ mode="eval",
20
+ hydra_overrides_extra=[],
21
+ apply_postprocessing=True,
22
+ ):
23
+
24
+ if apply_postprocessing:
25
+ hydra_overrides_extra = hydra_overrides_extra.copy()
26
+ hydra_overrides_extra += [
27
+ # dynamically fall back to multi-mask if the single mask is not stable
28
+ "++model.sam_mask_decoder_extra_args.dynamic_multimask_via_stability=true",
29
+ "++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_delta=0.05",
30
+ "++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_thresh=0.98",
31
+ ]
32
+ # Read config and init model
33
+ cfg = compose(config_name=config_file, overrides=hydra_overrides_extra)
34
+ OmegaConf.resolve(cfg)
35
+ model = instantiate(cfg.model, _recursive_=True)
36
+ _load_checkpoint(model, ckpt_path)
37
+ model = model.to(device)
38
+ if mode == "eval":
39
+ model.eval()
40
+ return model
41
+
42
+
43
+ def build_sam2_video_predictor(
44
+ config_file,
45
+ ckpt_path=None,
46
+ device="cuda",
47
+ mode="eval",
48
+ hydra_overrides_extra=[],
49
+ apply_postprocessing=True,
50
+ ):
51
+ hydra_overrides = [
52
+ "++model._target_=sam2.sam2_video_predictor.SAM2VideoPredictor",
53
+ ]
54
+ if apply_postprocessing:
55
+ hydra_overrides_extra = hydra_overrides_extra.copy()
56
+ hydra_overrides_extra += [
57
+ # dynamically fall back to multi-mask if the single mask is not stable
58
+ "++model.sam_mask_decoder_extra_args.dynamic_multimask_via_stability=true",
59
+ "++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_delta=0.05",
60
+ "++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_thresh=0.98",
61
+ # the sigmoid mask logits on interacted frames with clicks in the memory encoder so that the encoded masks are exactly as what users see from clicking
62
+ "++model.binarize_mask_from_pts_for_mem_enc=true",
63
+ # fill small holes in the low-res masks up to `fill_hole_area` (before resizing them to the original video resolution)
64
+ "++model.fill_hole_area=8",
65
+ ]
66
+ hydra_overrides.extend(hydra_overrides_extra)
67
+
68
+ # Read config and init model
69
+ cfg = compose(config_name=config_file, overrides=hydra_overrides)
70
+ OmegaConf.resolve(cfg)
71
+ model = instantiate(cfg.model, _recursive_=True)
72
+ _load_checkpoint(model, ckpt_path)
73
+ model = model.to(device)
74
+ if mode == "eval":
75
+ model.eval()
76
+ return model
77
+
78
+
79
+ def _load_checkpoint(model, ckpt_path):
80
+ if ckpt_path is not None:
81
+ sd = torch.load(ckpt_path, map_location="cpu")["model"]
82
+ missing_keys, unexpected_keys = model.load_state_dict(sd)
83
+ if missing_keys:
84
+ logging.error(missing_keys)
85
+ raise RuntimeError()
86
+ if unexpected_keys:
87
+ logging.error(unexpected_keys)
88
+ raise RuntimeError()
89
+ logging.info("Loaded checkpoint sucessfully")
segment-anything-2/sam2/csrc/connected_components.cu ADDED
@@ -0,0 +1,289 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ // All rights reserved.
3
+
4
+ // This source code is licensed under the license found in the
5
+ // LICENSE file in the root directory of this source tree.
6
+
7
+ // adapted from https://github.com/zsef123/Connected_components_PyTorch
8
+ // with license found in the LICENSE_cctorch file in the root directory.
9
+ #include <ATen/cuda/CUDAContext.h>
10
+ #include <cuda.h>
11
+ #include <cuda_runtime.h>
12
+ #include <torch/extension.h>
13
+ #include <torch/script.h>
14
+ #include <vector>
15
+
16
+ // 2d
17
+ #define BLOCK_ROWS 16
18
+ #define BLOCK_COLS 16
19
+
20
+ namespace cc2d {
21
+
22
+ template <typename T>
23
+ __device__ __forceinline__ unsigned char hasBit(T bitmap, unsigned char pos) {
24
+ return (bitmap >> pos) & 1;
25
+ }
26
+
27
+ __device__ int32_t find(const int32_t* s_buf, int32_t n) {
28
+ while (s_buf[n] != n)
29
+ n = s_buf[n];
30
+ return n;
31
+ }
32
+
33
+ __device__ int32_t find_n_compress(int32_t* s_buf, int32_t n) {
34
+ const int32_t id = n;
35
+ while (s_buf[n] != n) {
36
+ n = s_buf[n];
37
+ s_buf[id] = n;
38
+ }
39
+ return n;
40
+ }
41
+
42
+ __device__ void union_(int32_t* s_buf, int32_t a, int32_t b) {
43
+ bool done;
44
+ do {
45
+ a = find(s_buf, a);
46
+ b = find(s_buf, b);
47
+
48
+ if (a < b) {
49
+ int32_t old = atomicMin(s_buf + b, a);
50
+ done = (old == b);
51
+ b = old;
52
+ } else if (b < a) {
53
+ int32_t old = atomicMin(s_buf + a, b);
54
+ done = (old == a);
55
+ a = old;
56
+ } else
57
+ done = true;
58
+
59
+ } while (!done);
60
+ }
61
+
62
+ __global__ void
63
+ init_labeling(int32_t* label, const uint32_t W, const uint32_t H) {
64
+ const uint32_t row = (blockIdx.y * blockDim.y + threadIdx.y) * 2;
65
+ const uint32_t col = (blockIdx.x * blockDim.x + threadIdx.x) * 2;
66
+ const uint32_t idx = row * W + col;
67
+
68
+ if (row < H && col < W)
69
+ label[idx] = idx;
70
+ }
71
+
72
+ __global__ void
73
+ merge(uint8_t* img, int32_t* label, const uint32_t W, const uint32_t H) {
74
+ const uint32_t row = (blockIdx.y * blockDim.y + threadIdx.y) * 2;
75
+ const uint32_t col = (blockIdx.x * blockDim.x + threadIdx.x) * 2;
76
+ const uint32_t idx = row * W + col;
77
+
78
+ if (row >= H || col >= W)
79
+ return;
80
+
81
+ uint32_t P = 0;
82
+
83
+ if (img[idx])
84
+ P |= 0x777;
85
+ if (row + 1 < H && img[idx + W])
86
+ P |= 0x777 << 4;
87
+ if (col + 1 < W && img[idx + 1])
88
+ P |= 0x777 << 1;
89
+
90
+ if (col == 0)
91
+ P &= 0xEEEE;
92
+ if (col + 1 >= W)
93
+ P &= 0x3333;
94
+ else if (col + 2 >= W)
95
+ P &= 0x7777;
96
+
97
+ if (row == 0)
98
+ P &= 0xFFF0;
99
+ if (row + 1 >= H)
100
+ P &= 0xFF;
101
+
102
+ if (P > 0) {
103
+ // If need check about top-left pixel(if flag the first bit) and hit the
104
+ // top-left pixel
105
+ if (hasBit(P, 0) && img[idx - W - 1]) {
106
+ union_(label, idx, idx - 2 * W - 2); // top left block
107
+ }
108
+
109
+ if ((hasBit(P, 1) && img[idx - W]) || (hasBit(P, 2) && img[idx - W + 1]))
110
+ union_(label, idx, idx - 2 * W); // top bottom block
111
+
112
+ if (hasBit(P, 3) && img[idx + 2 - W])
113
+ union_(label, idx, idx - 2 * W + 2); // top right block
114
+
115
+ if ((hasBit(P, 4) && img[idx - 1]) || (hasBit(P, 8) && img[idx + W - 1]))
116
+ union_(label, idx, idx - 2); // just left block
117
+ }
118
+ }
119
+
120
+ __global__ void compression(int32_t* label, const int32_t W, const int32_t H) {
121
+ const uint32_t row = (blockIdx.y * blockDim.y + threadIdx.y) * 2;
122
+ const uint32_t col = (blockIdx.x * blockDim.x + threadIdx.x) * 2;
123
+ const uint32_t idx = row * W + col;
124
+
125
+ if (row < H && col < W)
126
+ find_n_compress(label, idx);
127
+ }
128
+
129
+ __global__ void final_labeling(
130
+ const uint8_t* img,
131
+ int32_t* label,
132
+ const int32_t W,
133
+ const int32_t H) {
134
+ const uint32_t row = (blockIdx.y * blockDim.y + threadIdx.y) * 2;
135
+ const uint32_t col = (blockIdx.x * blockDim.x + threadIdx.x) * 2;
136
+ const uint32_t idx = row * W + col;
137
+
138
+ if (row >= H || col >= W)
139
+ return;
140
+
141
+ int32_t y = label[idx] + 1;
142
+
143
+ if (img[idx])
144
+ label[idx] = y;
145
+ else
146
+ label[idx] = 0;
147
+
148
+ if (col + 1 < W) {
149
+ if (img[idx + 1])
150
+ label[idx + 1] = y;
151
+ else
152
+ label[idx + 1] = 0;
153
+
154
+ if (row + 1 < H) {
155
+ if (img[idx + W + 1])
156
+ label[idx + W + 1] = y;
157
+ else
158
+ label[idx + W + 1] = 0;
159
+ }
160
+ }
161
+
162
+ if (row + 1 < H) {
163
+ if (img[idx + W])
164
+ label[idx + W] = y;
165
+ else
166
+ label[idx + W] = 0;
167
+ }
168
+ }
169
+
170
+ __global__ void init_counting(
171
+ const int32_t* label,
172
+ int32_t* count_init,
173
+ const int32_t W,
174
+ const int32_t H) {
175
+ const uint32_t row = (blockIdx.y * blockDim.y + threadIdx.y);
176
+ const uint32_t col = (blockIdx.x * blockDim.x + threadIdx.x);
177
+ const uint32_t idx = row * W + col;
178
+
179
+ if (row >= H || col >= W)
180
+ return;
181
+
182
+ int32_t y = label[idx];
183
+ if (y > 0) {
184
+ int32_t count_idx = y - 1;
185
+ atomicAdd(count_init + count_idx, 1);
186
+ }
187
+ }
188
+
189
+ __global__ void final_counting(
190
+ const int32_t* label,
191
+ const int32_t* count_init,
192
+ int32_t* count_final,
193
+ const int32_t W,
194
+ const int32_t H) {
195
+ const uint32_t row = (blockIdx.y * blockDim.y + threadIdx.y);
196
+ const uint32_t col = (blockIdx.x * blockDim.x + threadIdx.x);
197
+ const uint32_t idx = row * W + col;
198
+
199
+ if (row >= H || col >= W)
200
+ return;
201
+
202
+ int32_t y = label[idx];
203
+ if (y > 0) {
204
+ int32_t count_idx = y - 1;
205
+ count_final[idx] = count_init[count_idx];
206
+ } else {
207
+ count_final[idx] = 0;
208
+ }
209
+ }
210
+
211
+ } // namespace cc2d
212
+
213
+ std::vector<torch::Tensor> get_connected_componnets(
214
+ const torch::Tensor& inputs) {
215
+ AT_ASSERTM(inputs.is_cuda(), "inputs must be a CUDA tensor");
216
+ AT_ASSERTM(inputs.ndimension() == 4, "inputs must be [N, 1, H, W] shape");
217
+ AT_ASSERTM(
218
+ inputs.scalar_type() == torch::kUInt8, "inputs must be a uint8 type");
219
+
220
+ const uint32_t N = inputs.size(0);
221
+ const uint32_t C = inputs.size(1);
222
+ const uint32_t H = inputs.size(2);
223
+ const uint32_t W = inputs.size(3);
224
+
225
+ AT_ASSERTM(C == 1, "inputs must be [N, 1, H, W] shape");
226
+ AT_ASSERTM((H % 2) == 0, "height must be an even number");
227
+ AT_ASSERTM((W % 2) == 0, "width must be an even number");
228
+
229
+ // label must be uint32_t
230
+ auto label_options =
231
+ torch::TensorOptions().dtype(torch::kInt32).device(inputs.device());
232
+ torch::Tensor labels = torch::zeros({N, C, H, W}, label_options);
233
+ torch::Tensor counts_init = torch::zeros({N, C, H, W}, label_options);
234
+ torch::Tensor counts_final = torch::zeros({N, C, H, W}, label_options);
235
+
236
+ dim3 grid = dim3(
237
+ ((W + 1) / 2 + BLOCK_COLS - 1) / BLOCK_COLS,
238
+ ((H + 1) / 2 + BLOCK_ROWS - 1) / BLOCK_ROWS);
239
+ dim3 block = dim3(BLOCK_COLS, BLOCK_ROWS);
240
+ dim3 grid_count =
241
+ dim3((W + BLOCK_COLS) / BLOCK_COLS, (H + BLOCK_ROWS) / BLOCK_ROWS);
242
+ dim3 block_count = dim3(BLOCK_COLS, BLOCK_ROWS);
243
+ cudaStream_t stream = at::cuda::getCurrentCUDAStream();
244
+
245
+ for (int n = 0; n < N; n++) {
246
+ uint32_t offset = n * H * W;
247
+
248
+ cc2d::init_labeling<<<grid, block, 0, stream>>>(
249
+ labels.data_ptr<int32_t>() + offset, W, H);
250
+ cc2d::merge<<<grid, block, 0, stream>>>(
251
+ inputs.data_ptr<uint8_t>() + offset,
252
+ labels.data_ptr<int32_t>() + offset,
253
+ W,
254
+ H);
255
+ cc2d::compression<<<grid, block, 0, stream>>>(
256
+ labels.data_ptr<int32_t>() + offset, W, H);
257
+ cc2d::final_labeling<<<grid, block, 0, stream>>>(
258
+ inputs.data_ptr<uint8_t>() + offset,
259
+ labels.data_ptr<int32_t>() + offset,
260
+ W,
261
+ H);
262
+
263
+ // get the counting of each pixel
264
+ cc2d::init_counting<<<grid_count, block_count, 0, stream>>>(
265
+ labels.data_ptr<int32_t>() + offset,
266
+ counts_init.data_ptr<int32_t>() + offset,
267
+ W,
268
+ H);
269
+ cc2d::final_counting<<<grid_count, block_count, 0, stream>>>(
270
+ labels.data_ptr<int32_t>() + offset,
271
+ counts_init.data_ptr<int32_t>() + offset,
272
+ counts_final.data_ptr<int32_t>() + offset,
273
+ W,
274
+ H);
275
+ }
276
+
277
+ // returned values are [labels, counts]
278
+ std::vector<torch::Tensor> outputs;
279
+ outputs.push_back(labels);
280
+ outputs.push_back(counts_final);
281
+ return outputs;
282
+ }
283
+
284
+ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
285
+ m.def(
286
+ "get_connected_componnets",
287
+ &get_connected_componnets,
288
+ "get_connected_componnets");
289
+ }
segment-anything-2/sam2/modeling/__init__.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
segment-anything-2/sam2/modeling/backbones/__init__.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
segment-anything-2/sam2/modeling/backbones/hieradet.py ADDED
@@ -0,0 +1,295 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ from functools import partial
8
+ from typing import List, Tuple, Union
9
+
10
+ import torch
11
+ import torch.nn as nn
12
+ import torch.nn.functional as F
13
+
14
+ from sam2.modeling.backbones.utils import (
15
+ PatchEmbed,
16
+ window_partition,
17
+ window_unpartition,
18
+ )
19
+
20
+ from sam2.modeling.sam2_utils import DropPath, MLP
21
+
22
+
23
+ def do_pool(x: torch.Tensor, pool: nn.Module, norm: nn.Module = None) -> torch.Tensor:
24
+ if pool is None:
25
+ return x
26
+ # (B, H, W, C) -> (B, C, H, W)
27
+ x = x.permute(0, 3, 1, 2)
28
+ x = pool(x)
29
+ # (B, C, H', W') -> (B, H', W', C)
30
+ x = x.permute(0, 2, 3, 1)
31
+ if norm:
32
+ x = norm(x)
33
+
34
+ return x
35
+
36
+
37
+ class MultiScaleAttention(nn.Module):
38
+ def __init__(
39
+ self,
40
+ dim: int,
41
+ dim_out: int,
42
+ num_heads: int,
43
+ q_pool: nn.Module = None,
44
+ ):
45
+ super().__init__()
46
+
47
+ self.dim = dim
48
+ self.dim_out = dim_out
49
+
50
+ self.num_heads = num_heads
51
+ head_dim = dim_out // num_heads
52
+ self.scale = head_dim**-0.5
53
+
54
+ self.q_pool = q_pool
55
+ self.qkv = nn.Linear(dim, dim_out * 3)
56
+ self.proj = nn.Linear(dim_out, dim_out)
57
+
58
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
59
+ B, H, W, _ = x.shape
60
+ # qkv with shape (B, H * W, 3, nHead, C)
61
+ qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1)
62
+ # q, k, v with shape (B, H * W, nheads, C)
63
+ q, k, v = torch.unbind(qkv, 2)
64
+
65
+ # Q pooling (for downsample at stage changes)
66
+ if self.q_pool:
67
+ q = do_pool(q.reshape(B, H, W, -1), self.q_pool)
68
+ H, W = q.shape[1:3] # downsampled shape
69
+ q = q.reshape(B, H * W, self.num_heads, -1)
70
+
71
+ # Torch's SDPA expects [B, nheads, H*W, C] so we transpose
72
+ x = F.scaled_dot_product_attention(
73
+ q.transpose(1, 2),
74
+ k.transpose(1, 2),
75
+ v.transpose(1, 2),
76
+ )
77
+ # Transpose back
78
+ x = x.transpose(1, 2)
79
+ x = x.reshape(B, H, W, -1)
80
+
81
+ x = self.proj(x)
82
+
83
+ return x
84
+
85
+
86
+ class MultiScaleBlock(nn.Module):
87
+ def __init__(
88
+ self,
89
+ dim: int,
90
+ dim_out: int,
91
+ num_heads: int,
92
+ mlp_ratio: float = 4.0,
93
+ drop_path: float = 0.0,
94
+ norm_layer: Union[nn.Module, str] = "LayerNorm",
95
+ q_stride: Tuple[int, int] = None,
96
+ act_layer: nn.Module = nn.GELU,
97
+ window_size: int = 0,
98
+ ):
99
+ super().__init__()
100
+
101
+ if isinstance(norm_layer, str):
102
+ norm_layer = partial(getattr(nn, norm_layer), eps=1e-6)
103
+
104
+ self.dim = dim
105
+ self.dim_out = dim_out
106
+ self.norm1 = norm_layer(dim)
107
+
108
+ self.window_size = window_size
109
+
110
+ self.pool, self.q_stride = None, q_stride
111
+ if self.q_stride:
112
+ self.pool = nn.MaxPool2d(
113
+ kernel_size=q_stride, stride=q_stride, ceil_mode=False
114
+ )
115
+
116
+ self.attn = MultiScaleAttention(
117
+ dim,
118
+ dim_out,
119
+ num_heads=num_heads,
120
+ q_pool=self.pool,
121
+ )
122
+ self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
123
+
124
+ self.norm2 = norm_layer(dim_out)
125
+ self.mlp = MLP(
126
+ dim_out,
127
+ int(dim_out * mlp_ratio),
128
+ dim_out,
129
+ num_layers=2,
130
+ activation=act_layer,
131
+ )
132
+
133
+ if dim != dim_out:
134
+ self.proj = nn.Linear(dim, dim_out)
135
+
136
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
137
+ shortcut = x # B, H, W, C
138
+ x = self.norm1(x)
139
+
140
+ # Skip connection
141
+ if self.dim != self.dim_out:
142
+ shortcut = do_pool(self.proj(x), self.pool)
143
+
144
+ # Window partition
145
+ window_size = self.window_size
146
+ if window_size > 0:
147
+ H, W = x.shape[1], x.shape[2]
148
+ x, pad_hw = window_partition(x, window_size)
149
+
150
+ # Window Attention + Q Pooling (if stage change)
151
+ x = self.attn(x)
152
+ if self.q_stride:
153
+ # Shapes have changed due to Q pooling
154
+ window_size = self.window_size // self.q_stride[0]
155
+ H, W = shortcut.shape[1:3]
156
+
157
+ pad_h = (window_size - H % window_size) % window_size
158
+ pad_w = (window_size - W % window_size) % window_size
159
+ pad_hw = (H + pad_h, W + pad_w)
160
+
161
+ # Reverse window partition
162
+ if self.window_size > 0:
163
+ x = window_unpartition(x, window_size, pad_hw, (H, W))
164
+
165
+ x = shortcut + self.drop_path(x)
166
+ # MLP
167
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
168
+ return x
169
+
170
+
171
+ class Hiera(nn.Module):
172
+ """
173
+ Reference: https://arxiv.org/abs/2306.00989
174
+ """
175
+
176
+ def __init__(
177
+ self,
178
+ embed_dim: int = 96, # initial embed dim
179
+ num_heads: int = 1, # initial number of heads
180
+ drop_path_rate: float = 0.0, # stochastic depth
181
+ q_pool: int = 3, # number of q_pool stages
182
+ q_stride: Tuple[int, int] = (2, 2), # downsample stride bet. stages
183
+ stages: Tuple[int, ...] = (2, 3, 16, 3), # blocks per stage
184
+ dim_mul: float = 2.0, # dim_mul factor at stage shift
185
+ head_mul: float = 2.0, # head_mul factor at stage shift
186
+ window_pos_embed_bkg_spatial_size: Tuple[int, int] = (14, 14),
187
+ # window size per stage, when not using global att.
188
+ window_spec: Tuple[int, ...] = (
189
+ 8,
190
+ 4,
191
+ 14,
192
+ 7,
193
+ ),
194
+ # global attn in these blocks
195
+ global_att_blocks: Tuple[int, ...] = (
196
+ 12,
197
+ 16,
198
+ 20,
199
+ ),
200
+ return_interm_layers=True, # return feats from every stage
201
+ ):
202
+ super().__init__()
203
+
204
+ assert len(stages) == len(window_spec)
205
+ self.window_spec = window_spec
206
+
207
+ depth = sum(stages)
208
+ self.q_stride = q_stride
209
+ self.stage_ends = [sum(stages[:i]) - 1 for i in range(1, len(stages) + 1)]
210
+ assert 0 <= q_pool <= len(self.stage_ends[:-1])
211
+ self.q_pool_blocks = [x + 1 for x in self.stage_ends[:-1]][:q_pool]
212
+ self.return_interm_layers = return_interm_layers
213
+
214
+ self.patch_embed = PatchEmbed(
215
+ embed_dim=embed_dim,
216
+ )
217
+ # Which blocks have global att?
218
+ self.global_att_blocks = global_att_blocks
219
+
220
+ # Windowed positional embedding (https://arxiv.org/abs/2311.05613)
221
+ self.window_pos_embed_bkg_spatial_size = window_pos_embed_bkg_spatial_size
222
+ self.pos_embed = nn.Parameter(
223
+ torch.zeros(1, embed_dim, *self.window_pos_embed_bkg_spatial_size)
224
+ )
225
+ self.pos_embed_window = nn.Parameter(
226
+ torch.zeros(1, embed_dim, self.window_spec[0], self.window_spec[0])
227
+ )
228
+
229
+ dpr = [
230
+ x.item() for x in torch.linspace(0, drop_path_rate, depth)
231
+ ] # stochastic depth decay rule
232
+
233
+ cur_stage = 1
234
+ self.blocks = nn.ModuleList()
235
+
236
+ for i in range(depth):
237
+ dim_out = embed_dim
238
+ # lags by a block, so first block of
239
+ # next stage uses an initial window size
240
+ # of previous stage and final window size of current stage
241
+ window_size = self.window_spec[cur_stage - 1]
242
+
243
+ if self.global_att_blocks is not None:
244
+ window_size = 0 if i in self.global_att_blocks else window_size
245
+
246
+ if i - 1 in self.stage_ends:
247
+ dim_out = int(embed_dim * dim_mul)
248
+ num_heads = int(num_heads * head_mul)
249
+ cur_stage += 1
250
+
251
+ block = MultiScaleBlock(
252
+ dim=embed_dim,
253
+ dim_out=dim_out,
254
+ num_heads=num_heads,
255
+ drop_path=dpr[i],
256
+ q_stride=self.q_stride if i in self.q_pool_blocks else None,
257
+ window_size=window_size,
258
+ )
259
+
260
+ embed_dim = dim_out
261
+ self.blocks.append(block)
262
+
263
+ self.channel_list = (
264
+ [self.blocks[i].dim_out for i in self.stage_ends[::-1]]
265
+ if return_interm_layers
266
+ else [self.blocks[-1].dim_out]
267
+ )
268
+
269
+ def _get_pos_embed(self, hw: Tuple[int, int]) -> torch.Tensor:
270
+ h, w = hw
271
+ window_embed = self.pos_embed_window
272
+ pos_embed = F.interpolate(self.pos_embed, size=(h, w), mode="bicubic")
273
+ pos_embed = pos_embed + window_embed.tile(
274
+ [x // y for x, y in zip(pos_embed.shape, window_embed.shape)]
275
+ )
276
+ pos_embed = pos_embed.permute(0, 2, 3, 1)
277
+ return pos_embed
278
+
279
+ def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
280
+ x = self.patch_embed(x)
281
+ # x: (B, H, W, C)
282
+
283
+ # Add pos embed
284
+ x = x + self._get_pos_embed(x.shape[1:3])
285
+
286
+ outputs = []
287
+ for i, blk in enumerate(self.blocks):
288
+ x = blk(x)
289
+ if (i == self.stage_ends[-1]) or (
290
+ i in self.stage_ends and self.return_interm_layers
291
+ ):
292
+ feats = x.permute(0, 3, 1, 2)
293
+ outputs.append(feats)
294
+
295
+ return outputs
segment-anything-2/sam2/modeling/backbones/image_encoder.py ADDED
@@ -0,0 +1,133 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ from typing import List, Optional
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.nn.functional as F
12
+
13
+
14
+ class ImageEncoder(nn.Module):
15
+ def __init__(
16
+ self,
17
+ trunk: nn.Module,
18
+ neck: nn.Module,
19
+ scalp: int = 0,
20
+ ):
21
+ super().__init__()
22
+ self.trunk = trunk
23
+ self.neck = neck
24
+ self.scalp = scalp
25
+ assert (
26
+ self.trunk.channel_list == self.neck.backbone_channel_list
27
+ ), f"Channel dims of trunk and neck do not match. Trunk: {self.trunk.channel_list}, neck: {self.neck.backbone_channel_list}"
28
+
29
+ def forward(self, sample: torch.Tensor):
30
+ # Forward through backbone
31
+ features, pos = self.neck(self.trunk(sample))
32
+ if self.scalp > 0:
33
+ # Discard the lowest resolution features
34
+ features, pos = features[: -self.scalp], pos[: -self.scalp]
35
+
36
+ src = features[-1]
37
+ output = {
38
+ "vision_features": src,
39
+ "vision_pos_enc": pos,
40
+ "backbone_fpn": features,
41
+ }
42
+ return output
43
+
44
+
45
+ class FpnNeck(nn.Module):
46
+ """
47
+ A modified variant of Feature Pyramid Network (FPN) neck
48
+ (we remove output conv and also do bicubic interpolation similar to ViT
49
+ pos embed interpolation)
50
+ """
51
+
52
+ def __init__(
53
+ self,
54
+ position_encoding: nn.Module,
55
+ d_model: int,
56
+ backbone_channel_list: List[int],
57
+ kernel_size: int = 1,
58
+ stride: int = 1,
59
+ padding: int = 0,
60
+ fpn_interp_model: str = "bilinear",
61
+ fuse_type: str = "sum",
62
+ fpn_top_down_levels: Optional[List[int]] = None,
63
+ ):
64
+ """Initialize the neck
65
+ :param trunk: the backbone
66
+ :param position_encoding: the positional encoding to use
67
+ :param d_model: the dimension of the model
68
+ :param neck_norm: the normalization to use
69
+ """
70
+ super().__init__()
71
+ self.position_encoding = position_encoding
72
+ self.convs = nn.ModuleList()
73
+ self.backbone_channel_list = backbone_channel_list
74
+ for dim in backbone_channel_list:
75
+ current = nn.Sequential()
76
+ current.add_module(
77
+ "conv",
78
+ nn.Conv2d(
79
+ in_channels=dim,
80
+ out_channels=d_model,
81
+ kernel_size=kernel_size,
82
+ stride=stride,
83
+ padding=padding,
84
+ ),
85
+ )
86
+
87
+ self.convs.append(current)
88
+ self.fpn_interp_model = fpn_interp_model
89
+ assert fuse_type in ["sum", "avg"]
90
+ self.fuse_type = fuse_type
91
+
92
+ # levels to have top-down features in its outputs
93
+ # e.g. if fpn_top_down_levels is [2, 3], then only outputs of level 2 and 3
94
+ # have top-down propagation, while outputs of level 0 and level 1 have only
95
+ # lateral features from the same backbone level.
96
+ if fpn_top_down_levels is None:
97
+ # default is to have top-down features on all levels
98
+ fpn_top_down_levels = range(len(self.convs))
99
+ self.fpn_top_down_levels = list(fpn_top_down_levels)
100
+
101
+ def forward(self, xs: List[torch.Tensor]):
102
+
103
+ out = [None] * len(self.convs)
104
+ pos = [None] * len(self.convs)
105
+ assert len(xs) == len(self.convs)
106
+ # fpn forward pass
107
+ # see https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/fpn.py
108
+ prev_features = None
109
+ # forward in top-down order (from low to high resolution)
110
+ n = len(self.convs) - 1
111
+ for i in range(n, -1, -1):
112
+ x = xs[i]
113
+ lateral_features = self.convs[n - i](x)
114
+ if i in self.fpn_top_down_levels and prev_features is not None:
115
+ top_down_features = F.interpolate(
116
+ prev_features.to(dtype=torch.float32),
117
+ scale_factor=2.0,
118
+ mode=self.fpn_interp_model,
119
+ align_corners=(
120
+ None if self.fpn_interp_model == "nearest" else False
121
+ ),
122
+ antialias=False,
123
+ )
124
+ prev_features = lateral_features + top_down_features
125
+ if self.fuse_type == "avg":
126
+ prev_features /= 2
127
+ else:
128
+ prev_features = lateral_features
129
+ x_out = prev_features
130
+ out[i] = x_out
131
+ pos[i] = self.position_encoding(x_out).to(x_out.dtype)
132
+
133
+ return out, pos
segment-anything-2/sam2/modeling/backbones/utils.py ADDED
@@ -0,0 +1,95 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ """Some utilities for backbones, in particular for windowing"""
8
+
9
+ from typing import Tuple
10
+
11
+ import torch
12
+ import torch.nn as nn
13
+ import torch.nn.functional as F
14
+
15
+
16
+ def window_partition(x, window_size):
17
+ """
18
+ Partition into non-overlapping windows with padding if needed.
19
+ Args:
20
+ x (tensor): input tokens with [B, H, W, C].
21
+ window_size (int): window size.
22
+ Returns:
23
+ windows: windows after partition with [B * num_windows, window_size, window_size, C].
24
+ (Hp, Wp): padded height and width before partition
25
+ """
26
+ B, H, W, C = x.shape
27
+
28
+ pad_h = (window_size - H % window_size) % window_size
29
+ pad_w = (window_size - W % window_size) % window_size
30
+ if pad_h > 0 or pad_w > 0:
31
+ x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h))
32
+ Hp, Wp = H + pad_h, W + pad_w
33
+
34
+ x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C)
35
+ windows = (
36
+ x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
37
+ )
38
+ return windows, (Hp, Wp)
39
+
40
+
41
+ def window_unpartition(windows, window_size, pad_hw, hw):
42
+ """
43
+ Window unpartition into original sequences and removing padding.
44
+ Args:
45
+ x (tensor): input tokens with [B * num_windows, window_size, window_size, C].
46
+ window_size (int): window size.
47
+ pad_hw (Tuple): padded height and width (Hp, Wp).
48
+ hw (Tuple): original height and width (H, W) before padding.
49
+ Returns:
50
+ x: unpartitioned sequences with [B, H, W, C].
51
+ """
52
+ Hp, Wp = pad_hw
53
+ H, W = hw
54
+ B = windows.shape[0] // (Hp * Wp // window_size // window_size)
55
+ x = windows.view(
56
+ B, Hp // window_size, Wp // window_size, window_size, window_size, -1
57
+ )
58
+ x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1)
59
+
60
+ if Hp > H or Wp > W:
61
+ x = x[:, :H, :W, :].contiguous()
62
+ return x
63
+
64
+
65
+ class PatchEmbed(nn.Module):
66
+ """
67
+ Image to Patch Embedding.
68
+ """
69
+
70
+ def __init__(
71
+ self,
72
+ kernel_size: Tuple[int, ...] = (7, 7),
73
+ stride: Tuple[int, ...] = (4, 4),
74
+ padding: Tuple[int, ...] = (3, 3),
75
+ in_chans: int = 3,
76
+ embed_dim: int = 768,
77
+ ):
78
+ """
79
+ Args:
80
+ kernel_size (Tuple): kernel size of the projection layer.
81
+ stride (Tuple): stride of the projection layer.
82
+ padding (Tuple): padding size of the projection layer.
83
+ in_chans (int): Number of input image channels.
84
+ embed_dim (int): embed_dim (int): Patch embedding dimension.
85
+ """
86
+ super().__init__()
87
+ self.proj = nn.Conv2d(
88
+ in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding
89
+ )
90
+
91
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
92
+ x = self.proj(x)
93
+ # B C H W -> B H W C
94
+ x = x.permute(0, 2, 3, 1)
95
+ return x
segment-anything-2/sam2/modeling/memory_attention.py ADDED
@@ -0,0 +1,169 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ from typing import Optional
8
+
9
+ import torch
10
+ from torch import nn, Tensor
11
+
12
+ from sam2.modeling.sam.transformer import RoPEAttention
13
+
14
+ from sam2.modeling.sam2_utils import get_activation_fn, get_clones
15
+
16
+
17
+ class MemoryAttentionLayer(nn.Module):
18
+
19
+ def __init__(
20
+ self,
21
+ activation: str,
22
+ cross_attention: nn.Module,
23
+ d_model: int,
24
+ dim_feedforward: int,
25
+ dropout: float,
26
+ pos_enc_at_attn: bool,
27
+ pos_enc_at_cross_attn_keys: bool,
28
+ pos_enc_at_cross_attn_queries: bool,
29
+ self_attention: nn.Module,
30
+ ):
31
+ super().__init__()
32
+ self.d_model = d_model
33
+ self.dim_feedforward = dim_feedforward
34
+ self.dropout_value = dropout
35
+ self.self_attn = self_attention
36
+ self.cross_attn_image = cross_attention
37
+
38
+ # Implementation of Feedforward model
39
+ self.linear1 = nn.Linear(d_model, dim_feedforward)
40
+ self.dropout = nn.Dropout(dropout)
41
+ self.linear2 = nn.Linear(dim_feedforward, d_model)
42
+
43
+ self.norm1 = nn.LayerNorm(d_model)
44
+ self.norm2 = nn.LayerNorm(d_model)
45
+ self.norm3 = nn.LayerNorm(d_model)
46
+ self.dropout1 = nn.Dropout(dropout)
47
+ self.dropout2 = nn.Dropout(dropout)
48
+ self.dropout3 = nn.Dropout(dropout)
49
+
50
+ self.activation_str = activation
51
+ self.activation = get_activation_fn(activation)
52
+
53
+ # Where to add pos enc
54
+ self.pos_enc_at_attn = pos_enc_at_attn
55
+ self.pos_enc_at_cross_attn_queries = pos_enc_at_cross_attn_queries
56
+ self.pos_enc_at_cross_attn_keys = pos_enc_at_cross_attn_keys
57
+
58
+ def _forward_sa(self, tgt, query_pos):
59
+ # Self-Attention
60
+ tgt2 = self.norm1(tgt)
61
+ q = k = tgt2 + query_pos if self.pos_enc_at_attn else tgt2
62
+ tgt2 = self.self_attn(q, k, v=tgt2)
63
+ tgt = tgt + self.dropout1(tgt2)
64
+ return tgt
65
+
66
+ def _forward_ca(self, tgt, memory, query_pos, pos, num_k_exclude_rope=0):
67
+ kwds = {}
68
+ if num_k_exclude_rope > 0:
69
+ assert isinstance(self.cross_attn_image, RoPEAttention)
70
+ kwds = {"num_k_exclude_rope": num_k_exclude_rope}
71
+
72
+ # Cross-Attention
73
+ tgt2 = self.norm2(tgt)
74
+ tgt2 = self.cross_attn_image(
75
+ q=tgt2 + query_pos if self.pos_enc_at_cross_attn_queries else tgt2,
76
+ k=memory + pos if self.pos_enc_at_cross_attn_keys else memory,
77
+ v=memory,
78
+ **kwds,
79
+ )
80
+ tgt = tgt + self.dropout2(tgt2)
81
+ return tgt
82
+
83
+ def forward(
84
+ self,
85
+ tgt,
86
+ memory,
87
+ pos: Optional[Tensor] = None,
88
+ query_pos: Optional[Tensor] = None,
89
+ num_k_exclude_rope: int = 0,
90
+ ) -> torch.Tensor:
91
+
92
+ # Self-Attn, Cross-Attn
93
+ tgt = self._forward_sa(tgt, query_pos)
94
+ tgt = self._forward_ca(tgt, memory, query_pos, pos, num_k_exclude_rope)
95
+ # MLP
96
+ tgt2 = self.norm3(tgt)
97
+ tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
98
+ tgt = tgt + self.dropout3(tgt2)
99
+ return tgt
100
+
101
+
102
+ class MemoryAttention(nn.Module):
103
+ def __init__(
104
+ self,
105
+ d_model: int,
106
+ pos_enc_at_input: bool,
107
+ layer: nn.Module,
108
+ num_layers: int,
109
+ batch_first: bool = True, # Do layers expect batch first input?
110
+ ):
111
+ super().__init__()
112
+ self.d_model = d_model
113
+ self.layers = get_clones(layer, num_layers)
114
+ self.num_layers = num_layers
115
+ self.norm = nn.LayerNorm(d_model)
116
+ self.pos_enc_at_input = pos_enc_at_input
117
+ self.batch_first = batch_first
118
+
119
+ def forward(
120
+ self,
121
+ curr: torch.Tensor, # self-attention inputs
122
+ memory: torch.Tensor, # cross-attention inputs
123
+ curr_pos: Optional[Tensor] = None, # pos_enc for self-attention inputs
124
+ memory_pos: Optional[Tensor] = None, # pos_enc for cross-attention inputs
125
+ num_obj_ptr_tokens: int = 0, # number of object pointer *tokens*
126
+ ):
127
+ if isinstance(curr, list):
128
+ assert isinstance(curr_pos, list)
129
+ assert len(curr) == len(curr_pos) == 1
130
+ curr, curr_pos = (
131
+ curr[0],
132
+ curr_pos[0],
133
+ )
134
+
135
+ assert (
136
+ curr.shape[1] == memory.shape[1]
137
+ ), "Batch size must be the same for curr and memory"
138
+
139
+ output = curr
140
+ if self.pos_enc_at_input and curr_pos is not None:
141
+ output = output + 0.1 * curr_pos
142
+
143
+ if self.batch_first:
144
+ # Convert to batch first
145
+ output = output.transpose(0, 1)
146
+ curr_pos = curr_pos.transpose(0, 1)
147
+ memory = memory.transpose(0, 1)
148
+ memory_pos = memory_pos.transpose(0, 1)
149
+
150
+ for layer in self.layers:
151
+ kwds = {}
152
+ if isinstance(layer.cross_attn_image, RoPEAttention):
153
+ kwds = {"num_k_exclude_rope": num_obj_ptr_tokens}
154
+
155
+ output = layer(
156
+ tgt=output,
157
+ memory=memory,
158
+ pos=memory_pos,
159
+ query_pos=curr_pos,
160
+ **kwds,
161
+ )
162
+ normed_output = self.norm(output)
163
+
164
+ if self.batch_first:
165
+ # Convert back to seq first
166
+ normed_output = normed_output.transpose(0, 1)
167
+ curr_pos = curr_pos.transpose(0, 1)
168
+
169
+ return normed_output
segment-anything-2/sam2/modeling/memory_encoder.py ADDED
@@ -0,0 +1,181 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import math
8
+ from typing import Tuple
9
+
10
+ import torch
11
+ import torch.nn as nn
12
+ import torch.nn.functional as F
13
+
14
+ from sam2.modeling.sam2_utils import DropPath, get_clones, LayerNorm2d
15
+
16
+
17
+ class MaskDownSampler(nn.Module):
18
+ """
19
+ Progressively downsample a mask by total_stride, each time by stride.
20
+ Note that LayerNorm is applied per *token*, like in ViT.
21
+
22
+ With each downsample (by a factor stride**2), channel capacity increases by the same factor.
23
+ In the end, we linearly project to embed_dim channels.
24
+ """
25
+
26
+ def __init__(
27
+ self,
28
+ embed_dim=256,
29
+ kernel_size=4,
30
+ stride=4,
31
+ padding=0,
32
+ total_stride=16,
33
+ activation=nn.GELU,
34
+ ):
35
+ super().__init__()
36
+ num_layers = int(math.log2(total_stride) // math.log2(stride))
37
+ assert stride**num_layers == total_stride
38
+ self.encoder = nn.Sequential()
39
+ mask_in_chans, mask_out_chans = 1, 1
40
+ for _ in range(num_layers):
41
+ mask_out_chans = mask_in_chans * (stride**2)
42
+ self.encoder.append(
43
+ nn.Conv2d(
44
+ mask_in_chans,
45
+ mask_out_chans,
46
+ kernel_size=kernel_size,
47
+ stride=stride,
48
+ padding=padding,
49
+ )
50
+ )
51
+ self.encoder.append(LayerNorm2d(mask_out_chans))
52
+ self.encoder.append(activation())
53
+ mask_in_chans = mask_out_chans
54
+
55
+ self.encoder.append(nn.Conv2d(mask_out_chans, embed_dim, kernel_size=1))
56
+
57
+ def forward(self, x):
58
+ return self.encoder(x)
59
+
60
+
61
+ # Lightly adapted from ConvNext (https://github.com/facebookresearch/ConvNeXt)
62
+ class CXBlock(nn.Module):
63
+ r"""ConvNeXt Block. There are two equivalent implementations:
64
+ (1) DwConv -> LayerNorm (channels_first) -> 1x1 Conv -> GELU -> 1x1 Conv; all in (N, C, H, W)
65
+ (2) DwConv -> Permute to (N, H, W, C); LayerNorm (channels_last) -> Linear -> GELU -> Linear; Permute back
66
+ We use (2) as we find it slightly faster in PyTorch
67
+
68
+ Args:
69
+ dim (int): Number of input channels.
70
+ drop_path (float): Stochastic depth rate. Default: 0.0
71
+ layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6.
72
+ """
73
+
74
+ def __init__(
75
+ self,
76
+ dim,
77
+ kernel_size=7,
78
+ padding=3,
79
+ drop_path=0.0,
80
+ layer_scale_init_value=1e-6,
81
+ use_dwconv=True,
82
+ ):
83
+ super().__init__()
84
+ self.dwconv = nn.Conv2d(
85
+ dim,
86
+ dim,
87
+ kernel_size=kernel_size,
88
+ padding=padding,
89
+ groups=dim if use_dwconv else 1,
90
+ ) # depthwise conv
91
+ self.norm = LayerNorm2d(dim, eps=1e-6)
92
+ self.pwconv1 = nn.Linear(
93
+ dim, 4 * dim
94
+ ) # pointwise/1x1 convs, implemented with linear layers
95
+ self.act = nn.GELU()
96
+ self.pwconv2 = nn.Linear(4 * dim, dim)
97
+ self.gamma = (
98
+ nn.Parameter(layer_scale_init_value * torch.ones((dim)), requires_grad=True)
99
+ if layer_scale_init_value > 0
100
+ else None
101
+ )
102
+ self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
103
+
104
+ def forward(self, x):
105
+ input = x
106
+ x = self.dwconv(x)
107
+ x = self.norm(x)
108
+ x = x.permute(0, 2, 3, 1) # (N, C, H, W) -> (N, H, W, C)
109
+ x = self.pwconv1(x)
110
+ x = self.act(x)
111
+ x = self.pwconv2(x)
112
+ if self.gamma is not None:
113
+ x = self.gamma * x
114
+ x = x.permute(0, 3, 1, 2) # (N, H, W, C) -> (N, C, H, W)
115
+
116
+ x = input + self.drop_path(x)
117
+ return x
118
+
119
+
120
+ class Fuser(nn.Module):
121
+ def __init__(self, layer, num_layers, dim=None, input_projection=False):
122
+ super().__init__()
123
+ self.proj = nn.Identity()
124
+ self.layers = get_clones(layer, num_layers)
125
+
126
+ if input_projection:
127
+ assert dim is not None
128
+ self.proj = nn.Conv2d(dim, dim, kernel_size=1)
129
+
130
+ def forward(self, x):
131
+ # normally x: (N, C, H, W)
132
+ x = self.proj(x)
133
+ for layer in self.layers:
134
+ x = layer(x)
135
+ return x
136
+
137
+
138
+ class MemoryEncoder(nn.Module):
139
+ def __init__(
140
+ self,
141
+ out_dim,
142
+ mask_downsampler,
143
+ fuser,
144
+ position_encoding,
145
+ in_dim=256, # in_dim of pix_feats
146
+ ):
147
+ super().__init__()
148
+
149
+ self.mask_downsampler = mask_downsampler
150
+
151
+ self.pix_feat_proj = nn.Conv2d(in_dim, in_dim, kernel_size=1)
152
+ self.fuser = fuser
153
+ self.position_encoding = position_encoding
154
+ self.out_proj = nn.Identity()
155
+ if out_dim != in_dim:
156
+ self.out_proj = nn.Conv2d(in_dim, out_dim, kernel_size=1)
157
+
158
+ def forward(
159
+ self,
160
+ pix_feat: torch.Tensor,
161
+ masks: torch.Tensor,
162
+ skip_mask_sigmoid: bool = False,
163
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
164
+ ## Process masks
165
+ # sigmoid, so that less domain shift from gt masks which are bool
166
+ if not skip_mask_sigmoid:
167
+ masks = F.sigmoid(masks)
168
+ masks = self.mask_downsampler(masks)
169
+
170
+ ## Fuse pix_feats and downsampled masks
171
+ # in case the visual features are on CPU, cast them to CUDA
172
+ pix_feat = pix_feat.to(masks.device)
173
+
174
+ x = self.pix_feat_proj(pix_feat)
175
+ x = x + masks
176
+ x = self.fuser(x)
177
+ x = self.out_proj(x)
178
+
179
+ pos = self.position_encoding(x).to(x.dtype)
180
+
181
+ return {"vision_features": x, "vision_pos_enc": [pos]}
segment-anything-2/sam2/modeling/position_encoding.py ADDED
@@ -0,0 +1,216 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import math
8
+ from typing import Any, Optional, Tuple
9
+
10
+ import numpy as np
11
+
12
+ import torch
13
+ from torch import nn
14
+
15
+
16
+ class PositionEmbeddingSine(nn.Module):
17
+ """
18
+ This is a more standard version of the position embedding, very similar to the one
19
+ used by the Attention is all you need paper, generalized to work on images.
20
+ """
21
+
22
+ def __init__(
23
+ self,
24
+ num_pos_feats,
25
+ temperature: int = 10000,
26
+ normalize: bool = True,
27
+ scale: Optional[float] = None,
28
+ ):
29
+ super().__init__()
30
+ assert num_pos_feats % 2 == 0, "Expecting even model width"
31
+ self.num_pos_feats = num_pos_feats // 2
32
+ self.temperature = temperature
33
+ self.normalize = normalize
34
+ if scale is not None and normalize is False:
35
+ raise ValueError("normalize should be True if scale is passed")
36
+ if scale is None:
37
+ scale = 2 * math.pi
38
+ self.scale = scale
39
+
40
+ self.cache = {}
41
+
42
+ def _encode_xy(self, x, y):
43
+ # The positions are expected to be normalized
44
+ assert len(x) == len(y) and x.ndim == y.ndim == 1
45
+ x_embed = x * self.scale
46
+ y_embed = y * self.scale
47
+
48
+ dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
49
+ dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)
50
+
51
+ pos_x = x_embed[:, None] / dim_t
52
+ pos_y = y_embed[:, None] / dim_t
53
+ pos_x = torch.stack(
54
+ (pos_x[:, 0::2].sin(), pos_x[:, 1::2].cos()), dim=2
55
+ ).flatten(1)
56
+ pos_y = torch.stack(
57
+ (pos_y[:, 0::2].sin(), pos_y[:, 1::2].cos()), dim=2
58
+ ).flatten(1)
59
+ return pos_x, pos_y
60
+
61
+ @torch.no_grad()
62
+ def encode_boxes(self, x, y, w, h):
63
+ pos_x, pos_y = self._encode_xy(x, y)
64
+ pos = torch.cat((pos_y, pos_x, h[:, None], w[:, None]), dim=1)
65
+ return pos
66
+
67
+ encode = encode_boxes # Backwards compatibility
68
+
69
+ @torch.no_grad()
70
+ def encode_points(self, x, y, labels):
71
+ (bx, nx), (by, ny), (bl, nl) = x.shape, y.shape, labels.shape
72
+ assert bx == by and nx == ny and bx == bl and nx == nl
73
+ pos_x, pos_y = self._encode_xy(x.flatten(), y.flatten())
74
+ pos_x, pos_y = pos_x.reshape(bx, nx, -1), pos_y.reshape(by, ny, -1)
75
+ pos = torch.cat((pos_y, pos_x, labels[:, :, None]), dim=2)
76
+ return pos
77
+
78
+ @torch.no_grad()
79
+ def forward(self, x: torch.Tensor):
80
+ cache_key = (x.shape[-2], x.shape[-1])
81
+ if cache_key in self.cache:
82
+ return self.cache[cache_key][None].repeat(x.shape[0], 1, 1, 1)
83
+ y_embed = (
84
+ torch.arange(1, x.shape[-2] + 1, dtype=torch.float32, device=x.device)
85
+ .view(1, -1, 1)
86
+ .repeat(x.shape[0], 1, x.shape[-1])
87
+ )
88
+ x_embed = (
89
+ torch.arange(1, x.shape[-1] + 1, dtype=torch.float32, device=x.device)
90
+ .view(1, 1, -1)
91
+ .repeat(x.shape[0], x.shape[-2], 1)
92
+ )
93
+
94
+ if self.normalize:
95
+ eps = 1e-6
96
+ y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
97
+ x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
98
+
99
+ dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
100
+ dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)
101
+
102
+ pos_x = x_embed[:, :, :, None] / dim_t
103
+ pos_y = y_embed[:, :, :, None] / dim_t
104
+ pos_x = torch.stack(
105
+ (pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4
106
+ ).flatten(3)
107
+ pos_y = torch.stack(
108
+ (pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4
109
+ ).flatten(3)
110
+ pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
111
+ self.cache[cache_key] = pos[0]
112
+ return pos
113
+
114
+
115
+ class PositionEmbeddingRandom(nn.Module):
116
+ """
117
+ Positional encoding using random spatial frequencies.
118
+ """
119
+
120
+ def __init__(self, num_pos_feats: int = 64, scale: Optional[float] = None) -> None:
121
+ super().__init__()
122
+ if scale is None or scale <= 0.0:
123
+ scale = 1.0
124
+ self.register_buffer(
125
+ "positional_encoding_gaussian_matrix",
126
+ scale * torch.randn((2, num_pos_feats)),
127
+ )
128
+
129
+ def _pe_encoding(self, coords: torch.Tensor) -> torch.Tensor:
130
+ """Positionally encode points that are normalized to [0,1]."""
131
+ # assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape
132
+ coords = 2 * coords - 1
133
+ coords = coords @ self.positional_encoding_gaussian_matrix
134
+ coords = 2 * np.pi * coords
135
+ # outputs d_1 x ... x d_n x C shape
136
+ return torch.cat([torch.sin(coords), torch.cos(coords)], dim=-1)
137
+
138
+ def forward(self, size: Tuple[int, int]) -> torch.Tensor:
139
+ """Generate positional encoding for a grid of the specified size."""
140
+ h, w = size
141
+ device: Any = self.positional_encoding_gaussian_matrix.device
142
+ grid = torch.ones((h, w), device=device, dtype=torch.float32)
143
+ y_embed = grid.cumsum(dim=0) - 0.5
144
+ x_embed = grid.cumsum(dim=1) - 0.5
145
+ y_embed = y_embed / h
146
+ x_embed = x_embed / w
147
+
148
+ pe = self._pe_encoding(torch.stack([x_embed, y_embed], dim=-1))
149
+ return pe.permute(2, 0, 1) # C x H x W
150
+
151
+ def forward_with_coords(
152
+ self, coords_input: torch.Tensor, image_size: Tuple[int, int]
153
+ ) -> torch.Tensor:
154
+ """Positionally encode points that are not normalized to [0,1]."""
155
+ coords = coords_input.clone()
156
+ coords[:, :, 0] = coords[:, :, 0] / image_size[1]
157
+ coords[:, :, 1] = coords[:, :, 1] / image_size[0]
158
+ return self._pe_encoding(coords.to(torch.float)) # B x N x C
159
+
160
+
161
+ # Rotary Positional Encoding, adapted from:
162
+ # 1. https://github.com/meta-llama/codellama/blob/main/llama/model.py
163
+ # 2. https://github.com/naver-ai/rope-vit
164
+ # 3. https://github.com/lucidrains/rotary-embedding-torch
165
+
166
+
167
+ def init_t_xy(end_x: int, end_y: int):
168
+ t = torch.arange(end_x * end_y, dtype=torch.float32)
169
+ t_x = (t % end_x).float()
170
+ t_y = torch.div(t, end_x, rounding_mode="floor").float()
171
+ return t_x, t_y
172
+
173
+
174
+ def compute_axial_cis(dim: int, end_x: int, end_y: int, theta: float = 10000.0):
175
+ freqs_x = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim))
176
+ freqs_y = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim))
177
+
178
+ t_x, t_y = init_t_xy(end_x, end_y)
179
+ freqs_x = torch.outer(t_x, freqs_x)
180
+ freqs_y = torch.outer(t_y, freqs_y)
181
+ freqs_cis_x = torch.polar(torch.ones_like(freqs_x), freqs_x)
182
+ freqs_cis_y = torch.polar(torch.ones_like(freqs_y), freqs_y)
183
+ return torch.cat([freqs_cis_x, freqs_cis_y], dim=-1)
184
+
185
+
186
+ def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
187
+ ndim = x.ndim
188
+ assert 0 <= 1 < ndim
189
+ assert freqs_cis.shape == (x.shape[-2], x.shape[-1])
190
+ shape = [d if i >= ndim - 2 else 1 for i, d in enumerate(x.shape)]
191
+ return freqs_cis.view(*shape)
192
+
193
+
194
+ def apply_rotary_enc(
195
+ xq: torch.Tensor,
196
+ xk: torch.Tensor,
197
+ freqs_cis: torch.Tensor,
198
+ repeat_freqs_k: bool = False,
199
+ ):
200
+ xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
201
+ xk_ = (
202
+ torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
203
+ if xk.shape[-2] != 0
204
+ else None
205
+ )
206
+ freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
207
+ xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
208
+ if xk_ is None:
209
+ # no keys to rotate, due to dropout
210
+ return xq_out.type_as(xq).to(xq.device), xk
211
+ # repeat freqs along seq_len dim to match k seq_len
212
+ if repeat_freqs_k:
213
+ r = xk_.shape[-2] // xq_.shape[-2]
214
+ freqs_cis = freqs_cis.repeat(*([1] * (freqs_cis.ndim - 2)), r, 1)
215
+ xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
216
+ return xq_out.type_as(xq).to(xq.device), xk_out.type_as(xk).to(xk.device)
segment-anything-2/sam2/modeling/sam/__init__.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
segment-anything-2/sam2/modeling/sam/mask_decoder.py ADDED
@@ -0,0 +1,295 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ from typing import List, Optional, Tuple, Type
8
+
9
+ import torch
10
+ from torch import nn
11
+
12
+ from sam2.modeling.sam2_utils import LayerNorm2d, MLP
13
+
14
+
15
+ class MaskDecoder(nn.Module):
16
+ def __init__(
17
+ self,
18
+ *,
19
+ transformer_dim: int,
20
+ transformer: nn.Module,
21
+ num_multimask_outputs: int = 3,
22
+ activation: Type[nn.Module] = nn.GELU,
23
+ iou_head_depth: int = 3,
24
+ iou_head_hidden_dim: int = 256,
25
+ use_high_res_features: bool = False,
26
+ iou_prediction_use_sigmoid=False,
27
+ dynamic_multimask_via_stability=False,
28
+ dynamic_multimask_stability_delta=0.05,
29
+ dynamic_multimask_stability_thresh=0.98,
30
+ pred_obj_scores: bool = False,
31
+ pred_obj_scores_mlp: bool = False,
32
+ use_multimask_token_for_obj_ptr: bool = False,
33
+ ) -> None:
34
+ """
35
+ Predicts masks given an image and prompt embeddings, using a
36
+ transformer architecture.
37
+
38
+ Arguments:
39
+ transformer_dim (int): the channel dimension of the transformer
40
+ transformer (nn.Module): the transformer used to predict masks
41
+ num_multimask_outputs (int): the number of masks to predict
42
+ when disambiguating masks
43
+ activation (nn.Module): the type of activation to use when
44
+ upscaling masks
45
+ iou_head_depth (int): the depth of the MLP used to predict
46
+ mask quality
47
+ iou_head_hidden_dim (int): the hidden dimension of the MLP
48
+ used to predict mask quality
49
+ """
50
+ super().__init__()
51
+ self.transformer_dim = transformer_dim
52
+ self.transformer = transformer
53
+
54
+ self.num_multimask_outputs = num_multimask_outputs
55
+
56
+ self.iou_token = nn.Embedding(1, transformer_dim)
57
+ self.num_mask_tokens = num_multimask_outputs + 1
58
+ self.mask_tokens = nn.Embedding(self.num_mask_tokens, transformer_dim)
59
+
60
+ self.pred_obj_scores = pred_obj_scores
61
+ if self.pred_obj_scores:
62
+ self.obj_score_token = nn.Embedding(1, transformer_dim)
63
+ self.use_multimask_token_for_obj_ptr = use_multimask_token_for_obj_ptr
64
+
65
+ self.output_upscaling = nn.Sequential(
66
+ nn.ConvTranspose2d(
67
+ transformer_dim, transformer_dim // 4, kernel_size=2, stride=2
68
+ ),
69
+ LayerNorm2d(transformer_dim // 4),
70
+ activation(),
71
+ nn.ConvTranspose2d(
72
+ transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2
73
+ ),
74
+ activation(),
75
+ )
76
+ self.use_high_res_features = use_high_res_features
77
+ if use_high_res_features:
78
+ self.conv_s0 = nn.Conv2d(
79
+ transformer_dim, transformer_dim // 8, kernel_size=1, stride=1
80
+ )
81
+ self.conv_s1 = nn.Conv2d(
82
+ transformer_dim, transformer_dim // 4, kernel_size=1, stride=1
83
+ )
84
+
85
+ self.output_hypernetworks_mlps = nn.ModuleList(
86
+ [
87
+ MLP(transformer_dim, transformer_dim, transformer_dim // 8, 3)
88
+ for i in range(self.num_mask_tokens)
89
+ ]
90
+ )
91
+
92
+ self.iou_prediction_head = MLP(
93
+ transformer_dim,
94
+ iou_head_hidden_dim,
95
+ self.num_mask_tokens,
96
+ iou_head_depth,
97
+ sigmoid_output=iou_prediction_use_sigmoid,
98
+ )
99
+ if self.pred_obj_scores:
100
+ self.pred_obj_score_head = nn.Linear(transformer_dim, 1)
101
+ if pred_obj_scores_mlp:
102
+ self.pred_obj_score_head = MLP(transformer_dim, transformer_dim, 1, 3)
103
+
104
+ # When outputting a single mask, optionally we can dynamically fall back to the best
105
+ # multimask output token if the single mask output token gives low stability scores.
106
+ self.dynamic_multimask_via_stability = dynamic_multimask_via_stability
107
+ self.dynamic_multimask_stability_delta = dynamic_multimask_stability_delta
108
+ self.dynamic_multimask_stability_thresh = dynamic_multimask_stability_thresh
109
+
110
+ def forward(
111
+ self,
112
+ image_embeddings: torch.Tensor,
113
+ image_pe: torch.Tensor,
114
+ sparse_prompt_embeddings: torch.Tensor,
115
+ dense_prompt_embeddings: torch.Tensor,
116
+ multimask_output: bool,
117
+ repeat_image: bool,
118
+ high_res_features: Optional[List[torch.Tensor]] = None,
119
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
120
+ """
121
+ Predict masks given image and prompt embeddings.
122
+
123
+ Arguments:
124
+ image_embeddings (torch.Tensor): the embeddings from the image encoder
125
+ image_pe (torch.Tensor): positional encoding with the shape of image_embeddings
126
+ sparse_prompt_embeddings (torch.Tensor): the embeddings of the points and boxes
127
+ dense_prompt_embeddings (torch.Tensor): the embeddings of the mask inputs
128
+ multimask_output (bool): Whether to return multiple masks or a single
129
+ mask.
130
+
131
+ Returns:
132
+ torch.Tensor: batched predicted masks
133
+ torch.Tensor: batched predictions of mask quality
134
+ torch.Tensor: batched SAM token for mask output
135
+ """
136
+ masks, iou_pred, mask_tokens_out, object_score_logits = self.predict_masks(
137
+ image_embeddings=image_embeddings,
138
+ image_pe=image_pe,
139
+ sparse_prompt_embeddings=sparse_prompt_embeddings,
140
+ dense_prompt_embeddings=dense_prompt_embeddings,
141
+ repeat_image=repeat_image,
142
+ high_res_features=high_res_features,
143
+ )
144
+
145
+ # Select the correct mask or masks for output
146
+ if multimask_output:
147
+ masks = masks[:, 1:, :, :]
148
+ iou_pred = iou_pred[:, 1:]
149
+ elif self.dynamic_multimask_via_stability and not self.training:
150
+ masks, iou_pred = self._dynamic_multimask_via_stability(masks, iou_pred)
151
+ else:
152
+ masks = masks[:, 0:1, :, :]
153
+ iou_pred = iou_pred[:, 0:1]
154
+
155
+ if multimask_output and self.use_multimask_token_for_obj_ptr:
156
+ sam_tokens_out = mask_tokens_out[:, 1:] # [b, 3, c] shape
157
+ else:
158
+ # Take the mask output token. Here we *always* use the token for single mask output.
159
+ # At test time, even if we track after 1-click (and using multimask_output=True),
160
+ # we still take the single mask token here. The rationale is that we always track
161
+ # after multiple clicks during training, so the past tokens seen during training
162
+ # are always the single mask token (and we'll let it be the object-memory token).
163
+ sam_tokens_out = mask_tokens_out[:, 0:1] # [b, 1, c] shape
164
+
165
+ # Prepare output
166
+ return masks, iou_pred, sam_tokens_out, object_score_logits
167
+
168
+ def predict_masks(
169
+ self,
170
+ image_embeddings: torch.Tensor,
171
+ image_pe: torch.Tensor,
172
+ sparse_prompt_embeddings: torch.Tensor,
173
+ dense_prompt_embeddings: torch.Tensor,
174
+ repeat_image: bool,
175
+ high_res_features: Optional[List[torch.Tensor]] = None,
176
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
177
+ """Predicts masks. See 'forward' for more details."""
178
+ # Concatenate output tokens
179
+ s = 0
180
+ if self.pred_obj_scores:
181
+ output_tokens = torch.cat(
182
+ [
183
+ self.obj_score_token.weight,
184
+ self.iou_token.weight,
185
+ self.mask_tokens.weight,
186
+ ],
187
+ dim=0,
188
+ )
189
+ s = 1
190
+ else:
191
+ output_tokens = torch.cat(
192
+ [self.iou_token.weight, self.mask_tokens.weight], dim=0
193
+ )
194
+ output_tokens = output_tokens.unsqueeze(0).expand(
195
+ sparse_prompt_embeddings.size(0), -1, -1
196
+ )
197
+ tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1)
198
+
199
+ # Expand per-image data in batch direction to be per-mask
200
+ if repeat_image:
201
+ src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0)
202
+ else:
203
+ assert image_embeddings.shape[0] == tokens.shape[0]
204
+ src = image_embeddings
205
+ src = src + dense_prompt_embeddings
206
+ assert (
207
+ image_pe.size(0) == 1
208
+ ), "image_pe should have size 1 in batch dim (from `get_dense_pe()`)"
209
+ pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0)
210
+ b, c, h, w = src.shape
211
+
212
+ # Run the transformer
213
+ hs, src = self.transformer(src, pos_src, tokens)
214
+ iou_token_out = hs[:, s, :]
215
+ mask_tokens_out = hs[:, s + 1 : (s + 1 + self.num_mask_tokens), :]
216
+
217
+ # Upscale mask embeddings and predict masks using the mask tokens
218
+ src = src.transpose(1, 2).view(b, c, h, w)
219
+ if not self.use_high_res_features:
220
+ upscaled_embedding = self.output_upscaling(src)
221
+ else:
222
+ dc1, ln1, act1, dc2, act2 = self.output_upscaling
223
+ feat_s0, feat_s1 = high_res_features
224
+ upscaled_embedding = act1(ln1(dc1(src) + feat_s1))
225
+ upscaled_embedding = act2(dc2(upscaled_embedding) + feat_s0)
226
+
227
+ hyper_in_list: List[torch.Tensor] = []
228
+ for i in range(self.num_mask_tokens):
229
+ hyper_in_list.append(
230
+ self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :])
231
+ )
232
+ hyper_in = torch.stack(hyper_in_list, dim=1)
233
+ b, c, h, w = upscaled_embedding.shape
234
+ masks = (hyper_in @ upscaled_embedding.view(b, c, h * w)).view(b, -1, h, w)
235
+
236
+ # Generate mask quality predictions
237
+ iou_pred = self.iou_prediction_head(iou_token_out)
238
+ if self.pred_obj_scores:
239
+ assert s == 1
240
+ object_score_logits = self.pred_obj_score_head(hs[:, 0, :])
241
+ else:
242
+ # Obj scores logits - default to 10.0, i.e. assuming the object is present, sigmoid(10)=1
243
+ object_score_logits = 10.0 * iou_pred.new_ones(iou_pred.shape[0], 1)
244
+
245
+ return masks, iou_pred, mask_tokens_out, object_score_logits
246
+
247
+ def _get_stability_scores(self, mask_logits):
248
+ """
249
+ Compute stability scores of the mask logits based on the IoU between upper and
250
+ lower thresholds, similar to https://github.com/fairinternal/onevision/pull/568.
251
+ """
252
+ mask_logits = mask_logits.flatten(-2)
253
+ stability_delta = self.dynamic_multimask_stability_delta
254
+ area_i = torch.sum(mask_logits > stability_delta, dim=-1).float()
255
+ area_u = torch.sum(mask_logits > -stability_delta, dim=-1).float()
256
+ stability_scores = torch.where(area_u > 0, area_i / area_u, 1.0)
257
+ return stability_scores
258
+
259
+ def _dynamic_multimask_via_stability(self, all_mask_logits, all_iou_scores):
260
+ """
261
+ When outputting a single mask, if the stability score from the current single-mask
262
+ output (based on output token 0) falls below a threshold, we instead select from
263
+ multi-mask outputs (based on output token 1~3) the mask with the highest predicted
264
+ IoU score. This is intended to ensure a valid mask for both clicking and tracking.
265
+ """
266
+ # The best mask from multimask output tokens (1~3)
267
+ multimask_logits = all_mask_logits[:, 1:, :, :]
268
+ multimask_iou_scores = all_iou_scores[:, 1:]
269
+ best_scores_inds = torch.argmax(multimask_iou_scores, dim=-1)
270
+ batch_inds = torch.arange(
271
+ multimask_iou_scores.size(0), device=all_iou_scores.device
272
+ )
273
+ best_multimask_logits = multimask_logits[batch_inds, best_scores_inds]
274
+ best_multimask_logits = best_multimask_logits.unsqueeze(1)
275
+ best_multimask_iou_scores = multimask_iou_scores[batch_inds, best_scores_inds]
276
+ best_multimask_iou_scores = best_multimask_iou_scores.unsqueeze(1)
277
+
278
+ # The mask from singlemask output token 0 and its stability score
279
+ singlemask_logits = all_mask_logits[:, 0:1, :, :]
280
+ singlemask_iou_scores = all_iou_scores[:, 0:1]
281
+ stability_scores = self._get_stability_scores(singlemask_logits)
282
+ is_stable = stability_scores >= self.dynamic_multimask_stability_thresh
283
+
284
+ # Dynamically fall back to best multimask output upon low stability scores.
285
+ mask_logits_out = torch.where(
286
+ is_stable[..., None, None].expand_as(singlemask_logits),
287
+ singlemask_logits,
288
+ best_multimask_logits,
289
+ )
290
+ iou_scores_out = torch.where(
291
+ is_stable.expand_as(singlemask_iou_scores),
292
+ singlemask_iou_scores,
293
+ best_multimask_iou_scores,
294
+ )
295
+ return mask_logits_out, iou_scores_out
segment-anything-2/sam2/modeling/sam/prompt_encoder.py ADDED
@@ -0,0 +1,182 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ from typing import Optional, Tuple, Type
8
+
9
+ import torch
10
+ from torch import nn
11
+
12
+ from sam2.modeling.position_encoding import PositionEmbeddingRandom
13
+
14
+ from sam2.modeling.sam2_utils import LayerNorm2d
15
+
16
+
17
+ class PromptEncoder(nn.Module):
18
+ def __init__(
19
+ self,
20
+ embed_dim: int,
21
+ image_embedding_size: Tuple[int, int],
22
+ input_image_size: Tuple[int, int],
23
+ mask_in_chans: int,
24
+ activation: Type[nn.Module] = nn.GELU,
25
+ ) -> None:
26
+ """
27
+ Encodes prompts for input to SAM's mask decoder.
28
+
29
+ Arguments:
30
+ embed_dim (int): The prompts' embedding dimension
31
+ image_embedding_size (tuple(int, int)): The spatial size of the
32
+ image embedding, as (H, W).
33
+ input_image_size (int): The padded size of the image as input
34
+ to the image encoder, as (H, W).
35
+ mask_in_chans (int): The number of hidden channels used for
36
+ encoding input masks.
37
+ activation (nn.Module): The activation to use when encoding
38
+ input masks.
39
+ """
40
+ super().__init__()
41
+ self.embed_dim = embed_dim
42
+ self.input_image_size = input_image_size
43
+ self.image_embedding_size = image_embedding_size
44
+ self.pe_layer = PositionEmbeddingRandom(embed_dim // 2)
45
+
46
+ self.num_point_embeddings: int = 4 # pos/neg point + 2 box corners
47
+ point_embeddings = [
48
+ nn.Embedding(1, embed_dim) for i in range(self.num_point_embeddings)
49
+ ]
50
+ self.point_embeddings = nn.ModuleList(point_embeddings)
51
+ self.not_a_point_embed = nn.Embedding(1, embed_dim)
52
+
53
+ self.mask_input_size = (
54
+ 4 * image_embedding_size[0],
55
+ 4 * image_embedding_size[1],
56
+ )
57
+ self.mask_downscaling = nn.Sequential(
58
+ nn.Conv2d(1, mask_in_chans // 4, kernel_size=2, stride=2),
59
+ LayerNorm2d(mask_in_chans // 4),
60
+ activation(),
61
+ nn.Conv2d(mask_in_chans // 4, mask_in_chans, kernel_size=2, stride=2),
62
+ LayerNorm2d(mask_in_chans),
63
+ activation(),
64
+ nn.Conv2d(mask_in_chans, embed_dim, kernel_size=1),
65
+ )
66
+ self.no_mask_embed = nn.Embedding(1, embed_dim)
67
+
68
+ def get_dense_pe(self) -> torch.Tensor:
69
+ """
70
+ Returns the positional encoding used to encode point prompts,
71
+ applied to a dense set of points the shape of the image encoding.
72
+
73
+ Returns:
74
+ torch.Tensor: Positional encoding with shape
75
+ 1x(embed_dim)x(embedding_h)x(embedding_w)
76
+ """
77
+ return self.pe_layer(self.image_embedding_size).unsqueeze(0)
78
+
79
+ def _embed_points(
80
+ self,
81
+ points: torch.Tensor,
82
+ labels: torch.Tensor,
83
+ pad: bool,
84
+ ) -> torch.Tensor:
85
+ """Embeds point prompts."""
86
+ points = points + 0.5 # Shift to center of pixel
87
+ if pad:
88
+ padding_point = torch.zeros((points.shape[0], 1, 2), device=points.device)
89
+ padding_label = -torch.ones((labels.shape[0], 1), device=labels.device)
90
+ points = torch.cat([points, padding_point], dim=1)
91
+ labels = torch.cat([labels, padding_label], dim=1)
92
+ point_embedding = self.pe_layer.forward_with_coords(
93
+ points, self.input_image_size
94
+ )
95
+ point_embedding[labels == -1] = 0.0
96
+ point_embedding[labels == -1] += self.not_a_point_embed.weight
97
+ point_embedding[labels == 0] += self.point_embeddings[0].weight
98
+ point_embedding[labels == 1] += self.point_embeddings[1].weight
99
+ point_embedding[labels == 2] += self.point_embeddings[2].weight
100
+ point_embedding[labels == 3] += self.point_embeddings[3].weight
101
+ return point_embedding
102
+
103
+ def _embed_boxes(self, boxes: torch.Tensor) -> torch.Tensor:
104
+ """Embeds box prompts."""
105
+ boxes = boxes + 0.5 # Shift to center of pixel
106
+ coords = boxes.reshape(-1, 2, 2)
107
+ corner_embedding = self.pe_layer.forward_with_coords(
108
+ coords, self.input_image_size
109
+ )
110
+ corner_embedding[:, 0, :] += self.point_embeddings[2].weight
111
+ corner_embedding[:, 1, :] += self.point_embeddings[3].weight
112
+ return corner_embedding
113
+
114
+ def _embed_masks(self, masks: torch.Tensor) -> torch.Tensor:
115
+ """Embeds mask inputs."""
116
+ mask_embedding = self.mask_downscaling(masks)
117
+ return mask_embedding
118
+
119
+ def _get_batch_size(
120
+ self,
121
+ points: Optional[Tuple[torch.Tensor, torch.Tensor]],
122
+ boxes: Optional[torch.Tensor],
123
+ masks: Optional[torch.Tensor],
124
+ ) -> int:
125
+ """
126
+ Gets the batch size of the output given the batch size of the input prompts.
127
+ """
128
+ if points is not None:
129
+ return points[0].shape[0]
130
+ elif boxes is not None:
131
+ return boxes.shape[0]
132
+ elif masks is not None:
133
+ return masks.shape[0]
134
+ else:
135
+ return 1
136
+
137
+ def _get_device(self) -> torch.device:
138
+ return self.point_embeddings[0].weight.device
139
+
140
+ def forward(
141
+ self,
142
+ points: Optional[Tuple[torch.Tensor, torch.Tensor]],
143
+ boxes: Optional[torch.Tensor],
144
+ masks: Optional[torch.Tensor],
145
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
146
+ """
147
+ Embeds different types of prompts, returning both sparse and dense
148
+ embeddings.
149
+
150
+ Arguments:
151
+ points (tuple(torch.Tensor, torch.Tensor) or none): point coordinates
152
+ and labels to embed.
153
+ boxes (torch.Tensor or none): boxes to embed
154
+ masks (torch.Tensor or none): masks to embed
155
+
156
+ Returns:
157
+ torch.Tensor: sparse embeddings for the points and boxes, with shape
158
+ BxNx(embed_dim), where N is determined by the number of input points
159
+ and boxes.
160
+ torch.Tensor: dense embeddings for the masks, in the shape
161
+ Bx(embed_dim)x(embed_H)x(embed_W)
162
+ """
163
+ bs = self._get_batch_size(points, boxes, masks)
164
+ sparse_embeddings = torch.empty(
165
+ (bs, 0, self.embed_dim), device=self._get_device()
166
+ )
167
+ if points is not None:
168
+ coords, labels = points
169
+ point_embeddings = self._embed_points(coords, labels, pad=(boxes is None))
170
+ sparse_embeddings = torch.cat([sparse_embeddings, point_embeddings], dim=1)
171
+ if boxes is not None:
172
+ box_embeddings = self._embed_boxes(boxes)
173
+ sparse_embeddings = torch.cat([sparse_embeddings, box_embeddings], dim=1)
174
+
175
+ if masks is not None:
176
+ dense_embeddings = self._embed_masks(masks)
177
+ else:
178
+ dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, 1).expand(
179
+ bs, -1, self.image_embedding_size[0], self.image_embedding_size[1]
180
+ )
181
+
182
+ return sparse_embeddings, dense_embeddings
segment-anything-2/sam2/modeling/sam/transformer.py ADDED
@@ -0,0 +1,327 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import math
8
+ import warnings
9
+ from functools import partial
10
+ from typing import Tuple, Type
11
+
12
+ import torch
13
+ import torch.nn.functional as F
14
+ from torch import nn, Tensor
15
+
16
+ from sam2.modeling.position_encoding import apply_rotary_enc, compute_axial_cis
17
+
18
+ from sam2.modeling.sam2_utils import MLP
19
+ from sam2.utils.misc import get_sdpa_settings
20
+
21
+ warnings.simplefilter(action="ignore", category=FutureWarning)
22
+ OLD_GPU, USE_FLASH_ATTN, MATH_KERNEL_ON = True, True, True
23
+
24
+
25
+ class TwoWayTransformer(nn.Module):
26
+ def __init__(
27
+ self,
28
+ depth: int,
29
+ embedding_dim: int,
30
+ num_heads: int,
31
+ mlp_dim: int,
32
+ activation: Type[nn.Module] = nn.ReLU,
33
+ attention_downsample_rate: int = 2,
34
+ ) -> None:
35
+ """
36
+ A transformer decoder that attends to an input image using
37
+ queries whose positional embedding is supplied.
38
+
39
+ Args:
40
+ depth (int): number of layers in the transformer
41
+ embedding_dim (int): the channel dimension for the input embeddings
42
+ num_heads (int): the number of heads for multihead attention. Must
43
+ divide embedding_dim
44
+ mlp_dim (int): the channel dimension internal to the MLP block
45
+ activation (nn.Module): the activation to use in the MLP block
46
+ """
47
+ super().__init__()
48
+ self.depth = depth
49
+ self.embedding_dim = embedding_dim
50
+ self.num_heads = num_heads
51
+ self.mlp_dim = mlp_dim
52
+ self.layers = nn.ModuleList()
53
+
54
+ for i in range(depth):
55
+ self.layers.append(
56
+ TwoWayAttentionBlock(
57
+ embedding_dim=embedding_dim,
58
+ num_heads=num_heads,
59
+ mlp_dim=mlp_dim,
60
+ activation=activation,
61
+ attention_downsample_rate=attention_downsample_rate,
62
+ skip_first_layer_pe=(i == 0),
63
+ )
64
+ )
65
+
66
+ self.final_attn_token_to_image = Attention(
67
+ embedding_dim, num_heads, downsample_rate=attention_downsample_rate
68
+ )
69
+ self.norm_final_attn = nn.LayerNorm(embedding_dim)
70
+
71
+ def forward(
72
+ self,
73
+ image_embedding: Tensor,
74
+ image_pe: Tensor,
75
+ point_embedding: Tensor,
76
+ ) -> Tuple[Tensor, Tensor]:
77
+ """
78
+ Args:
79
+ image_embedding (torch.Tensor): image to attend to. Should be shape
80
+ B x embedding_dim x h x w for any h and w.
81
+ image_pe (torch.Tensor): the positional encoding to add to the image. Must
82
+ have the same shape as image_embedding.
83
+ point_embedding (torch.Tensor): the embedding to add to the query points.
84
+ Must have shape B x N_points x embedding_dim for any N_points.
85
+
86
+ Returns:
87
+ torch.Tensor: the processed point_embedding
88
+ torch.Tensor: the processed image_embedding
89
+ """
90
+ # BxCxHxW -> BxHWxC == B x N_image_tokens x C
91
+ bs, c, h, w = image_embedding.shape
92
+ image_embedding = image_embedding.flatten(2).permute(0, 2, 1)
93
+ image_pe = image_pe.flatten(2).permute(0, 2, 1)
94
+
95
+ # Prepare queries
96
+ queries = point_embedding
97
+ keys = image_embedding
98
+
99
+ # Apply transformer blocks and final layernorm
100
+ for layer in self.layers:
101
+ queries, keys = layer(
102
+ queries=queries,
103
+ keys=keys,
104
+ query_pe=point_embedding,
105
+ key_pe=image_pe,
106
+ )
107
+
108
+ # Apply the final attention layer from the points to the image
109
+ q = queries + point_embedding
110
+ k = keys + image_pe
111
+ attn_out = self.final_attn_token_to_image(q=q, k=k, v=keys)
112
+ queries = queries + attn_out
113
+ queries = self.norm_final_attn(queries)
114
+
115
+ return queries, keys
116
+
117
+
118
+ class TwoWayAttentionBlock(nn.Module):
119
+ def __init__(
120
+ self,
121
+ embedding_dim: int,
122
+ num_heads: int,
123
+ mlp_dim: int = 2048,
124
+ activation: Type[nn.Module] = nn.ReLU,
125
+ attention_downsample_rate: int = 2,
126
+ skip_first_layer_pe: bool = False,
127
+ ) -> None:
128
+ """
129
+ A transformer block with four layers: (1) self-attention of sparse
130
+ inputs, (2) cross attention of sparse inputs to dense inputs, (3) mlp
131
+ block on sparse inputs, and (4) cross attention of dense inputs to sparse
132
+ inputs.
133
+
134
+ Arguments:
135
+ embedding_dim (int): the channel dimension of the embeddings
136
+ num_heads (int): the number of heads in the attention layers
137
+ mlp_dim (int): the hidden dimension of the mlp block
138
+ activation (nn.Module): the activation of the mlp block
139
+ skip_first_layer_pe (bool): skip the PE on the first layer
140
+ """
141
+ super().__init__()
142
+ self.self_attn = Attention(embedding_dim, num_heads)
143
+ self.norm1 = nn.LayerNorm(embedding_dim)
144
+
145
+ self.cross_attn_token_to_image = Attention(
146
+ embedding_dim, num_heads, downsample_rate=attention_downsample_rate
147
+ )
148
+ self.norm2 = nn.LayerNorm(embedding_dim)
149
+
150
+ self.mlp = MLP(
151
+ embedding_dim, mlp_dim, embedding_dim, num_layers=2, activation=activation
152
+ )
153
+ self.norm3 = nn.LayerNorm(embedding_dim)
154
+
155
+ self.norm4 = nn.LayerNorm(embedding_dim)
156
+ self.cross_attn_image_to_token = Attention(
157
+ embedding_dim, num_heads, downsample_rate=attention_downsample_rate
158
+ )
159
+
160
+ self.skip_first_layer_pe = skip_first_layer_pe
161
+
162
+ def forward(
163
+ self, queries: Tensor, keys: Tensor, query_pe: Tensor, key_pe: Tensor
164
+ ) -> Tuple[Tensor, Tensor]:
165
+ # Self attention block
166
+ if self.skip_first_layer_pe:
167
+ queries = self.self_attn(q=queries, k=queries, v=queries)
168
+ else:
169
+ q = queries + query_pe
170
+ attn_out = self.self_attn(q=q, k=q, v=queries)
171
+ queries = queries + attn_out
172
+ queries = self.norm1(queries)
173
+
174
+ # Cross attention block, tokens attending to image embedding
175
+ q = queries + query_pe
176
+ k = keys + key_pe
177
+ attn_out = self.cross_attn_token_to_image(q=q, k=k, v=keys)
178
+ queries = queries + attn_out
179
+ queries = self.norm2(queries)
180
+
181
+ # MLP block
182
+ mlp_out = self.mlp(queries)
183
+ queries = queries + mlp_out
184
+ queries = self.norm3(queries)
185
+
186
+ # Cross attention block, image embedding attending to tokens
187
+ q = queries + query_pe
188
+ k = keys + key_pe
189
+ attn_out = self.cross_attn_image_to_token(q=k, k=q, v=queries)
190
+ keys = keys + attn_out
191
+ keys = self.norm4(keys)
192
+
193
+ return queries, keys
194
+
195
+
196
+ class Attention(nn.Module):
197
+ """
198
+ An attention layer that allows for downscaling the size of the embedding
199
+ after projection to queries, keys, and values.
200
+ """
201
+
202
+ def __init__(
203
+ self,
204
+ embedding_dim: int,
205
+ num_heads: int,
206
+ downsample_rate: int = 1,
207
+ dropout: float = 0.0,
208
+ kv_in_dim: int = None,
209
+ ) -> None:
210
+ super().__init__()
211
+ self.embedding_dim = embedding_dim
212
+ self.kv_in_dim = kv_in_dim if kv_in_dim is not None else embedding_dim
213
+ self.internal_dim = embedding_dim // downsample_rate
214
+ self.num_heads = num_heads
215
+ assert (
216
+ self.internal_dim % num_heads == 0
217
+ ), "num_heads must divide embedding_dim."
218
+
219
+ self.q_proj = nn.Linear(embedding_dim, self.internal_dim)
220
+ self.k_proj = nn.Linear(self.kv_in_dim, self.internal_dim)
221
+ self.v_proj = nn.Linear(self.kv_in_dim, self.internal_dim)
222
+ self.out_proj = nn.Linear(self.internal_dim, embedding_dim)
223
+
224
+ self.dropout_p = dropout
225
+
226
+ def _separate_heads(self, x: Tensor, num_heads: int) -> Tensor:
227
+ b, n, c = x.shape
228
+ x = x.reshape(b, n, num_heads, c // num_heads)
229
+ return x.transpose(1, 2) # B x N_heads x N_tokens x C_per_head
230
+
231
+ def _recombine_heads(self, x: Tensor) -> Tensor:
232
+ b, n_heads, n_tokens, c_per_head = x.shape
233
+ x = x.transpose(1, 2)
234
+ return x.reshape(b, n_tokens, n_heads * c_per_head) # B x N_tokens x C
235
+
236
+ def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor:
237
+ # Input projections
238
+ q = self.q_proj(q)
239
+ k = self.k_proj(k)
240
+ v = self.v_proj(v)
241
+
242
+ # Separate into heads
243
+ q = self._separate_heads(q, self.num_heads)
244
+ k = self._separate_heads(k, self.num_heads)
245
+ v = self._separate_heads(v, self.num_heads)
246
+
247
+ dropout_p = self.dropout_p if self.training else 0.0
248
+ # Attention
249
+ with torch.backends.cuda.sdp_kernel(
250
+ enable_flash=USE_FLASH_ATTN,
251
+ # if Flash attention kernel is off, then math kernel needs to be enabled
252
+ enable_math=(OLD_GPU and dropout_p > 0.0) or MATH_KERNEL_ON,
253
+ enable_mem_efficient=OLD_GPU,
254
+ ):
255
+ out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p)
256
+
257
+ out = self._recombine_heads(out)
258
+ out = self.out_proj(out)
259
+
260
+ return out
261
+
262
+
263
+ class RoPEAttention(Attention):
264
+ """Attention with rotary position encoding."""
265
+
266
+ def __init__(
267
+ self,
268
+ *args,
269
+ rope_theta=10000.0,
270
+ # whether to repeat q rope to match k length
271
+ # this is needed for cross-attention to memories
272
+ rope_k_repeat=False,
273
+ feat_sizes=(32, 32), # [w, h] for stride 16 feats at 512 resolution
274
+ **kwargs,
275
+ ):
276
+ super().__init__(*args, **kwargs)
277
+
278
+ self.compute_cis = partial(
279
+ compute_axial_cis, dim=self.internal_dim // self.num_heads, theta=rope_theta
280
+ )
281
+ freqs_cis = self.compute_cis(end_x=feat_sizes[0], end_y=feat_sizes[1])
282
+ self.freqs_cis = freqs_cis
283
+ self.rope_k_repeat = rope_k_repeat
284
+
285
+ def forward(
286
+ self, q: Tensor, k: Tensor, v: Tensor, num_k_exclude_rope: int = 0
287
+ ) -> Tensor:
288
+ # Input projections
289
+ q = self.q_proj(q)
290
+ k = self.k_proj(k)
291
+ v = self.v_proj(v)
292
+
293
+ # Separate into heads
294
+ q = self._separate_heads(q, self.num_heads)
295
+ k = self._separate_heads(k, self.num_heads)
296
+ v = self._separate_heads(v, self.num_heads)
297
+
298
+ # Apply rotary position encoding
299
+ w = h = math.sqrt(q.shape[-2])
300
+ self.freqs_cis = self.freqs_cis.to(q.device)
301
+ if self.freqs_cis.shape[0] != q.shape[-2]:
302
+ self.freqs_cis = self.compute_cis(end_x=w, end_y=h).to(q.device)
303
+ if q.shape[-2] != k.shape[-2]:
304
+ assert self.rope_k_repeat
305
+
306
+ num_k_rope = k.size(-2) - num_k_exclude_rope
307
+ q, k[:, :, :num_k_rope] = apply_rotary_enc(
308
+ q,
309
+ k[:, :, :num_k_rope],
310
+ freqs_cis=self.freqs_cis,
311
+ repeat_freqs_k=self.rope_k_repeat,
312
+ )
313
+
314
+ dropout_p = self.dropout_p if self.training else 0.0
315
+ # Attention
316
+ with torch.backends.cuda.sdp_kernel(
317
+ enable_flash=USE_FLASH_ATTN,
318
+ # if Flash attention kernel is off, then math kernel needs to be enabled
319
+ enable_math=(OLD_GPU and dropout_p > 0.0) or MATH_KERNEL_ON,
320
+ enable_mem_efficient=OLD_GPU,
321
+ ):
322
+ out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p)
323
+
324
+ out = self._recombine_heads(out)
325
+ out = self.out_proj(out)
326
+
327
+ return out
segment-anything-2/sam2/modeling/sam2_base.py ADDED
@@ -0,0 +1,829 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import torch
8
+ import torch.distributed
9
+ import torch.nn.functional as F
10
+
11
+ from torch.nn.init import trunc_normal_
12
+
13
+ from sam2.modeling.sam.mask_decoder import MaskDecoder
14
+ from sam2.modeling.sam.prompt_encoder import PromptEncoder
15
+ from sam2.modeling.sam.transformer import TwoWayTransformer
16
+ from sam2.modeling.sam2_utils import get_1d_sine_pe, MLP, select_closest_cond_frames
17
+
18
+ # a large negative value as a placeholder score for missing objects
19
+ NO_OBJ_SCORE = -1024.0
20
+
21
+
22
+ class SAM2Base(torch.nn.Module):
23
+ def __init__(
24
+ self,
25
+ image_encoder,
26
+ memory_attention,
27
+ memory_encoder,
28
+ num_maskmem=7, # default 1 input frame + 6 previous frames
29
+ image_size=512,
30
+ backbone_stride=16, # stride of the image backbone output
31
+ sigmoid_scale_for_mem_enc=1.0, # scale factor for mask sigmoid prob
32
+ sigmoid_bias_for_mem_enc=0.0, # bias factor for mask sigmoid prob
33
+ # During evaluation, whether to binarize the sigmoid mask logits on interacted frames with clicks
34
+ binarize_mask_from_pts_for_mem_enc=False,
35
+ use_mask_input_as_output_without_sam=False, # on frames with mask input, whether to directly output the input mask without using a SAM prompt encoder + mask decoder
36
+ # The maximum number of conditioning frames to participate in the memory attention (-1 means no limit; if there are more conditioning frames than this limit,
37
+ # we only cross-attend to the temporally closest `max_cond_frames_in_attn` conditioning frames in the encoder when tracking each frame). This gives the model
38
+ # a temporal locality when handling a large number of annotated frames (since closer frames should be more important) and also avoids GPU OOM.
39
+ max_cond_frames_in_attn=-1,
40
+ # on the first frame, whether to directly add the no-memory embedding to the image feature
41
+ # (instead of using the transformer encoder)
42
+ directly_add_no_mem_embed=False,
43
+ # whether to use high-resolution feature maps in the SAM mask decoder
44
+ use_high_res_features_in_sam=False,
45
+ # whether to output multiple (3) masks for the first click on initial conditioning frames
46
+ multimask_output_in_sam=False,
47
+ # the minimum and maximum number of clicks to use multimask_output_in_sam (only relevant when `multimask_output_in_sam=True`;
48
+ # default is 1 for both, meaning that only the first click gives multimask output; also note that a box counts as two points)
49
+ multimask_min_pt_num=1,
50
+ multimask_max_pt_num=1,
51
+ # whether to also use multimask output for tracking (not just for the first click on initial conditioning frames; only relevant when `multimask_output_in_sam=True`)
52
+ multimask_output_for_tracking=False,
53
+ # Whether to use multimask tokens for obj ptr; Only relevant when both
54
+ # use_obj_ptrs_in_encoder=True and multimask_output_for_tracking=True
55
+ use_multimask_token_for_obj_ptr: bool = False,
56
+ # whether to use sigmoid to restrict ious prediction to [0-1]
57
+ iou_prediction_use_sigmoid=False,
58
+ # The memory bank's temporal stride during evaluation (i.e. the `r` parameter in XMem and Cutie; XMem and Cutie use r=5).
59
+ # For r>1, the (self.num_maskmem - 1) non-conditioning memory frames consist of
60
+ # (self.num_maskmem - 2) nearest frames from every r-th frames, plus the last frame.
61
+ memory_temporal_stride_for_eval=1,
62
+ # if `add_all_frames_to_correct_as_cond` is True, we also append to the conditioning frame list any frame that receives a later correction click
63
+ # if `add_all_frames_to_correct_as_cond` is False, we conditioning frame list to only use those initial conditioning frames
64
+ add_all_frames_to_correct_as_cond=False,
65
+ # whether to apply non-overlapping constraints on the object masks in the memory encoder during evaluation (to avoid/alleviate superposing masks)
66
+ non_overlap_masks_for_mem_enc=False,
67
+ # whether to cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder
68
+ use_obj_ptrs_in_encoder=False,
69
+ # the maximum number of object pointers from other frames in encoder cross attention (only relevant when `use_obj_ptrs_in_encoder=True`)
70
+ max_obj_ptrs_in_encoder=16,
71
+ # whether to add temporal positional encoding to the object pointers in the encoder (only relevant when `use_obj_ptrs_in_encoder=True`)
72
+ add_tpos_enc_to_obj_ptrs=True,
73
+ # whether to add an extra linear projection layer for the temporal positional encoding in the object pointers to avoid potential interference
74
+ # with spatial positional encoding (only relevant when both `use_obj_ptrs_in_encoder=True` and `add_tpos_enc_to_obj_ptrs=True`)
75
+ proj_tpos_enc_in_obj_ptrs=False,
76
+ # whether to only attend to object pointers in the past (before the current frame) in the encoder during evaluation
77
+ # (only relevant when `use_obj_ptrs_in_encoder=True`; this might avoid pointer information too far in the future to distract the initial tracking)
78
+ only_obj_ptrs_in_the_past_for_eval=False,
79
+ # Whether to predict if there is an object in the frame
80
+ pred_obj_scores: bool = False,
81
+ # Whether to use an MLP to predict object scores
82
+ pred_obj_scores_mlp: bool = False,
83
+ # Only relevant if pred_obj_scores=True and use_obj_ptrs_in_encoder=True;
84
+ # Whether to have a fixed no obj pointer when there is no object present
85
+ # or to use it as an additive embedding with obj_ptr produced by decoder
86
+ fixed_no_obj_ptr: bool = False,
87
+ # Soft no object, i.e. mix in no_obj_ptr softly,
88
+ # hope to make recovery easier if there is a mistake and mitigate accumulation of errors
89
+ soft_no_obj_ptr: bool = False,
90
+ use_mlp_for_obj_ptr_proj: bool = False,
91
+ # extra arguments used to construct the SAM mask decoder; if not None, it should be a dict of kwargs to be passed into `MaskDecoder` class.
92
+ sam_mask_decoder_extra_args=None,
93
+ compile_image_encoder: bool = False,
94
+ ):
95
+ super().__init__()
96
+
97
+ # Part 1: the image backbone
98
+ self.image_encoder = image_encoder
99
+ # Use level 0, 1, 2 for high-res setting, or just level 2 for the default setting
100
+ self.use_high_res_features_in_sam = use_high_res_features_in_sam
101
+ self.num_feature_levels = 3 if use_high_res_features_in_sam else 1
102
+ self.use_obj_ptrs_in_encoder = use_obj_ptrs_in_encoder
103
+ self.max_obj_ptrs_in_encoder = max_obj_ptrs_in_encoder
104
+ if use_obj_ptrs_in_encoder:
105
+ # A conv layer to downsample the mask prompt to stride 4 (the same stride as
106
+ # low-res SAM mask logits) and to change its scales from 0~1 to SAM logit scale,
107
+ # so that it can be fed into the SAM mask decoder to generate a pointer.
108
+ self.mask_downsample = torch.nn.Conv2d(1, 1, kernel_size=4, stride=4)
109
+ self.add_tpos_enc_to_obj_ptrs = add_tpos_enc_to_obj_ptrs
110
+ if proj_tpos_enc_in_obj_ptrs:
111
+ assert add_tpos_enc_to_obj_ptrs # these options need to be used together
112
+ self.proj_tpos_enc_in_obj_ptrs = proj_tpos_enc_in_obj_ptrs
113
+ self.only_obj_ptrs_in_the_past_for_eval = only_obj_ptrs_in_the_past_for_eval
114
+
115
+ # Part 2: memory attention to condition current frame's visual features
116
+ # with memories (and obj ptrs) from past frames
117
+ self.memory_attention = memory_attention
118
+ self.hidden_dim = memory_attention.d_model
119
+
120
+ # Part 3: memory encoder for the previous frame's outputs
121
+ self.memory_encoder = memory_encoder
122
+ self.mem_dim = self.hidden_dim
123
+ if hasattr(self.memory_encoder, "out_proj") and hasattr(
124
+ self.memory_encoder.out_proj, "weight"
125
+ ):
126
+ # if there is compression of memories along channel dim
127
+ self.mem_dim = self.memory_encoder.out_proj.weight.shape[0]
128
+ self.num_maskmem = num_maskmem # Number of memories accessible
129
+ # Temporal encoding of the memories
130
+ self.maskmem_tpos_enc = torch.nn.Parameter(
131
+ torch.zeros(num_maskmem, 1, 1, self.mem_dim)
132
+ )
133
+ trunc_normal_(self.maskmem_tpos_enc, std=0.02)
134
+ # a single token to indicate no memory embedding from previous frames
135
+ self.no_mem_embed = torch.nn.Parameter(torch.zeros(1, 1, self.hidden_dim))
136
+ self.no_mem_pos_enc = torch.nn.Parameter(torch.zeros(1, 1, self.hidden_dim))
137
+ trunc_normal_(self.no_mem_embed, std=0.02)
138
+ trunc_normal_(self.no_mem_pos_enc, std=0.02)
139
+ self.directly_add_no_mem_embed = directly_add_no_mem_embed
140
+ # Apply sigmoid to the output raw mask logits (to turn them from
141
+ # range (-inf, +inf) to range (0, 1)) before feeding them into the memory encoder
142
+ self.sigmoid_scale_for_mem_enc = sigmoid_scale_for_mem_enc
143
+ self.sigmoid_bias_for_mem_enc = sigmoid_bias_for_mem_enc
144
+ self.binarize_mask_from_pts_for_mem_enc = binarize_mask_from_pts_for_mem_enc
145
+ self.non_overlap_masks_for_mem_enc = non_overlap_masks_for_mem_enc
146
+ self.memory_temporal_stride_for_eval = memory_temporal_stride_for_eval
147
+ # On frames with mask input, whether to directly output the input mask without
148
+ # using a SAM prompt encoder + mask decoder
149
+ self.use_mask_input_as_output_without_sam = use_mask_input_as_output_without_sam
150
+ self.multimask_output_in_sam = multimask_output_in_sam
151
+ self.multimask_min_pt_num = multimask_min_pt_num
152
+ self.multimask_max_pt_num = multimask_max_pt_num
153
+ self.multimask_output_for_tracking = multimask_output_for_tracking
154
+ self.use_multimask_token_for_obj_ptr = use_multimask_token_for_obj_ptr
155
+ self.iou_prediction_use_sigmoid = iou_prediction_use_sigmoid
156
+
157
+ # Part 4: SAM-style prompt encoder (for both mask and point inputs)
158
+ # and SAM-style mask decoder for the final mask output
159
+ self.image_size = image_size
160
+ self.backbone_stride = backbone_stride
161
+ self.sam_mask_decoder_extra_args = sam_mask_decoder_extra_args
162
+ self.pred_obj_scores = pred_obj_scores
163
+ self.pred_obj_scores_mlp = pred_obj_scores_mlp
164
+ self.fixed_no_obj_ptr = fixed_no_obj_ptr
165
+ self.soft_no_obj_ptr = soft_no_obj_ptr
166
+ if self.fixed_no_obj_ptr:
167
+ assert self.pred_obj_scores
168
+ assert self.use_obj_ptrs_in_encoder
169
+ if self.pred_obj_scores and self.use_obj_ptrs_in_encoder:
170
+ self.no_obj_ptr = torch.nn.Parameter(torch.zeros(1, self.hidden_dim))
171
+ trunc_normal_(self.no_obj_ptr, std=0.02)
172
+ self.use_mlp_for_obj_ptr_proj = use_mlp_for_obj_ptr_proj
173
+
174
+ self._build_sam_heads()
175
+ self.add_all_frames_to_correct_as_cond = add_all_frames_to_correct_as_cond
176
+ self.max_cond_frames_in_attn = max_cond_frames_in_attn
177
+
178
+ # Model compilation
179
+ if compile_image_encoder:
180
+ # Compile the forward function (not the full module) to allow loading checkpoints.
181
+ print(
182
+ "Image encoder compilation is enabled. First forward pass will be slow."
183
+ )
184
+ self.image_encoder.forward = torch.compile(
185
+ self.image_encoder.forward,
186
+ mode="max-autotune",
187
+ fullgraph=True,
188
+ dynamic=False,
189
+ )
190
+
191
+ @property
192
+ def device(self):
193
+ return next(self.parameters()).device
194
+
195
+ def forward(self, *args, **kwargs):
196
+ raise NotImplementedError(
197
+ "Please use the corresponding methods in SAM2VideoPredictor for inference."
198
+ "See notebooks/video_predictor_example.ipynb for an example."
199
+ )
200
+
201
+ def _build_sam_heads(self):
202
+ """Build SAM-style prompt encoder and mask decoder."""
203
+ self.sam_prompt_embed_dim = self.hidden_dim
204
+ self.sam_image_embedding_size = self.image_size // self.backbone_stride
205
+
206
+ # build PromptEncoder and MaskDecoder from SAM
207
+ # (their hyperparameters like `mask_in_chans=16` are from SAM code)
208
+ self.sam_prompt_encoder = PromptEncoder(
209
+ embed_dim=self.sam_prompt_embed_dim,
210
+ image_embedding_size=(
211
+ self.sam_image_embedding_size,
212
+ self.sam_image_embedding_size,
213
+ ),
214
+ input_image_size=(self.image_size, self.image_size),
215
+ mask_in_chans=16,
216
+ )
217
+ self.sam_mask_decoder = MaskDecoder(
218
+ num_multimask_outputs=3,
219
+ transformer=TwoWayTransformer(
220
+ depth=2,
221
+ embedding_dim=self.sam_prompt_embed_dim,
222
+ mlp_dim=2048,
223
+ num_heads=8,
224
+ ),
225
+ transformer_dim=self.sam_prompt_embed_dim,
226
+ iou_head_depth=3,
227
+ iou_head_hidden_dim=256,
228
+ use_high_res_features=self.use_high_res_features_in_sam,
229
+ iou_prediction_use_sigmoid=self.iou_prediction_use_sigmoid,
230
+ pred_obj_scores=self.pred_obj_scores,
231
+ pred_obj_scores_mlp=self.pred_obj_scores_mlp,
232
+ use_multimask_token_for_obj_ptr=self.use_multimask_token_for_obj_ptr,
233
+ **(self.sam_mask_decoder_extra_args or {}),
234
+ )
235
+ if self.use_obj_ptrs_in_encoder:
236
+ # a linear projection on SAM output tokens to turn them into object pointers
237
+ self.obj_ptr_proj = torch.nn.Linear(self.hidden_dim, self.hidden_dim)
238
+ if self.use_mlp_for_obj_ptr_proj:
239
+ self.obj_ptr_proj = MLP(
240
+ self.hidden_dim, self.hidden_dim, self.hidden_dim, 3
241
+ )
242
+ else:
243
+ self.obj_ptr_proj = torch.nn.Identity()
244
+ if self.proj_tpos_enc_in_obj_ptrs:
245
+ # a linear projection on temporal positional encoding in object pointers to
246
+ # avoid potential interference with spatial positional encoding
247
+ self.obj_ptr_tpos_proj = torch.nn.Linear(self.hidden_dim, self.mem_dim)
248
+ else:
249
+ self.obj_ptr_tpos_proj = torch.nn.Identity()
250
+
251
+ def _forward_sam_heads(
252
+ self,
253
+ backbone_features,
254
+ point_inputs=None,
255
+ mask_inputs=None,
256
+ high_res_features=None,
257
+ multimask_output=False,
258
+ ):
259
+ """
260
+ Forward SAM prompt encoders and mask heads.
261
+
262
+ Inputs:
263
+ - backbone_features: image features of [B, C, H, W] shape
264
+ - point_inputs: a dictionary with "point_coords" and "point_labels", where
265
+ 1) "point_coords" has [B, P, 2] shape and float32 dtype and contains the
266
+ absolute pixel-unit coordinate in (x, y) format of the P input points
267
+ 2) "point_labels" has shape [B, P] and int32 dtype, where 1 means
268
+ positive clicks, 0 means negative clicks, and -1 means padding
269
+ - mask_inputs: a mask of [B, 1, H*16, W*16] shape, float or bool, with the
270
+ same spatial size as the image.
271
+ - high_res_features: either 1) None or 2) or a list of length 2 containing
272
+ two feature maps of [B, C, 4*H, 4*W] and [B, C, 2*H, 2*W] shapes respectively,
273
+ which will be used as high-resolution feature maps for SAM decoder.
274
+ - multimask_output: if it's True, we output 3 candidate masks and their 3
275
+ corresponding IoU estimates, and if it's False, we output only 1 mask and
276
+ its corresponding IoU estimate.
277
+
278
+ Outputs:
279
+ - low_res_multimasks: [B, M, H*4, W*4] shape (where M = 3 if
280
+ `multimask_output=True` and M = 1 if `multimask_output=False`), the SAM
281
+ output mask logits (before sigmoid) for the low-resolution masks, with 4x
282
+ the resolution (1/4 stride) of the input backbone_features.
283
+ - high_res_multimasks: [B, M, H*16, W*16] shape (where M = 3
284
+ if `multimask_output=True` and M = 1 if `multimask_output=False`),
285
+ upsampled from the low-resolution masks, with shape size as the image
286
+ (stride is 1 pixel).
287
+ - ious, [B, M] shape, where (where M = 3 if `multimask_output=True` and M = 1
288
+ if `multimask_output=False`), the estimated IoU of each output mask.
289
+ - low_res_masks: [B, 1, H*4, W*4] shape, the best mask in `low_res_multimasks`.
290
+ If `multimask_output=True`, it's the mask with the highest IoU estimate.
291
+ If `multimask_output=False`, it's the same as `low_res_multimasks`.
292
+ - high_res_masks: [B, 1, H*16, W*16] shape, the best mask in `high_res_multimasks`.
293
+ If `multimask_output=True`, it's the mask with the highest IoU estimate.
294
+ If `multimask_output=False`, it's the same as `high_res_multimasks`.
295
+ - obj_ptr: [B, C] shape, the object pointer vector for the output mask, extracted
296
+ based on the output token from the SAM mask decoder.
297
+ """
298
+ B = backbone_features.size(0)
299
+ device = backbone_features.device
300
+ assert backbone_features.size(1) == self.sam_prompt_embed_dim
301
+ assert backbone_features.size(2) == self.sam_image_embedding_size
302
+ assert backbone_features.size(3) == self.sam_image_embedding_size
303
+
304
+ # a) Handle point prompts
305
+ if point_inputs is not None:
306
+ sam_point_coords = point_inputs["point_coords"]
307
+ sam_point_labels = point_inputs["point_labels"]
308
+ assert sam_point_coords.size(0) == B and sam_point_labels.size(0) == B
309
+ else:
310
+ # If no points are provide, pad with an empty point (with label -1)
311
+ sam_point_coords = torch.zeros(B, 1, 2, device=device)
312
+ sam_point_labels = -torch.ones(B, 1, dtype=torch.int32, device=device)
313
+
314
+ # b) Handle mask prompts
315
+ if mask_inputs is not None:
316
+ # If mask_inputs is provided, downsize it into low-res mask input if needed
317
+ # and feed it as a dense mask prompt into the SAM mask encoder
318
+ assert len(mask_inputs.shape) == 4 and mask_inputs.shape[:2] == (B, 1)
319
+ if mask_inputs.shape[-2:] != self.sam_prompt_encoder.mask_input_size:
320
+ sam_mask_prompt = F.interpolate(
321
+ mask_inputs.float(),
322
+ size=self.sam_prompt_encoder.mask_input_size,
323
+ align_corners=False,
324
+ mode="bilinear",
325
+ antialias=True, # use antialias for downsampling
326
+ )
327
+ else:
328
+ sam_mask_prompt = mask_inputs
329
+ else:
330
+ # Otherwise, simply feed None (and SAM's prompt encoder will add
331
+ # a learned `no_mask_embed` to indicate no mask input in this case).
332
+ sam_mask_prompt = None
333
+
334
+ sparse_embeddings, dense_embeddings = self.sam_prompt_encoder(
335
+ points=(sam_point_coords, sam_point_labels),
336
+ boxes=None,
337
+ masks=sam_mask_prompt,
338
+ )
339
+ (
340
+ low_res_multimasks,
341
+ ious,
342
+ sam_output_tokens,
343
+ object_score_logits,
344
+ ) = self.sam_mask_decoder(
345
+ image_embeddings=backbone_features,
346
+ image_pe=self.sam_prompt_encoder.get_dense_pe(),
347
+ sparse_prompt_embeddings=sparse_embeddings,
348
+ dense_prompt_embeddings=dense_embeddings,
349
+ multimask_output=multimask_output,
350
+ repeat_image=False, # the image is already batched
351
+ high_res_features=high_res_features,
352
+ )
353
+ if self.pred_obj_scores:
354
+ is_obj_appearing = object_score_logits > 0
355
+
356
+ # Mask used for spatial memories is always a *hard* choice between obj and no obj,
357
+ # consistent with the actual mask prediction
358
+ low_res_multimasks = torch.where(
359
+ is_obj_appearing[:, None, None],
360
+ low_res_multimasks,
361
+ NO_OBJ_SCORE,
362
+ )
363
+
364
+ # convert masks from possibly bfloat16 (or float16) to float32
365
+ # (older PyTorch versions before 2.1 don't support `interpolate` on bf16)
366
+ low_res_multimasks = low_res_multimasks.float()
367
+ high_res_multimasks = F.interpolate(
368
+ low_res_multimasks,
369
+ size=(self.image_size, self.image_size),
370
+ mode="bilinear",
371
+ align_corners=False,
372
+ )
373
+
374
+ sam_output_token = sam_output_tokens[:, 0]
375
+ if multimask_output:
376
+ # take the best mask prediction (with the highest IoU estimation)
377
+ best_iou_inds = torch.argmax(ious, dim=-1)
378
+ batch_inds = torch.arange(B, device=device)
379
+ low_res_masks = low_res_multimasks[batch_inds, best_iou_inds].unsqueeze(1)
380
+ high_res_masks = high_res_multimasks[batch_inds, best_iou_inds].unsqueeze(1)
381
+ if sam_output_tokens.size(1) > 1:
382
+ sam_output_token = sam_output_tokens[batch_inds, best_iou_inds]
383
+ else:
384
+ low_res_masks, high_res_masks = low_res_multimasks, high_res_multimasks
385
+
386
+ # Extract object pointer from the SAM output token (with occlusion handling)
387
+ obj_ptr = self.obj_ptr_proj(sam_output_token)
388
+ if self.pred_obj_scores:
389
+ # Allow *soft* no obj ptr, unlike for masks
390
+ if self.soft_no_obj_ptr:
391
+ # Only hard possible with gt
392
+ assert not self.teacher_force_obj_scores_for_mem
393
+ lambda_is_obj_appearing = object_score_logits.sigmoid()
394
+ else:
395
+ lambda_is_obj_appearing = is_obj_appearing.float()
396
+
397
+ if self.fixed_no_obj_ptr:
398
+ obj_ptr = lambda_is_obj_appearing * obj_ptr
399
+ obj_ptr = obj_ptr + (1 - lambda_is_obj_appearing) * self.no_obj_ptr
400
+
401
+ return (
402
+ low_res_multimasks,
403
+ high_res_multimasks,
404
+ ious,
405
+ low_res_masks,
406
+ high_res_masks,
407
+ obj_ptr,
408
+ object_score_logits,
409
+ )
410
+
411
+ def _use_mask_as_output(self, backbone_features, high_res_features, mask_inputs):
412
+ """
413
+ Directly turn binary `mask_inputs` into a output mask logits without using SAM.
414
+ (same input and output shapes as in _forward_sam_heads above).
415
+ """
416
+ # Use -10/+10 as logits for neg/pos pixels (very close to 0/1 in prob after sigmoid).
417
+ out_scale, out_bias = 20.0, -10.0 # sigmoid(-10.0)=4.5398e-05
418
+ mask_inputs_float = mask_inputs.float()
419
+ high_res_masks = mask_inputs_float * out_scale + out_bias
420
+ low_res_masks = F.interpolate(
421
+ high_res_masks,
422
+ size=(high_res_masks.size(-2) // 4, high_res_masks.size(-1) // 4),
423
+ align_corners=False,
424
+ mode="bilinear",
425
+ antialias=True, # use antialias for downsampling
426
+ )
427
+ # a dummy IoU prediction of all 1's under mask input
428
+ ious = mask_inputs.new_ones(mask_inputs.size(0), 1).float()
429
+ if not self.use_obj_ptrs_in_encoder:
430
+ # all zeros as a dummy object pointer (of shape [B, C])
431
+ obj_ptr = torch.zeros(
432
+ mask_inputs.size(0), self.hidden_dim, device=mask_inputs.device
433
+ )
434
+ else:
435
+ # produce an object pointer using the SAM decoder from the mask input
436
+ _, _, _, _, _, obj_ptr, _ = self._forward_sam_heads(
437
+ backbone_features=backbone_features,
438
+ mask_inputs=self.mask_downsample(mask_inputs_float),
439
+ high_res_features=high_res_features,
440
+ )
441
+ # In this method, we are treating mask_input as output, e.g. using it directly to create spatial mem;
442
+ # Below, we follow the same design axiom to use mask_input to decide if obj appears or not instead of relying
443
+ # on the object_scores from the SAM decoder.
444
+ is_obj_appearing = torch.any(mask_inputs.flatten(1).float() > 0.0, dim=1)
445
+ is_obj_appearing = is_obj_appearing[..., None]
446
+ lambda_is_obj_appearing = is_obj_appearing.float()
447
+ object_score_logits = out_scale * lambda_is_obj_appearing + out_bias
448
+ if self.pred_obj_scores:
449
+ if self.fixed_no_obj_ptr:
450
+ obj_ptr = lambda_is_obj_appearing * obj_ptr
451
+ obj_ptr = obj_ptr + (1 - lambda_is_obj_appearing) * self.no_obj_ptr
452
+
453
+ return (
454
+ low_res_masks,
455
+ high_res_masks,
456
+ ious,
457
+ low_res_masks,
458
+ high_res_masks,
459
+ obj_ptr,
460
+ object_score_logits,
461
+ )
462
+
463
+ def forward_image(self, img_batch: torch.Tensor):
464
+ """Get the image feature on the input batch."""
465
+ backbone_out = self.image_encoder(img_batch)
466
+ if self.use_high_res_features_in_sam:
467
+ # precompute projected level 0 and level 1 features in SAM decoder
468
+ # to avoid running it again on every SAM click
469
+ backbone_out["backbone_fpn"][0] = self.sam_mask_decoder.conv_s0(
470
+ backbone_out["backbone_fpn"][0]
471
+ )
472
+ backbone_out["backbone_fpn"][1] = self.sam_mask_decoder.conv_s1(
473
+ backbone_out["backbone_fpn"][1]
474
+ )
475
+ return backbone_out
476
+
477
+ def _prepare_backbone_features(self, backbone_out):
478
+ """Prepare and flatten visual features."""
479
+ backbone_out = backbone_out.copy()
480
+ assert len(backbone_out["backbone_fpn"]) == len(backbone_out["vision_pos_enc"])
481
+ assert len(backbone_out["backbone_fpn"]) >= self.num_feature_levels
482
+
483
+ feature_maps = backbone_out["backbone_fpn"][-self.num_feature_levels :]
484
+ vision_pos_embeds = backbone_out["vision_pos_enc"][-self.num_feature_levels :]
485
+
486
+ feat_sizes = [(x.shape[-2], x.shape[-1]) for x in vision_pos_embeds]
487
+ # flatten NxCxHxW to HWxNxC
488
+ vision_feats = [x.flatten(2).permute(2, 0, 1) for x in feature_maps]
489
+ vision_pos_embeds = [x.flatten(2).permute(2, 0, 1) for x in vision_pos_embeds]
490
+
491
+ return backbone_out, vision_feats, vision_pos_embeds, feat_sizes
492
+
493
+ def _prepare_memory_conditioned_features(
494
+ self,
495
+ frame_idx,
496
+ is_init_cond_frame,
497
+ current_vision_feats,
498
+ current_vision_pos_embeds,
499
+ feat_sizes,
500
+ output_dict,
501
+ num_frames,
502
+ track_in_reverse=False, # tracking in reverse time order (for demo usage)
503
+ ):
504
+ """Fuse the current frame's visual feature map with previous memory."""
505
+ B = current_vision_feats[-1].size(1) # batch size on this frame
506
+ C = self.hidden_dim
507
+ H, W = feat_sizes[-1] # top-level (lowest-resolution) feature size
508
+ device = current_vision_feats[-1].device
509
+ # The case of `self.num_maskmem == 0` below is primarily used for reproducing SAM on images.
510
+ # In this case, we skip the fusion with any memory.
511
+ if self.num_maskmem == 0: # Disable memory and skip fusion
512
+ pix_feat = current_vision_feats[-1].permute(1, 2, 0).view(B, C, H, W)
513
+ return pix_feat
514
+
515
+ num_obj_ptr_tokens = 0
516
+ # Step 1: condition the visual features of the current frame on previous memories
517
+ if not is_init_cond_frame:
518
+ # Retrieve the memories encoded with the maskmem backbone
519
+ to_cat_memory, to_cat_memory_pos_embed = [], []
520
+ # Add conditioning frames's output first (all cond frames have t_pos=0 for
521
+ # when getting temporal positional embedding below)
522
+ assert len(output_dict["cond_frame_outputs"]) > 0
523
+ # Select a maximum number of temporally closest cond frames for cross attention
524
+ cond_outputs = output_dict["cond_frame_outputs"]
525
+ selected_cond_outputs, unselected_cond_outputs = select_closest_cond_frames(
526
+ frame_idx, cond_outputs, self.max_cond_frames_in_attn
527
+ )
528
+ t_pos_and_prevs = [(0, out) for out in selected_cond_outputs.values()]
529
+ # Add last (self.num_maskmem - 1) frames before current frame for non-conditioning memory
530
+ # the earliest one has t_pos=1 and the latest one has t_pos=self.num_maskmem-1
531
+ # We also allow taking the memory frame non-consecutively (with r>1), in which case
532
+ # we take (self.num_maskmem - 2) frames among every r-th frames plus the last frame.
533
+ r = self.memory_temporal_stride_for_eval
534
+ for t_pos in range(1, self.num_maskmem):
535
+ t_rel = self.num_maskmem - t_pos # how many frames before current frame
536
+ if t_rel == 1:
537
+ # for t_rel == 1, we take the last frame (regardless of r)
538
+ if not track_in_reverse:
539
+ # the frame immediately before this frame (i.e. frame_idx - 1)
540
+ prev_frame_idx = frame_idx - t_rel
541
+ else:
542
+ # the frame immediately after this frame (i.e. frame_idx + 1)
543
+ prev_frame_idx = frame_idx + t_rel
544
+ else:
545
+ # for t_rel >= 2, we take the memory frame from every r-th frames
546
+ if not track_in_reverse:
547
+ # first find the nearest frame among every r-th frames before this frame
548
+ # for r=1, this would be (frame_idx - 2)
549
+ prev_frame_idx = ((frame_idx - 2) // r) * r
550
+ # then seek further among every r-th frames
551
+ prev_frame_idx = prev_frame_idx - (t_rel - 2) * r
552
+ else:
553
+ # first find the nearest frame among every r-th frames after this frame
554
+ # for r=1, this would be (frame_idx + 2)
555
+ prev_frame_idx = -(-(frame_idx + 2) // r) * r
556
+ # then seek further among every r-th frames
557
+ prev_frame_idx = prev_frame_idx + (t_rel - 2) * r
558
+ out = output_dict["non_cond_frame_outputs"].get(prev_frame_idx, None)
559
+ if out is None:
560
+ # If an unselected conditioning frame is among the last (self.num_maskmem - 1)
561
+ # frames, we still attend to it as if it's a non-conditioning frame.
562
+ out = unselected_cond_outputs.get(prev_frame_idx, None)
563
+ t_pos_and_prevs.append((t_pos, out))
564
+
565
+ for t_pos, prev in t_pos_and_prevs:
566
+ if prev is None:
567
+ continue # skip padding frames
568
+ # "maskmem_features" might have been offloaded to CPU in demo use cases,
569
+ # so we load it back to GPU (it's a no-op if it's already on GPU).
570
+ feats = prev["maskmem_features"].cuda(non_blocking=True)
571
+ to_cat_memory.append(feats.flatten(2).permute(2, 0, 1))
572
+ # Spatial positional encoding (it might have been offloaded to CPU in eval)
573
+ maskmem_enc = prev["maskmem_pos_enc"][-1].cuda()
574
+ maskmem_enc = maskmem_enc.flatten(2).permute(2, 0, 1)
575
+ # Temporal positional encoding
576
+ maskmem_enc = (
577
+ maskmem_enc + self.maskmem_tpos_enc[self.num_maskmem - t_pos - 1]
578
+ )
579
+ to_cat_memory_pos_embed.append(maskmem_enc)
580
+
581
+ # Construct the list of past object pointers
582
+ if self.use_obj_ptrs_in_encoder:
583
+ max_obj_ptrs_in_encoder = min(num_frames, self.max_obj_ptrs_in_encoder)
584
+ # First add those object pointers from selected conditioning frames
585
+ # (optionally, only include object pointers in the past during evaluation)
586
+ if not self.training and self.only_obj_ptrs_in_the_past_for_eval:
587
+ ptr_cond_outputs = {
588
+ t: out
589
+ for t, out in selected_cond_outputs.items()
590
+ if (t >= frame_idx if track_in_reverse else t <= frame_idx)
591
+ }
592
+ else:
593
+ ptr_cond_outputs = selected_cond_outputs
594
+ pos_and_ptrs = [
595
+ # Temporal pos encoding contains how far away each pointer is from current frame
596
+ (abs(frame_idx - t), out["obj_ptr"])
597
+ for t, out in ptr_cond_outputs.items()
598
+ ]
599
+ # Add up to (max_obj_ptrs_in_encoder - 1) non-conditioning frames before current frame
600
+ for t_diff in range(1, max_obj_ptrs_in_encoder):
601
+ t = frame_idx + t_diff if track_in_reverse else frame_idx - t_diff
602
+ if t < 0 or (num_frames is not None and t >= num_frames):
603
+ break
604
+ out = output_dict["non_cond_frame_outputs"].get(
605
+ t, unselected_cond_outputs.get(t, None)
606
+ )
607
+ if out is not None:
608
+ pos_and_ptrs.append((t_diff, out["obj_ptr"]))
609
+ # If we have at least one object pointer, add them to the across attention
610
+ if len(pos_and_ptrs) > 0:
611
+ pos_list, ptrs_list = zip(*pos_and_ptrs)
612
+ # stack object pointers along dim=0 into [ptr_seq_len, B, C] shape
613
+ obj_ptrs = torch.stack(ptrs_list, dim=0)
614
+ # a temporal positional embedding based on how far each object pointer is from
615
+ # the current frame (sine embedding normalized by the max pointer num).
616
+ if self.add_tpos_enc_to_obj_ptrs:
617
+ t_diff_max = max_obj_ptrs_in_encoder - 1
618
+ tpos_dim = C if self.proj_tpos_enc_in_obj_ptrs else self.mem_dim
619
+ obj_pos = torch.tensor(pos_list, device=device)
620
+ obj_pos = get_1d_sine_pe(obj_pos / t_diff_max, dim=tpos_dim)
621
+ obj_pos = self.obj_ptr_tpos_proj(obj_pos)
622
+ obj_pos = obj_pos.unsqueeze(1).expand(-1, B, self.mem_dim)
623
+ else:
624
+ obj_pos = obj_ptrs.new_zeros(len(pos_list), B, self.mem_dim)
625
+ if self.mem_dim < C:
626
+ # split a pointer into (C // self.mem_dim) tokens for self.mem_dim < C
627
+ obj_ptrs = obj_ptrs.reshape(
628
+ -1, B, C // self.mem_dim, self.mem_dim
629
+ )
630
+ obj_ptrs = obj_ptrs.permute(0, 2, 1, 3).flatten(0, 1)
631
+ obj_pos = obj_pos.repeat_interleave(C // self.mem_dim, dim=0)
632
+ to_cat_memory.append(obj_ptrs)
633
+ to_cat_memory_pos_embed.append(obj_pos)
634
+ num_obj_ptr_tokens = obj_ptrs.shape[0]
635
+ else:
636
+ num_obj_ptr_tokens = 0
637
+ else:
638
+ # for initial conditioning frames, encode them without using any previous memory
639
+ if self.directly_add_no_mem_embed:
640
+ # directly add no-mem embedding (instead of using the transformer encoder)
641
+ pix_feat_with_mem = current_vision_feats[-1] + self.no_mem_embed
642
+ pix_feat_with_mem = pix_feat_with_mem.permute(1, 2, 0).view(B, C, H, W)
643
+ return pix_feat_with_mem
644
+
645
+ # Use a dummy token on the first frame (to avoid emtpy memory input to tranformer encoder)
646
+ to_cat_memory = [self.no_mem_embed.expand(1, B, self.mem_dim)]
647
+ to_cat_memory_pos_embed = [self.no_mem_pos_enc.expand(1, B, self.mem_dim)]
648
+
649
+ # Step 2: Concatenate the memories and forward through the transformer encoder
650
+ memory = torch.cat(to_cat_memory, dim=0)
651
+ memory_pos_embed = torch.cat(to_cat_memory_pos_embed, dim=0)
652
+
653
+ pix_feat_with_mem = self.memory_attention(
654
+ curr=current_vision_feats,
655
+ curr_pos=current_vision_pos_embeds,
656
+ memory=memory,
657
+ memory_pos=memory_pos_embed,
658
+ num_obj_ptr_tokens=num_obj_ptr_tokens,
659
+ )
660
+ # reshape the output (HW)BC => BCHW
661
+ pix_feat_with_mem = pix_feat_with_mem.permute(1, 2, 0).view(B, C, H, W)
662
+ return pix_feat_with_mem
663
+
664
+ def _encode_new_memory(
665
+ self,
666
+ current_vision_feats,
667
+ feat_sizes,
668
+ pred_masks_high_res,
669
+ is_mask_from_pts,
670
+ ):
671
+ """Encode the current image and its prediction into a memory feature."""
672
+ B = current_vision_feats[-1].size(1) # batch size on this frame
673
+ C = self.hidden_dim
674
+ H, W = feat_sizes[-1] # top-level (lowest-resolution) feature size
675
+ # top-level feature, (HW)BC => BCHW
676
+ pix_feat = current_vision_feats[-1].permute(1, 2, 0).view(B, C, H, W)
677
+ if self.non_overlap_masks_for_mem_enc and not self.training:
678
+ # optionally, apply non-overlapping constraints to the masks (it's applied
679
+ # in the batch dimension and should only be used during eval, where all
680
+ # the objects come from the same video under batch size 1).
681
+ pred_masks_high_res = self._apply_non_overlapping_constraints(
682
+ pred_masks_high_res
683
+ )
684
+ # scale the raw mask logits with a temperature before applying sigmoid
685
+ binarize = self.binarize_mask_from_pts_for_mem_enc and is_mask_from_pts
686
+ if binarize and not self.training:
687
+ mask_for_mem = (pred_masks_high_res > 0).float()
688
+ else:
689
+ # apply sigmoid on the raw mask logits to turn them into range (0, 1)
690
+ mask_for_mem = torch.sigmoid(pred_masks_high_res)
691
+ # apply scale and bias terms to the sigmoid probabilities
692
+ if self.sigmoid_scale_for_mem_enc != 1.0:
693
+ mask_for_mem = mask_for_mem * self.sigmoid_scale_for_mem_enc
694
+ if self.sigmoid_bias_for_mem_enc != 0.0:
695
+ mask_for_mem = mask_for_mem + self.sigmoid_bias_for_mem_enc
696
+ maskmem_out = self.memory_encoder(
697
+ pix_feat, mask_for_mem, skip_mask_sigmoid=True # sigmoid already applied
698
+ )
699
+ maskmem_features = maskmem_out["vision_features"]
700
+ maskmem_pos_enc = maskmem_out["vision_pos_enc"]
701
+
702
+ return maskmem_features, maskmem_pos_enc
703
+
704
+ def track_step(
705
+ self,
706
+ frame_idx,
707
+ is_init_cond_frame,
708
+ current_vision_feats,
709
+ current_vision_pos_embeds,
710
+ feat_sizes,
711
+ point_inputs,
712
+ mask_inputs,
713
+ output_dict,
714
+ num_frames,
715
+ track_in_reverse=False, # tracking in reverse time order (for demo usage)
716
+ # Whether to run the memory encoder on the predicted masks. Sometimes we might want
717
+ # to skip the memory encoder with `run_mem_encoder=False`. For example,
718
+ # in demo we might call `track_step` multiple times for each user click,
719
+ # and only encode the memory when the user finalizes their clicks. And in ablation
720
+ # settings like SAM training on static images, we don't need the memory encoder.
721
+ run_mem_encoder=True,
722
+ # The previously predicted SAM mask logits (which can be fed together with new clicks in demo).
723
+ prev_sam_mask_logits=None,
724
+ ):
725
+ current_out = {"point_inputs": point_inputs, "mask_inputs": mask_inputs}
726
+ # High-resolution feature maps for the SAM head, reshape (HW)BC => BCHW
727
+ if len(current_vision_feats) > 1:
728
+ high_res_features = [
729
+ x.permute(1, 2, 0).view(x.size(1), x.size(2), *s)
730
+ for x, s in zip(current_vision_feats[:-1], feat_sizes[:-1])
731
+ ]
732
+ else:
733
+ high_res_features = None
734
+ if mask_inputs is not None and self.use_mask_input_as_output_without_sam:
735
+ # When use_mask_input_as_output_without_sam=True, we directly output the mask input
736
+ # (see it as a GT mask) without using a SAM prompt encoder + mask decoder.
737
+ pix_feat = current_vision_feats[-1].permute(1, 2, 0)
738
+ pix_feat = pix_feat.view(-1, self.hidden_dim, *feat_sizes[-1])
739
+ sam_outputs = self._use_mask_as_output(
740
+ pix_feat, high_res_features, mask_inputs
741
+ )
742
+ else:
743
+ # fused the visual feature with previous memory features in the memory bank
744
+ pix_feat_with_mem = self._prepare_memory_conditioned_features(
745
+ frame_idx=frame_idx,
746
+ is_init_cond_frame=is_init_cond_frame,
747
+ current_vision_feats=current_vision_feats[-1:],
748
+ current_vision_pos_embeds=current_vision_pos_embeds[-1:],
749
+ feat_sizes=feat_sizes[-1:],
750
+ output_dict=output_dict,
751
+ num_frames=num_frames,
752
+ track_in_reverse=track_in_reverse,
753
+ )
754
+ # apply SAM-style segmentation head
755
+ # here we might feed previously predicted low-res SAM mask logits into the SAM mask decoder,
756
+ # e.g. in demo where such logits come from earlier interaction instead of correction sampling
757
+ # (in this case, any `mask_inputs` shouldn't reach here as they are sent to _use_mask_as_output instead)
758
+ if prev_sam_mask_logits is not None:
759
+ assert point_inputs is not None and mask_inputs is None
760
+ mask_inputs = prev_sam_mask_logits
761
+ multimask_output = self._use_multimask(is_init_cond_frame, point_inputs)
762
+ sam_outputs = self._forward_sam_heads(
763
+ backbone_features=pix_feat_with_mem,
764
+ point_inputs=point_inputs,
765
+ mask_inputs=mask_inputs,
766
+ high_res_features=high_res_features,
767
+ multimask_output=multimask_output,
768
+ )
769
+ (
770
+ _,
771
+ _,
772
+ _,
773
+ low_res_masks,
774
+ high_res_masks,
775
+ obj_ptr,
776
+ _,
777
+ ) = sam_outputs
778
+
779
+ current_out["pred_masks"] = low_res_masks
780
+ current_out["pred_masks_high_res"] = high_res_masks
781
+ current_out["obj_ptr"] = obj_ptr
782
+
783
+ # Finally run the memory encoder on the predicted mask to encode
784
+ # it into a new memory feature (that can be used in future frames)
785
+ if run_mem_encoder and self.num_maskmem > 0:
786
+ high_res_masks_for_mem_enc = high_res_masks
787
+ maskmem_features, maskmem_pos_enc = self._encode_new_memory(
788
+ current_vision_feats=current_vision_feats,
789
+ feat_sizes=feat_sizes,
790
+ pred_masks_high_res=high_res_masks_for_mem_enc,
791
+ is_mask_from_pts=(point_inputs is not None),
792
+ )
793
+ current_out["maskmem_features"] = maskmem_features
794
+ current_out["maskmem_pos_enc"] = maskmem_pos_enc
795
+ else:
796
+ current_out["maskmem_features"] = None
797
+ current_out["maskmem_pos_enc"] = None
798
+
799
+ return current_out
800
+
801
+ def _use_multimask(self, is_init_cond_frame, point_inputs):
802
+ """Whether to use multimask output in the SAM head."""
803
+ num_pts = 0 if point_inputs is None else point_inputs["point_labels"].size(1)
804
+ multimask_output = (
805
+ self.multimask_output_in_sam
806
+ and (is_init_cond_frame or self.multimask_output_for_tracking)
807
+ and (self.multimask_min_pt_num <= num_pts <= self.multimask_max_pt_num)
808
+ )
809
+ return multimask_output
810
+
811
+ def _apply_non_overlapping_constraints(self, pred_masks):
812
+ """
813
+ Apply non-overlapping constraints to the object scores in pred_masks. Here we
814
+ keep only the highest scoring object at each spatial location in pred_masks.
815
+ """
816
+ batch_size = pred_masks.size(0)
817
+ if batch_size == 1:
818
+ return pred_masks
819
+
820
+ device = pred_masks.device
821
+ # "max_obj_inds": object index of the object with the highest score at each location
822
+ max_obj_inds = torch.argmax(pred_masks, dim=0, keepdim=True)
823
+ # "batch_obj_inds": object index of each object slice (along dim 0) in `pred_masks`
824
+ batch_obj_inds = torch.arange(batch_size, device=device)[:, None, None, None]
825
+ keep = max_obj_inds == batch_obj_inds
826
+ # suppress overlapping regions' scores below -10.0 so that the foreground regions
827
+ # don't overlap (here sigmoid(-10.0)=4.5398e-05)
828
+ pred_masks = torch.where(keep, pred_masks, torch.clamp(pred_masks, max=-10.0))
829
+ return pred_masks
segment-anything-2/sam2/modeling/sam2_utils.py ADDED
@@ -0,0 +1,149 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+
8
+ import copy
9
+
10
+ import torch
11
+ import torch.nn as nn
12
+ import torch.nn.functional as F
13
+
14
+
15
+ def select_closest_cond_frames(frame_idx, cond_frame_outputs, max_cond_frame_num):
16
+ """
17
+ Select up to `max_cond_frame_num` conditioning frames from `cond_frame_outputs`
18
+ that are temporally closest to the current frame at `frame_idx`. Here, we take
19
+ - a) the closest conditioning frame before `frame_idx` (if any);
20
+ - b) the closest conditioning frame after `frame_idx` (if any);
21
+ - c) any other temporally closest conditioning frames until reaching a total
22
+ of `max_cond_frame_num` conditioning frames.
23
+
24
+ Outputs:
25
+ - selected_outputs: selected items (keys & values) from `cond_frame_outputs`.
26
+ - unselected_outputs: items (keys & values) not selected in `cond_frame_outputs`.
27
+ """
28
+ if max_cond_frame_num == -1 or len(cond_frame_outputs) <= max_cond_frame_num:
29
+ selected_outputs = cond_frame_outputs
30
+ unselected_outputs = {}
31
+ else:
32
+ assert max_cond_frame_num >= 2, "we should allow using 2+ conditioning frames"
33
+ selected_outputs = {}
34
+
35
+ # the closest conditioning frame before `frame_idx` (if any)
36
+ idx_before = max((t for t in cond_frame_outputs if t < frame_idx), default=None)
37
+ if idx_before is not None:
38
+ selected_outputs[idx_before] = cond_frame_outputs[idx_before]
39
+
40
+ # the closest conditioning frame after `frame_idx` (if any)
41
+ idx_after = min((t for t in cond_frame_outputs if t >= frame_idx), default=None)
42
+ if idx_after is not None:
43
+ selected_outputs[idx_after] = cond_frame_outputs[idx_after]
44
+
45
+ # add other temporally closest conditioning frames until reaching a total
46
+ # of `max_cond_frame_num` conditioning frames.
47
+ num_remain = max_cond_frame_num - len(selected_outputs)
48
+ inds_remain = sorted(
49
+ (t for t in cond_frame_outputs if t not in selected_outputs),
50
+ key=lambda x: abs(x - frame_idx),
51
+ )[:num_remain]
52
+ selected_outputs.update((t, cond_frame_outputs[t]) for t in inds_remain)
53
+ unselected_outputs = {
54
+ t: v for t, v in cond_frame_outputs.items() if t not in selected_outputs
55
+ }
56
+
57
+ return selected_outputs, unselected_outputs
58
+
59
+
60
+ def get_1d_sine_pe(pos_inds, dim, temperature=10000):
61
+ """
62
+ Get 1D sine positional embedding as in the original Transformer paper.
63
+ """
64
+ pe_dim = dim // 2
65
+ dim_t = torch.arange(pe_dim, dtype=torch.float32, device=pos_inds.device)
66
+ dim_t = temperature ** (2 * (dim_t // 2) / pe_dim)
67
+
68
+ pos_embed = pos_inds.unsqueeze(-1) / dim_t
69
+ pos_embed = torch.cat([pos_embed.sin(), pos_embed.cos()], dim=-1)
70
+ return pos_embed
71
+
72
+
73
+ def get_activation_fn(activation):
74
+ """Return an activation function given a string"""
75
+ if activation == "relu":
76
+ return F.relu
77
+ if activation == "gelu":
78
+ return F.gelu
79
+ if activation == "glu":
80
+ return F.glu
81
+ raise RuntimeError(f"activation should be relu/gelu, not {activation}.")
82
+
83
+
84
+ def get_clones(module, N):
85
+ return nn.ModuleList([copy.deepcopy(module) for i in range(N)])
86
+
87
+
88
+ class DropPath(nn.Module):
89
+ # adapted from https://github.com/huggingface/pytorch-image-models/blob/main/timm/layers/drop.py
90
+ def __init__(self, drop_prob=0.0, scale_by_keep=True):
91
+ super(DropPath, self).__init__()
92
+ self.drop_prob = drop_prob
93
+ self.scale_by_keep = scale_by_keep
94
+
95
+ def forward(self, x):
96
+ if self.drop_prob == 0.0 or not self.training:
97
+ return x
98
+ keep_prob = 1 - self.drop_prob
99
+ shape = (x.shape[0],) + (1,) * (x.ndim - 1)
100
+ random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
101
+ if keep_prob > 0.0 and self.scale_by_keep:
102
+ random_tensor.div_(keep_prob)
103
+ return x * random_tensor
104
+
105
+
106
+ # Lightly adapted from
107
+ # https://github.com/facebookresearch/MaskFormer/blob/main/mask_former/modeling/transformer/transformer_predictor.py # noqa
108
+ class MLP(nn.Module):
109
+ def __init__(
110
+ self,
111
+ input_dim: int,
112
+ hidden_dim: int,
113
+ output_dim: int,
114
+ num_layers: int,
115
+ activation: nn.Module = nn.ReLU,
116
+ sigmoid_output: bool = False,
117
+ ) -> None:
118
+ super().__init__()
119
+ self.num_layers = num_layers
120
+ h = [hidden_dim] * (num_layers - 1)
121
+ self.layers = nn.ModuleList(
122
+ nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim])
123
+ )
124
+ self.sigmoid_output = sigmoid_output
125
+ self.act = activation()
126
+
127
+ def forward(self, x):
128
+ for i, layer in enumerate(self.layers):
129
+ x = self.act(layer(x)) if i < self.num_layers - 1 else layer(x)
130
+ if self.sigmoid_output:
131
+ x = F.sigmoid(x)
132
+ return x
133
+
134
+
135
+ # From https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/batch_norm.py # noqa
136
+ # Itself from https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119 # noqa
137
+ class LayerNorm2d(nn.Module):
138
+ def __init__(self, num_channels: int, eps: float = 1e-6) -> None:
139
+ super().__init__()
140
+ self.weight = nn.Parameter(torch.ones(num_channels))
141
+ self.bias = nn.Parameter(torch.zeros(num_channels))
142
+ self.eps = eps
143
+
144
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
145
+ u = x.mean(1, keepdim=True)
146
+ s = (x - u).pow(2).mean(1, keepdim=True)
147
+ x = (x - u) / torch.sqrt(s + self.eps)
148
+ x = self.weight[:, None, None] * x + self.bias[:, None, None]
149
+ return x
segment-anything-2/sam2/sam2_image_predictor.py ADDED
@@ -0,0 +1,446 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import logging
8
+
9
+ from typing import List, Optional, Tuple, Union
10
+
11
+ import numpy as np
12
+ import torch
13
+ from PIL.Image import Image
14
+
15
+ from sam2.modeling.sam2_base import SAM2Base
16
+
17
+ from sam2.utils.transforms import SAM2Transforms
18
+
19
+
20
+ class SAM2ImagePredictor:
21
+ def __init__(
22
+ self,
23
+ sam_model: SAM2Base,
24
+ mask_threshold=0.0,
25
+ max_hole_area=0.0,
26
+ max_sprinkle_area=0.0,
27
+ ) -> None:
28
+ """
29
+ Uses SAM-2 to calculate the image embedding for an image, and then
30
+ allow repeated, efficient mask prediction given prompts.
31
+
32
+ Arguments:
33
+ sam_model (Sam-2): The model to use for mask prediction.
34
+ mask_threshold (float): The threshold to use when converting mask logits
35
+ to binary masks. Masks are thresholded at 0 by default.
36
+ fill_hole_area (int): If fill_hole_area > 0, we fill small holes in up to
37
+ the maximum area of fill_hole_area in low_res_masks.
38
+ """
39
+ super().__init__()
40
+ self.model = sam_model
41
+ self._transforms = SAM2Transforms(
42
+ resolution=self.model.image_size,
43
+ mask_threshold=mask_threshold,
44
+ max_hole_area=max_hole_area,
45
+ max_sprinkle_area=max_sprinkle_area,
46
+ )
47
+
48
+ # Predictor state
49
+ self._is_image_set = False
50
+ self._features = None
51
+ self._orig_hw = None
52
+ # Whether the predictor is set for single image or a batch of images
53
+ self._is_batch = False
54
+
55
+ # Predictor config
56
+ self.mask_threshold = mask_threshold
57
+
58
+ # Spatial dim for backbone feature maps
59
+ self._bb_feat_sizes = [
60
+ (256, 256),
61
+ (128, 128),
62
+ (64, 64),
63
+ ]
64
+
65
+ @torch.no_grad()
66
+ def set_image(
67
+ self,
68
+ image: Union[np.ndarray, Image],
69
+ ) -> None:
70
+ """
71
+ Calculates the image embeddings for the provided image, allowing
72
+ masks to be predicted with the 'predict' method.
73
+
74
+ Arguments:
75
+ image (np.ndarray or PIL Image): The input image to embed in RGB format. The image should be in HWC format if np.ndarray, or WHC format if PIL Image
76
+ with pixel values in [0, 255].
77
+ image_format (str): The color format of the image, in ['RGB', 'BGR'].
78
+ """
79
+ self.reset_predictor()
80
+ # Transform the image to the form expected by the model
81
+ if isinstance(image, np.ndarray):
82
+ logging.info("For numpy array image, we assume (HxWxC) format")
83
+ self._orig_hw = [image.shape[:2]]
84
+ elif isinstance(image, Image):
85
+ w, h = image.size
86
+ self._orig_hw = [(h, w)]
87
+ else:
88
+ raise NotImplementedError("Image format not supported")
89
+
90
+ input_image = self._transforms(image)
91
+ input_image = input_image[None, ...].to(self.device)
92
+
93
+ assert (
94
+ len(input_image.shape) == 4 and input_image.shape[1] == 3
95
+ ), f"input_image must be of size 1x3xHxW, got {input_image.shape}"
96
+ logging.info("Computing image embeddings for the provided image...")
97
+ backbone_out = self.model.forward_image(input_image)
98
+ _, vision_feats, _, _ = self.model._prepare_backbone_features(backbone_out)
99
+ # Add no_mem_embed, which is added to the lowest rest feat. map during training on videos
100
+ if self.model.directly_add_no_mem_embed:
101
+ vision_feats[-1] = vision_feats[-1] + self.model.no_mem_embed
102
+
103
+ feats = [
104
+ feat.permute(1, 2, 0).view(1, -1, *feat_size)
105
+ for feat, feat_size in zip(vision_feats[::-1], self._bb_feat_sizes[::-1])
106
+ ][::-1]
107
+ self._features = {"image_embed": feats[-1], "high_res_feats": feats[:-1]}
108
+ self._is_image_set = True
109
+ logging.info("Image embeddings computed.")
110
+
111
+ @torch.no_grad()
112
+ def set_image_batch(
113
+ self,
114
+ image_list: List[Union[np.ndarray]],
115
+ ) -> None:
116
+ """
117
+ Calculates the image embeddings for the provided image batch, allowing
118
+ masks to be predicted with the 'predict_batch' method.
119
+
120
+ Arguments:
121
+ image_list (List[np.ndarray]): The input images to embed in RGB format. The image should be in HWC format if np.ndarray
122
+ with pixel values in [0, 255].
123
+ """
124
+ self.reset_predictor()
125
+ assert isinstance(image_list, list)
126
+ self._orig_hw = []
127
+ for image in image_list:
128
+ assert isinstance(
129
+ image, np.ndarray
130
+ ), "Images are expected to be an np.ndarray in RGB format, and of shape HWC"
131
+ self._orig_hw.append(image.shape[:2])
132
+ # Transform the image to the form expected by the model
133
+ img_batch = self._transforms.forward_batch(image_list)
134
+ img_batch = img_batch.to(self.device)
135
+ batch_size = img_batch.shape[0]
136
+ assert (
137
+ len(img_batch.shape) == 4 and img_batch.shape[1] == 3
138
+ ), f"img_batch must be of size Bx3xHxW, got {img_batch.shape}"
139
+ logging.info("Computing image embeddings for the provided images...")
140
+ backbone_out = self.model.forward_image(img_batch)
141
+ _, vision_feats, _, _ = self.model._prepare_backbone_features(backbone_out)
142
+ # Add no_mem_embed, which is added to the lowest rest feat. map during training on videos
143
+ if self.model.directly_add_no_mem_embed:
144
+ vision_feats[-1] = vision_feats[-1] + self.model.no_mem_embed
145
+
146
+ feats = [
147
+ feat.permute(1, 2, 0).view(batch_size, -1, *feat_size)
148
+ for feat, feat_size in zip(vision_feats[::-1], self._bb_feat_sizes[::-1])
149
+ ][::-1]
150
+ self._features = {"image_embed": feats[-1], "high_res_feats": feats[:-1]}
151
+ self._is_image_set = True
152
+ self._is_batch = True
153
+ logging.info("Image embeddings computed.")
154
+
155
+ def predict_batch(
156
+ self,
157
+ point_coords_batch: List[np.ndarray] = None,
158
+ point_labels_batch: List[np.ndarray] = None,
159
+ box_batch: List[np.ndarray] = None,
160
+ mask_input_batch: List[np.ndarray] = None,
161
+ multimask_output: bool = True,
162
+ return_logits: bool = False,
163
+ normalize_coords=True,
164
+ ) -> Tuple[List[np.ndarray], List[np.ndarray], List[np.ndarray]]:
165
+ """This function is very similar to predict(...), however it is used for batched mode, when the model is expected to generate predictions on multiple images.
166
+ It returns a tupele of lists of masks, ious, and low_res_masks_logits.
167
+ """
168
+ assert self._is_batch, "This function should only be used when in batched mode"
169
+ if not self._is_image_set:
170
+ raise RuntimeError(
171
+ "An image must be set with .set_image_batch(...) before mask prediction."
172
+ )
173
+ num_images = len(self._features["image_embed"])
174
+ all_masks = []
175
+ all_ious = []
176
+ all_low_res_masks = []
177
+ for img_idx in range(num_images):
178
+ # Transform input prompts
179
+ point_coords = (
180
+ point_coords_batch[img_idx] if point_coords_batch is not None else None
181
+ )
182
+ point_labels = (
183
+ point_labels_batch[img_idx] if point_labels_batch is not None else None
184
+ )
185
+ box = box_batch[img_idx] if box_batch is not None else None
186
+ mask_input = (
187
+ mask_input_batch[img_idx] if mask_input_batch is not None else None
188
+ )
189
+ mask_input, unnorm_coords, labels, unnorm_box = self._prep_prompts(
190
+ point_coords,
191
+ point_labels,
192
+ box,
193
+ mask_input,
194
+ normalize_coords,
195
+ img_idx=img_idx,
196
+ )
197
+ masks, iou_predictions, low_res_masks = self._predict(
198
+ unnorm_coords,
199
+ labels,
200
+ unnorm_box,
201
+ mask_input,
202
+ multimask_output,
203
+ return_logits=return_logits,
204
+ img_idx=img_idx,
205
+ )
206
+ masks_np = masks.squeeze(0).float().detach().cpu().numpy()
207
+ iou_predictions_np = (
208
+ iou_predictions.squeeze(0).float().detach().cpu().numpy()
209
+ )
210
+ low_res_masks_np = low_res_masks.squeeze(0).float().detach().cpu().numpy()
211
+ all_masks.append(masks_np)
212
+ all_ious.append(iou_predictions_np)
213
+ all_low_res_masks.append(low_res_masks_np)
214
+
215
+ return all_masks, all_ious, all_low_res_masks
216
+
217
+ def predict(
218
+ self,
219
+ point_coords: Optional[np.ndarray] = None,
220
+ point_labels: Optional[np.ndarray] = None,
221
+ box: Optional[np.ndarray] = None,
222
+ mask_input: Optional[np.ndarray] = None,
223
+ multimask_output: bool = True,
224
+ return_logits: bool = False,
225
+ normalize_coords=True,
226
+ ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
227
+ """
228
+ Predict masks for the given input prompts, using the currently set image.
229
+
230
+ Arguments:
231
+ point_coords (np.ndarray or None): A Nx2 array of point prompts to the
232
+ model. Each point is in (X,Y) in pixels.
233
+ point_labels (np.ndarray or None): A length N array of labels for the
234
+ point prompts. 1 indicates a foreground point and 0 indicates a
235
+ background point.
236
+ box (np.ndarray or None): A length 4 array given a box prompt to the
237
+ model, in XYXY format.
238
+ mask_input (np.ndarray): A low resolution mask input to the model, typically
239
+ coming from a previous prediction iteration. Has form 1xHxW, where
240
+ for SAM, H=W=256.
241
+ multimask_output (bool): If true, the model will return three masks.
242
+ For ambiguous input prompts (such as a single click), this will often
243
+ produce better masks than a single prediction. If only a single
244
+ mask is needed, the model's predicted quality score can be used
245
+ to select the best mask. For non-ambiguous prompts, such as multiple
246
+ input prompts, multimask_output=False can give better results.
247
+ return_logits (bool): If true, returns un-thresholded masks logits
248
+ instead of a binary mask.
249
+ normalize_coords (bool): If true, the point coordinates will be normalized to the range [0,1] and point_coords is expected to be wrt. image dimensions.
250
+
251
+ Returns:
252
+ (np.ndarray): The output masks in CxHxW format, where C is the
253
+ number of masks, and (H, W) is the original image size.
254
+ (np.ndarray): An array of length C containing the model's
255
+ predictions for the quality of each mask.
256
+ (np.ndarray): An array of shape CxHxW, where C is the number
257
+ of masks and H=W=256. These low resolution logits can be passed to
258
+ a subsequent iteration as mask input.
259
+ """
260
+ if not self._is_image_set:
261
+ raise RuntimeError(
262
+ "An image must be set with .set_image(...) before mask prediction."
263
+ )
264
+
265
+ # Transform input prompts
266
+
267
+ mask_input, unnorm_coords, labels, unnorm_box = self._prep_prompts(
268
+ point_coords, point_labels, box, mask_input, normalize_coords
269
+ )
270
+
271
+ masks, iou_predictions, low_res_masks = self._predict(
272
+ unnorm_coords,
273
+ labels,
274
+ unnorm_box,
275
+ mask_input,
276
+ multimask_output,
277
+ return_logits=return_logits,
278
+ )
279
+
280
+ masks_np = masks.squeeze(0).float().detach().cpu().numpy()
281
+ iou_predictions_np = iou_predictions.squeeze(0).float().detach().cpu().numpy()
282
+ low_res_masks_np = low_res_masks.squeeze(0).float().detach().cpu().numpy()
283
+ return masks_np, iou_predictions_np, low_res_masks_np
284
+
285
+ def _prep_prompts(
286
+ self, point_coords, point_labels, box, mask_logits, normalize_coords, img_idx=-1
287
+ ):
288
+
289
+ unnorm_coords, labels, unnorm_box, mask_input = None, None, None, None
290
+ if point_coords is not None:
291
+ assert (
292
+ point_labels is not None
293
+ ), "point_labels must be supplied if point_coords is supplied."
294
+ point_coords = torch.as_tensor(
295
+ point_coords, dtype=torch.float, device=self.device
296
+ )
297
+ unnorm_coords = self._transforms.transform_coords(
298
+ point_coords, normalize=normalize_coords, orig_hw=self._orig_hw[img_idx]
299
+ )
300
+ labels = torch.as_tensor(point_labels, dtype=torch.int, device=self.device)
301
+ if len(unnorm_coords.shape) == 2:
302
+ unnorm_coords, labels = unnorm_coords[None, ...], labels[None, ...]
303
+ if box is not None:
304
+ box = torch.as_tensor(box, dtype=torch.float, device=self.device)
305
+ unnorm_box = self._transforms.transform_boxes(
306
+ box, normalize=normalize_coords, orig_hw=self._orig_hw[img_idx]
307
+ ) # Bx2x2
308
+ if mask_logits is not None:
309
+ mask_input = torch.as_tensor(
310
+ mask_logits, dtype=torch.float, device=self.device
311
+ )
312
+ if len(mask_input.shape) == 3:
313
+ mask_input = mask_input[None, :, :, :]
314
+ return mask_input, unnorm_coords, labels, unnorm_box
315
+
316
+ @torch.no_grad()
317
+ def _predict(
318
+ self,
319
+ point_coords: Optional[torch.Tensor],
320
+ point_labels: Optional[torch.Tensor],
321
+ boxes: Optional[torch.Tensor] = None,
322
+ mask_input: Optional[torch.Tensor] = None,
323
+ multimask_output: bool = True,
324
+ return_logits: bool = False,
325
+ img_idx: int = -1,
326
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
327
+ """
328
+ Predict masks for the given input prompts, using the currently set image.
329
+ Input prompts are batched torch tensors and are expected to already be
330
+ transformed to the input frame using SAM2Transforms.
331
+
332
+ Arguments:
333
+ point_coords (torch.Tensor or None): A BxNx2 array of point prompts to the
334
+ model. Each point is in (X,Y) in pixels.
335
+ point_labels (torch.Tensor or None): A BxN array of labels for the
336
+ point prompts. 1 indicates a foreground point and 0 indicates a
337
+ background point.
338
+ boxes (np.ndarray or None): A Bx4 array given a box prompt to the
339
+ model, in XYXY format.
340
+ mask_input (np.ndarray): A low resolution mask input to the model, typically
341
+ coming from a previous prediction iteration. Has form Bx1xHxW, where
342
+ for SAM, H=W=256. Masks returned by a previous iteration of the
343
+ predict method do not need further transformation.
344
+ multimask_output (bool): If true, the model will return three masks.
345
+ For ambiguous input prompts (such as a single click), this will often
346
+ produce better masks than a single prediction. If only a single
347
+ mask is needed, the model's predicted quality score can be used
348
+ to select the best mask. For non-ambiguous prompts, such as multiple
349
+ input prompts, multimask_output=False can give better results.
350
+ return_logits (bool): If true, returns un-thresholded masks logits
351
+ instead of a binary mask.
352
+
353
+ Returns:
354
+ (torch.Tensor): The output masks in BxCxHxW format, where C is the
355
+ number of masks, and (H, W) is the original image size.
356
+ (torch.Tensor): An array of shape BxC containing the model's
357
+ predictions for the quality of each mask.
358
+ (torch.Tensor): An array of shape BxCxHxW, where C is the number
359
+ of masks and H=W=256. These low res logits can be passed to
360
+ a subsequent iteration as mask input.
361
+ """
362
+ if not self._is_image_set:
363
+ raise RuntimeError(
364
+ "An image must be set with .set_image(...) before mask prediction."
365
+ )
366
+
367
+ if point_coords is not None:
368
+ concat_points = (point_coords, point_labels)
369
+ else:
370
+ concat_points = None
371
+
372
+ # Embed prompts
373
+ if boxes is not None:
374
+ box_coords = boxes.reshape(-1, 2, 2)
375
+ box_labels = torch.tensor([[2, 3]], dtype=torch.int, device=boxes.device)
376
+ box_labels = box_labels.repeat(boxes.size(0), 1)
377
+ # we merge "boxes" and "points" into a single "concat_points" input (where
378
+ # boxes are added at the beginning) to sam_prompt_encoder
379
+ if concat_points is not None:
380
+ concat_coords = torch.cat([box_coords, concat_points[0]], dim=1)
381
+ concat_labels = torch.cat([box_labels, concat_points[1]], dim=1)
382
+ concat_points = (concat_coords, concat_labels)
383
+ else:
384
+ concat_points = (box_coords, box_labels)
385
+
386
+ sparse_embeddings, dense_embeddings = self.model.sam_prompt_encoder(
387
+ points=concat_points,
388
+ boxes=None,
389
+ masks=mask_input,
390
+ )
391
+
392
+ # Predict masks
393
+ batched_mode = (
394
+ concat_points is not None and concat_points[0].shape[0] > 1
395
+ ) # multi object prediction
396
+ high_res_features = [
397
+ feat_level[img_idx].unsqueeze(0)
398
+ for feat_level in self._features["high_res_feats"]
399
+ ]
400
+ low_res_masks, iou_predictions, _, _ = self.model.sam_mask_decoder(
401
+ image_embeddings=self._features["image_embed"][img_idx].unsqueeze(0),
402
+ image_pe=self.model.sam_prompt_encoder.get_dense_pe(),
403
+ sparse_prompt_embeddings=sparse_embeddings,
404
+ dense_prompt_embeddings=dense_embeddings,
405
+ multimask_output=multimask_output,
406
+ repeat_image=batched_mode,
407
+ high_res_features=high_res_features,
408
+ )
409
+
410
+ # Upscale the masks to the original image resolution
411
+ masks = self._transforms.postprocess_masks(
412
+ low_res_masks, self._orig_hw[img_idx]
413
+ )
414
+ low_res_masks = torch.clamp(low_res_masks, -32.0, 32.0)
415
+ if not return_logits:
416
+ masks = masks > self.mask_threshold
417
+
418
+ return masks, iou_predictions, low_res_masks
419
+
420
+ def get_image_embedding(self) -> torch.Tensor:
421
+ """
422
+ Returns the image embeddings for the currently set image, with
423
+ shape 1xCxHxW, where C is the embedding dimension and (H,W) are
424
+ the embedding spatial dimension of SAM (typically C=256, H=W=64).
425
+ """
426
+ if not self._is_image_set:
427
+ raise RuntimeError(
428
+ "An image must be set with .set_image(...) to generate an embedding."
429
+ )
430
+ assert (
431
+ self._features is not None
432
+ ), "Features must exist if an image has been set."
433
+ return self._features["image_embed"]
434
+
435
+ @property
436
+ def device(self) -> torch.device:
437
+ return self.model.device
438
+
439
+ def reset_predictor(self) -> None:
440
+ """
441
+ Resets the image embeddings and other state variables.
442
+ """
443
+ self._is_image_set = False
444
+ self._features = None
445
+ self._orig_hw = None
446
+ self._is_batch = False
segment-anything-2/sam2/sam2_video_predictor.py ADDED
@@ -0,0 +1,898 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ from collections import OrderedDict
8
+
9
+ import torch
10
+
11
+ from tqdm import tqdm
12
+
13
+ from sam2.modeling.sam2_base import NO_OBJ_SCORE, SAM2Base
14
+ from sam2.utils.misc import concat_points, fill_holes_in_mask_scores, load_video_frames
15
+
16
+
17
+ class SAM2VideoPredictor(SAM2Base):
18
+ """The predictor class to handle user interactions and manage inference states."""
19
+
20
+ def __init__(
21
+ self,
22
+ fill_hole_area=0,
23
+ # whether to apply non-overlapping constraints on the output object masks
24
+ non_overlap_masks=False,
25
+ # whether to clear non-conditioning memory of the surrounding frames (which may contain outdated information) after adding correction clicks;
26
+ # note that this would only apply to *single-object tracking* unless `clear_non_cond_mem_for_multi_obj` is also set to True)
27
+ clear_non_cond_mem_around_input=False,
28
+ # whether to also clear non-conditioning memory of the surrounding frames (only effective when `clear_non_cond_mem_around_input` is True).
29
+ clear_non_cond_mem_for_multi_obj=False,
30
+ **kwargs,
31
+ ):
32
+ super().__init__(**kwargs)
33
+ self.fill_hole_area = fill_hole_area
34
+ self.non_overlap_masks = non_overlap_masks
35
+ self.clear_non_cond_mem_around_input = clear_non_cond_mem_around_input
36
+ self.clear_non_cond_mem_for_multi_obj = clear_non_cond_mem_for_multi_obj
37
+
38
+ @torch.inference_mode()
39
+ def init_state(
40
+ self,
41
+ video_path,
42
+ offload_video_to_cpu=False,
43
+ offload_state_to_cpu=False,
44
+ async_loading_frames=False,
45
+ ):
46
+ """Initialize a inference state."""
47
+ images, video_height, video_width = load_video_frames(
48
+ video_path=video_path,
49
+ image_size=self.image_size,
50
+ offload_video_to_cpu=offload_video_to_cpu,
51
+ async_loading_frames=async_loading_frames,
52
+ )
53
+ inference_state = {}
54
+ inference_state["images"] = images
55
+ inference_state["num_frames"] = len(images)
56
+ # whether to offload the video frames to CPU memory
57
+ # turning on this option saves the GPU memory with only a very small overhead
58
+ inference_state["offload_video_to_cpu"] = offload_video_to_cpu
59
+ # whether to offload the inference state to CPU memory
60
+ # turning on this option saves the GPU memory at the cost of a lower tracking fps
61
+ # (e.g. in a test case of 768x768 model, fps dropped from 27 to 24 when tracking one object
62
+ # and from 24 to 21 when tracking two objects)
63
+ inference_state["offload_state_to_cpu"] = offload_state_to_cpu
64
+ # the original video height and width, used for resizing final output scores
65
+ inference_state["video_height"] = video_height
66
+ inference_state["video_width"] = video_width
67
+ inference_state["device"] = torch.device("cuda")
68
+ if offload_state_to_cpu:
69
+ inference_state["storage_device"] = torch.device("cpu")
70
+ else:
71
+ inference_state["storage_device"] = torch.device("cuda")
72
+ # inputs on each frame
73
+ inference_state["point_inputs_per_obj"] = {}
74
+ inference_state["mask_inputs_per_obj"] = {}
75
+ # visual features on a small number of recently visited frames for quick interactions
76
+ inference_state["cached_features"] = {}
77
+ # values that don't change across frames (so we only need to hold one copy of them)
78
+ inference_state["constants"] = {}
79
+ # mapping between client-side object id and model-side object index
80
+ inference_state["obj_id_to_idx"] = OrderedDict()
81
+ inference_state["obj_idx_to_id"] = OrderedDict()
82
+ inference_state["obj_ids"] = []
83
+ # A storage to hold the model's tracking results and states on each frame
84
+ inference_state["output_dict"] = {
85
+ "cond_frame_outputs": {}, # dict containing {frame_idx: <out>}
86
+ "non_cond_frame_outputs": {}, # dict containing {frame_idx: <out>}
87
+ }
88
+ # Slice (view) of each object tracking results, sharing the same memory with "output_dict"
89
+ inference_state["output_dict_per_obj"] = {}
90
+ # A temporary storage to hold new outputs when user interact with a frame
91
+ # to add clicks or mask (it's merged into "output_dict" before propagation starts)
92
+ inference_state["temp_output_dict_per_obj"] = {}
93
+ # Frames that already holds consolidated outputs from click or mask inputs
94
+ # (we directly use their consolidated outputs during tracking)
95
+ inference_state["consolidated_frame_inds"] = {
96
+ "cond_frame_outputs": set(), # set containing frame indices
97
+ "non_cond_frame_outputs": set(), # set containing frame indices
98
+ }
99
+ # metadata for each tracking frame (e.g. which direction it's tracked)
100
+ inference_state["tracking_has_started"] = False
101
+ inference_state["frames_already_tracked"] = {}
102
+ # Warm up the visual backbone and cache the image feature on frame 0
103
+ self._get_image_feature(inference_state, frame_idx=0, batch_size=1)
104
+ return inference_state
105
+
106
+ def _obj_id_to_idx(self, inference_state, obj_id):
107
+ """Map client-side object id to model-side object index."""
108
+ obj_idx = inference_state["obj_id_to_idx"].get(obj_id, None)
109
+ if obj_idx is not None:
110
+ return obj_idx
111
+
112
+ # This is a new object id not sent to the server before. We only allow adding
113
+ # new objects *before* the tracking starts.
114
+ allow_new_object = not inference_state["tracking_has_started"]
115
+ if allow_new_object:
116
+ # get the next object slot
117
+ obj_idx = len(inference_state["obj_id_to_idx"])
118
+ inference_state["obj_id_to_idx"][obj_id] = obj_idx
119
+ inference_state["obj_idx_to_id"][obj_idx] = obj_id
120
+ inference_state["obj_ids"] = list(inference_state["obj_id_to_idx"])
121
+ # set up input and output structures for this object
122
+ inference_state["point_inputs_per_obj"][obj_idx] = {}
123
+ inference_state["mask_inputs_per_obj"][obj_idx] = {}
124
+ inference_state["output_dict_per_obj"][obj_idx] = {
125
+ "cond_frame_outputs": {}, # dict containing {frame_idx: <out>}
126
+ "non_cond_frame_outputs": {}, # dict containing {frame_idx: <out>}
127
+ }
128
+ inference_state["temp_output_dict_per_obj"][obj_idx] = {
129
+ "cond_frame_outputs": {}, # dict containing {frame_idx: <out>}
130
+ "non_cond_frame_outputs": {}, # dict containing {frame_idx: <out>}
131
+ }
132
+ return obj_idx
133
+ else:
134
+ raise RuntimeError(
135
+ f"Cannot add new object id {obj_id} after tracking starts. "
136
+ f"All existing object ids: {inference_state['obj_ids']}. "
137
+ f"Please call 'reset_state' to restart from scratch."
138
+ )
139
+
140
+ def _obj_idx_to_id(self, inference_state, obj_idx):
141
+ """Map model-side object index to client-side object id."""
142
+ return inference_state["obj_idx_to_id"][obj_idx]
143
+
144
+ def _get_obj_num(self, inference_state):
145
+ """Get the total number of unique object ids received so far in this session."""
146
+ return len(inference_state["obj_idx_to_id"])
147
+
148
+ @torch.inference_mode()
149
+ def add_new_points(
150
+ self,
151
+ inference_state,
152
+ frame_idx,
153
+ obj_id,
154
+ points,
155
+ labels,
156
+ clear_old_points=True,
157
+ normalize_coords=True,
158
+ ):
159
+ """Add new points to a frame."""
160
+ obj_idx = self._obj_id_to_idx(inference_state, obj_id)
161
+ point_inputs_per_frame = inference_state["point_inputs_per_obj"][obj_idx]
162
+ mask_inputs_per_frame = inference_state["mask_inputs_per_obj"][obj_idx]
163
+
164
+ if not isinstance(points, torch.Tensor):
165
+ points = torch.tensor(points, dtype=torch.float32)
166
+ if not isinstance(labels, torch.Tensor):
167
+ labels = torch.tensor(labels, dtype=torch.int32)
168
+ if points.dim() == 2:
169
+ points = points.unsqueeze(0) # add batch dimension
170
+ if labels.dim() == 1:
171
+ labels = labels.unsqueeze(0) # add batch dimension
172
+ if normalize_coords:
173
+ video_H = inference_state["video_height"]
174
+ video_W = inference_state["video_width"]
175
+ points = points / torch.tensor([video_W, video_H]).to(points.device)
176
+ # scale the (normalized) coordinates by the model's internal image size
177
+ points = points * self.image_size
178
+ points = points.to(inference_state["device"])
179
+ labels = labels.to(inference_state["device"])
180
+
181
+ if not clear_old_points:
182
+ point_inputs = point_inputs_per_frame.get(frame_idx, None)
183
+ else:
184
+ point_inputs = None
185
+ point_inputs = concat_points(point_inputs, points, labels)
186
+
187
+ point_inputs_per_frame[frame_idx] = point_inputs
188
+ mask_inputs_per_frame.pop(frame_idx, None)
189
+ # If this frame hasn't been tracked before, we treat it as an initial conditioning
190
+ # frame, meaning that the inputs points are to generate segments on this frame without
191
+ # using any memory from other frames, like in SAM. Otherwise (if it has been tracked),
192
+ # the input points will be used to correct the already tracked masks.
193
+ is_init_cond_frame = frame_idx not in inference_state["frames_already_tracked"]
194
+ # whether to track in reverse time order
195
+ if is_init_cond_frame:
196
+ reverse = False
197
+ else:
198
+ reverse = inference_state["frames_already_tracked"][frame_idx]["reverse"]
199
+ obj_output_dict = inference_state["output_dict_per_obj"][obj_idx]
200
+ obj_temp_output_dict = inference_state["temp_output_dict_per_obj"][obj_idx]
201
+ # Add a frame to conditioning output if it's an initial conditioning frame or
202
+ # if the model sees all frames receiving clicks/mask as conditioning frames.
203
+ is_cond = is_init_cond_frame or self.add_all_frames_to_correct_as_cond
204
+ storage_key = "cond_frame_outputs" if is_cond else "non_cond_frame_outputs"
205
+
206
+ # Get any previously predicted mask logits on this object and feed it along with
207
+ # the new clicks into the SAM mask decoder.
208
+ prev_sam_mask_logits = None
209
+ # lookup temporary output dict first, which contains the most recent output
210
+ # (if not found, then lookup conditioning and non-conditioning frame output)
211
+ prev_out = obj_temp_output_dict[storage_key].get(frame_idx)
212
+ if prev_out is None:
213
+ prev_out = obj_output_dict["cond_frame_outputs"].get(frame_idx)
214
+ if prev_out is None:
215
+ prev_out = obj_output_dict["non_cond_frame_outputs"].get(frame_idx)
216
+
217
+ if prev_out is not None and prev_out["pred_masks"] is not None:
218
+ prev_sam_mask_logits = prev_out["pred_masks"].cuda(non_blocking=True)
219
+ # Clamp the scale of prev_sam_mask_logits to avoid rare numerical issues.
220
+ prev_sam_mask_logits = torch.clamp(prev_sam_mask_logits, -32.0, 32.0)
221
+ current_out, _ = self._run_single_frame_inference(
222
+ inference_state=inference_state,
223
+ output_dict=obj_output_dict, # run on the slice of a single object
224
+ frame_idx=frame_idx,
225
+ batch_size=1, # run on the slice of a single object
226
+ is_init_cond_frame=is_init_cond_frame,
227
+ point_inputs=point_inputs,
228
+ mask_inputs=None,
229
+ reverse=reverse,
230
+ # Skip the memory encoder when adding clicks or mask. We execute the memory encoder
231
+ # at the beginning of `propagate_in_video` (after user finalize their clicks). This
232
+ # allows us to enforce non-overlapping constraints on all objects before encoding
233
+ # them into memory.
234
+ run_mem_encoder=False,
235
+ prev_sam_mask_logits=prev_sam_mask_logits,
236
+ )
237
+ # Add the output to the output dict (to be used as future memory)
238
+ obj_temp_output_dict[storage_key][frame_idx] = current_out
239
+
240
+ # Resize the output mask to the original video resolution
241
+ obj_ids = inference_state["obj_ids"]
242
+ consolidated_out = self._consolidate_temp_output_across_obj(
243
+ inference_state,
244
+ frame_idx,
245
+ is_cond=is_cond,
246
+ run_mem_encoder=False,
247
+ consolidate_at_video_res=True,
248
+ )
249
+ _, video_res_masks = self._get_orig_video_res_output(
250
+ inference_state, consolidated_out["pred_masks_video_res"]
251
+ )
252
+ return frame_idx, obj_ids, video_res_masks
253
+
254
+ @torch.inference_mode()
255
+ def add_new_mask(
256
+ self,
257
+ inference_state,
258
+ frame_idx,
259
+ obj_id,
260
+ mask,
261
+ ):
262
+ """Add new mask to a frame."""
263
+ obj_idx = self._obj_id_to_idx(inference_state, obj_id)
264
+ point_inputs_per_frame = inference_state["point_inputs_per_obj"][obj_idx]
265
+ mask_inputs_per_frame = inference_state["mask_inputs_per_obj"][obj_idx]
266
+
267
+ if not isinstance(mask, torch.Tensor):
268
+ mask = torch.tensor(mask, dtype=torch.bool)
269
+ assert mask.dim() == 2
270
+ mask_H, mask_W = mask.shape
271
+ mask_inputs_orig = mask[None, None] # add batch and channel dimension
272
+ mask_inputs_orig = mask_inputs_orig.float().to(inference_state["device"])
273
+
274
+ # resize the mask if it doesn't match the model's image size
275
+ if mask_H != self.image_size or mask_W != self.image_size:
276
+ mask_inputs = torch.nn.functional.interpolate(
277
+ mask_inputs_orig,
278
+ size=(self.image_size, self.image_size),
279
+ align_corners=False,
280
+ mode="bilinear",
281
+ antialias=True, # use antialias for downsampling
282
+ )
283
+ mask_inputs = (mask_inputs >= 0.5).float()
284
+ else:
285
+ mask_inputs = mask_inputs_orig
286
+
287
+ mask_inputs_per_frame[frame_idx] = mask_inputs
288
+ point_inputs_per_frame.pop(frame_idx, None)
289
+ # If this frame hasn't been tracked before, we treat it as an initial conditioning
290
+ # frame, meaning that the inputs points are to generate segments on this frame without
291
+ # using any memory from other frames, like in SAM. Otherwise (if it has been tracked),
292
+ # the input points will be used to correct the already tracked masks.
293
+ is_init_cond_frame = frame_idx not in inference_state["frames_already_tracked"]
294
+ # whether to track in reverse time order
295
+ if is_init_cond_frame:
296
+ reverse = False
297
+ else:
298
+ reverse = inference_state["frames_already_tracked"][frame_idx]["reverse"]
299
+ obj_output_dict = inference_state["output_dict_per_obj"][obj_idx]
300
+ obj_temp_output_dict = inference_state["temp_output_dict_per_obj"][obj_idx]
301
+ # Add a frame to conditioning output if it's an initial conditioning frame or
302
+ # if the model sees all frames receiving clicks/mask as conditioning frames.
303
+ is_cond = is_init_cond_frame or self.add_all_frames_to_correct_as_cond
304
+ storage_key = "cond_frame_outputs" if is_cond else "non_cond_frame_outputs"
305
+
306
+ current_out, _ = self._run_single_frame_inference(
307
+ inference_state=inference_state,
308
+ output_dict=obj_output_dict, # run on the slice of a single object
309
+ frame_idx=frame_idx,
310
+ batch_size=1, # run on the slice of a single object
311
+ is_init_cond_frame=is_init_cond_frame,
312
+ point_inputs=None,
313
+ mask_inputs=mask_inputs,
314
+ reverse=reverse,
315
+ # Skip the memory encoder when adding clicks or mask. We execute the memory encoder
316
+ # at the beginning of `propagate_in_video` (after user finalize their clicks). This
317
+ # allows us to enforce non-overlapping constraints on all objects before encoding
318
+ # them into memory.
319
+ run_mem_encoder=False,
320
+ )
321
+ # Add the output to the output dict (to be used as future memory)
322
+ obj_temp_output_dict[storage_key][frame_idx] = current_out
323
+
324
+ # Resize the output mask to the original video resolution
325
+ obj_ids = inference_state["obj_ids"]
326
+ consolidated_out = self._consolidate_temp_output_across_obj(
327
+ inference_state,
328
+ frame_idx,
329
+ is_cond=is_cond,
330
+ run_mem_encoder=False,
331
+ consolidate_at_video_res=True,
332
+ )
333
+ _, video_res_masks = self._get_orig_video_res_output(
334
+ inference_state, consolidated_out["pred_masks_video_res"]
335
+ )
336
+ return frame_idx, obj_ids, video_res_masks
337
+
338
+ def _get_orig_video_res_output(self, inference_state, any_res_masks):
339
+ """
340
+ Resize the object scores to the original video resolution (video_res_masks)
341
+ and apply non-overlapping constraints for final output.
342
+ """
343
+ device = inference_state["device"]
344
+ video_H = inference_state["video_height"]
345
+ video_W = inference_state["video_width"]
346
+ any_res_masks = any_res_masks.to(device, non_blocking=True)
347
+ if any_res_masks.shape[-2:] == (video_H, video_W):
348
+ video_res_masks = any_res_masks
349
+ else:
350
+ video_res_masks = torch.nn.functional.interpolate(
351
+ any_res_masks,
352
+ size=(video_H, video_W),
353
+ mode="bilinear",
354
+ align_corners=False,
355
+ )
356
+ if self.non_overlap_masks:
357
+ video_res_masks = self._apply_non_overlapping_constraints(video_res_masks)
358
+ return any_res_masks, video_res_masks
359
+
360
+ def _consolidate_temp_output_across_obj(
361
+ self,
362
+ inference_state,
363
+ frame_idx,
364
+ is_cond,
365
+ run_mem_encoder,
366
+ consolidate_at_video_res=False,
367
+ ):
368
+ """
369
+ Consolidate the per-object temporary outputs in `temp_output_dict_per_obj` on
370
+ a frame into a single output for all objects, including
371
+ 1) fill any missing objects either from `output_dict_per_obj` (if they exist in
372
+ `output_dict_per_obj` for this frame) or leave them as placeholder values
373
+ (if they don't exist in `output_dict_per_obj` for this frame);
374
+ 2) if specified, rerun memory encoder after apply non-overlapping constraints
375
+ on the object scores.
376
+ """
377
+ batch_size = self._get_obj_num(inference_state)
378
+ storage_key = "cond_frame_outputs" if is_cond else "non_cond_frame_outputs"
379
+ # Optionally, we allow consolidating the temporary outputs at the original
380
+ # video resolution (to provide a better editing experience for mask prompts).
381
+ if consolidate_at_video_res:
382
+ assert not run_mem_encoder, "memory encoder cannot run at video resolution"
383
+ consolidated_H = inference_state["video_height"]
384
+ consolidated_W = inference_state["video_width"]
385
+ consolidated_mask_key = "pred_masks_video_res"
386
+ else:
387
+ consolidated_H = consolidated_W = self.image_size // 4
388
+ consolidated_mask_key = "pred_masks"
389
+
390
+ # Initialize `consolidated_out`. Its "maskmem_features" and "maskmem_pos_enc"
391
+ # will be added when rerunning the memory encoder after applying non-overlapping
392
+ # constraints to object scores. Its "pred_masks" are prefilled with a large
393
+ # negative value (NO_OBJ_SCORE) to represent missing objects.
394
+ consolidated_out = {
395
+ "maskmem_features": None,
396
+ "maskmem_pos_enc": None,
397
+ consolidated_mask_key: torch.full(
398
+ size=(batch_size, 1, consolidated_H, consolidated_W),
399
+ fill_value=NO_OBJ_SCORE,
400
+ dtype=torch.float32,
401
+ device=inference_state["storage_device"],
402
+ ),
403
+ "obj_ptr": torch.full(
404
+ size=(batch_size, self.hidden_dim),
405
+ fill_value=NO_OBJ_SCORE,
406
+ dtype=torch.float32,
407
+ device=inference_state["device"],
408
+ ),
409
+ }
410
+ empty_mask_ptr = None
411
+ for obj_idx in range(batch_size):
412
+ obj_temp_output_dict = inference_state["temp_output_dict_per_obj"][obj_idx]
413
+ obj_output_dict = inference_state["output_dict_per_obj"][obj_idx]
414
+ out = obj_temp_output_dict[storage_key].get(frame_idx, None)
415
+ # If the object doesn't appear in "temp_output_dict_per_obj" on this frame,
416
+ # we fall back and look up its previous output in "output_dict_per_obj".
417
+ # We look up both "cond_frame_outputs" and "non_cond_frame_outputs" in
418
+ # "output_dict_per_obj" to find a previous output for this object.
419
+ if out is None:
420
+ out = obj_output_dict["cond_frame_outputs"].get(frame_idx, None)
421
+ if out is None:
422
+ out = obj_output_dict["non_cond_frame_outputs"].get(frame_idx, None)
423
+ # If the object doesn't appear in "output_dict_per_obj" either, we skip it
424
+ # and leave its mask scores to the default scores (i.e. the NO_OBJ_SCORE
425
+ # placeholder above) and set its object pointer to be a dummy pointer.
426
+ if out is None:
427
+ # Fill in dummy object pointers for those objects without any inputs or
428
+ # tracking outcomes on this frame (only do it under `run_mem_encoder=True`,
429
+ # i.e. when we need to build the memory for tracking).
430
+ if run_mem_encoder:
431
+ if empty_mask_ptr is None:
432
+ empty_mask_ptr = self._get_empty_mask_ptr(
433
+ inference_state, frame_idx
434
+ )
435
+ # fill object pointer with a dummy pointer (based on an empty mask)
436
+ consolidated_out["obj_ptr"][obj_idx : obj_idx + 1] = empty_mask_ptr
437
+ continue
438
+ # Add the temporary object output mask to consolidated output mask
439
+ obj_mask = out["pred_masks"]
440
+ consolidated_pred_masks = consolidated_out[consolidated_mask_key]
441
+ if obj_mask.shape[-2:] == consolidated_pred_masks.shape[-2:]:
442
+ consolidated_pred_masks[obj_idx : obj_idx + 1] = obj_mask
443
+ else:
444
+ # Resize first if temporary object mask has a different resolution
445
+ resized_obj_mask = torch.nn.functional.interpolate(
446
+ obj_mask,
447
+ size=consolidated_pred_masks.shape[-2:],
448
+ mode="bilinear",
449
+ align_corners=False,
450
+ )
451
+ consolidated_pred_masks[obj_idx : obj_idx + 1] = resized_obj_mask
452
+ consolidated_out["obj_ptr"][obj_idx : obj_idx + 1] = out["obj_ptr"]
453
+
454
+ # Optionally, apply non-overlapping constraints on the consolidated scores
455
+ # and rerun the memory encoder
456
+ if run_mem_encoder:
457
+ device = inference_state["device"]
458
+ high_res_masks = torch.nn.functional.interpolate(
459
+ consolidated_out["pred_masks"].to(device, non_blocking=True),
460
+ size=(self.image_size, self.image_size),
461
+ mode="bilinear",
462
+ align_corners=False,
463
+ )
464
+ if self.non_overlap_masks_for_mem_enc:
465
+ high_res_masks = self._apply_non_overlapping_constraints(high_res_masks)
466
+ maskmem_features, maskmem_pos_enc = self._run_memory_encoder(
467
+ inference_state=inference_state,
468
+ frame_idx=frame_idx,
469
+ batch_size=batch_size,
470
+ high_res_masks=high_res_masks,
471
+ is_mask_from_pts=True, # these frames are what the user interacted with
472
+ )
473
+ consolidated_out["maskmem_features"] = maskmem_features
474
+ consolidated_out["maskmem_pos_enc"] = maskmem_pos_enc
475
+
476
+ return consolidated_out
477
+
478
+ def _get_empty_mask_ptr(self, inference_state, frame_idx):
479
+ """Get a dummy object pointer based on an empty mask on the current frame."""
480
+ # A dummy (empty) mask with a single object
481
+ batch_size = 1
482
+ mask_inputs = torch.zeros(
483
+ (batch_size, 1, self.image_size, self.image_size),
484
+ dtype=torch.float32,
485
+ device=inference_state["device"],
486
+ )
487
+
488
+ # Retrieve correct image features
489
+ (
490
+ _,
491
+ _,
492
+ current_vision_feats,
493
+ current_vision_pos_embeds,
494
+ feat_sizes,
495
+ ) = self._get_image_feature(inference_state, frame_idx, batch_size)
496
+
497
+ # Feed the empty mask and image feature above to get a dummy object pointer
498
+ current_out = self.track_step(
499
+ frame_idx=frame_idx,
500
+ is_init_cond_frame=True,
501
+ current_vision_feats=current_vision_feats,
502
+ current_vision_pos_embeds=current_vision_pos_embeds,
503
+ feat_sizes=feat_sizes,
504
+ point_inputs=None,
505
+ mask_inputs=mask_inputs,
506
+ output_dict={},
507
+ num_frames=inference_state["num_frames"],
508
+ track_in_reverse=False,
509
+ run_mem_encoder=False,
510
+ prev_sam_mask_logits=None,
511
+ )
512
+ return current_out["obj_ptr"]
513
+
514
+ @torch.inference_mode()
515
+ def propagate_in_video_preflight(self, inference_state):
516
+ """Prepare inference_state and consolidate temporary outputs before tracking."""
517
+ # Tracking has started and we don't allow adding new objects until session is reset.
518
+ inference_state["tracking_has_started"] = True
519
+ batch_size = self._get_obj_num(inference_state)
520
+
521
+ # Consolidate per-object temporary outputs in "temp_output_dict_per_obj" and
522
+ # add them into "output_dict".
523
+ temp_output_dict_per_obj = inference_state["temp_output_dict_per_obj"]
524
+ output_dict = inference_state["output_dict"]
525
+ # "consolidated_frame_inds" contains indices of those frames where consolidated
526
+ # temporary outputs have been added (either in this call or any previous calls
527
+ # to `propagate_in_video_preflight`).
528
+ consolidated_frame_inds = inference_state["consolidated_frame_inds"]
529
+ for is_cond in [False, True]:
530
+ # Separately consolidate conditioning and non-conditioning temp outptus
531
+ storage_key = "cond_frame_outputs" if is_cond else "non_cond_frame_outputs"
532
+ # Find all the frames that contain temporary outputs for any objects
533
+ # (these should be the frames that have just received clicks for mask inputs
534
+ # via `add_new_points` or `add_new_mask`)
535
+ temp_frame_inds = set()
536
+ for obj_temp_output_dict in temp_output_dict_per_obj.values():
537
+ temp_frame_inds.update(obj_temp_output_dict[storage_key].keys())
538
+ consolidated_frame_inds[storage_key].update(temp_frame_inds)
539
+ # consolidate the temprary output across all objects on this frame
540
+ for frame_idx in temp_frame_inds:
541
+ consolidated_out = self._consolidate_temp_output_across_obj(
542
+ inference_state, frame_idx, is_cond=is_cond, run_mem_encoder=True
543
+ )
544
+ # merge them into "output_dict" and also create per-object slices
545
+ output_dict[storage_key][frame_idx] = consolidated_out
546
+ self._add_output_per_object(
547
+ inference_state, frame_idx, consolidated_out, storage_key
548
+ )
549
+ clear_non_cond_mem = self.clear_non_cond_mem_around_input and (
550
+ self.clear_non_cond_mem_for_multi_obj or batch_size <= 1
551
+ )
552
+ if clear_non_cond_mem:
553
+ # clear non-conditioning memory of the surrounding frames
554
+ self._clear_non_cond_mem_around_input(inference_state, frame_idx)
555
+
556
+ # clear temporary outputs in `temp_output_dict_per_obj`
557
+ for obj_temp_output_dict in temp_output_dict_per_obj.values():
558
+ obj_temp_output_dict[storage_key].clear()
559
+
560
+ # edge case: if an output is added to "cond_frame_outputs", we remove any prior
561
+ # output on the same frame in "non_cond_frame_outputs"
562
+ for frame_idx in output_dict["cond_frame_outputs"]:
563
+ output_dict["non_cond_frame_outputs"].pop(frame_idx, None)
564
+ for obj_output_dict in inference_state["output_dict_per_obj"].values():
565
+ for frame_idx in obj_output_dict["cond_frame_outputs"]:
566
+ obj_output_dict["non_cond_frame_outputs"].pop(frame_idx, None)
567
+ for frame_idx in consolidated_frame_inds["cond_frame_outputs"]:
568
+ assert frame_idx in output_dict["cond_frame_outputs"]
569
+ consolidated_frame_inds["non_cond_frame_outputs"].discard(frame_idx)
570
+
571
+ # Make sure that the frame indices in "consolidated_frame_inds" are exactly those frames
572
+ # with either points or mask inputs (which should be true under a correct workflow).
573
+ all_consolidated_frame_inds = (
574
+ consolidated_frame_inds["cond_frame_outputs"]
575
+ | consolidated_frame_inds["non_cond_frame_outputs"]
576
+ )
577
+ input_frames_inds = set()
578
+ for point_inputs_per_frame in inference_state["point_inputs_per_obj"].values():
579
+ input_frames_inds.update(point_inputs_per_frame.keys())
580
+ for mask_inputs_per_frame in inference_state["mask_inputs_per_obj"].values():
581
+ input_frames_inds.update(mask_inputs_per_frame.keys())
582
+ assert all_consolidated_frame_inds == input_frames_inds
583
+
584
+ @torch.inference_mode()
585
+ def propagate_in_video(
586
+ self,
587
+ inference_state,
588
+ start_frame_idx=None,
589
+ max_frame_num_to_track=None,
590
+ reverse=False,
591
+ ):
592
+ """Propagate the input points across frames to track in the entire video."""
593
+ self.propagate_in_video_preflight(inference_state)
594
+
595
+ output_dict = inference_state["output_dict"]
596
+ consolidated_frame_inds = inference_state["consolidated_frame_inds"]
597
+ obj_ids = inference_state["obj_ids"]
598
+ num_frames = inference_state["num_frames"]
599
+ batch_size = self._get_obj_num(inference_state)
600
+ if len(output_dict["cond_frame_outputs"]) == 0:
601
+ raise RuntimeError("No points are provided; please add points first")
602
+ clear_non_cond_mem = self.clear_non_cond_mem_around_input and (
603
+ self.clear_non_cond_mem_for_multi_obj or batch_size <= 1
604
+ )
605
+
606
+ # set start index, end index, and processing order
607
+ if start_frame_idx is None:
608
+ # default: start from the earliest frame with input points
609
+ start_frame_idx = min(output_dict["cond_frame_outputs"])
610
+ if max_frame_num_to_track is None:
611
+ # default: track all the frames in the video
612
+ max_frame_num_to_track = num_frames
613
+ if reverse:
614
+ end_frame_idx = max(start_frame_idx - max_frame_num_to_track, 0)
615
+ if start_frame_idx > 0:
616
+ processing_order = range(start_frame_idx, end_frame_idx - 1, -1)
617
+ else:
618
+ processing_order = [] # skip reverse tracking if starting from frame 0
619
+ else:
620
+ end_frame_idx = min(
621
+ start_frame_idx + max_frame_num_to_track, num_frames - 1
622
+ )
623
+ processing_order = range(start_frame_idx, end_frame_idx + 1)
624
+
625
+ for frame_idx in tqdm(processing_order, desc="propagate in video"):
626
+ # We skip those frames already in consolidated outputs (these are frames
627
+ # that received input clicks or mask). Note that we cannot directly run
628
+ # batched forward on them via `_run_single_frame_inference` because the
629
+ # number of clicks on each object might be different.
630
+ if frame_idx in consolidated_frame_inds["cond_frame_outputs"]:
631
+ storage_key = "cond_frame_outputs"
632
+ current_out = output_dict[storage_key][frame_idx]
633
+ pred_masks = current_out["pred_masks"]
634
+ if clear_non_cond_mem:
635
+ # clear non-conditioning memory of the surrounding frames
636
+ self._clear_non_cond_mem_around_input(inference_state, frame_idx)
637
+ elif frame_idx in consolidated_frame_inds["non_cond_frame_outputs"]:
638
+ storage_key = "non_cond_frame_outputs"
639
+ current_out = output_dict[storage_key][frame_idx]
640
+ pred_masks = current_out["pred_masks"]
641
+ else:
642
+ storage_key = "non_cond_frame_outputs"
643
+ current_out, pred_masks = self._run_single_frame_inference(
644
+ inference_state=inference_state,
645
+ output_dict=output_dict,
646
+ frame_idx=frame_idx,
647
+ batch_size=batch_size,
648
+ is_init_cond_frame=False,
649
+ point_inputs=None,
650
+ mask_inputs=None,
651
+ reverse=reverse,
652
+ run_mem_encoder=True,
653
+ )
654
+ output_dict[storage_key][frame_idx] = current_out
655
+ # Create slices of per-object outputs for subsequent interaction with each
656
+ # individual object after tracking.
657
+ self._add_output_per_object(
658
+ inference_state, frame_idx, current_out, storage_key
659
+ )
660
+ inference_state["frames_already_tracked"][frame_idx] = {"reverse": reverse}
661
+
662
+ # Resize the output mask to the original video resolution (we directly use
663
+ # the mask scores on GPU for output to avoid any CPU conversion in between)
664
+ _, video_res_masks = self._get_orig_video_res_output(
665
+ inference_state, pred_masks
666
+ )
667
+ yield frame_idx, obj_ids, video_res_masks
668
+
669
+ def _add_output_per_object(
670
+ self, inference_state, frame_idx, current_out, storage_key
671
+ ):
672
+ """
673
+ Split a multi-object output into per-object output slices and add them into
674
+ `output_dict_per_obj`. The resulting slices share the same tensor storage.
675
+ """
676
+ maskmem_features = current_out["maskmem_features"]
677
+ assert maskmem_features is None or isinstance(maskmem_features, torch.Tensor)
678
+
679
+ maskmem_pos_enc = current_out["maskmem_pos_enc"]
680
+ assert maskmem_pos_enc is None or isinstance(maskmem_pos_enc, list)
681
+
682
+ output_dict_per_obj = inference_state["output_dict_per_obj"]
683
+ for obj_idx, obj_output_dict in output_dict_per_obj.items():
684
+ obj_slice = slice(obj_idx, obj_idx + 1)
685
+ obj_out = {
686
+ "maskmem_features": None,
687
+ "maskmem_pos_enc": None,
688
+ "pred_masks": current_out["pred_masks"][obj_slice],
689
+ "obj_ptr": current_out["obj_ptr"][obj_slice],
690
+ }
691
+ if maskmem_features is not None:
692
+ obj_out["maskmem_features"] = maskmem_features[obj_slice]
693
+ if maskmem_pos_enc is not None:
694
+ obj_out["maskmem_pos_enc"] = [x[obj_slice] for x in maskmem_pos_enc]
695
+ obj_output_dict[storage_key][frame_idx] = obj_out
696
+
697
+ @torch.inference_mode()
698
+ def reset_state(self, inference_state):
699
+ """Remove all input points or mask in all frames throughout the video."""
700
+ self._reset_tracking_results(inference_state)
701
+ # Remove all object ids
702
+ inference_state["obj_id_to_idx"].clear()
703
+ inference_state["obj_idx_to_id"].clear()
704
+ inference_state["obj_ids"].clear()
705
+ inference_state["point_inputs_per_obj"].clear()
706
+ inference_state["mask_inputs_per_obj"].clear()
707
+ inference_state["output_dict_per_obj"].clear()
708
+ inference_state["temp_output_dict_per_obj"].clear()
709
+
710
+ def _reset_tracking_results(self, inference_state):
711
+ """Reset all tracking inputs and results across the videos."""
712
+ for v in inference_state["point_inputs_per_obj"].values():
713
+ v.clear()
714
+ for v in inference_state["mask_inputs_per_obj"].values():
715
+ v.clear()
716
+ for v in inference_state["output_dict_per_obj"].values():
717
+ v["cond_frame_outputs"].clear()
718
+ v["non_cond_frame_outputs"].clear()
719
+ for v in inference_state["temp_output_dict_per_obj"].values():
720
+ v["cond_frame_outputs"].clear()
721
+ v["non_cond_frame_outputs"].clear()
722
+ inference_state["output_dict"]["cond_frame_outputs"].clear()
723
+ inference_state["output_dict"]["non_cond_frame_outputs"].clear()
724
+ inference_state["consolidated_frame_inds"]["cond_frame_outputs"].clear()
725
+ inference_state["consolidated_frame_inds"]["non_cond_frame_outputs"].clear()
726
+ inference_state["tracking_has_started"] = False
727
+ inference_state["frames_already_tracked"].clear()
728
+
729
+ def _get_image_feature(self, inference_state, frame_idx, batch_size):
730
+ """Compute the image features on a given frame."""
731
+ # Look up in the cache first
732
+ image, backbone_out = inference_state["cached_features"].get(
733
+ frame_idx, (None, None)
734
+ )
735
+ if backbone_out is None:
736
+ # Cache miss -- we will run inference on a single image
737
+ image = inference_state["images"][frame_idx].cuda().float().unsqueeze(0)
738
+ backbone_out = self.forward_image(image)
739
+ # Cache the most recent frame's feature (for repeated interactions with
740
+ # a frame; we can use an LRU cache for more frames in the future).
741
+ inference_state["cached_features"] = {frame_idx: (image, backbone_out)}
742
+
743
+ # expand the features to have the same dimension as the number of objects
744
+ expanded_image = image.expand(batch_size, -1, -1, -1)
745
+ expanded_backbone_out = {
746
+ "backbone_fpn": backbone_out["backbone_fpn"].copy(),
747
+ "vision_pos_enc": backbone_out["vision_pos_enc"].copy(),
748
+ }
749
+ for i, feat in enumerate(expanded_backbone_out["backbone_fpn"]):
750
+ expanded_backbone_out["backbone_fpn"][i] = feat.expand(
751
+ batch_size, -1, -1, -1
752
+ )
753
+ for i, pos in enumerate(expanded_backbone_out["vision_pos_enc"]):
754
+ pos = pos.expand(batch_size, -1, -1, -1)
755
+ expanded_backbone_out["vision_pos_enc"][i] = pos
756
+
757
+ features = self._prepare_backbone_features(expanded_backbone_out)
758
+ features = (expanded_image,) + features
759
+ return features
760
+
761
+ def _run_single_frame_inference(
762
+ self,
763
+ inference_state,
764
+ output_dict,
765
+ frame_idx,
766
+ batch_size,
767
+ is_init_cond_frame,
768
+ point_inputs,
769
+ mask_inputs,
770
+ reverse,
771
+ run_mem_encoder,
772
+ prev_sam_mask_logits=None,
773
+ ):
774
+ """Run tracking on a single frame based on current inputs and previous memory."""
775
+ # Retrieve correct image features
776
+ (
777
+ _,
778
+ _,
779
+ current_vision_feats,
780
+ current_vision_pos_embeds,
781
+ feat_sizes,
782
+ ) = self._get_image_feature(inference_state, frame_idx, batch_size)
783
+
784
+ # point and mask should not appear as input simultaneously on the same frame
785
+ assert point_inputs is None or mask_inputs is None
786
+ current_out = self.track_step(
787
+ frame_idx=frame_idx,
788
+ is_init_cond_frame=is_init_cond_frame,
789
+ current_vision_feats=current_vision_feats,
790
+ current_vision_pos_embeds=current_vision_pos_embeds,
791
+ feat_sizes=feat_sizes,
792
+ point_inputs=point_inputs,
793
+ mask_inputs=mask_inputs,
794
+ output_dict=output_dict,
795
+ num_frames=inference_state["num_frames"],
796
+ track_in_reverse=reverse,
797
+ run_mem_encoder=run_mem_encoder,
798
+ prev_sam_mask_logits=prev_sam_mask_logits,
799
+ )
800
+
801
+ # optionally offload the output to CPU memory to save GPU space
802
+ storage_device = inference_state["storage_device"]
803
+ maskmem_features = current_out["maskmem_features"]
804
+ if maskmem_features is not None:
805
+ maskmem_features = maskmem_features.to(torch.bfloat16)
806
+ maskmem_features = maskmem_features.to(storage_device, non_blocking=True)
807
+ pred_masks_gpu = current_out["pred_masks"]
808
+ # potentially fill holes in the predicted masks
809
+ if self.fill_hole_area > 0:
810
+ pred_masks_gpu = fill_holes_in_mask_scores(
811
+ pred_masks_gpu, self.fill_hole_area
812
+ )
813
+ pred_masks = pred_masks_gpu.to(storage_device, non_blocking=True)
814
+ # "maskmem_pos_enc" is the same across frames, so we only need to store one copy of it
815
+ maskmem_pos_enc = self._get_maskmem_pos_enc(inference_state, current_out)
816
+ # object pointer is a small tensor, so we always keep it on GPU memory for fast access
817
+ obj_ptr = current_out["obj_ptr"]
818
+ # make a compact version of this frame's output to reduce the state size
819
+ compact_current_out = {
820
+ "maskmem_features": maskmem_features,
821
+ "maskmem_pos_enc": maskmem_pos_enc,
822
+ "pred_masks": pred_masks,
823
+ "obj_ptr": obj_ptr,
824
+ }
825
+ return compact_current_out, pred_masks_gpu
826
+
827
+ def _run_memory_encoder(
828
+ self, inference_state, frame_idx, batch_size, high_res_masks, is_mask_from_pts
829
+ ):
830
+ """
831
+ Run the memory encoder on `high_res_masks`. This is usually after applying
832
+ non-overlapping constraints to object scores. Since their scores changed, their
833
+ memory also need to be computed again with the memory encoder.
834
+ """
835
+ # Retrieve correct image features
836
+ _, _, current_vision_feats, _, feat_sizes = self._get_image_feature(
837
+ inference_state, frame_idx, batch_size
838
+ )
839
+ maskmem_features, maskmem_pos_enc = self._encode_new_memory(
840
+ current_vision_feats=current_vision_feats,
841
+ feat_sizes=feat_sizes,
842
+ pred_masks_high_res=high_res_masks,
843
+ is_mask_from_pts=is_mask_from_pts,
844
+ )
845
+
846
+ # optionally offload the output to CPU memory to save GPU space
847
+ storage_device = inference_state["storage_device"]
848
+ maskmem_features = maskmem_features.to(torch.bfloat16)
849
+ maskmem_features = maskmem_features.to(storage_device, non_blocking=True)
850
+ # "maskmem_pos_enc" is the same across frames, so we only need to store one copy of it
851
+ maskmem_pos_enc = self._get_maskmem_pos_enc(
852
+ inference_state, {"maskmem_pos_enc": maskmem_pos_enc}
853
+ )
854
+ return maskmem_features, maskmem_pos_enc
855
+
856
+ def _get_maskmem_pos_enc(self, inference_state, current_out):
857
+ """
858
+ `maskmem_pos_enc` is the same across frames and objects, so we cache it as
859
+ a constant in the inference session to reduce session storage size.
860
+ """
861
+ model_constants = inference_state["constants"]
862
+ # "out_maskmem_pos_enc" should be either a list of tensors or None
863
+ out_maskmem_pos_enc = current_out["maskmem_pos_enc"]
864
+ if out_maskmem_pos_enc is not None:
865
+ if "maskmem_pos_enc" not in model_constants:
866
+ assert isinstance(out_maskmem_pos_enc, list)
867
+ # only take the slice for one object, since it's same across objects
868
+ maskmem_pos_enc = [x[0:1].clone() for x in out_maskmem_pos_enc]
869
+ model_constants["maskmem_pos_enc"] = maskmem_pos_enc
870
+ else:
871
+ maskmem_pos_enc = model_constants["maskmem_pos_enc"]
872
+ # expand the cached maskmem_pos_enc to the actual batch size
873
+ batch_size = out_maskmem_pos_enc[0].size(0)
874
+ expanded_maskmem_pos_enc = [
875
+ x.expand(batch_size, -1, -1, -1) for x in maskmem_pos_enc
876
+ ]
877
+ else:
878
+ expanded_maskmem_pos_enc = None
879
+ return expanded_maskmem_pos_enc
880
+
881
+ def _clear_non_cond_mem_around_input(self, inference_state, frame_idx):
882
+ """
883
+ Remove the non-conditioning memory around the input frame. When users provide
884
+ correction clicks, the surrounding frames' non-conditioning memories can still
885
+ contain outdated object appearance information and could confuse the model.
886
+
887
+ This method clears those non-conditioning memories surrounding the interacted
888
+ frame to avoid giving the model both old and new information about the object.
889
+ """
890
+ r = self.memory_temporal_stride_for_eval
891
+ frame_idx_begin = frame_idx - r * self.num_maskmem
892
+ frame_idx_end = frame_idx + r * self.num_maskmem
893
+ output_dict = inference_state["output_dict"]
894
+ non_cond_frame_outputs = output_dict["non_cond_frame_outputs"]
895
+ for t in range(frame_idx_begin, frame_idx_end + 1):
896
+ non_cond_frame_outputs.pop(t, None)
897
+ for obj_output_dict in inference_state["output_dict_per_obj"].values():
898
+ obj_output_dict["non_cond_frame_outputs"].pop(t, None)
segment-anything-2/sam2/utils/__init__.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
segment-anything-2/sam2/utils/amg.py ADDED
@@ -0,0 +1,348 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import math
8
+ from copy import deepcopy
9
+ from itertools import product
10
+ from typing import Any, Dict, Generator, ItemsView, List, Tuple
11
+
12
+ import numpy as np
13
+ import torch
14
+
15
+ # Very lightly adapted from https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/utils/amg.py
16
+
17
+
18
+ class MaskData:
19
+ """
20
+ A structure for storing masks and their related data in batched format.
21
+ Implements basic filtering and concatenation.
22
+ """
23
+
24
+ def __init__(self, **kwargs) -> None:
25
+ for v in kwargs.values():
26
+ assert isinstance(
27
+ v, (list, np.ndarray, torch.Tensor)
28
+ ), "MaskData only supports list, numpy arrays, and torch tensors."
29
+ self._stats = dict(**kwargs)
30
+
31
+ def __setitem__(self, key: str, item: Any) -> None:
32
+ assert isinstance(
33
+ item, (list, np.ndarray, torch.Tensor)
34
+ ), "MaskData only supports list, numpy arrays, and torch tensors."
35
+ self._stats[key] = item
36
+
37
+ def __delitem__(self, key: str) -> None:
38
+ del self._stats[key]
39
+
40
+ def __getitem__(self, key: str) -> Any:
41
+ return self._stats[key]
42
+
43
+ def items(self) -> ItemsView[str, Any]:
44
+ return self._stats.items()
45
+
46
+ def filter(self, keep: torch.Tensor) -> None:
47
+ for k, v in self._stats.items():
48
+ if v is None:
49
+ self._stats[k] = None
50
+ elif isinstance(v, torch.Tensor):
51
+ self._stats[k] = v[torch.as_tensor(keep, device=v.device)]
52
+ elif isinstance(v, np.ndarray):
53
+ self._stats[k] = v[keep.detach().cpu().numpy()]
54
+ elif isinstance(v, list) and keep.dtype == torch.bool:
55
+ self._stats[k] = [a for i, a in enumerate(v) if keep[i]]
56
+ elif isinstance(v, list):
57
+ self._stats[k] = [v[i] for i in keep]
58
+ else:
59
+ raise TypeError(f"MaskData key {k} has an unsupported type {type(v)}.")
60
+
61
+ def cat(self, new_stats: "MaskData") -> None:
62
+ for k, v in new_stats.items():
63
+ if k not in self._stats or self._stats[k] is None:
64
+ self._stats[k] = deepcopy(v)
65
+ elif isinstance(v, torch.Tensor):
66
+ self._stats[k] = torch.cat([self._stats[k], v], dim=0)
67
+ elif isinstance(v, np.ndarray):
68
+ self._stats[k] = np.concatenate([self._stats[k], v], axis=0)
69
+ elif isinstance(v, list):
70
+ self._stats[k] = self._stats[k] + deepcopy(v)
71
+ else:
72
+ raise TypeError(f"MaskData key {k} has an unsupported type {type(v)}.")
73
+
74
+ def to_numpy(self) -> None:
75
+ for k, v in self._stats.items():
76
+ if isinstance(v, torch.Tensor):
77
+ self._stats[k] = v.float().detach().cpu().numpy()
78
+
79
+
80
+ def is_box_near_crop_edge(
81
+ boxes: torch.Tensor, crop_box: List[int], orig_box: List[int], atol: float = 20.0
82
+ ) -> torch.Tensor:
83
+ """Filter masks at the edge of a crop, but not at the edge of the original image."""
84
+ crop_box_torch = torch.as_tensor(crop_box, dtype=torch.float, device=boxes.device)
85
+ orig_box_torch = torch.as_tensor(orig_box, dtype=torch.float, device=boxes.device)
86
+ boxes = uncrop_boxes_xyxy(boxes, crop_box).float()
87
+ near_crop_edge = torch.isclose(boxes, crop_box_torch[None, :], atol=atol, rtol=0)
88
+ near_image_edge = torch.isclose(boxes, orig_box_torch[None, :], atol=atol, rtol=0)
89
+ near_crop_edge = torch.logical_and(near_crop_edge, ~near_image_edge)
90
+ return torch.any(near_crop_edge, dim=1)
91
+
92
+
93
+ def box_xyxy_to_xywh(box_xyxy: torch.Tensor) -> torch.Tensor:
94
+ box_xywh = deepcopy(box_xyxy)
95
+ box_xywh[2] = box_xywh[2] - box_xywh[0]
96
+ box_xywh[3] = box_xywh[3] - box_xywh[1]
97
+ return box_xywh
98
+
99
+
100
+ def batch_iterator(batch_size: int, *args) -> Generator[List[Any], None, None]:
101
+ assert len(args) > 0 and all(
102
+ len(a) == len(args[0]) for a in args
103
+ ), "Batched iteration must have inputs of all the same size."
104
+ n_batches = len(args[0]) // batch_size + int(len(args[0]) % batch_size != 0)
105
+ for b in range(n_batches):
106
+ yield [arg[b * batch_size : (b + 1) * batch_size] for arg in args]
107
+
108
+
109
+ def mask_to_rle_pytorch(tensor: torch.Tensor) -> List[Dict[str, Any]]:
110
+ """
111
+ Encodes masks to an uncompressed RLE, in the format expected by
112
+ pycoco tools.
113
+ """
114
+ # Put in fortran order and flatten h,w
115
+ b, h, w = tensor.shape
116
+ tensor = tensor.permute(0, 2, 1).flatten(1)
117
+
118
+ # Compute change indices
119
+ diff = tensor[:, 1:] ^ tensor[:, :-1]
120
+ change_indices = diff.nonzero()
121
+
122
+ # Encode run length
123
+ out = []
124
+ for i in range(b):
125
+ cur_idxs = change_indices[change_indices[:, 0] == i, 1]
126
+ cur_idxs = torch.cat(
127
+ [
128
+ torch.tensor([0], dtype=cur_idxs.dtype, device=cur_idxs.device),
129
+ cur_idxs + 1,
130
+ torch.tensor([h * w], dtype=cur_idxs.dtype, device=cur_idxs.device),
131
+ ]
132
+ )
133
+ btw_idxs = cur_idxs[1:] - cur_idxs[:-1]
134
+ counts = [] if tensor[i, 0] == 0 else [0]
135
+ counts.extend(btw_idxs.detach().cpu().tolist())
136
+ out.append({"size": [h, w], "counts": counts})
137
+ return out
138
+
139
+
140
+ def rle_to_mask(rle: Dict[str, Any]) -> np.ndarray:
141
+ """Compute a binary mask from an uncompressed RLE."""
142
+ h, w = rle["size"]
143
+ mask = np.empty(h * w, dtype=bool)
144
+ idx = 0
145
+ parity = False
146
+ for count in rle["counts"]:
147
+ mask[idx : idx + count] = parity
148
+ idx += count
149
+ parity ^= True
150
+ mask = mask.reshape(w, h)
151
+ return mask.transpose() # Put in C order
152
+
153
+
154
+ def area_from_rle(rle: Dict[str, Any]) -> int:
155
+ return sum(rle["counts"][1::2])
156
+
157
+
158
+ def calculate_stability_score(
159
+ masks: torch.Tensor, mask_threshold: float, threshold_offset: float
160
+ ) -> torch.Tensor:
161
+ """
162
+ Computes the stability score for a batch of masks. The stability
163
+ score is the IoU between the binary masks obtained by thresholding
164
+ the predicted mask logits at high and low values.
165
+ """
166
+ # One mask is always contained inside the other.
167
+ # Save memory by preventing unnecessary cast to torch.int64
168
+ intersections = (
169
+ (masks > (mask_threshold + threshold_offset))
170
+ .sum(-1, dtype=torch.int16)
171
+ .sum(-1, dtype=torch.int32)
172
+ )
173
+ unions = (
174
+ (masks > (mask_threshold - threshold_offset))
175
+ .sum(-1, dtype=torch.int16)
176
+ .sum(-1, dtype=torch.int32)
177
+ )
178
+ return intersections / unions
179
+
180
+
181
+ def build_point_grid(n_per_side: int) -> np.ndarray:
182
+ """Generates a 2D grid of points evenly spaced in [0,1]x[0,1]."""
183
+ offset = 1 / (2 * n_per_side)
184
+ points_one_side = np.linspace(offset, 1 - offset, n_per_side)
185
+ points_x = np.tile(points_one_side[None, :], (n_per_side, 1))
186
+ points_y = np.tile(points_one_side[:, None], (1, n_per_side))
187
+ points = np.stack([points_x, points_y], axis=-1).reshape(-1, 2)
188
+ return points
189
+
190
+
191
+ def build_all_layer_point_grids(
192
+ n_per_side: int, n_layers: int, scale_per_layer: int
193
+ ) -> List[np.ndarray]:
194
+ """Generates point grids for all crop layers."""
195
+ points_by_layer = []
196
+ for i in range(n_layers + 1):
197
+ n_points = int(n_per_side / (scale_per_layer**i))
198
+ points_by_layer.append(build_point_grid(n_points))
199
+ return points_by_layer
200
+
201
+
202
+ def generate_crop_boxes(
203
+ im_size: Tuple[int, ...], n_layers: int, overlap_ratio: float
204
+ ) -> Tuple[List[List[int]], List[int]]:
205
+ """
206
+ Generates a list of crop boxes of different sizes. Each layer
207
+ has (2**i)**2 boxes for the ith layer.
208
+ """
209
+ crop_boxes, layer_idxs = [], []
210
+ im_h, im_w = im_size
211
+ short_side = min(im_h, im_w)
212
+
213
+ # Original image
214
+ crop_boxes.append([0, 0, im_w, im_h])
215
+ layer_idxs.append(0)
216
+
217
+ def crop_len(orig_len, n_crops, overlap):
218
+ return int(math.ceil((overlap * (n_crops - 1) + orig_len) / n_crops))
219
+
220
+ for i_layer in range(n_layers):
221
+ n_crops_per_side = 2 ** (i_layer + 1)
222
+ overlap = int(overlap_ratio * short_side * (2 / n_crops_per_side))
223
+
224
+ crop_w = crop_len(im_w, n_crops_per_side, overlap)
225
+ crop_h = crop_len(im_h, n_crops_per_side, overlap)
226
+
227
+ crop_box_x0 = [int((crop_w - overlap) * i) for i in range(n_crops_per_side)]
228
+ crop_box_y0 = [int((crop_h - overlap) * i) for i in range(n_crops_per_side)]
229
+
230
+ # Crops in XYWH format
231
+ for x0, y0 in product(crop_box_x0, crop_box_y0):
232
+ box = [x0, y0, min(x0 + crop_w, im_w), min(y0 + crop_h, im_h)]
233
+ crop_boxes.append(box)
234
+ layer_idxs.append(i_layer + 1)
235
+
236
+ return crop_boxes, layer_idxs
237
+
238
+
239
+ def uncrop_boxes_xyxy(boxes: torch.Tensor, crop_box: List[int]) -> torch.Tensor:
240
+ x0, y0, _, _ = crop_box
241
+ offset = torch.tensor([[x0, y0, x0, y0]], device=boxes.device)
242
+ # Check if boxes has a channel dimension
243
+ if len(boxes.shape) == 3:
244
+ offset = offset.unsqueeze(1)
245
+ return boxes + offset
246
+
247
+
248
+ def uncrop_points(points: torch.Tensor, crop_box: List[int]) -> torch.Tensor:
249
+ x0, y0, _, _ = crop_box
250
+ offset = torch.tensor([[x0, y0]], device=points.device)
251
+ # Check if points has a channel dimension
252
+ if len(points.shape) == 3:
253
+ offset = offset.unsqueeze(1)
254
+ return points + offset
255
+
256
+
257
+ def uncrop_masks(
258
+ masks: torch.Tensor, crop_box: List[int], orig_h: int, orig_w: int
259
+ ) -> torch.Tensor:
260
+ x0, y0, x1, y1 = crop_box
261
+ if x0 == 0 and y0 == 0 and x1 == orig_w and y1 == orig_h:
262
+ return masks
263
+ # Coordinate transform masks
264
+ pad_x, pad_y = orig_w - (x1 - x0), orig_h - (y1 - y0)
265
+ pad = (x0, pad_x - x0, y0, pad_y - y0)
266
+ return torch.nn.functional.pad(masks, pad, value=0)
267
+
268
+
269
+ def remove_small_regions(
270
+ mask: np.ndarray, area_thresh: float, mode: str
271
+ ) -> Tuple[np.ndarray, bool]:
272
+ """
273
+ Removes small disconnected regions and holes in a mask. Returns the
274
+ mask and an indicator of if the mask has been modified.
275
+ """
276
+ import cv2 # type: ignore
277
+
278
+ assert mode in ["holes", "islands"]
279
+ correct_holes = mode == "holes"
280
+ working_mask = (correct_holes ^ mask).astype(np.uint8)
281
+ n_labels, regions, stats, _ = cv2.connectedComponentsWithStats(working_mask, 8)
282
+ sizes = stats[:, -1][1:] # Row 0 is background label
283
+ small_regions = [i + 1 for i, s in enumerate(sizes) if s < area_thresh]
284
+ if len(small_regions) == 0:
285
+ return mask, False
286
+ fill_labels = [0] + small_regions
287
+ if not correct_holes:
288
+ fill_labels = [i for i in range(n_labels) if i not in fill_labels]
289
+ # If every region is below threshold, keep largest
290
+ if len(fill_labels) == 0:
291
+ fill_labels = [int(np.argmax(sizes)) + 1]
292
+ mask = np.isin(regions, fill_labels)
293
+ return mask, True
294
+
295
+
296
+ def coco_encode_rle(uncompressed_rle: Dict[str, Any]) -> Dict[str, Any]:
297
+ from pycocotools import mask as mask_utils # type: ignore
298
+
299
+ h, w = uncompressed_rle["size"]
300
+ rle = mask_utils.frPyObjects(uncompressed_rle, h, w)
301
+ rle["counts"] = rle["counts"].decode("utf-8") # Necessary to serialize with json
302
+ return rle
303
+
304
+
305
+ def batched_mask_to_box(masks: torch.Tensor) -> torch.Tensor:
306
+ """
307
+ Calculates boxes in XYXY format around masks. Return [0,0,0,0] for
308
+ an empty mask. For input shape C1xC2x...xHxW, the output shape is C1xC2x...x4.
309
+ """
310
+ # torch.max below raises an error on empty inputs, just skip in this case
311
+ if torch.numel(masks) == 0:
312
+ return torch.zeros(*masks.shape[:-2], 4, device=masks.device)
313
+
314
+ # Normalize shape to CxHxW
315
+ shape = masks.shape
316
+ h, w = shape[-2:]
317
+ if len(shape) > 2:
318
+ masks = masks.flatten(0, -3)
319
+ else:
320
+ masks = masks.unsqueeze(0)
321
+
322
+ # Get top and bottom edges
323
+ in_height, _ = torch.max(masks, dim=-1)
324
+ in_height_coords = in_height * torch.arange(h, device=in_height.device)[None, :]
325
+ bottom_edges, _ = torch.max(in_height_coords, dim=-1)
326
+ in_height_coords = in_height_coords + h * (~in_height)
327
+ top_edges, _ = torch.min(in_height_coords, dim=-1)
328
+
329
+ # Get left and right edges
330
+ in_width, _ = torch.max(masks, dim=-2)
331
+ in_width_coords = in_width * torch.arange(w, device=in_width.device)[None, :]
332
+ right_edges, _ = torch.max(in_width_coords, dim=-1)
333
+ in_width_coords = in_width_coords + w * (~in_width)
334
+ left_edges, _ = torch.min(in_width_coords, dim=-1)
335
+
336
+ # If the mask is empty the right edge will be to the left of the left edge.
337
+ # Replace these boxes with [0, 0, 0, 0]
338
+ empty_filter = (right_edges < left_edges) | (bottom_edges < top_edges)
339
+ out = torch.stack([left_edges, top_edges, right_edges, bottom_edges], dim=-1)
340
+ out = out * (~empty_filter).unsqueeze(-1)
341
+
342
+ # Return to original shape
343
+ if len(shape) > 2:
344
+ out = out.reshape(*shape[:-2], 4)
345
+ else:
346
+ out = out[0]
347
+
348
+ return out
segment-anything-2/sam2/utils/misc.py ADDED
@@ -0,0 +1,238 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import os
8
+ import warnings
9
+ from threading import Thread
10
+
11
+ import numpy as np
12
+ import torch
13
+ from PIL import Image
14
+ from tqdm import tqdm
15
+
16
+
17
+ def get_sdpa_settings():
18
+ if torch.cuda.is_available():
19
+ old_gpu = torch.cuda.get_device_properties(0).major < 7
20
+ # only use Flash Attention on Ampere (8.0) or newer GPUs
21
+ use_flash_attn = torch.cuda.get_device_properties(0).major >= 8
22
+ if not use_flash_attn:
23
+ warnings.warn(
24
+ "Flash Attention is disabled as it requires a GPU with Ampere (8.0) CUDA capability.",
25
+ category=UserWarning,
26
+ stacklevel=2,
27
+ )
28
+ # keep math kernel for PyTorch versions before 2.2 (Flash Attention v2 is only
29
+ # available on PyTorch 2.2+, while Flash Attention v1 cannot handle all cases)
30
+ pytorch_version = tuple(int(v) for v in torch.__version__.split(".")[:2])
31
+ if pytorch_version < (2, 2):
32
+ warnings.warn(
33
+ f"You are using PyTorch {torch.__version__} without Flash Attention v2 support. "
34
+ "Consider upgrading to PyTorch 2.2+ for Flash Attention v2 (which could be faster).",
35
+ category=UserWarning,
36
+ stacklevel=2,
37
+ )
38
+ math_kernel_on = pytorch_version < (2, 2) or not use_flash_attn
39
+ else:
40
+ old_gpu = True
41
+ use_flash_attn = False
42
+ math_kernel_on = True
43
+
44
+ return old_gpu, use_flash_attn, math_kernel_on
45
+
46
+
47
+ def get_connected_components(mask):
48
+ """
49
+ Get the connected components (8-connectivity) of binary masks of shape (N, 1, H, W).
50
+
51
+ Inputs:
52
+ - mask: A binary mask tensor of shape (N, 1, H, W), where 1 is foreground and 0 is
53
+ background.
54
+
55
+ Outputs:
56
+ - labels: A tensor of shape (N, 1, H, W) containing the connected component labels
57
+ for foreground pixels and 0 for background pixels.
58
+ - counts: A tensor of shape (N, 1, H, W) containing the area of the connected
59
+ components for foreground pixels and 0 for background pixels.
60
+ """
61
+ from sam2 import _C
62
+
63
+ return _C.get_connected_componnets(mask.to(torch.uint8).contiguous())
64
+
65
+
66
+ def mask_to_box(masks: torch.Tensor):
67
+ """
68
+ compute bounding box given an input mask
69
+
70
+ Inputs:
71
+ - masks: [B, 1, H, W] boxes, dtype=torch.Tensor
72
+
73
+ Returns:
74
+ - box_coords: [B, 1, 4], contains (x, y) coordinates of top left and bottom right box corners, dtype=torch.Tensor
75
+ """
76
+ B, _, h, w = masks.shape
77
+ device = masks.device
78
+ xs = torch.arange(w, device=device, dtype=torch.int32)
79
+ ys = torch.arange(h, device=device, dtype=torch.int32)
80
+ grid_xs, grid_ys = torch.meshgrid(xs, ys, indexing="xy")
81
+ grid_xs = grid_xs[None, None, ...].expand(B, 1, h, w)
82
+ grid_ys = grid_ys[None, None, ...].expand(B, 1, h, w)
83
+ min_xs, _ = torch.min(torch.where(masks, grid_xs, w).flatten(-2), dim=-1)
84
+ max_xs, _ = torch.max(torch.where(masks, grid_xs, -1).flatten(-2), dim=-1)
85
+ min_ys, _ = torch.min(torch.where(masks, grid_ys, h).flatten(-2), dim=-1)
86
+ max_ys, _ = torch.max(torch.where(masks, grid_ys, -1).flatten(-2), dim=-1)
87
+ bbox_coords = torch.stack((min_xs, min_ys, max_xs, max_ys), dim=-1)
88
+
89
+ return bbox_coords
90
+
91
+
92
+ def _load_img_as_tensor(img_path, image_size):
93
+ img_pil = Image.open(img_path)
94
+ img_np = np.array(img_pil.convert("RGB").resize((image_size, image_size)))
95
+ if img_np.dtype == np.uint8: # np.uint8 is expected for JPEG images
96
+ img_np = img_np / 255.0
97
+ else:
98
+ raise RuntimeError(f"Unknown image dtype: {img_np.dtype} on {img_path}")
99
+ img = torch.from_numpy(img_np).permute(2, 0, 1)
100
+ video_width, video_height = img_pil.size # the original video size
101
+ return img, video_height, video_width
102
+
103
+
104
+ class AsyncVideoFrameLoader:
105
+ """
106
+ A list of video frames to be load asynchronously without blocking session start.
107
+ """
108
+
109
+ def __init__(self, img_paths, image_size, offload_video_to_cpu, img_mean, img_std):
110
+ self.img_paths = img_paths
111
+ self.image_size = image_size
112
+ self.offload_video_to_cpu = offload_video_to_cpu
113
+ self.img_mean = img_mean
114
+ self.img_std = img_std
115
+ # items in `self._images` will be loaded asynchronously
116
+ self.images = [None] * len(img_paths)
117
+ # catch and raise any exceptions in the async loading thread
118
+ self.exception = None
119
+ # video_height and video_width be filled when loading the first image
120
+ self.video_height = None
121
+ self.video_width = None
122
+
123
+ # load the first frame to fill video_height and video_width and also
124
+ # to cache it (since it's most likely where the user will click)
125
+ self.__getitem__(0)
126
+
127
+ # load the rest of frames asynchronously without blocking the session start
128
+ def _load_frames():
129
+ try:
130
+ for n in tqdm(range(len(self.images)), desc="frame loading (JPEG)"):
131
+ self.__getitem__(n)
132
+ except Exception as e:
133
+ self.exception = e
134
+
135
+ self.thread = Thread(target=_load_frames, daemon=True)
136
+ self.thread.start()
137
+
138
+ def __getitem__(self, index):
139
+ if self.exception is not None:
140
+ raise RuntimeError("Failure in frame loading thread") from self.exception
141
+
142
+ img = self.images[index]
143
+ if img is not None:
144
+ return img
145
+
146
+ img, video_height, video_width = _load_img_as_tensor(
147
+ self.img_paths[index], self.image_size
148
+ )
149
+ self.video_height = video_height
150
+ self.video_width = video_width
151
+ # normalize by mean and std
152
+ img -= self.img_mean
153
+ img /= self.img_std
154
+ if not self.offload_video_to_cpu:
155
+ img = img.cuda(non_blocking=True)
156
+ self.images[index] = img
157
+ return img
158
+
159
+ def __len__(self):
160
+ return len(self.images)
161
+
162
+
163
+ def load_video_frames(
164
+ video_path,
165
+ image_size,
166
+ offload_video_to_cpu,
167
+ img_mean=(0.485, 0.456, 0.406),
168
+ img_std=(0.229, 0.224, 0.225),
169
+ async_loading_frames=False,
170
+ ):
171
+ """
172
+ Load the video frames from a directory of JPEG files ("<frame_index>.jpg" format).
173
+
174
+ The frames are resized to image_size x image_size and are loaded to GPU if
175
+ `offload_video_to_cpu` is `False` and to CPU if `offload_video_to_cpu` is `True`.
176
+
177
+ You can load a frame asynchronously by setting `async_loading_frames` to `True`.
178
+ """
179
+ if isinstance(video_path, str) and os.path.isdir(video_path):
180
+ jpg_folder = video_path
181
+ else:
182
+ raise NotImplementedError("Only JPEG frames are supported at this moment")
183
+
184
+ frame_names = [
185
+ p
186
+ for p in os.listdir(jpg_folder)
187
+ if os.path.splitext(p)[-1] in [".jpg", ".jpeg", ".JPG", ".JPEG"]
188
+ ]
189
+ frame_names.sort(key=lambda p: int(os.path.splitext(p)[0]))
190
+ num_frames = len(frame_names)
191
+ if num_frames == 0:
192
+ raise RuntimeError(f"no images found in {jpg_folder}")
193
+ img_paths = [os.path.join(jpg_folder, frame_name) for frame_name in frame_names]
194
+ img_mean = torch.tensor(img_mean, dtype=torch.float32)[:, None, None]
195
+ img_std = torch.tensor(img_std, dtype=torch.float32)[:, None, None]
196
+
197
+ if async_loading_frames:
198
+ lazy_images = AsyncVideoFrameLoader(
199
+ img_paths, image_size, offload_video_to_cpu, img_mean, img_std
200
+ )
201
+ return lazy_images, lazy_images.video_height, lazy_images.video_width
202
+
203
+ images = torch.zeros(num_frames, 3, image_size, image_size, dtype=torch.float32)
204
+ for n, img_path in enumerate(tqdm(img_paths, desc="frame loading (JPEG)")):
205
+ images[n], video_height, video_width = _load_img_as_tensor(img_path, image_size)
206
+ if not offload_video_to_cpu:
207
+ images = images.cuda()
208
+ img_mean = img_mean.cuda()
209
+ img_std = img_std.cuda()
210
+ # normalize by mean and std
211
+ images -= img_mean
212
+ images /= img_std
213
+ return images, video_height, video_width
214
+
215
+
216
+ def fill_holes_in_mask_scores(mask, max_area):
217
+ """
218
+ A post processor to fill small holes in mask scores with area under `max_area`.
219
+ """
220
+ # Holes are those connected components in background with area <= self.max_area
221
+ # (background regions are those with mask scores <= 0)
222
+ assert max_area > 0, "max_area must be positive"
223
+ labels, areas = get_connected_components(mask <= 0)
224
+ is_hole = (labels > 0) & (areas <= max_area)
225
+ # We fill holes with a small positive mask score (0.1) to change them to foreground.
226
+ mask = torch.where(is_hole, 0.1, mask)
227
+ return mask
228
+
229
+
230
+ def concat_points(old_point_inputs, new_points, new_labels):
231
+ """Add new points and labels to previous point inputs (add at the end)."""
232
+ if old_point_inputs is None:
233
+ points, labels = new_points, new_labels
234
+ else:
235
+ points = torch.cat([old_point_inputs["point_coords"], new_points], dim=1)
236
+ labels = torch.cat([old_point_inputs["point_labels"], new_labels], dim=1)
237
+
238
+ return {"point_coords": points, "point_labels": labels}
segment-anything-2/sam2/utils/transforms.py ADDED
@@ -0,0 +1,99 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import torch
8
+ import torch.nn as nn
9
+ import torch.nn.functional as F
10
+ from torchvision.transforms import Normalize, Resize, ToTensor
11
+
12
+
13
+ class SAM2Transforms(nn.Module):
14
+ def __init__(
15
+ self, resolution, mask_threshold, max_hole_area=0.0, max_sprinkle_area=0.0
16
+ ):
17
+ """
18
+ Transforms for SAM2.
19
+ """
20
+ super().__init__()
21
+ self.resolution = resolution
22
+ self.mask_threshold = mask_threshold
23
+ self.max_hole_area = max_hole_area
24
+ self.max_sprinkle_area = max_sprinkle_area
25
+ self.mean = [0.485, 0.456, 0.406]
26
+ self.std = [0.229, 0.224, 0.225]
27
+ self.to_tensor = ToTensor()
28
+ self.transforms = torch.jit.script(
29
+ nn.Sequential(
30
+ Resize((self.resolution, self.resolution)),
31
+ Normalize(self.mean, self.std),
32
+ )
33
+ )
34
+
35
+ def __call__(self, x):
36
+ x = self.to_tensor(x)
37
+ return self.transforms(x)
38
+
39
+ def forward_batch(self, img_list):
40
+ img_batch = [self.transforms(self.to_tensor(img)) for img in img_list]
41
+ img_batch = torch.stack(img_batch, dim=0)
42
+ return img_batch
43
+
44
+ def transform_coords(
45
+ self, coords: torch.Tensor, normalize=False, orig_hw=None
46
+ ) -> torch.Tensor:
47
+ """
48
+ Expects a torch tensor with length 2 in the last dimension. The coordinates can be in absolute image or normalized coordinates,
49
+ If the coords are in absolute image coordinates, normalize should be set to True and original image size is required.
50
+
51
+ Returns
52
+ Un-normalized coordinates in the range of [0, 1] which is expected by the SAM2 model.
53
+ """
54
+ if normalize:
55
+ assert orig_hw is not None
56
+ h, w = orig_hw
57
+ coords = coords.clone()
58
+ coords[..., 0] = coords[..., 0] / w
59
+ coords[..., 1] = coords[..., 1] / h
60
+
61
+ coords = coords * self.resolution # unnormalize coords
62
+ return coords
63
+
64
+ def transform_boxes(
65
+ self, boxes: torch.Tensor, normalize=False, orig_hw=None
66
+ ) -> torch.Tensor:
67
+ """
68
+ Expects a tensor of shape Bx4. The coordinates can be in absolute image or normalized coordinates,
69
+ if the coords are in absolute image coordinates, normalize should be set to True and original image size is required.
70
+ """
71
+ boxes = self.transform_coords(boxes.reshape(-1, 2, 2), normalize, orig_hw)
72
+ return boxes
73
+
74
+ def postprocess_masks(self, masks: torch.Tensor, orig_hw) -> torch.Tensor:
75
+ """
76
+ Perform PostProcessing on output masks.
77
+ """
78
+ from sam2.utils.misc import get_connected_components
79
+
80
+ masks = masks.float()
81
+ if self.max_hole_area > 0:
82
+ # Holes are those connected components in background with area <= self.fill_hole_area
83
+ # (background regions are those with mask scores <= self.mask_threshold)
84
+ mask_flat = masks.flatten(0, 1).unsqueeze(1) # flatten as 1-channel image
85
+ labels, areas = get_connected_components(mask_flat <= self.mask_threshold)
86
+ is_hole = (labels > 0) & (areas <= self.max_hole_area)
87
+ is_hole = is_hole.reshape_as(masks)
88
+ # We fill holes with a small positive mask score (10.0) to change them to foreground.
89
+ masks = torch.where(is_hole, self.mask_threshold + 10.0, masks)
90
+
91
+ if self.max_sprinkle_area > 0:
92
+ labels, areas = get_connected_components(mask_flat > self.mask_threshold)
93
+ is_hole = (labels > 0) & (areas <= self.max_sprinkle_area)
94
+ is_hole = is_hole.reshape_as(masks)
95
+ # We fill holes with negative mask score (-10.0) to change them to background.
96
+ masks = torch.where(is_hole, self.mask_threshold - 10.0, masks)
97
+
98
+ masks = F.interpolate(masks, orig_hw, mode="bilinear", align_corners=False)
99
+ return masks
segment-anything-2/sam2_configs/__init__.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
segment-anything-2/sam2_configs/sam2_hiera_l.yaml ADDED
@@ -0,0 +1,117 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # @package _global_
2
+
3
+ # Model
4
+ model:
5
+ _target_: sam2.modeling.sam2_base.SAM2Base
6
+ image_encoder:
7
+ _target_: sam2.modeling.backbones.image_encoder.ImageEncoder
8
+ scalp: 1
9
+ trunk:
10
+ _target_: sam2.modeling.backbones.hieradet.Hiera
11
+ embed_dim: 144
12
+ num_heads: 2
13
+ stages: [2, 6, 36, 4]
14
+ global_att_blocks: [23, 33, 43]
15
+ window_pos_embed_bkg_spatial_size: [7, 7]
16
+ window_spec: [8, 4, 16, 8]
17
+ neck:
18
+ _target_: sam2.modeling.backbones.image_encoder.FpnNeck
19
+ position_encoding:
20
+ _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
21
+ num_pos_feats: 256
22
+ normalize: true
23
+ scale: null
24
+ temperature: 10000
25
+ d_model: 256
26
+ backbone_channel_list: [1152, 576, 288, 144]
27
+ fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features
28
+ fpn_interp_model: nearest
29
+
30
+ memory_attention:
31
+ _target_: sam2.modeling.memory_attention.MemoryAttention
32
+ d_model: 256
33
+ pos_enc_at_input: true
34
+ layer:
35
+ _target_: sam2.modeling.memory_attention.MemoryAttentionLayer
36
+ activation: relu
37
+ dim_feedforward: 2048
38
+ dropout: 0.1
39
+ pos_enc_at_attn: false
40
+ self_attention:
41
+ _target_: sam2.modeling.sam.transformer.RoPEAttention
42
+ rope_theta: 10000.0
43
+ feat_sizes: [32, 32]
44
+ embedding_dim: 256
45
+ num_heads: 1
46
+ downsample_rate: 1
47
+ dropout: 0.1
48
+ d_model: 256
49
+ pos_enc_at_cross_attn_keys: true
50
+ pos_enc_at_cross_attn_queries: false
51
+ cross_attention:
52
+ _target_: sam2.modeling.sam.transformer.RoPEAttention
53
+ rope_theta: 10000.0
54
+ feat_sizes: [32, 32]
55
+ rope_k_repeat: True
56
+ embedding_dim: 256
57
+ num_heads: 1
58
+ downsample_rate: 1
59
+ dropout: 0.1
60
+ kv_in_dim: 64
61
+ num_layers: 4
62
+
63
+ memory_encoder:
64
+ _target_: sam2.modeling.memory_encoder.MemoryEncoder
65
+ out_dim: 64
66
+ position_encoding:
67
+ _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
68
+ num_pos_feats: 64
69
+ normalize: true
70
+ scale: null
71
+ temperature: 10000
72
+ mask_downsampler:
73
+ _target_: sam2.modeling.memory_encoder.MaskDownSampler
74
+ kernel_size: 3
75
+ stride: 2
76
+ padding: 1
77
+ fuser:
78
+ _target_: sam2.modeling.memory_encoder.Fuser
79
+ layer:
80
+ _target_: sam2.modeling.memory_encoder.CXBlock
81
+ dim: 256
82
+ kernel_size: 7
83
+ padding: 3
84
+ layer_scale_init_value: 1e-6
85
+ use_dwconv: True # depth-wise convs
86
+ num_layers: 2
87
+
88
+ num_maskmem: 7
89
+ image_size: 1024
90
+ # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask
91
+ sigmoid_scale_for_mem_enc: 20.0
92
+ sigmoid_bias_for_mem_enc: -10.0
93
+ use_mask_input_as_output_without_sam: true
94
+ # Memory
95
+ directly_add_no_mem_embed: true
96
+ # use high-resolution feature map in the SAM mask decoder
97
+ use_high_res_features_in_sam: true
98
+ # output 3 masks on the first click on initial conditioning frames
99
+ multimask_output_in_sam: true
100
+ # SAM heads
101
+ iou_prediction_use_sigmoid: True
102
+ # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder
103
+ use_obj_ptrs_in_encoder: true
104
+ add_tpos_enc_to_obj_ptrs: false
105
+ only_obj_ptrs_in_the_past_for_eval: true
106
+ # object occlusion prediction
107
+ pred_obj_scores: true
108
+ pred_obj_scores_mlp: true
109
+ fixed_no_obj_ptr: true
110
+ # multimask tracking settings
111
+ multimask_output_for_tracking: true
112
+ use_multimask_token_for_obj_ptr: true
113
+ multimask_min_pt_num: 0
114
+ multimask_max_pt_num: 1
115
+ use_mlp_for_obj_ptr_proj: true
116
+ # Compilation flag
117
+ compile_image_encoder: False
segment-anything-2/sam2_configs/sam2_hiera_s.yaml ADDED
@@ -0,0 +1,116 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # @package _global_
2
+
3
+ # Model
4
+ model:
5
+ _target_: sam2.modeling.sam2_base.SAM2Base
6
+ image_encoder:
7
+ _target_: sam2.modeling.backbones.image_encoder.ImageEncoder
8
+ scalp: 1
9
+ trunk:
10
+ _target_: sam2.modeling.backbones.hieradet.Hiera
11
+ embed_dim: 96
12
+ num_heads: 1
13
+ stages: [1, 2, 11, 2]
14
+ global_att_blocks: [7, 10, 13]
15
+ window_pos_embed_bkg_spatial_size: [7, 7]
16
+ neck:
17
+ _target_: sam2.modeling.backbones.image_encoder.FpnNeck
18
+ position_encoding:
19
+ _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
20
+ num_pos_feats: 256
21
+ normalize: true
22
+ scale: null
23
+ temperature: 10000
24
+ d_model: 256
25
+ backbone_channel_list: [768, 384, 192, 96]
26
+ fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features
27
+ fpn_interp_model: nearest
28
+
29
+ memory_attention:
30
+ _target_: sam2.modeling.memory_attention.MemoryAttention
31
+ d_model: 256
32
+ pos_enc_at_input: true
33
+ layer:
34
+ _target_: sam2.modeling.memory_attention.MemoryAttentionLayer
35
+ activation: relu
36
+ dim_feedforward: 2048
37
+ dropout: 0.1
38
+ pos_enc_at_attn: false
39
+ self_attention:
40
+ _target_: sam2.modeling.sam.transformer.RoPEAttention
41
+ rope_theta: 10000.0
42
+ feat_sizes: [32, 32]
43
+ embedding_dim: 256
44
+ num_heads: 1
45
+ downsample_rate: 1
46
+ dropout: 0.1
47
+ d_model: 256
48
+ pos_enc_at_cross_attn_keys: true
49
+ pos_enc_at_cross_attn_queries: false
50
+ cross_attention:
51
+ _target_: sam2.modeling.sam.transformer.RoPEAttention
52
+ rope_theta: 10000.0
53
+ feat_sizes: [32, 32]
54
+ rope_k_repeat: True
55
+ embedding_dim: 256
56
+ num_heads: 1
57
+ downsample_rate: 1
58
+ dropout: 0.1
59
+ kv_in_dim: 64
60
+ num_layers: 4
61
+
62
+ memory_encoder:
63
+ _target_: sam2.modeling.memory_encoder.MemoryEncoder
64
+ out_dim: 64
65
+ position_encoding:
66
+ _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
67
+ num_pos_feats: 64
68
+ normalize: true
69
+ scale: null
70
+ temperature: 10000
71
+ mask_downsampler:
72
+ _target_: sam2.modeling.memory_encoder.MaskDownSampler
73
+ kernel_size: 3
74
+ stride: 2
75
+ padding: 1
76
+ fuser:
77
+ _target_: sam2.modeling.memory_encoder.Fuser
78
+ layer:
79
+ _target_: sam2.modeling.memory_encoder.CXBlock
80
+ dim: 256
81
+ kernel_size: 7
82
+ padding: 3
83
+ layer_scale_init_value: 1e-6
84
+ use_dwconv: True # depth-wise convs
85
+ num_layers: 2
86
+
87
+ num_maskmem: 7
88
+ image_size: 1024
89
+ # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask
90
+ sigmoid_scale_for_mem_enc: 20.0
91
+ sigmoid_bias_for_mem_enc: -10.0
92
+ use_mask_input_as_output_without_sam: true
93
+ # Memory
94
+ directly_add_no_mem_embed: true
95
+ # use high-resolution feature map in the SAM mask decoder
96
+ use_high_res_features_in_sam: true
97
+ # output 3 masks on the first click on initial conditioning frames
98
+ multimask_output_in_sam: true
99
+ # SAM heads
100
+ iou_prediction_use_sigmoid: True
101
+ # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder
102
+ use_obj_ptrs_in_encoder: true
103
+ add_tpos_enc_to_obj_ptrs: false
104
+ only_obj_ptrs_in_the_past_for_eval: true
105
+ # object occlusion prediction
106
+ pred_obj_scores: true
107
+ pred_obj_scores_mlp: true
108
+ fixed_no_obj_ptr: true
109
+ # multimask tracking settings
110
+ multimask_output_for_tracking: true
111
+ use_multimask_token_for_obj_ptr: true
112
+ multimask_min_pt_num: 0
113
+ multimask_max_pt_num: 1
114
+ use_mlp_for_obj_ptr_proj: true
115
+ # Compilation flag
116
+ compile_image_encoder: False
segment-anything-2/sam2_configs/sam2_hiera_t.yaml ADDED
@@ -0,0 +1,118 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # @package _global_
2
+
3
+ # Model
4
+ model:
5
+ _target_: sam2.modeling.sam2_base.SAM2Base
6
+ image_encoder:
7
+ _target_: sam2.modeling.backbones.image_encoder.ImageEncoder
8
+ scalp: 1
9
+ trunk:
10
+ _target_: sam2.modeling.backbones.hieradet.Hiera
11
+ embed_dim: 96
12
+ num_heads: 1
13
+ stages: [1, 2, 7, 2]
14
+ global_att_blocks: [5, 7, 9]
15
+ window_pos_embed_bkg_spatial_size: [7, 7]
16
+ neck:
17
+ _target_: sam2.modeling.backbones.image_encoder.FpnNeck
18
+ position_encoding:
19
+ _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
20
+ num_pos_feats: 256
21
+ normalize: true
22
+ scale: null
23
+ temperature: 10000
24
+ d_model: 256
25
+ backbone_channel_list: [768, 384, 192, 96]
26
+ fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features
27
+ fpn_interp_model: nearest
28
+
29
+ memory_attention:
30
+ _target_: sam2.modeling.memory_attention.MemoryAttention
31
+ d_model: 256
32
+ pos_enc_at_input: true
33
+ layer:
34
+ _target_: sam2.modeling.memory_attention.MemoryAttentionLayer
35
+ activation: relu
36
+ dim_feedforward: 2048
37
+ dropout: 0.1
38
+ pos_enc_at_attn: false
39
+ self_attention:
40
+ _target_: sam2.modeling.sam.transformer.RoPEAttention
41
+ rope_theta: 10000.0
42
+ feat_sizes: [32, 32]
43
+ embedding_dim: 256
44
+ num_heads: 1
45
+ downsample_rate: 1
46
+ dropout: 0.1
47
+ d_model: 256
48
+ pos_enc_at_cross_attn_keys: true
49
+ pos_enc_at_cross_attn_queries: false
50
+ cross_attention:
51
+ _target_: sam2.modeling.sam.transformer.RoPEAttention
52
+ rope_theta: 10000.0
53
+ feat_sizes: [32, 32]
54
+ rope_k_repeat: True
55
+ embedding_dim: 256
56
+ num_heads: 1
57
+ downsample_rate: 1
58
+ dropout: 0.1
59
+ kv_in_dim: 64
60
+ num_layers: 4
61
+
62
+ memory_encoder:
63
+ _target_: sam2.modeling.memory_encoder.MemoryEncoder
64
+ out_dim: 64
65
+ position_encoding:
66
+ _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
67
+ num_pos_feats: 64
68
+ normalize: true
69
+ scale: null
70
+ temperature: 10000
71
+ mask_downsampler:
72
+ _target_: sam2.modeling.memory_encoder.MaskDownSampler
73
+ kernel_size: 3
74
+ stride: 2
75
+ padding: 1
76
+ fuser:
77
+ _target_: sam2.modeling.memory_encoder.Fuser
78
+ layer:
79
+ _target_: sam2.modeling.memory_encoder.CXBlock
80
+ dim: 256
81
+ kernel_size: 7
82
+ padding: 3
83
+ layer_scale_init_value: 1e-6
84
+ use_dwconv: True # depth-wise convs
85
+ num_layers: 2
86
+
87
+ num_maskmem: 7
88
+ image_size: 1024
89
+ # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask
90
+ # SAM decoder
91
+ sigmoid_scale_for_mem_enc: 20.0
92
+ sigmoid_bias_for_mem_enc: -10.0
93
+ use_mask_input_as_output_without_sam: true
94
+ # Memory
95
+ directly_add_no_mem_embed: true
96
+ # use high-resolution feature map in the SAM mask decoder
97
+ use_high_res_features_in_sam: true
98
+ # output 3 masks on the first click on initial conditioning frames
99
+ multimask_output_in_sam: true
100
+ # SAM heads
101
+ iou_prediction_use_sigmoid: True
102
+ # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder
103
+ use_obj_ptrs_in_encoder: true
104
+ add_tpos_enc_to_obj_ptrs: false
105
+ only_obj_ptrs_in_the_past_for_eval: true
106
+ # object occlusion prediction
107
+ pred_obj_scores: true
108
+ pred_obj_scores_mlp: true
109
+ fixed_no_obj_ptr: true
110
+ # multimask tracking settings
111
+ multimask_output_for_tracking: true
112
+ use_multimask_token_for_obj_ptr: true
113
+ multimask_min_pt_num: 0
114
+ multimask_max_pt_num: 1
115
+ use_mlp_for_obj_ptr_proj: true
116
+ # Compilation flag
117
+ # HieraT does not currently support compilation, should always be set to False
118
+ compile_image_encoder: False
segment-anything-2/sav_dataset/LICENSE ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ BSD License
2
+
3
+ For SAM 2 Eval software
4
+
5
+ Copyright (c) Meta Platforms, Inc. and affiliates.
6
+
7
+ Redistribution and use in source and binary forms, with or without modification,
8
+ are permitted provided that the following conditions are met:
9
+
10
+ * Redistributions of source code must retain the above copyright notice, this
11
+ list of conditions and the following disclaimer.
12
+
13
+ * Redistributions in binary form must reproduce the above copyright notice,
14
+ this list of conditions and the following disclaimer in the documentation
15
+ and/or other materials provided with the distribution.
16
+
17
+ * Neither the name Meta nor the names of its contributors may be used to
18
+ endorse or promote products derived from this software without specific
19
+ prior written permission.
20
+
21
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
22
+ ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
23
+ WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
24
+ DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR
25
+ ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
26
+ (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
27
+ LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
28
+ ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
29
+ (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
30
+ SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
segment-anything-2/sav_dataset/LICENSE_DAVIS ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ BSD 3-Clause License
2
+
3
+ Copyright (c) 2020, DAVIS: Densely Annotated VIdeo Segmentation
4
+ All rights reserved.
5
+
6
+ Redistribution and use in source and binary forms, with or without
7
+ modification, are permitted provided that the following conditions are met:
8
+
9
+ 1. Redistributions of source code must retain the above copyright notice, this
10
+ list of conditions and the following disclaimer.
11
+
12
+ 2. Redistributions in binary form must reproduce the above copyright notice,
13
+ this list of conditions and the following disclaimer in the documentation
14
+ and/or other materials provided with the distribution.
15
+
16
+ 3. Neither the name of the copyright holder nor the names of its
17
+ contributors may be used to endorse or promote products derived from
18
+ this software without specific prior written permission.
19
+
20
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21
+ AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22
+ IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23
+ DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24
+ FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25
+ DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26
+ SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27
+ CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28
+ OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29
+ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
segment-anything-2/sav_dataset/LICENSE_VOS_BENCHMARK ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ Copyright 2023 Rex Cheng
2
+
3
+ Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
4
+
5
+ The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
6
+
7
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
segment-anything-2/sav_dataset/README.md ADDED
@@ -0,0 +1,164 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Segment Anything Video (SA-V) Dataset
2
+
3
+ ## Overview
4
+
5
+ [Segment Anything Video (SA-V)](https://ai.meta.com/datasets/segment-anything-video/), consists of 51K diverse videos and 643K high-quality spatio-temporal segmentation masks (i.e., masklets). The dataset is released under the CC by 4.0 license. Browse the dataset [here](https://sam2.metademolab.com/dataset).
6
+
7
+ ![SA-V dataset](../assets/sa_v_dataset.jpg?raw=true)
8
+
9
+ ## Getting Started
10
+
11
+ ### Download the dataset
12
+
13
+ Visit [here](https://ai.meta.com/datasets/segment-anything-video-downloads/) to download SA-V including the training, val and test sets.
14
+
15
+ ### Dataset Stats
16
+
17
+ | | Num Videos | Num Masklets |
18
+ | ---------- | ---------- | ----------------------------------------- |
19
+ | SA-V train | 50,583 | 642,036 (auto 451,720 and manual 190,316) |
20
+ | SA-V val | 155 | 293 |
21
+ | SA-V test | 150 | 278 |
22
+
23
+ ### Notebooks
24
+
25
+ To load and visualize the SA-V training set annotations, refer to the example [sav_visualization_example.ipynb](./sav_visualization_example.ipynb) notebook.
26
+
27
+ ### SA-V train
28
+
29
+ For SA-V training set we release the mp4 videos and store the masklet annotations per video as json files . Automatic masklets and manual masklets are stored separately as two json files: `{video_id}_auto.json` and `{video_id}_manual.json`. They can be loaded as dictionaries in python in the format below.
30
+
31
+ ```
32
+ {
33
+ "video_id" : str; video id
34
+ "video_duration" : float64; the duration in seconds of this video
35
+ "video_frame_count" : float64; the number of frames in the video
36
+ "video_height" : float64; the height of the video
37
+ "video_width" : float64; the width of the video
38
+ "video_resolution" : float64; video_height $\times$ video_width
39
+ "video_environment" : List[str]; "Indoor" or "Outdoor"
40
+ "video_split" : str; "train" for training set
41
+ "masklet" : List[List[Dict]]; masklet annotations in list of list of RLEs.
42
+ The outer list is over frames in the video and the inner list
43
+ is over objects in the video.
44
+ "masklet_id" : List[int]; the masklet ids
45
+ "masklet_size_rel" : List[float]; the average mask area normalized by resolution
46
+ across all the frames where the object is visible
47
+ "masklet_size_abs" : List[float]; the average mask area (in pixels)
48
+ across all the frames where the object is visible
49
+ "masklet_size_bucket" : List[str]; "small": $1$ <= masklet_size_abs < $32^2$,
50
+ "medium": $32^2$ <= masklet_size_abs < $96^2$,
51
+ and "large": masklet_size_abs > $96^2$
52
+ "masklet_visibility_changes" : List[int]; the number of times where the visibility changes
53
+ after the first appearance (e.g., invisible -> visible
54
+ or visible -> invisible)
55
+ "masklet_first_appeared_frame" : List[int]; the index of the frame where the object appears
56
+ the first time in the video. Always 0 for auto masklets.
57
+ "masklet_frame_count" : List[int]; the number of frames being annotated. Note that
58
+ videos are annotated at 6 fps (annotated every 4 frames)
59
+ while the videos are at 24 fps.
60
+ "masklet_edited_frame_count" : List[int]; the number of frames being edited by human annotators.
61
+ Always 0 for auto masklets.
62
+ "masklet_type" : List[str]; "auto" or "manual"
63
+ "masklet_stability_score" : Optional[List[List[float]]]; per-mask stability scores. Auto annotation only.
64
+ "masklet_num" : int; the number of manual/auto masklets in the video
65
+
66
+ }
67
+ ```
68
+
69
+ Note that in SA-V train, there are in total 50,583 videos where all of them have manual annotations. Among the 50,583 videos there are 48,436 videos that also have automatic annotations.
70
+
71
+ ### SA-V val and test
72
+
73
+ For SA-V val and test sets, we release the extracted frames as jpeg files, and the masks as png files with the following directory structure:
74
+
75
+ ```
76
+ sav_val(sav_test)
77
+ ├── sav_val.txt (sav_test.txt): a list of video ids in the split
78
+ ├── JPEGImages_24fps # videos are extracted at 24 fps
79
+ │ ├── {video_id}
80
+ │ │ ├── 00000.jpg # video frame
81
+ │ │ ├── 00001.jpg # video frame
82
+ │ │ ├── 00002.jpg # video frame
83
+ │ │ ├── 00003.jpg # video frame
84
+ │ │ └── ...
85
+ │ ├── {video_id}
86
+ │ ├── {video_id}
87
+ │ └── ...
88
+ └── Annotations_6fps # videos are annotated at 6 fps
89
+ ├── {video_id}
90
+ │ ├── 000 # obj 000
91
+ │ │ ├── 00000.png # mask for object 000 in 00000.jpg
92
+ │ │ ├── 00004.png # mask for object 000 in 00004.jpg
93
+ │ │ ├── 00008.png # mask for object 000 in 00008.jpg
94
+ │ │ ├── 00012.png # mask for object 000 in 00012.jpg
95
+ │ │ └── ...
96
+ │ ├── 001 # obj 001
97
+ │ ├── 002 # obj 002
98
+ │ └── ...
99
+ ├── {video_id}
100
+ ├── {video_id}
101
+ └── ...
102
+ ```
103
+
104
+ All masklets in val and test sets are manually annotated in every frame by annotators. For each annotated object in a video, we store the annotated masks in a single png. This is because the annotated objects may overlap, e.g., it is possible in our SA-V dataset for there to be a mask for the whole person as well as a separate mask for their hands.
105
+
106
+ ## SA-V Val and Test Evaluation
107
+
108
+ We provide an evaluator to compute the common J and F metrics on SA-V val and test sets. To run the evaluation, we need to first install a few dependencies as follows:
109
+
110
+ ```
111
+ pip install -r requirements.txt
112
+ ```
113
+
114
+ Then we can evaluate the predictions as follows:
115
+
116
+ ```
117
+ python sav_evaluator.py --gt_root {GT_ROOT} --pred_root {PRED_ROOT}
118
+ ```
119
+
120
+ or run
121
+
122
+ ```
123
+ python sav_evaluator.py --help
124
+ ```
125
+
126
+ to print a complete help message.
127
+
128
+ The evaluator expects the `GT_ROOT` to be one of the following folder structures, and `GT_ROOT` and `PRED_ROOT` to have the same structure.
129
+
130
+ - Same as SA-V val and test directory structure
131
+
132
+ ```
133
+ {GT_ROOT} # gt root folder
134
+ ├── {video_id}
135
+ │ ├── 000 # all masks associated with obj 000
136
+ │ │ ├── 00000.png # mask for object 000 in frame 00000 (binary mask)
137
+ │ │ └── ...
138
+ │ ├── 001 # all masks associated with obj 001
139
+ │ ├── 002 # all masks associated with obj 002
140
+ │ └── ...
141
+ ├── {video_id}
142
+ ├── {video_id}
143
+ └── ...
144
+ ```
145
+
146
+ In the paper for the experiments on SA-V val and test, we run inference on the 24 fps videos, and evaluate on the subset of frames where we have ground truth annotations (first and last annotated frames dropped). The evaluator will ignore the masks in frames where we don't have ground truth annotations.
147
+
148
+ - Same as [DAVIS](https://github.com/davisvideochallenge/davis2017-evaluation) directory structure
149
+
150
+ ```
151
+ {GT_ROOT} # gt root folder
152
+ ├── {video_id}
153
+ │ ├── 00000.png # annotations in frame 00000 (may contain multiple objects)
154
+ │ └── ...
155
+ ├── {video_id}
156
+ ├── {video_id}
157
+ └── ...
158
+ ```
159
+
160
+ ## License
161
+
162
+ The evaluation code is licensed under the [BSD 3 license](./LICENSE). Please refer to the paper for more details on the models. The videos and annotations in SA-V Dataset are released under CC BY 4.0.
163
+
164
+ Third-party code: the evaluation software is heavily adapted from [`VOS-Benchmark`](https://github.com/hkchengrex/vos-benchmark) and [`DAVIS`](https://github.com/davisvideochallenge/davis2017-evaluation) (with their licenses in [`LICENSE_DAVIS`](./LICENSE_DAVIS) and [`LICENSE_VOS_BENCHMARK`](./LICENSE_VOS_BENCHMARK)).
segment-anything-2/sav_dataset/requirements.txt ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ pycocoevalcap
2
+ scikit-image
3
+ opencv-python
4
+ tqdm
5
+ pillow
6
+ numpy
7
+ matplotlib
segment-anything-2/sav_dataset/sav_evaluator.py ADDED
@@ -0,0 +1,89 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the sav_dataset directory of this source tree.
6
+
7
+ # adapted from https://github.com/hkchengrex/vos-benchmark
8
+ # and https://github.com/davisvideochallenge/davis2017-evaluation
9
+ # with their licenses found in the LICENSE_VOS_BENCHMARK and LICENSE_DAVIS files
10
+ # in the sav_dataset directory.
11
+ from argparse import ArgumentParser
12
+
13
+ from utils.sav_benchmark import benchmark
14
+
15
+ """
16
+ The structure of the {GT_ROOT} can be either of the follow two structures.
17
+ {GT_ROOT} and {PRED_ROOT} should be of the same format
18
+
19
+ 1. SA-V val/test structure
20
+ {GT_ROOT} # gt root folder
21
+ ├── {video_id}
22
+ │ ├── 000 # all masks associated with obj 000
23
+ │ │ ├── {frame_id}.png # mask for object 000 in {frame_id} (binary mask)
24
+ │ │ └── ...
25
+ │ ├── 001 # all masks associated with obj 001
26
+ │ ├── 002 # all masks associated with obj 002
27
+ │ └── ...
28
+ ├── {video_id}
29
+ ├── {video_id}
30
+ └── ...
31
+
32
+ 2. Similar to DAVIS structure:
33
+
34
+ {GT_ROOT} # gt root folder
35
+ ├── {video_id}
36
+ │ ├── {frame_id}.png # annotation in {frame_id} (may contain multiple objects)
37
+ │ └── ...
38
+ ├── {video_id}
39
+ ├── {video_id}
40
+ └── ...
41
+ """
42
+
43
+
44
+ parser = ArgumentParser()
45
+ parser.add_argument(
46
+ "--gt_root",
47
+ required=True,
48
+ help="Path to the GT folder. For SA-V, it's sav_val/Annotations_6fps or sav_test/Annotations_6fps",
49
+ )
50
+ parser.add_argument(
51
+ "--pred_root",
52
+ required=True,
53
+ help="Path to a folder containing folders of masks to be evaluated, with exactly the same structure as gt_root",
54
+ )
55
+ parser.add_argument(
56
+ "-n", "--num_processes", default=16, type=int, help="Number of concurrent processes"
57
+ )
58
+ parser.add_argument(
59
+ "-s",
60
+ "--strict",
61
+ help="Make sure every video in the gt_root folder has a corresponding video in the prediction",
62
+ action="store_true",
63
+ )
64
+ parser.add_argument(
65
+ "-q",
66
+ "--quiet",
67
+ help="Quietly run evaluation without printing the information out",
68
+ action="store_true",
69
+ )
70
+
71
+ # https://github.com/davisvideochallenge/davis2017-evaluation/blob/d34fdef71ce3cb24c1a167d860b707e575b3034c/davis2017/evaluation.py#L85
72
+ parser.add_argument(
73
+ "--do_not_skip_first_and_last_frame",
74
+ help="In SA-V val and test, we skip the first and the last annotated frames in evaluation. "
75
+ "Set this to true for evaluation on settings that doen't skip first and last frames",
76
+ action="store_true",
77
+ )
78
+
79
+
80
+ if __name__ == "__main__":
81
+ args = parser.parse_args()
82
+ benchmark(
83
+ [args.gt_root],
84
+ [args.pred_root],
85
+ args.strict,
86
+ args.num_processes,
87
+ verbose=not args.quiet,
88
+ skip_first_and_last=not args.do_not_skip_first_and_last_frame,
89
+ )
segment-anything-2/sav_dataset/sav_visualization_example.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
segment-anything-2/sav_dataset/utils/sav_benchmark.py ADDED
@@ -0,0 +1,488 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the sav_dataset directory of this source tree.
6
+
7
+ # adapted from https://github.com/hkchengrex/vos-benchmark
8
+ # and https://github.com/davisvideochallenge/davis2017-evaluation
9
+ # with their licenses found in the LICENSE_VOS_BENCHMARK and LICENSE_DAVIS files
10
+ # in the sav_dataset directory.
11
+ import math
12
+ import os
13
+ import time
14
+ from collections import defaultdict
15
+ from multiprocessing import Pool
16
+ from os import path
17
+ from typing import Dict, List, Tuple
18
+
19
+ import cv2
20
+ import numpy as np
21
+ import tqdm
22
+ from PIL import Image
23
+ from skimage.morphology import disk
24
+
25
+
26
+ class VideoEvaluator:
27
+ def __init__(self, gt_root, pred_root, skip_first_and_last=True) -> None:
28
+ """
29
+ gt_root: path to the folder storing the gt masks
30
+ pred_root: path to the folder storing the predicted masks
31
+ skip_first_and_last: whether we should skip the evaluation of the first and the last frame.
32
+ True for SA-V val and test, same as in DAVIS semi-supervised evaluation.
33
+ """
34
+ self.gt_root = gt_root
35
+ self.pred_root = pred_root
36
+ self.skip_first_and_last = skip_first_and_last
37
+
38
+ def __call__(self, vid_name: str) -> Tuple[str, Dict[str, float], Dict[str, float]]:
39
+ """
40
+ vid_name: name of the video to evaluate
41
+ """
42
+
43
+ # scan the folder to find subfolders for evaluation and
44
+ # check if the folder structure is SA-V
45
+ to_evaluate, is_sav_format = self.scan_vid_folder(vid_name)
46
+
47
+ # evaluate each (gt_path, pred_path) pair
48
+ eval_results = []
49
+ for all_frames, obj_id, gt_path, pred_path in to_evaluate:
50
+ if self.skip_first_and_last:
51
+ # skip the first and the last frames
52
+ all_frames = all_frames[1:-1]
53
+
54
+ evaluator = Evaluator(name=vid_name, obj_id=obj_id)
55
+ for frame in all_frames:
56
+ gt_array, pred_array = self.get_gt_and_pred(
57
+ gt_path, pred_path, frame, is_sav_format
58
+ )
59
+ evaluator.feed_frame(mask=pred_array, gt=gt_array)
60
+
61
+ iou, boundary_f = evaluator.conclude()
62
+ eval_results.append((obj_id, iou, boundary_f))
63
+
64
+ if is_sav_format:
65
+ iou_output, boundary_f_output = self.consolidate(eval_results)
66
+ else:
67
+ assert len(eval_results) == 1
68
+ iou_output = eval_results[0][1]
69
+ boundary_f_output = eval_results[0][2]
70
+
71
+ return vid_name, iou_output, boundary_f_output
72
+
73
+ def get_gt_and_pred(
74
+ self,
75
+ gt_path: str,
76
+ pred_path: str,
77
+ f_name: str,
78
+ is_sav_format: bool,
79
+ ) -> Tuple[np.ndarray, np.ndarray]:
80
+ """
81
+ Get the ground-truth and predicted masks for a single frame.
82
+ """
83
+ gt_mask_path = path.join(gt_path, f_name)
84
+ pred_mask_path = path.join(pred_path, f_name)
85
+ assert os.path.exists(pred_mask_path), f"{pred_mask_path} not found"
86
+
87
+ gt_array = np.array(Image.open(gt_mask_path))
88
+ pred_array = np.array(Image.open(pred_mask_path))
89
+ assert (
90
+ gt_array.shape[-2:] == pred_array.shape[-2:]
91
+ ), f"shape mismatch: {gt_mask_path}, {pred_mask_path}"
92
+
93
+ if is_sav_format:
94
+ assert len(np.unique(gt_array)) <= 2, (
95
+ f"found more than 1 object in {gt_mask_path} "
96
+ "SA-V format assumes one object mask per png file."
97
+ )
98
+ assert len(np.unique(pred_array)) <= 2, (
99
+ f"found more than 1 object in {pred_mask_path} "
100
+ "SA-V format assumes one object mask per png file."
101
+ )
102
+ gt_array = gt_array > 0
103
+ pred_array = pred_array > 0
104
+
105
+ return gt_array, pred_array
106
+
107
+ def scan_vid_folder(self, vid_name) -> Tuple[List, bool]:
108
+ """
109
+ Scan the folder structure of the video and return a list of folders for evaluate.
110
+ """
111
+
112
+ vid_gt_path = path.join(self.gt_root, vid_name)
113
+ vid_pred_path = path.join(self.pred_root, vid_name)
114
+ all_files_and_dirs = sorted(os.listdir(vid_gt_path))
115
+ to_evaluate = []
116
+ if all(name.endswith(".png") for name in all_files_and_dirs):
117
+ # All files are png files, dataset structure similar to DAVIS
118
+ is_sav_format = False
119
+ frames = all_files_and_dirs
120
+ obj_dir = None
121
+ to_evaluate.append((frames, obj_dir, vid_gt_path, vid_pred_path))
122
+ else:
123
+ # SA-V dataset structure, going one layer down into each subdirectory
124
+ is_sav_format = True
125
+ for obj_dir in all_files_and_dirs:
126
+ obj_gt_path = path.join(vid_gt_path, obj_dir)
127
+ obj_pred_path = path.join(vid_pred_path, obj_dir)
128
+ frames = sorted(os.listdir(obj_gt_path))
129
+ to_evaluate.append((frames, obj_dir, obj_gt_path, obj_pred_path))
130
+ return to_evaluate, is_sav_format
131
+
132
+ def consolidate(
133
+ self, eval_results
134
+ ) -> Tuple[str, Dict[str, float], Dict[str, float]]:
135
+ """
136
+ Consolidate the results of all the objects from the video into one dictionary.
137
+ """
138
+ iou_output = {}
139
+ boundary_f_output = {}
140
+ for obj_id, iou, boundary_f in eval_results:
141
+ assert len(iou) == 1
142
+ key = list(iou.keys())[0]
143
+ iou_output[obj_id] = iou[key]
144
+ boundary_f_output[obj_id] = boundary_f[key]
145
+ return iou_output, boundary_f_output
146
+
147
+
148
+ #################################################################################################################
149
+ # Functions below are from https://github.com/hkchengrex/vos-benchmark with minor modifications
150
+ # _seg2bmap from https://github.com/hkchengrex/vos-benchmark/blob/main/vos_benchmark/utils.py
151
+ # get_iou and Evaluator from https://github.com/hkchengrex/vos-benchmark/blob/main/vos_benchmark/evaluator.py
152
+ # benchmark from https://github.com/hkchengrex/vos-benchmark/blob/main/vos_benchmark/benchmark.py with slight mod
153
+ #################################################################################################################
154
+
155
+
156
+ def _seg2bmap(seg, width=None, height=None):
157
+ """
158
+ From a segmentation, compute a binary boundary map with 1 pixel wide
159
+ boundaries. The boundary pixels are offset by 1/2 pixel towards the
160
+ origin from the actual segment boundary.
161
+ Arguments:
162
+ seg : Segments labeled from 1..k.
163
+ width : Width of desired bmap <= seg.shape[1]
164
+ height : Height of desired bmap <= seg.shape[0]
165
+ Returns:
166
+ bmap (ndarray): Binary boundary map.
167
+ David Martin <[email protected]>
168
+ January 2003
169
+ """
170
+
171
+ seg = seg.astype(bool)
172
+ seg[seg > 0] = 1
173
+
174
+ assert np.atleast_3d(seg).shape[2] == 1
175
+
176
+ width = seg.shape[1] if width is None else width
177
+ height = seg.shape[0] if height is None else height
178
+
179
+ h, w = seg.shape[:2]
180
+
181
+ ar1 = float(width) / float(height)
182
+ ar2 = float(w) / float(h)
183
+
184
+ assert not (
185
+ width > w | height > h | abs(ar1 - ar2) > 0.01
186
+ ), "Can" "t convert %dx%d seg to %dx%d bmap." % (w, h, width, height)
187
+
188
+ e = np.zeros_like(seg)
189
+ s = np.zeros_like(seg)
190
+ se = np.zeros_like(seg)
191
+
192
+ e[:, :-1] = seg[:, 1:]
193
+ s[:-1, :] = seg[1:, :]
194
+ se[:-1, :-1] = seg[1:, 1:]
195
+
196
+ b = seg ^ e | seg ^ s | seg ^ se
197
+ b[-1, :] = seg[-1, :] ^ e[-1, :]
198
+ b[:, -1] = seg[:, -1] ^ s[:, -1]
199
+ b[-1, -1] = 0
200
+
201
+ if w == width and h == height:
202
+ bmap = b
203
+ else:
204
+ bmap = np.zeros((height, width))
205
+ for x in range(w):
206
+ for y in range(h):
207
+ if b[y, x]:
208
+ j = 1 + math.floor((y - 1) + height / h)
209
+ i = 1 + math.floor((x - 1) + width / h)
210
+ bmap[j, i] = 1
211
+
212
+ return bmap
213
+
214
+
215
+ def get_iou(intersection, pixel_sum):
216
+ # handle edge cases without resorting to epsilon
217
+ if intersection == pixel_sum:
218
+ # both mask and gt have zero pixels in them
219
+ assert intersection == 0
220
+ return 1
221
+
222
+ return intersection / (pixel_sum - intersection)
223
+
224
+
225
+ class Evaluator:
226
+ def __init__(self, boundary=0.008, name=None, obj_id=None):
227
+ # boundary: used in computing boundary F-score
228
+ self.boundary = boundary
229
+ self.name = name
230
+ self.obj_id = obj_id
231
+ self.objects_in_gt = set()
232
+ self.objects_in_masks = set()
233
+
234
+ self.object_iou = defaultdict(list)
235
+ self.boundary_f = defaultdict(list)
236
+
237
+ def feed_frame(self, mask: np.ndarray, gt: np.ndarray):
238
+ """
239
+ Compute and accumulate metrics for a single frame (mask/gt pair)
240
+ """
241
+
242
+ # get all objects in the ground-truth
243
+ gt_objects = np.unique(gt)
244
+ gt_objects = gt_objects[gt_objects != 0].tolist()
245
+
246
+ # get all objects in the predicted mask
247
+ mask_objects = np.unique(mask)
248
+ mask_objects = mask_objects[mask_objects != 0].tolist()
249
+
250
+ self.objects_in_gt.update(set(gt_objects))
251
+ self.objects_in_masks.update(set(mask_objects))
252
+
253
+ all_objects = self.objects_in_gt.union(self.objects_in_masks)
254
+
255
+ # boundary disk for boundary F-score. It is the same for all objects.
256
+ bound_pix = np.ceil(self.boundary * np.linalg.norm(mask.shape))
257
+ boundary_disk = disk(bound_pix)
258
+
259
+ for obj_idx in all_objects:
260
+ obj_mask = mask == obj_idx
261
+ obj_gt = gt == obj_idx
262
+
263
+ # object iou
264
+ self.object_iou[obj_idx].append(
265
+ get_iou((obj_mask * obj_gt).sum(), obj_mask.sum() + obj_gt.sum())
266
+ )
267
+ """
268
+ # boundary f-score
269
+ This part is copied from davis2017-evaluation
270
+ """
271
+ mask_boundary = _seg2bmap(obj_mask)
272
+ gt_boundary = _seg2bmap(obj_gt)
273
+ mask_dilated = cv2.dilate(mask_boundary.astype(np.uint8), boundary_disk)
274
+ gt_dilated = cv2.dilate(gt_boundary.astype(np.uint8), boundary_disk)
275
+
276
+ # Get the intersection
277
+ gt_match = gt_boundary * mask_dilated
278
+ fg_match = mask_boundary * gt_dilated
279
+
280
+ # Area of the intersection
281
+ n_fg = np.sum(mask_boundary)
282
+ n_gt = np.sum(gt_boundary)
283
+
284
+ # Compute precision and recall
285
+ if n_fg == 0 and n_gt > 0:
286
+ precision = 1
287
+ recall = 0
288
+ elif n_fg > 0 and n_gt == 0:
289
+ precision = 0
290
+ recall = 1
291
+ elif n_fg == 0 and n_gt == 0:
292
+ precision = 1
293
+ recall = 1
294
+ else:
295
+ precision = np.sum(fg_match) / float(n_fg)
296
+ recall = np.sum(gt_match) / float(n_gt)
297
+
298
+ # Compute F measure
299
+ if precision + recall == 0:
300
+ F = 0
301
+ else:
302
+ F = 2 * precision * recall / (precision + recall)
303
+ self.boundary_f[obj_idx].append(F)
304
+
305
+ def conclude(self):
306
+ all_iou = {}
307
+ all_boundary_f = {}
308
+
309
+ for object_id in self.objects_in_gt:
310
+ all_iou[object_id] = np.mean(self.object_iou[object_id]) * 100
311
+ all_boundary_f[object_id] = np.mean(self.boundary_f[object_id]) * 100
312
+
313
+ return all_iou, all_boundary_f
314
+
315
+
316
+ def benchmark(
317
+ gt_roots,
318
+ mask_roots,
319
+ strict=True,
320
+ num_processes=None,
321
+ *,
322
+ verbose=True,
323
+ skip_first_and_last=True,
324
+ ):
325
+ """
326
+ gt_roots: a list of paths to datasets, i.e., [path_to_DatasetA, path_to_DatasetB, ...]
327
+ mask_roots: same as above, but the .png are masks predicted by the model
328
+ strict: when True, all videos in the dataset must have corresponding predictions.
329
+ Setting it to False is useful in cases where the ground-truth contains both train/val
330
+ sets, but the model only predicts the val subset.
331
+ Either way, if a video is predicted (i.e., the corresponding folder exists),
332
+ then it must at least contain all the masks in the ground truth annotations.
333
+ Masks that are in the prediction but not in the ground-truth
334
+ (i.e., sparse annotations) are ignored.
335
+ skip_first_and_last: whether we should skip the first and the last frame in evaluation.
336
+ This is used by DAVIS 2017 in their semi-supervised evaluation.
337
+ It should be disabled for unsupervised evaluation.
338
+ """
339
+
340
+ assert len(gt_roots) == len(mask_roots)
341
+ single_dataset = len(gt_roots) == 1
342
+
343
+ if verbose:
344
+ if skip_first_and_last:
345
+ print(
346
+ "We are *SKIPPING* the evaluation of the first and the last frame (standard for semi-supervised video object segmentation)."
347
+ )
348
+ else:
349
+ print(
350
+ "We are *NOT SKIPPING* the evaluation of the first and the last frame (*NOT STANDARD* for semi-supervised video object segmentation)."
351
+ )
352
+
353
+ pool = Pool(num_processes)
354
+ start = time.time()
355
+ to_wait = []
356
+ for gt_root, mask_root in zip(gt_roots, mask_roots):
357
+ # Validate folders
358
+ validated = True
359
+ gt_videos = os.listdir(gt_root)
360
+ mask_videos = os.listdir(mask_root)
361
+
362
+ # if the user passed the root directory instead of Annotations
363
+ if len(gt_videos) != len(mask_videos):
364
+ if "Annotations" in gt_videos:
365
+ if ".png" not in os.listdir(path.join(gt_root, "Annotations"))[0]:
366
+ gt_root = path.join(gt_root, "Annotations")
367
+ gt_videos = os.listdir(gt_root)
368
+
369
+ # remove non-folder items
370
+ gt_videos = list(filter(lambda x: path.isdir(path.join(gt_root, x)), gt_videos))
371
+ mask_videos = list(
372
+ filter(lambda x: path.isdir(path.join(mask_root, x)), mask_videos)
373
+ )
374
+
375
+ if not strict:
376
+ videos = sorted(list(set(gt_videos) & set(mask_videos)))
377
+ else:
378
+ gt_extras = set(gt_videos) - set(mask_videos)
379
+ mask_extras = set(mask_videos) - set(gt_videos)
380
+
381
+ if len(gt_extras) > 0:
382
+ print(
383
+ f"Videos that are in {gt_root} but not in {mask_root}: {gt_extras}"
384
+ )
385
+ validated = False
386
+ if len(mask_extras) > 0:
387
+ print(
388
+ f"Videos that are in {mask_root} but not in {gt_root}: {mask_extras}"
389
+ )
390
+ validated = False
391
+ if not validated:
392
+ print("Validation failed. Exiting.")
393
+ exit(1)
394
+
395
+ videos = sorted(gt_videos)
396
+
397
+ if verbose:
398
+ print(
399
+ f"In dataset {gt_root}, we are evaluating on {len(videos)} videos: {videos}"
400
+ )
401
+
402
+ if single_dataset:
403
+ if verbose:
404
+ results = tqdm.tqdm(
405
+ pool.imap(
406
+ VideoEvaluator(
407
+ gt_root, mask_root, skip_first_and_last=skip_first_and_last
408
+ ),
409
+ videos,
410
+ ),
411
+ total=len(videos),
412
+ )
413
+ else:
414
+ results = pool.map(
415
+ VideoEvaluator(
416
+ gt_root, mask_root, skip_first_and_last=skip_first_and_last
417
+ ),
418
+ videos,
419
+ )
420
+ else:
421
+ to_wait.append(
422
+ pool.map_async(
423
+ VideoEvaluator(
424
+ gt_root, mask_root, skip_first_and_last=skip_first_and_last
425
+ ),
426
+ videos,
427
+ )
428
+ )
429
+
430
+ pool.close()
431
+
432
+ all_global_jf, all_global_j, all_global_f = [], [], []
433
+ all_object_metrics = []
434
+ for i, mask_root in enumerate(mask_roots):
435
+ if not single_dataset:
436
+ results = to_wait[i].get()
437
+
438
+ all_iou = []
439
+ all_boundary_f = []
440
+ object_metrics = {}
441
+ for name, iou, boundary_f in results:
442
+ all_iou.extend(list(iou.values()))
443
+ all_boundary_f.extend(list(boundary_f.values()))
444
+ object_metrics[name] = (iou, boundary_f)
445
+
446
+ global_j = np.array(all_iou).mean()
447
+ global_f = np.array(all_boundary_f).mean()
448
+ global_jf = (global_j + global_f) / 2
449
+
450
+ time_taken = time.time() - start
451
+ """
452
+ Build string for reporting results
453
+ """
454
+ # find max length for padding
455
+ ml = max(*[len(n) for n in object_metrics.keys()], len("Global score"))
456
+ # build header
457
+ out_string = f'{"sequence":<{ml}},{"obj":>3}, {"J&F":>4}, {"J":>4}, {"F":>4}\n'
458
+ out_string += f'{"Global score":<{ml}},{"":>3}, {global_jf:.1f}, {global_j:.1f}, {global_f:.1f}\n'
459
+ # append one line for each object
460
+ for name, (iou, boundary_f) in object_metrics.items():
461
+ for object_idx in iou.keys():
462
+ j, f = iou[object_idx], boundary_f[object_idx]
463
+ jf = (j + f) / 2
464
+ out_string += (
465
+ f"{name:<{ml}},{object_idx:03}, {jf:>4.1f}, {j:>4.1f}, {f:>4.1f}\n"
466
+ )
467
+
468
+ # print to console
469
+ if verbose:
470
+ print(out_string.replace(",", " "), end="")
471
+ print("\nSummary:")
472
+ print(
473
+ f"Global score: J&F: {global_jf:.1f} J: {global_j:.1f} F: {global_f:.1f}"
474
+ )
475
+ print(f"Time taken: {time_taken:.2f}s")
476
+
477
+ # print to file
478
+ result_path = path.join(mask_root, "results.csv")
479
+ print(f"Saving the results to {result_path}")
480
+ with open(result_path, "w") as f:
481
+ f.write(out_string)
482
+
483
+ all_global_jf.append(global_jf)
484
+ all_global_j.append(global_j)
485
+ all_global_f.append(global_f)
486
+ all_object_metrics.append(object_metrics)
487
+
488
+ return all_global_jf, all_global_j, all_global_f, all_object_metrics
segment-anything-2/sav_dataset/utils/sav_utils.py ADDED
@@ -0,0 +1,175 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the sav_dataset directory of this source tree.
6
+ import json
7
+ import os
8
+ from typing import Dict, List, Optional, Tuple
9
+
10
+ import cv2
11
+ import matplotlib.pyplot as plt
12
+ import numpy as np
13
+ import pycocotools.mask as mask_util
14
+
15
+
16
+ def decode_video(video_path: str) -> List[np.ndarray]:
17
+ """
18
+ Decode the video and return the RGB frames
19
+ """
20
+ video = cv2.VideoCapture(video_path)
21
+ video_frames = []
22
+ while video.isOpened():
23
+ ret, frame = video.read()
24
+ if ret:
25
+ frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
26
+ video_frames.append(frame)
27
+ else:
28
+ break
29
+ return video_frames
30
+
31
+
32
+ def show_anns(masks, colors: List, borders=True) -> None:
33
+ """
34
+ show the annotations
35
+ """
36
+ # return if no masks
37
+ if len(masks) == 0:
38
+ return
39
+
40
+ # sort masks by size
41
+ sorted_annot_and_color = sorted(
42
+ zip(masks, colors), key=(lambda x: x[0].sum()), reverse=True
43
+ )
44
+ H, W = sorted_annot_and_color[0][0].shape[0], sorted_annot_and_color[0][0].shape[1]
45
+
46
+ canvas = np.ones((H, W, 4))
47
+ canvas[:, :, 3] = 0 # set the alpha channel
48
+ contour_thickness = max(1, int(min(5, 0.01 * min(H, W))))
49
+ for mask, color in sorted_annot_and_color:
50
+ canvas[mask] = np.concatenate([color, [0.55]])
51
+ if borders:
52
+ contours, _ = cv2.findContours(
53
+ np.array(mask, dtype=np.uint8), cv2.RETR_TREE, cv2.CHAIN_APPROX_NONE
54
+ )
55
+ cv2.drawContours(
56
+ canvas, contours, -1, (0.05, 0.05, 0.05, 1), thickness=contour_thickness
57
+ )
58
+
59
+ ax = plt.gca()
60
+ ax.imshow(canvas)
61
+
62
+
63
+ class SAVDataset:
64
+ """
65
+ SAVDataset is a class to load the SAV dataset and visualize the annotations.
66
+ """
67
+
68
+ def __init__(self, sav_dir, annot_sample_rate=4):
69
+ """
70
+ Args:
71
+ sav_dir: the directory of the SAV dataset
72
+ annot_sample_rate: the sampling rate of the annotations.
73
+ The annotations are aligned with the videos at 6 fps.
74
+ """
75
+ self.sav_dir = sav_dir
76
+ self.annot_sample_rate = annot_sample_rate
77
+ self.manual_mask_colors = np.random.random((256, 3))
78
+ self.auto_mask_colors = np.random.random((256, 3))
79
+
80
+ def read_frames(self, mp4_path: str) -> None:
81
+ """
82
+ Read the frames and downsample them to align with the annotations.
83
+ """
84
+ if not os.path.exists(mp4_path):
85
+ print(f"{mp4_path} doesn't exist.")
86
+ return None
87
+ else:
88
+ # decode the video
89
+ frames = decode_video(mp4_path)
90
+ print(f"There are {len(frames)} frames decoded from {mp4_path} (24fps).")
91
+
92
+ # downsample the frames to align with the annotations
93
+ frames = frames[:: self.annot_sample_rate]
94
+ print(
95
+ f"Videos are annotated every {self.annot_sample_rate} frames. "
96
+ "To align with the annotations, "
97
+ f"downsample the video to {len(frames)} frames."
98
+ )
99
+ return frames
100
+
101
+ def get_frames_and_annotations(
102
+ self, video_id: str
103
+ ) -> Tuple[List | None, Dict | None, Dict | None]:
104
+ """
105
+ Get the frames and annotations for video.
106
+ """
107
+ # load the video
108
+ mp4_path = os.path.join(self.sav_dir, video_id + ".mp4")
109
+ frames = self.read_frames(mp4_path)
110
+ if frames is None:
111
+ return None, None, None
112
+
113
+ # load the manual annotations
114
+ manual_annot_path = os.path.join(self.sav_dir, video_id + "_manual.json")
115
+ if not os.path.exists(manual_annot_path):
116
+ print(f"{manual_annot_path} doesn't exist. Something might be wrong.")
117
+ manual_annot = None
118
+ else:
119
+ manual_annot = json.load(open(manual_annot_path))
120
+
121
+ # load the manual annotations
122
+ auto_annot_path = os.path.join(self.sav_dir, video_id + "_auto.json")
123
+ if not os.path.exists(auto_annot_path):
124
+ print(f"{auto_annot_path} doesn't exist.")
125
+ auto_annot = None
126
+ else:
127
+ auto_annot = json.load(open(auto_annot_path))
128
+
129
+ return frames, manual_annot, auto_annot
130
+
131
+ def visualize_annotation(
132
+ self,
133
+ frames: List[np.ndarray],
134
+ auto_annot: Optional[Dict],
135
+ manual_annot: Optional[Dict],
136
+ annotated_frame_id: int,
137
+ show_auto=True,
138
+ show_manual=True,
139
+ ) -> None:
140
+ """
141
+ Visualize the annotations on the annotated_frame_id.
142
+ If show_manual is True, show the manual annotations.
143
+ If show_auto is True, show the auto annotations.
144
+ By default, show both auto and manual annotations.
145
+ """
146
+
147
+ if annotated_frame_id >= len(frames):
148
+ print("invalid annotated_frame_id")
149
+ return
150
+
151
+ rles = []
152
+ colors = []
153
+ if show_manual and manual_annot is not None:
154
+ rles.extend(manual_annot["masklet"][annotated_frame_id])
155
+ colors.extend(
156
+ self.manual_mask_colors[
157
+ : len(manual_annot["masklet"][annotated_frame_id])
158
+ ]
159
+ )
160
+ if show_auto and auto_annot is not None:
161
+ rles.extend(auto_annot["masklet"][annotated_frame_id])
162
+ colors.extend(
163
+ self.auto_mask_colors[: len(auto_annot["masklet"][annotated_frame_id])]
164
+ )
165
+
166
+ plt.imshow(frames[annotated_frame_id])
167
+
168
+ if len(rles) > 0:
169
+ masks = [mask_util.decode(rle) > 0 for rle in rles]
170
+ show_anns(masks, colors)
171
+ else:
172
+ print("No annotation will be shown")
173
+
174
+ plt.axis("off")
175
+ plt.show()
segment-anything-2/setup.py ADDED
@@ -0,0 +1,73 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ from setuptools import find_packages, setup
8
+ from torch.utils.cpp_extension import BuildExtension, CUDAExtension
9
+
10
+ # Package metadata
11
+ NAME = "SAM 2"
12
+ VERSION = "1.0"
13
+ DESCRIPTION = "SAM 2: Segment Anything in Images and Videos"
14
+ URL = "https://github.com/facebookresearch/segment-anything-2"
15
+ AUTHOR = "Meta AI"
16
+ AUTHOR_EMAIL = "[email protected]"
17
+ LICENSE = "Apache 2.0"
18
+
19
+ # Read the contents of README file
20
+ with open("README.md", "r") as f:
21
+ LONG_DESCRIPTION = f.read()
22
+
23
+ # Required dependencies
24
+ REQUIRED_PACKAGES = [
25
+ "torch>=2.3.1",
26
+ "torchvision>=0.18.1",
27
+ "numpy>=1.24.4",
28
+ "tqdm>=4.66.1",
29
+ "hydra-core>=1.3.2",
30
+ "iopath>=0.1.10",
31
+ "pillow>=9.4.0",
32
+ ]
33
+
34
+ EXTRA_PACKAGES = {
35
+ "demo": ["matplotlib>=3.9.1", "jupyter>=1.0.0", "opencv-python>=4.7.0"],
36
+ "dev": ["black==24.2.0", "usort==1.0.2", "ufmt==2.0.0b2"],
37
+ }
38
+
39
+
40
+ def get_extensions():
41
+ srcs = ["sam2/csrc/connected_components.cu"]
42
+ compile_args = {
43
+ "cxx": [],
44
+ "nvcc": [
45
+ "-DCUDA_HAS_FP16=1",
46
+ "-D__CUDA_NO_HALF_OPERATORS__",
47
+ "-D__CUDA_NO_HALF_CONVERSIONS__",
48
+ "-D__CUDA_NO_HALF2_OPERATORS__",
49
+ "-allow-unsupported-compiler"
50
+ ],
51
+ }
52
+ ext_modules = [CUDAExtension("sam2._C", srcs, extra_compile_args=compile_args)]
53
+ return ext_modules
54
+
55
+
56
+ # Setup configuration
57
+ setup(
58
+ name=NAME,
59
+ version=VERSION,
60
+ description=DESCRIPTION,
61
+ long_description=LONG_DESCRIPTION,
62
+ long_description_content_type="text/markdown",
63
+ url=URL,
64
+ author=AUTHOR,
65
+ author_email=AUTHOR_EMAIL,
66
+ license=LICENSE,
67
+ packages=find_packages(exclude="notebooks"),
68
+ install_requires=REQUIRED_PACKAGES,
69
+ extras_require=EXTRA_PACKAGES,
70
+ python_requires=">=3.10.0",
71
+ ext_modules=get_extensions(),
72
+ cmdclass={"build_ext": BuildExtension.with_options(no_python_abi_suffix=True)},
73
+ )