shikunl commited on
Commit
7617596
β€’
1 Parent(s): ef1faf6
This view is limited to 50 files because it contains too many changes. Β  See raw diff
Files changed (50) hide show
  1. .gitignore +162 -1
  2. .gitmodules +3 -0
  3. .pre-commit-config.yaml +37 -0
  4. .style.yapf +5 -0
  5. README.md +4 -4
  6. app.py +20 -19
  7. gradio_caption.py β†’ app_caption.py +21 -7
  8. gradio_vqa.py +0 -33
  9. patch +82 -0
  10. prismer/dataset/__init__.py +0 -32
  11. prismer/dataset/ade_features.pt +0 -3
  12. prismer/dataset/background_features.pt +0 -3
  13. prismer/dataset/caption_dataset.py +0 -63
  14. prismer/dataset/clip_pca.pkl +0 -3
  15. prismer/dataset/coco_features.pt +0 -3
  16. prismer/dataset/detection_features.pt +0 -3
  17. prismer/dataset/randaugment.py +0 -265
  18. prismer/dataset/utils.py +0 -192
  19. prismer/dataset/vqa_dataset.py +0 -53
  20. prismer/demo.py +0 -77
  21. prismer/demo_vis.py +0 -161
  22. prismer/download_checkpoints.py +0 -124
  23. prismer/experts/depth/base_model.py +0 -16
  24. prismer/experts/depth/blocks.py +0 -383
  25. prismer/experts/depth/generate_dataset.py +0 -32
  26. prismer/experts/depth/models.py +0 -124
  27. prismer/experts/depth/vit.py +0 -576
  28. prismer/experts/edge/generate_dataset.py +0 -32
  29. prismer/experts/edge/images.py +0 -50
  30. prismer/experts/edge/model.py +0 -286
  31. prismer/experts/generate_depth.py +0 -56
  32. prismer/experts/generate_edge.py +0 -57
  33. prismer/experts/generate_normal.py +0 -58
  34. prismer/experts/generate_objdet.py +0 -115
  35. prismer/experts/generate_ocrdet.py +0 -86
  36. prismer/experts/generate_segmentation.py +0 -56
  37. prismer/experts/model_bank.py +0 -139
  38. prismer/experts/normal/generate_dataset.py +0 -34
  39. prismer/experts/normal/models/NNET.py +0 -22
  40. prismer/experts/normal/models/baseline.py +0 -85
  41. prismer/experts/normal/models/submodules/decoder.py +0 -202
  42. prismer/experts/normal/models/submodules/encoder.py +0 -32
  43. prismer/experts/normal/models/submodules/submodules.py +0 -140
  44. prismer/experts/normal/utils/losses.py +0 -178
  45. prismer/experts/normal/utils/utils.py +0 -191
  46. prismer/experts/obj_detection/configs/Base-CRCNN-COCO.yaml +0 -48
  47. prismer/experts/obj_detection/configs/O365_CRFR50_CAS_2x.yaml +0 -15
  48. prismer/experts/obj_detection/configs/OID_CRFR50_CAS_2x.yaml +0 -22
  49. prismer/experts/obj_detection/configs/Partitioned_COIM_R50_6x+2x.yaml +0 -28
  50. prismer/experts/obj_detection/configs/Partitioned_COI_R50_2x.yaml +0 -29
.gitignore CHANGED
@@ -1 +1,162 @@
1
- .idea
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ cache/
2
+
3
+ # Byte-compiled / optimized / DLL files
4
+ __pycache__/
5
+ *.py[cod]
6
+ *$py.class
7
+
8
+ # C extensions
9
+ *.so
10
+
11
+ # Distribution / packaging
12
+ .Python
13
+ build/
14
+ develop-eggs/
15
+ dist/
16
+ downloads/
17
+ eggs/
18
+ .eggs/
19
+ lib/
20
+ lib64/
21
+ parts/
22
+ sdist/
23
+ var/
24
+ wheels/
25
+ share/python-wheels/
26
+ *.egg-info/
27
+ .installed.cfg
28
+ *.egg
29
+ MANIFEST
30
+
31
+ # PyInstaller
32
+ # Usually these files are written by a python script from a template
33
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
34
+ *.manifest
35
+ *.spec
36
+
37
+ # Installer logs
38
+ pip-log.txt
39
+ pip-delete-this-directory.txt
40
+
41
+ # Unit test / coverage reports
42
+ htmlcov/
43
+ .tox/
44
+ .nox/
45
+ .coverage
46
+ .coverage.*
47
+ .cache
48
+ nosetests.xml
49
+ coverage.xml
50
+ *.cover
51
+ *.py,cover
52
+ .hypothesis/
53
+ .pytest_cache/
54
+ cover/
55
+
56
+ # Translations
57
+ *.mo
58
+ *.pot
59
+
60
+ # Django stuff:
61
+ *.log
62
+ local_settings.py
63
+ db.sqlite3
64
+ db.sqlite3-journal
65
+
66
+ # Flask stuff:
67
+ instance/
68
+ .webassets-cache
69
+
70
+ # Scrapy stuff:
71
+ .scrapy
72
+
73
+ # Sphinx documentation
74
+ docs/_build/
75
+
76
+ # PyBuilder
77
+ .pybuilder/
78
+ target/
79
+
80
+ # Jupyter Notebook
81
+ .ipynb_checkpoints
82
+
83
+ # IPython
84
+ profile_default/
85
+ ipython_config.py
86
+
87
+ # pyenv
88
+ # For a library or package, you might want to ignore these files since the code is
89
+ # intended to run in multiple environments; otherwise, check them in:
90
+ # .python-version
91
+
92
+ # pipenv
93
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
94
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
95
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
96
+ # install all needed dependencies.
97
+ #Pipfile.lock
98
+
99
+ # poetry
100
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
101
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
102
+ # commonly ignored for libraries.
103
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
104
+ #poetry.lock
105
+
106
+ # pdm
107
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
108
+ #pdm.lock
109
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
110
+ # in version control.
111
+ # https://pdm.fming.dev/#use-with-ide
112
+ .pdm.toml
113
+
114
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
115
+ __pypackages__/
116
+
117
+ # Celery stuff
118
+ celerybeat-schedule
119
+ celerybeat.pid
120
+
121
+ # SageMath parsed files
122
+ *.sage.py
123
+
124
+ # Environments
125
+ .env
126
+ .venv
127
+ env/
128
+ venv/
129
+ ENV/
130
+ env.bak/
131
+ venv.bak/
132
+
133
+ # Spyder project settings
134
+ .spyderproject
135
+ .spyproject
136
+
137
+ # Rope project settings
138
+ .ropeproject
139
+
140
+ # mkdocs documentation
141
+ /site
142
+
143
+ # mypy
144
+ .mypy_cache/
145
+ .dmypy.json
146
+ dmypy.json
147
+
148
+ # Pyre type checker
149
+ .pyre/
150
+
151
+ # pytype static type analyzer
152
+ .pytype/
153
+
154
+ # Cython debug symbols
155
+ cython_debug/
156
+
157
+ # PyCharm
158
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
159
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
160
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
161
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
162
+ #.idea/
.gitmodules ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ [submodule "prismer"]
2
+ path = prismer
3
+ url = https://github.com/nvlabs/prismer
.pre-commit-config.yaml ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ exclude: patch
2
+ repos:
3
+ - repo: https://github.com/pre-commit/pre-commit-hooks
4
+ rev: v4.2.0
5
+ hooks:
6
+ - id: check-executables-have-shebangs
7
+ - id: check-json
8
+ - id: check-merge-conflict
9
+ - id: check-shebang-scripts-are-executable
10
+ - id: check-toml
11
+ - id: check-yaml
12
+ - id: double-quote-string-fixer
13
+ - id: end-of-file-fixer
14
+ - id: mixed-line-ending
15
+ args: ['--fix=lf']
16
+ - id: requirements-txt-fixer
17
+ - id: trailing-whitespace
18
+ - repo: https://github.com/myint/docformatter
19
+ rev: v1.4
20
+ hooks:
21
+ - id: docformatter
22
+ args: ['--in-place']
23
+ - repo: https://github.com/pycqa/isort
24
+ rev: 5.12.0
25
+ hooks:
26
+ - id: isort
27
+ - repo: https://github.com/pre-commit/mirrors-mypy
28
+ rev: v0.991
29
+ hooks:
30
+ - id: mypy
31
+ args: ['--ignore-missing-imports']
32
+ additional_dependencies: ['types-python-slugify']
33
+ - repo: https://github.com/google/yapf
34
+ rev: v0.32.0
35
+ hooks:
36
+ - id: yapf
37
+ args: ['--parallel', '--in-place']
.style.yapf ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ [style]
2
+ based_on_style = pep8
3
+ blank_line_before_nested_class_or_def = false
4
+ spaces_before_comment = 2
5
+ split_before_logical_operator = true
README.md CHANGED
@@ -1,11 +1,11 @@
1
  ---
2
  title: Prismer
3
- emoji: πŸ”»
4
- colorFrom: yellow
5
- colorTo: purple
6
  sdk: docker
7
- sdk_version: 3.20.1
8
  pinned: false
 
9
  ---
10
 
11
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
1
  ---
2
  title: Prismer
3
+ emoji: πŸ¦€
4
+ colorFrom: blue
5
+ colorTo: green
6
  sdk: docker
 
7
  pinned: false
8
+ license: other
9
  ---
10
 
11
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py CHANGED
@@ -1,34 +1,35 @@
 
 
1
  from __future__ import annotations
2
 
3
  import os
 
 
 
4
  import gradio as gr
5
 
6
- from gradio_caption import create_demo as create_caption
7
- from gradio_vqa import create_demo as create_vqa
8
- from prismer_model import build_deformable_conv, download_models
 
 
 
9
 
 
 
10
 
11
- # Prepare Prismer checkpoints
12
  download_models()
13
  build_deformable_conv()
14
 
15
- # Official Demo here
16
- description = """
17
- # Prismer
18
- The official demo for **Prismer: A Vision-Language Model with An Ensemble of Experts**.
19
- Please refer to our [project page](https://shikun.io/projects/prismer) or [github](https://github.com/NVlabs/prismer) for more details.
20
- """
21
 
22
  if (SPACE_ID := os.getenv('SPACE_ID')) is not None:
23
- description += f'For faster inference without waiting in queue, you may duplicate the space and upgrade to GPU in settings. <a href="https://huggingface.co/spaces/{SPACE_ID}?duplicate=true"><img style="display: inline; margin-top: 0em; margin-bottom: 0em" src="https://bit.ly/3gLdBN6" alt="Duplicate Space" /></a>'
24
-
25
 
26
- with gr.Blocks() as demo:
27
- gr.Markdown(description)
28
- with gr.Tab("Zero-shot Image Captioning"):
29
- create_caption()
30
- with gr.Tab("Visual Question Answering"):
31
- create_vqa()
32
 
33
  demo.queue(api_open=False).launch()
34
-
 
1
+ #!/usr/bin/env python
2
+
3
  from __future__ import annotations
4
 
5
  import os
6
+ import shutil
7
+ import subprocess
8
+
9
  import gradio as gr
10
 
11
+ if os.getenv('SYSTEM') == 'spaces':
12
+ with open('patch') as f:
13
+ subprocess.run('patch -p1'.split(), cwd='prismer', stdin=f)
14
+ shutil.copytree('prismer/helpers/images',
15
+ 'prismer/images',
16
+ dirs_exist_ok=True)
17
 
18
+ from app_caption import create_demo as create_demo_caption
19
+ from prismer_model import build_deformable_conv, download_models
20
 
 
21
  download_models()
22
  build_deformable_conv()
23
 
24
+ DESCRIPTION = '# [Prismer](https://github.com/nvlabs/prismer)'
 
 
 
 
 
25
 
26
  if (SPACE_ID := os.getenv('SPACE_ID')) is not None:
27
+ DESCRIPTION += f'<p>For faster inference without waiting in queue, you may duplicate the space and upgrade to GPU in settings. <a href="https://huggingface.co/spaces/{SPACE_ID}?duplicate=true"><img style="display: inline; margin-top: 0em; margin-bottom: 0em" src="https://bit.ly/3gLdBN6" alt="Duplicate Space" /></a></p>'
 
28
 
29
+ with gr.Blocks(css='style.css') as demo:
30
+ gr.Markdown(DESCRIPTION)
31
+ with gr.Tabs():
32
+ with gr.TabItem('Caption'):
33
+ create_demo_caption()
 
34
 
35
  demo.queue(api_open=False).launch()
 
gradio_caption.py β†’ app_caption.py RENAMED
@@ -1,17 +1,24 @@
 
 
1
  from __future__ import annotations
2
 
3
  import os
4
  import pathlib
 
5
  import gradio as gr
6
 
7
- from prismer_model import run_experts
8
 
9
 
10
  def create_demo():
 
 
11
  with gr.Row():
12
  with gr.Column():
13
  image = gr.Image(label='Input', type='filepath')
14
- model_name = gr.Dropdown(label='Model', choices=['prismer_base'], value='prismer_base')
 
 
15
  run_button = gr.Button('Run')
16
  with gr.Column(scale=1.5):
17
  caption = gr.Text(label='Caption')
@@ -24,21 +31,28 @@ def create_demo():
24
  object_detection = gr.Image(label='Object Detection')
25
  ocr = gr.Image(label='OCR Detection')
26
 
27
- inputs = [image]
28
- outputs = [depth, edge, normals]
 
 
 
 
 
 
 
 
29
 
30
  paths = sorted(pathlib.Path('prismer/images').glob('*'))
31
  examples = [[path.as_posix(), 'prismer_base'] for path in paths]
32
  gr.Examples(examples=examples,
33
  inputs=inputs,
34
  outputs=outputs,
35
- fn=run_experts,
36
  cache_examples=os.getenv('SYSTEM') == 'spaces')
37
 
38
- run_button.click(fn=run_experts, inputs=inputs, outputs=outputs)
39
 
40
 
41
  if __name__ == '__main__':
42
  demo = create_demo()
43
  demo.queue().launch()
44
-
 
1
+ #!/usr/bin/env python
2
+
3
  from __future__ import annotations
4
 
5
  import os
6
  import pathlib
7
+
8
  import gradio as gr
9
 
10
+ from prismer_model import Model
11
 
12
 
13
  def create_demo():
14
+ model = Model()
15
+
16
  with gr.Row():
17
  with gr.Column():
18
  image = gr.Image(label='Input', type='filepath')
19
+ model_name = gr.Dropdown(label='Model',
20
+ choices=['prismer_base'],
21
+ value='prismer_base')
22
  run_button = gr.Button('Run')
23
  with gr.Column(scale=1.5):
24
  caption = gr.Text(label='Caption')
 
31
  object_detection = gr.Image(label='Object Detection')
32
  ocr = gr.Image(label='OCR Detection')
33
 
34
+ inputs = [image, model_name]
35
+ outputs = [
36
+ caption,
37
+ depth,
38
+ edge,
39
+ normals,
40
+ segmentation,
41
+ object_detection,
42
+ ocr,
43
+ ]
44
 
45
  paths = sorted(pathlib.Path('prismer/images').glob('*'))
46
  examples = [[path.as_posix(), 'prismer_base'] for path in paths]
47
  gr.Examples(examples=examples,
48
  inputs=inputs,
49
  outputs=outputs,
50
+ fn=model.run_caption,
51
  cache_examples=os.getenv('SYSTEM') == 'spaces')
52
 
53
+ run_button.click(fn=model.run_caption, inputs=inputs, outputs=outputs)
54
 
55
 
56
  if __name__ == '__main__':
57
  demo = create_demo()
58
  demo.queue().launch()
 
gradio_vqa.py DELETED
@@ -1,33 +0,0 @@
1
- import gradio as gr
2
- from PIL import Image
3
- import tempfile
4
-
5
-
6
- def predict_depth(model, image):
7
- depth = model.infer_pil(image)
8
- return depth
9
-
10
-
11
- def create_demo():
12
- with gr.Row():
13
- with gr.Column(scale=1):
14
- model_type = gr.Dropdown(["Prismer-Base", "Prismer-Large"], label="Model Size", value="Prismer-Base")
15
- ques = gr.Textbox(label="Question", placeholder="What's the number of this player?")
16
- rgb = gr.Image(label="Input Image", type='pil', elem_id='img-display-input').style(height="auto")
17
- submit = gr.Button("Submit")
18
- with gr.Column(scale=2):
19
- pred = gr.Textbox(label="Model Prediction")
20
- with gr.Row():
21
- depth = gr.Image(label="Depth", elem_id='img-display-output')
22
- edge = gr.Image(label="Edge", elem_id='img-display-output')
23
- normals = gr.Image(label="Normals", elem_id='img-display-output')
24
- with gr.Row():
25
- seg = gr.Image(label="Segmentation", elem_id='img-display-output')
26
- obj_det = gr.Image(label="Object Detection", elem_id='img-display-output')
27
- ocr_det = gr.Image(label="OCR Detection", elem_id='img-display-output')
28
-
29
- def on_submit(im, q, model_type):
30
- return pred, depth, edge, normals, seg, obj_det, ocr_det
31
-
32
- submit.click(on_submit, inputs=[rgb, ques, model_type], outputs=[pred, depth, edge, normals, seg, obj_det, ocr_det])
33
- examples = gr.Examples(examples=["examples/1.png"], inputs=[rgb])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
patch ADDED
@@ -0,0 +1,82 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ diff --git a/dataset/caption_dataset.py b/dataset/caption_dataset.py
2
+ index 266fdda..0cc5d3f 100644
3
+ --- a/dataset/caption_dataset.py
4
+ +++ b/dataset/caption_dataset.py
5
+ @@ -50,7 +50,7 @@ class Caption(Dataset):
6
+ elif self.dataset == 'demo':
7
+ img_path_split = self.data_list[index]['image'].split('/')
8
+ img_name = img_path_split[-2] + '/' + img_path_split[-1]
9
+ - image, labels, labels_info = get_expert_labels('', self.label_path, img_name, 'helpers', self.experts)
10
+ + image, labels, labels_info = get_expert_labels('prismer', self.label_path, img_name, 'helpers', self.experts)
11
+
12
+ experts = self.transform(image, labels)
13
+ experts = post_label_process(experts, labels_info)
14
+ diff --git a/dataset/utils.py b/dataset/utils.py
15
+ index b368aac..418358c 100644
16
+ --- a/dataset/utils.py
17
+ +++ b/dataset/utils.py
18
+ @@ -5,6 +5,7 @@
19
+ # https://github.com/NVlabs/prismer/blob/main/LICENSE
20
+
21
+ import os
22
+ +import pathlib
23
+ import re
24
+ import json
25
+ import torch
26
+ @@ -14,10 +15,12 @@ import torchvision.transforms as transforms
27
+ import torchvision.transforms.functional as transforms_f
28
+ from dataset.randaugment import RandAugment
29
+
30
+ -COCO_FEATURES = torch.load('dataset/coco_features.pt')['features']
31
+ -ADE_FEATURES = torch.load('dataset/ade_features.pt')['features']
32
+ -DETECTION_FEATURES = torch.load('dataset/detection_features.pt')['features']
33
+ -BACKGROUND_FEATURES = torch.load('dataset/background_features.pt')
34
+ +cur_dir = pathlib.Path(__file__).parent
35
+ +
36
+ +COCO_FEATURES = torch.load(cur_dir / 'coco_features.pt')['features']
37
+ +ADE_FEATURES = torch.load(cur_dir / 'ade_features.pt')['features']
38
+ +DETECTION_FEATURES = torch.load(cur_dir / 'detection_features.pt')['features']
39
+ +BACKGROUND_FEATURES = torch.load(cur_dir / 'background_features.pt')
40
+
41
+
42
+ class Transform:
43
+ diff --git a/model/prismer.py b/model/prismer.py
44
+ index 080253a..02362a4 100644
45
+ --- a/model/prismer.py
46
+ +++ b/model/prismer.py
47
+ @@ -5,6 +5,7 @@
48
+ # https://github.com/NVlabs/prismer/blob/main/LICENSE
49
+
50
+ import json
51
+ +import pathlib
52
+ import torch.nn as nn
53
+
54
+ from model.modules.vit import load_encoder
55
+ @@ -12,6 +13,9 @@ from model.modules.roberta import load_decoder
56
+ from transformers import RobertaTokenizer, RobertaConfig
57
+
58
+
59
+ +cur_dir = pathlib.Path(__file__).parent
60
+ +
61
+ +
62
+ class Prismer(nn.Module):
63
+ def __init__(self, config):
64
+ super().__init__()
65
+ @@ -26,7 +30,7 @@ class Prismer(nn.Module):
66
+ elif exp in ['obj_detection', 'ocr_detection']:
67
+ self.experts[exp] = 64
68
+
69
+ - prismer_config = json.load(open('configs/prismer.json', 'r'))[config['prismer_model']]
70
+ + prismer_config = json.load(open(f'{cur_dir.parent}/configs/prismer.json', 'r'))[config['prismer_model']]
71
+ roberta_config = RobertaConfig.from_dict(prismer_config['roberta_model'])
72
+
73
+ self.tokenizer = RobertaTokenizer.from_pretrained(prismer_config['roberta_model']['model_name'])
74
+ @@ -35,7 +39,7 @@ class Prismer(nn.Module):
75
+
76
+ self.prepare_to_train(config['freeze'])
77
+ self.ignored_modules = self.get_ignored_modules(config['freeze'])
78
+ -
79
+ +
80
+ def prepare_to_train(self, mode='none'):
81
+ for name, params in self.named_parameters():
82
+ if mode == 'freeze_lang':
prismer/dataset/__init__.py DELETED
@@ -1,32 +0,0 @@
1
- # Copyright (c) 2023, NVIDIA Corporation & Affiliates. All rights reserved.
2
- #
3
- # This work is made available under the Nvidia Source Code License-NC.
4
- # To view a copy of this license, visit
5
- # https://github.com/NVlabs/prismer/blob/main/LICENSE
6
-
7
- from torch.utils.data import DataLoader
8
-
9
- from dataset.vqa_dataset import VQA
10
- from dataset.caption_dataset import Caption
11
-
12
-
13
- def create_dataset(dataset, config):
14
- if dataset == 'vqa':
15
- train_dataset = VQA(config, train=True)
16
- test_dataset = VQA(config, train=False)
17
- return train_dataset, test_dataset
18
-
19
- elif dataset == 'caption':
20
- train_dataset = Caption(config, train=True)
21
- test_dataset = Caption(config, train=False)
22
- return train_dataset, test_dataset
23
-
24
-
25
- def create_loader(dataset, batch_size, num_workers, train, collate_fn=None):
26
- data_loader = DataLoader(dataset,
27
- batch_size=batch_size,
28
- num_workers=num_workers,
29
- collate_fn=collate_fn,
30
- shuffle=True if train else False,
31
- drop_last=True if train else False)
32
- return data_loader
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
prismer/dataset/ade_features.pt DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:4cb5c802ef9eaf9c9213219bdc5f8f2d6579400275ab443c43905fdcdd805e5a
3
- size 41899
 
 
 
 
prismer/dataset/background_features.pt DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:5f5bf8e158715633a9e0f307be1947ca74c294e522b705e1e61bd49361088092
3
- size 1003
 
 
 
 
prismer/dataset/caption_dataset.py DELETED
@@ -1,63 +0,0 @@
1
- # Copyright (c) 2023, NVIDIA Corporation & Affiliates. All rights reserved.
2
- #
3
- # This work is made available under the Nvidia Source Code License-NC.
4
- # To view a copy of this license, visit
5
- # https://github.com/NVlabs/prismer/blob/main/LICENSE
6
-
7
- import glob
8
-
9
- from torch.utils.data import Dataset
10
- from dataset.utils import *
11
- from PIL import ImageFile
12
- ImageFile.LOAD_TRUNCATED_IMAGES = True
13
-
14
-
15
- class Caption(Dataset):
16
- def __init__(self, config, train=True):
17
- self.data_path = config['data_path']
18
- self.label_path = config['label_path']
19
- self.experts = config['experts']
20
- self.prefix = config['prefix']
21
- self.dataset = config['dataset']
22
- self.transform = Transform(resize_resolution=config['image_resolution'], scale_size=[0.5, 1.0], train=train)
23
- self.train = train
24
-
25
- if train:
26
- self.data_list = []
27
- if self.dataset in ['coco', 'nocaps']:
28
- self.data_list += json.load(open(os.path.join(self.data_path, 'coco_karpathy_train.json'), 'r'))
29
- else:
30
- if self.dataset == 'coco':
31
- self.data_list = json.load(open(os.path.join(self.data_path, 'coco_karpathy_test.json'), 'r'))
32
- elif self.dataset == 'nocaps':
33
- self.data_list = json.load(open(os.path.join(self.data_path, 'nocaps_val.json'), 'r'))
34
- elif self.dataset == 'demo':
35
- data_folders = glob.glob(f'{self.data_path}/*/')
36
- self.data_list = [{'image': data} for f in data_folders for data in glob.glob(f + '*.jpg')]
37
- self.data_list += [{'image': data} for f in data_folders for data in glob.glob(f + '*.png')]
38
- self.data_list += [{'image': data} for f in data_folders for data in glob.glob(f + '*.jpeg')]
39
-
40
- def __len__(self):
41
- return len(self.data_list)
42
-
43
- def __getitem__(self, index):
44
- data = self.data_list[index]
45
-
46
- if self.dataset == 'coco':
47
- image, labels, labels_info = get_expert_labels(self.data_path, self.label_path, data['image'], 'vqav2', self.experts)
48
- elif self.dataset == 'nocaps':
49
- image, labels, labels_info = get_expert_labels(self.data_path, self.label_path, data['image'], 'nocaps', self.experts)
50
- elif self.dataset == 'demo':
51
- img_path_split = self.data_list[index]['image'].split('/')
52
- img_name = img_path_split[-2] + '/' + img_path_split[-1]
53
- image, labels, labels_info = get_expert_labels('prismer', self.label_path, img_name, 'helpers', self.experts)
54
-
55
- experts = self.transform(image, labels)
56
- experts = post_label_process(experts, labels_info)
57
-
58
- if self.train:
59
- caption = pre_caption(self.prefix + ' ' + self.data_list[index]['caption'], max_words=30)
60
- return experts, caption
61
- else:
62
- return experts, index
63
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
prismer/dataset/clip_pca.pkl DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:064950869bce7d622ff7e733dc64c8e1be18be926c3f64e66ae57164df3f8bc5
3
- size 401612
 
 
 
 
prismer/dataset/coco_features.pt DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:ccf18221afe8dddef3ffb9daad31d5c7a92cdc2f2f434d77cbeb48031bc75756
3
- size 36651
 
 
 
 
prismer/dataset/detection_features.pt DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:c38ba9352b2a9f832b14fdc19ac407527ffeaa2903958a73f6eb649f78119c76
3
- size 198443
 
 
 
 
prismer/dataset/randaugment.py DELETED
@@ -1,265 +0,0 @@
1
- # Copyright (c) 2023, NVIDIA Corporation & Affiliates. All rights reserved.
2
- #
3
- # This work is made available under the Nvidia Source Code License-NC.
4
- # To view a copy of this license, visit
5
- # https://github.com/NVlabs/prismer/blob/main/LICENSE
6
-
7
- import random
8
- import numpy as np
9
- import torch
10
-
11
- from PIL import Image, ImageOps, ImageEnhance, ImageDraw
12
-
13
-
14
- fillmask = {'depth': 0, 'normal': 0, 'edge': 0, 'seg_coco': 255, 'seg_ade': 255,
15
- 'obj_detection': 255, 'ocr_detection': 255}
16
- fillcolor = (0, 0, 0)
17
-
18
-
19
- def affine_transform(pair, affine_params):
20
- img, label = pair
21
- img = img.transform(img.size, Image.AFFINE, affine_params,
22
- resample=Image.BILINEAR, fillcolor=fillcolor)
23
- if label is not None:
24
- for exp in label:
25
- label[exp] = label[exp].transform(label[exp].size, Image.AFFINE, affine_params,
26
- resample=Image.NEAREST, fillcolor=fillmask[exp])
27
- return img, label
28
-
29
-
30
- def ShearX(pair, v): # [-0.3, 0.3]
31
- assert -0.3 <= v <= 0.3
32
- if random.random() > 0.5:
33
- v = -v
34
- return affine_transform(pair, (1, v, 0, 0, 1, 0))
35
-
36
-
37
- def ShearY(pair, v): # [-0.3, 0.3]
38
- assert -0.3 <= v <= 0.3
39
- if random.random() > 0.5:
40
- v = -v
41
- return affine_transform(pair, (1, 0, 0, v, 1, 0))
42
-
43
-
44
- def TranslateX(pair, v): # [-150, 150] => percentage: [-0.45, 0.45]
45
- assert -0.45 <= v <= 0.45
46
- if random.random() > 0.5:
47
- v = -v
48
- img, _ = pair
49
- v = v * img.size[0]
50
- return affine_transform(pair, (1, 0, v, 0, 1, 0))
51
-
52
-
53
- def TranslateY(pair, v): # [-150, 150] => percentage: [-0.45, 0.45]
54
- assert -0.45 <= v <= 0.45
55
- if random.random() > 0.5:
56
- v = -v
57
- img, _ = pair
58
- v = v * img.size[1]
59
- return affine_transform(pair, (1, 0, 0, 0, 1, v))
60
-
61
-
62
- def TranslateXAbs(pair, v): # [-150, 150] => percentage: [-0.45, 0.45]
63
- assert 0 <= v <= 10
64
- if random.random() > 0.5:
65
- v = -v
66
- return affine_transform(pair, (1, 0, v, 0, 1, 0))
67
-
68
-
69
- def TranslateYAbs(pair, v): # [-150, 150] => percentage: [-0.45, 0.45]
70
- assert 0 <= v <= 10
71
- if random.random() > 0.5:
72
- v = -v
73
- return affine_transform(pair, (1, 0, 0, 0, 1, v))
74
-
75
-
76
- def Rotate(pair, v): # [-30, 30]
77
- assert -30 <= v <= 30
78
- if random.random() > 0.5:
79
- v = -v
80
- img, label = pair
81
- img = img.rotate(v, fillcolor=fillcolor)
82
- if label is not None:
83
- for exp in label:
84
- label[exp] = label[exp].rotate(v, resample=Image.NEAREST, fillcolor=fillmask[exp])
85
- return img, label
86
-
87
-
88
- def AutoContrast(pair, _):
89
- img, label = pair
90
- return ImageOps.autocontrast(img), label
91
-
92
-
93
- def Invert(pair, _):
94
- img, label = pair
95
- return ImageOps.invert(img), label
96
-
97
-
98
- def Equalize(pair, _):
99
- img, label = pair
100
- return ImageOps.equalize(img), label
101
-
102
-
103
- def Flip(pair, _): # not from the paper
104
- img, label = pair
105
- return ImageOps.mirror(img), ImageOps.mirror(label)
106
-
107
-
108
- def Solarize(pair, v): # [0, 256]
109
- img, label = pair
110
- assert 0 <= v <= 256
111
- return ImageOps.solarize(img, v), label
112
-
113
-
114
- def Posterize(pair, v): # [4, 8]
115
- img, label = pair
116
- assert 4 <= v <= 8
117
- v = int(v)
118
- return ImageOps.posterize(img, v), label
119
-
120
-
121
- def Posterize2(pair, v): # [0, 4]
122
- img, label = pair
123
- assert 0 <= v <= 4
124
- v = int(v)
125
- return ImageOps.posterize(img, v), label
126
-
127
-
128
- def Contrast(pair, v): # [0.1,1.9]
129
- img, label = pair
130
- assert 0.1 <= v <= 1.9
131
- return ImageEnhance.Contrast(img).enhance(v), label
132
-
133
-
134
- def Color(pair, v): # [0.1,1.9]
135
- img, label = pair
136
- assert 0.1 <= v <= 1.9
137
- return ImageEnhance.Color(img).enhance(v), label
138
-
139
-
140
- def Brightness(pair, v): # [0.1,1.9]
141
- img, label = pair
142
- assert 0.1 <= v <= 1.9
143
- return ImageEnhance.Brightness(img).enhance(v), label
144
-
145
-
146
- def Sharpness(pair, v): # [0.1,1.9]
147
- img, label = pair
148
- assert 0.1 <= v <= 1.9
149
- return ImageEnhance.Sharpness(img).enhance(v), label
150
-
151
-
152
- def Cutout(pair, v): # [0, 60] => percentage: [0, 0.2]
153
- assert 0.0 <= v <= 0.2
154
- if v <= 0.:
155
- return pair
156
- img, label = pair
157
- v = v * img.size[0]
158
- return CutoutAbs(img, v), label
159
-
160
-
161
- def CutoutAbs(img, v): # [0, 60] => percentage: [0, 0.2]
162
- # assert 0 <= v <= 20
163
- if v < 0:
164
- return img
165
- w, h = img.size
166
- x0 = np.random.uniform(w)
167
- y0 = np.random.uniform(h)
168
-
169
- x0 = int(max(0, x0 - v / 2.))
170
- y0 = int(max(0, y0 - v / 2.))
171
- x1 = min(w, x0 + v)
172
- y1 = min(h, y0 + v)
173
-
174
- xy = (x0, y0, x1, y1)
175
- color = (125, 123, 114)
176
- # color = (0, 0, 0)
177
- img = img.copy()
178
- ImageDraw.Draw(img).rectangle(xy, color)
179
- return img
180
-
181
-
182
- def Identity(pair, v):
183
- return pair
184
-
185
-
186
- def augment_list(): # 16 oeprations and their ranges
187
- # https://github.com/google-research/uda/blob/master/image/randaugment/policies.py#L57
188
- l = [
189
- (Identity, 0., 1.0),
190
- (ShearX, 0., 0.3), # 0
191
- (ShearY, 0., 0.3), # 1
192
- (TranslateX, 0., 0.33), # 2
193
- (TranslateY, 0., 0.33), # 3
194
- (Rotate, 0, 30), # 4
195
- (AutoContrast, 0, 1), # 5
196
- # (Invert, 0, 1), # 6
197
- (Equalize, 0, 1), # 7
198
- # (Solarize, 0, 110), # 8
199
- # (Posterize, 4, 8), # 9
200
- # (Color, 0.1, 1.9), # 11
201
- (Brightness, 0.1, 1.9), # 12
202
- (Sharpness, 0.1, 1.9), # 13
203
- ]
204
- return l
205
-
206
-
207
- class Lighting(object):
208
- """Lighting noise(AlexNet - style PCA - based noise)"""
209
-
210
- def __init__(self, alphastd, eigval, eigvec):
211
- self.alphastd = alphastd
212
- self.eigval = torch.Tensor(eigval)
213
- self.eigvec = torch.Tensor(eigvec)
214
-
215
- def __call__(self, img):
216
- if self.alphastd == 0:
217
- return img
218
-
219
- alpha = img.new().resize_(3).normal_(0, self.alphastd)
220
- rgb = self.eigvec.type_as(img).clone() \
221
- .mul(alpha.view(1, 3).expand(3, 3)) \
222
- .mul(self.eigval.view(1, 3).expand(3, 3)) \
223
- .sum(1).squeeze()
224
-
225
- return img.add(rgb.view(3, 1, 1).expand_as(img))
226
-
227
-
228
- class CutoutDefault(object):
229
- """
230
- Reference : https://github.com/quark0/darts/blob/master/cnn/utils.py
231
- """
232
- def __init__(self, length):
233
- self.length = length
234
-
235
- def __call__(self, img):
236
- h, w = img.size(1), img.size(2)
237
- mask = np.ones((h, w), np.float32)
238
- y = np.random.randint(h)
239
- x = np.random.randint(w)
240
-
241
- y1 = np.clip(y - self.length // 2, 0, h)
242
- y2 = np.clip(y + self.length // 2, 0, h)
243
- x1 = np.clip(x - self.length // 2, 0, w)
244
- x2 = np.clip(x + self.length // 2, 0, w)
245
-
246
- mask[y1: y2, x1: x2] = 0.
247
- mask = torch.from_numpy(mask)
248
- mask = mask.expand_as(img)
249
- img *= mask
250
- return img
251
-
252
-
253
- class RandAugment:
254
- def __init__(self, n, m):
255
- self.n = n
256
- self.m = m # [0, 10]
257
- self.augment_list = augment_list()
258
-
259
- def __call__(self, img, label):
260
- pair = img, label
261
- ops = random.choices(self.augment_list, k=self.n)
262
- for op, minval, maxval in ops:
263
- val = (float(self.m) / 10) * float(maxval - minval) + minval
264
- pair = op(pair, val)
265
- return pair
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
prismer/dataset/utils.py DELETED
@@ -1,192 +0,0 @@
1
- # Copyright (c) 2023, NVIDIA Corporation & Affiliates. All rights reserved.
2
- #
3
- # This work is made available under the Nvidia Source Code License-NC.
4
- # To view a copy of this license, visit
5
- # https://github.com/NVlabs/prismer/blob/main/LICENSE
6
-
7
- import os
8
- import re
9
- import json
10
- import torch
11
- import PIL.Image as Image
12
- import numpy as np
13
- import torchvision.transforms as transforms
14
- import torchvision.transforms.functional as transforms_f
15
- import pathlib
16
- from dataset.randaugment import RandAugment
17
-
18
-
19
- cur_dir = pathlib.Path(__file__).parent
20
-
21
- COCO_FEATURES = torch.load(cur_dir / 'coco_features.pt')['features']
22
- ADE_FEATURES = torch.load(cur_dir / 'ade_features.pt')['features']
23
- DETECTION_FEATURES = torch.load(cur_dir / 'detection_features.pt')['features']
24
- BACKGROUND_FEATURES = torch.load(cur_dir / 'background_features.pt')
25
-
26
-
27
- class Transform:
28
- def __init__(self, resize_resolution=384, scale_size=[0.5, 1.0], train=False):
29
- self.resize_size = [resize_resolution, resize_resolution]
30
- self.scale_size = scale_size
31
- self.train = train
32
- self.randaugment = RandAugment(2, 5)
33
-
34
- def __call__(self, image, labels):
35
- if self.train:
36
- # random resize crop
37
- i, j, h, w = transforms.RandomResizedCrop.get_params(img=image, scale=self.scale_size, ratio=[3. / 4, 4. / 3])
38
- image = transforms_f.crop(image, i, j, h, w)
39
- if labels is not None:
40
- for exp in labels:
41
- labels[exp] = transforms_f.crop(labels[exp], i, j, h, w)
42
-
43
- # resize to the defined shape
44
- image = transforms_f.resize(image, self.resize_size, transforms_f.InterpolationMode.BICUBIC)
45
- if labels is not None:
46
- for exp in labels:
47
- labels[exp] = transforms_f.resize(labels[exp], [224, 224], transforms_f.InterpolationMode.NEAREST)
48
-
49
- if self.train:
50
- # random flipping
51
- if torch.rand(1) > 0.5:
52
- image = transforms_f.hflip(image)
53
- if labels is not None:
54
- for exp in labels:
55
- labels[exp] = transforms_f.hflip(labels[exp])
56
-
57
- # random augmentation
58
- image, labels = self.randaugment(image, labels)
59
-
60
- # transform to tensor
61
- image = transforms_f.to_tensor(image)
62
- if labels is not None:
63
- for exp in labels:
64
- if exp in ['depth', 'normal', 'edge']:
65
- labels[exp] = transforms_f.to_tensor(labels[exp])
66
- else:
67
- labels[exp] = (transforms_f.to_tensor(labels[exp]) * 255).long()
68
-
69
- # apply normalisation:
70
- image = transforms_f.normalize(image, mean=[0.48145466, 0.4578275, 0.40821073],
71
- std=[0.26862954, 0.26130258, 0.27577711])
72
- if labels is not None:
73
- return {'rgb': image, **labels}
74
- else:
75
- return{'rgb': image}
76
-
77
-
78
- def get_expert_labels(data_path, label_path, image_path, dataset, experts):
79
- image_full_path = os.path.join(data_path, dataset, image_path)
80
- image = Image.open(image_full_path).convert('RGB')
81
- if experts != 'none':
82
- labels = {}
83
- labels_info = {}
84
- ps = image_path.split('.')[-1]
85
- for exp in experts:
86
- if exp in ['seg_coco', 'seg_ade', 'edge', 'depth']:
87
- label_full_path = os.path.join(label_path, exp, dataset, image_path.replace(f'.{ps}', '.png'))
88
- if os.stat(label_full_path).st_size > 0:
89
- labels[exp] = Image.open(label_full_path).convert('L')
90
- else:
91
- labels[exp] = Image.fromarray(np.zeros([image.size[1], image.size[0]])).convert('L')
92
- elif exp == 'normal':
93
- label_full_path = os.path.join(label_path, exp, dataset, image_path.replace(f'.{ps}', '.png'))
94
- if os.stat(label_full_path).st_size > 0:
95
- labels[exp] = Image.open(label_full_path).convert('RGB')
96
- else:
97
- labels[exp] = Image.fromarray(np.zeros([image.size[1], image.size[0], 3])).convert('RGB')
98
- elif exp == 'obj_detection':
99
- label_full_path = os.path.join(label_path, exp, dataset, image_path.replace(f'.{ps}', '.png'))
100
- if os.stat(label_full_path).st_size > 0:
101
- labels[exp] = Image.open(label_full_path).convert('L')
102
- else:
103
- labels[exp] = Image.fromarray(255 * np.ones([image.size[1], image.size[0]])).convert('L')
104
- label_info_path = os.path.join(label_path, exp, dataset, image_path.replace(f'.{ps}', '.json'))
105
- labels_info[exp] = json.load(open(label_info_path, 'r'))
106
- elif exp == 'ocr_detection':
107
- label_full_path = os.path.join(label_path, exp, dataset, image_path.replace(f'.{ps}', '.png'))
108
- label_info_path = os.path.join(label_path, exp, dataset, image_path.replace(f'.{ps}', '.pt'))
109
- if os.path.exists(label_info_path):
110
- labels[exp] = Image.open(label_full_path).convert('L')
111
- labels_info[exp] = torch.load(label_info_path)
112
- else:
113
- labels[exp] = Image.fromarray(255 * np.ones([image.size[1], image.size[0]])).convert('L')
114
- labels_info[exp] = None
115
-
116
- else:
117
- labels, labels_info = None, None
118
- return image, labels, labels_info
119
-
120
-
121
- def post_label_process(inputs, labels_info):
122
- eps = 1e-6
123
- for exp in inputs:
124
- if exp in ['depth', 'normal', 'edge']: # remap to -1 to 1 range
125
- inputs[exp] = 2 * (inputs[exp] - inputs[exp].min()) / (inputs[exp].max() - inputs[exp].min() + eps) - 1
126
-
127
- elif exp == 'seg_coco': # in-paint with CLIP features
128
- text_emb = torch.empty([64, *inputs[exp].shape[1:]])
129
- for l in inputs[exp].unique():
130
- if l == 255:
131
- text_emb[:, (inputs[exp][0] == l)] = BACKGROUND_FEATURES.unsqueeze(-1)
132
- else:
133
- text_emb[:, (inputs[exp][0] == l)] = COCO_FEATURES[l].unsqueeze(-1)
134
- inputs[exp] = text_emb
135
-
136
- elif exp == 'seg_ade': # in-paint with CLIP features
137
- text_emb = torch.empty([64, *inputs[exp].shape[1:]])
138
- for l in inputs[exp].unique():
139
- if l == 255:
140
- text_emb[:, (inputs[exp][0] == l)] = BACKGROUND_FEATURES.unsqueeze(-1)
141
- else:
142
- text_emb[:, (inputs[exp][0] == l)] = ADE_FEATURES[l].unsqueeze(-1)
143
- inputs[exp] = text_emb
144
-
145
- elif exp == 'obj_detection': # in-paint with CLIP features
146
- text_emb = torch.empty([64, *inputs[exp].shape[1:]])
147
- label_map = labels_info[exp]
148
- for l in inputs[exp].unique():
149
- if l == 255:
150
- text_emb[:, (inputs[exp][0] == l)] = BACKGROUND_FEATURES.unsqueeze(-1)
151
- else:
152
- text_emb[:, (inputs[exp][0] == l)] = DETECTION_FEATURES[label_map[str(l.item())]].unsqueeze(-1)
153
- inputs[exp] = {'label': text_emb, 'instance': inputs[exp]}
154
-
155
- elif exp == 'ocr_detection': # in-paint with CLIP features
156
- text_emb = torch.empty([64, *inputs[exp].shape[1:]])
157
- label_map = labels_info[exp]
158
- for l in inputs[exp].unique():
159
- if l == 255:
160
- text_emb[:, (inputs[exp][0] == l)] = BACKGROUND_FEATURES.unsqueeze(-1)
161
- else:
162
- text_emb[:, (inputs[exp][0] == l)] = label_map[l.item()]['features'].unsqueeze(-1)
163
- inputs[exp] = text_emb
164
- return inputs
165
-
166
-
167
- def pre_caption(caption, max_words=50):
168
- caption = re.sub(r"([.!\"()*#:;~])", ' ', caption.capitalize()) # remove special characters
169
- caption = re.sub(r"\s{2,}", ' ', caption) # remove two white spaces
170
-
171
- caption = caption.rstrip('\n') # remove \num_ans_per_q symbol
172
- caption = caption.strip(' ') # remove leading and trailing white spaces
173
-
174
- # truncate caption to the max words
175
- caption_words = caption.split(' ')
176
- if len(caption_words) > max_words:
177
- caption = ' '.join(caption_words[:max_words])
178
- return caption
179
-
180
-
181
- def pre_question(question, max_words=50):
182
- question = re.sub(r"([.!\"()*#:;~])", ' ', question.capitalize()) # remove special characters
183
- question = question.strip()
184
-
185
- # truncate question
186
- question_words = question.split(' ')
187
- if len(question_words) > max_words:
188
- question = ' '.join(question_words[:max_words])
189
- if question[-1] != '?':
190
- question += '?'
191
- return question
192
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
prismer/dataset/vqa_dataset.py DELETED
@@ -1,53 +0,0 @@
1
- # Copyright (c) 2023, NVIDIA Corporation & Affiliates. All rights reserved.
2
- #
3
- # This work is made available under the Nvidia Source Code License-NC.
4
- # To view a copy of this license, visit
5
- # https://github.com/NVlabs/prismer/blob/main/LICENSE
6
-
7
- from torch.utils.data import Dataset
8
- from dataset.utils import *
9
- from PIL import ImageFile
10
- ImageFile.LOAD_TRUNCATED_IMAGES = True
11
-
12
-
13
- class VQA(Dataset):
14
- def __init__(self, config, train=True):
15
- self.data_path = config['data_path']
16
- self.label_path = config['label_path']
17
- self.experts = config['experts']
18
- self.transform = Transform(resize_resolution=config['image_resolution'], scale_size=[0.5, 1.0], train=train)
19
- self.train = train
20
-
21
- if train:
22
- self.data_list = []
23
- if 'vqav2' in config['datasets']:
24
- self.data_list += json.load(open(os.path.join(self.data_path, 'vqav2_train_val.json'), 'r'))
25
- if 'vg' in config['datasets']:
26
- self.data_list += json.load(open(os.path.join(self.data_path, 'vg_qa.json'), 'r'))
27
- else:
28
- self.data_list = json.load(open(os.path.join(self.data_path, 'vqav2_test.json'), 'r'))
29
- self.answer_list = json.load(open(os.path.join(self.data_path, 'answer_list.json'), 'r'))
30
-
31
- def __len__(self):
32
- return len(self.data_list)
33
-
34
- def __getitem__(self, index):
35
- data = self.data_list[index]
36
-
37
- if data['dataset'] == 'vqa':
38
- image, labels, labels_info = get_expert_labels(self.data_path, self.label_path, data['image'], 'vqav2', self.experts)
39
- elif data['dataset'] == 'vg':
40
- image, labels, labels_info = get_expert_labels(self.data_path, self.label_path, data['image'], 'vg', self.experts)
41
-
42
- experts = self.transform(image, labels)
43
- experts = post_label_process(experts, labels_info)
44
-
45
- if self.train:
46
- question = pre_question(data['question'], max_words=30)
47
- answers = data['answer']
48
- weights = torch.tensor(data['weight']) if data['dataset'] != 'vg' else torch.tensor(0.2)
49
- return experts, question, answers, weights
50
- else:
51
- question = pre_question(data['question'], max_words=30)
52
- question_id = data['question_id']
53
- return experts, index, question, question_id
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
prismer/demo.py DELETED
@@ -1,77 +0,0 @@
1
- import os
2
- import argparse
3
- import torch
4
- try:
5
- import ruamel_yaml as yaml
6
- except ModuleNotFoundError:
7
- import ruamel.yaml as yaml
8
-
9
-
10
- from model.prismer_caption import PrismerCaption
11
- from dataset import create_dataset, create_loader
12
- from tqdm import tqdm
13
-
14
- parser = argparse.ArgumentParser()
15
- parser.add_argument('--mode', default='')
16
- parser.add_argument('--port', default='')
17
-
18
- parser.add_argument('--exp_name', default='', type=str)
19
- args = parser.parse_args()
20
-
21
- # load config
22
- config = yaml.load(open('configs/caption.yaml', 'r'), Loader=yaml.Loader)['demo']
23
-
24
- # generate expert labels
25
- if len(config['experts']) > 0:
26
- script_name = f'python experts/generate_depth.py'
27
- os.system(script_name)
28
- print('***** Generated Depth *****')
29
-
30
- script_name = f'python experts/generate_edge.py'
31
- os.system(script_name)
32
- print('***** Generated Edge *****')
33
-
34
- script_name = f'python experts/generate_normal.py'
35
- os.system(script_name)
36
- print('***** Generated Surface Normals *****')
37
-
38
- script_name = f'python experts/generate_objdet.py'
39
- os.system(script_name)
40
- print('***** Generated Object Detection Labels *****')
41
-
42
- script_name = f'python experts/generate_ocrdet.py'
43
- os.system(script_name)
44
- print('***** Generated OCR Detection Labels *****')
45
-
46
- script_name = f'python experts/generate_segmentation.py'
47
- os.system(script_name)
48
- print('***** Generated Segmentation Labels *****')
49
-
50
- # load datasets
51
- _, test_dataset = create_dataset('caption', config)
52
- test_loader = create_loader(test_dataset, batch_size=1, num_workers=4, train=False)
53
-
54
- # load pre-trained model
55
- model = PrismerCaption(config)
56
- state_dict = torch.load(f'logging/caption_{args.exp_name}/pytorch_model.bin', map_location='cuda:0')
57
- model.load_state_dict(state_dict)
58
- tokenizer = model.tokenizer
59
-
60
- # inference
61
- model.eval()
62
- with torch.no_grad():
63
- for step, (experts, data_ids) in enumerate(tqdm(test_loader)):
64
- captions = model(experts, train=False, prefix=config['prefix'])
65
-
66
- captions = tokenizer(captions, max_length=30, padding='max_length', return_tensors='pt').input_ids
67
- caption = captions.to(experts['rgb'].device)[0]
68
-
69
- caption = tokenizer.decode(caption, skip_special_tokens=True)
70
- caption = caption.capitalize() + '.'
71
-
72
- # save caption
73
- save_path = test_loader.dataset.data_list[data_ids[0]]['image'].replace('jpg', 'txt')
74
- with open(save_path, 'w') as f:
75
- f.write(caption)
76
-
77
- print('All Done.')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
prismer/demo_vis.py DELETED
@@ -1,161 +0,0 @@
1
- import glob
2
- import os
3
- import json
4
- import torch
5
- import random
6
- import matplotlib.pyplot as plt
7
- import numpy as np
8
-
9
- from utils import create_ade20k_label_colormap
10
-
11
- obj_label_map = torch.load('dataset/detection_features.pt')['labels']
12
- coco_label_map = torch.load('dataset/coco_features.pt')['labels']
13
- ade_color = create_ade20k_label_colormap()
14
-
15
- file_path = 'helpers/images'
16
- expert_path = 'helpers/labels'
17
- plt.ioff()
18
-
19
-
20
- def get_label_path(file_name, expert_name, with_suffix=False):
21
- file_suffix = '.png' if not with_suffix else '_.png'
22
- label_name = ''.join(file_name.split('.')[:-1] + [file_suffix])
23
- label_path = os.path.join(expert_path, expert_name, label_name)
24
- return label_path
25
-
26
-
27
- def depth_prettify(file_name):
28
- label_path = get_label_path(file_name, 'depth')
29
- save_path = get_label_path(file_name, 'depth', True)
30
- depth = plt.imread(label_path)
31
- plt.imsave(save_path, depth, cmap='rainbow')
32
-
33
-
34
- def obj_detection_prettify(file_name):
35
- label_path = get_label_path(file_name, 'obj_detection')
36
- save_path = get_label_path(file_name, 'obj_detection', True)
37
-
38
- rgb = plt.imread(file_name)
39
- obj_labels = plt.imread(label_path)
40
- obj_labels_dict = json.load(open(label_path.replace('.png', '.json')))
41
-
42
- plt.imshow(rgb)
43
-
44
- num_objs = np.unique(obj_labels)[:-1].max()
45
- plt.imshow(obj_labels, cmap='terrain', vmax=num_objs + 1 / 255., alpha=0.5)
46
-
47
- for i in np.unique(obj_labels)[:-1]:
48
- obj_idx_all = np.where(obj_labels == i)
49
- obj_idx = random.randint(0, len(obj_idx_all[0]))
50
- x, y = obj_idx_all[1][obj_idx], obj_idx_all[0][obj_idx]
51
- obj_name = obj_label_map[obj_labels_dict[str(int(i * 255))]]
52
- plt.text(x, y, obj_name, c='white', horizontalalignment='center', verticalalignment='center')
53
-
54
- plt.axis('off')
55
- plt.savefig(save_path, bbox_inches='tight', transparent=True, pad_inches=0)
56
- plt.close()
57
-
58
-
59
- def seg_prettify(file_name):
60
- label_path = get_label_path(file_name, 'seg_coco')
61
- save_path = get_label_path(file_name, 'seg_coco', True)
62
-
63
- rgb = plt.imread(file_name)
64
- seg_labels = plt.imread(label_path)
65
-
66
- plt.imshow(rgb)
67
-
68
- seg_map = np.zeros(list(seg_labels.shape) + [3], dtype=np.int16)
69
- for i in np.unique(seg_labels):
70
- seg_map[seg_labels == i] = ade_color[int(i * 255)]
71
-
72
- plt.imshow(seg_map, alpha=0.5)
73
-
74
- for i in np.unique(seg_labels):
75
- obj_idx_all = np.where(seg_labels == i)
76
- obj_idx = random.randint(0, len(obj_idx_all[0]))
77
- x, y = obj_idx_all[1][obj_idx], obj_idx_all[0][obj_idx]
78
- obj_name = coco_label_map[int(i * 255)]
79
- plt.text(x, y, obj_name, c='white', horizontalalignment='center', verticalalignment='center')
80
-
81
- plt.axis('off')
82
- plt.savefig(save_path, bbox_inches='tight', transparent=True, pad_inches=0)
83
- plt.close()
84
-
85
-
86
- def ocr_detection_prettify(file_name):
87
- label_path = get_label_path(file_name, 'ocr_detection')
88
- save_path = get_label_path(file_name, 'ocr_detection', True)
89
-
90
- if os.path.exists(label_path):
91
- rgb = plt.imread(file_name)
92
- ocr_labels = plt.imread(label_path)
93
- ocr_labels_dict = torch.load(label_path.replace('.png', '.pt'))
94
-
95
- plt.imshow(rgb)
96
- plt.imshow((1 - ocr_labels) < 1, cmap='gray', alpha=0.8)
97
-
98
- for i in np.unique(ocr_labels)[:-1]:
99
- text_idx_all = np.where(ocr_labels == i)
100
- x, y = text_idx_all[1].mean(), text_idx_all[0].mean()
101
- text = ocr_labels_dict[int(i * 255)]['text']
102
- plt.text(x, y, text, c='white', horizontalalignment='center', verticalalignment='center')
103
-
104
- plt.axis('off')
105
- plt.savefig(save_path, bbox_inches='tight', transparent=True, pad_inches=0)
106
- plt.close()
107
- else:
108
- rgb = plt.imread(file_name)
109
- ocr_labels = np.ones_like(rgb, dtype=np.float32())
110
-
111
- plt.imshow(rgb)
112
- plt.imshow(ocr_labels, cmap='gray', alpha=0.8)
113
-
114
- x, y = rgb.shape[1] / 2, rgb.shape[0] / 2
115
- plt.text(x, y, 'No text detected', c='black', horizontalalignment='center', verticalalignment='center')
116
-
117
- plt.axis('off')
118
- plt.savefig(save_path, bbox_inches='tight', transparent=True, pad_inches=0)
119
- plt.close()
120
-
121
-
122
- im_list = glob.glob(file_path + '/*.jpg') + glob.glob(file_path + '/*.png') + glob.glob(file_path + '/*.jpeg')
123
-
124
- # prettify labels first:
125
- for i in range(len(im_list)):
126
- depth_prettify(im_list[i])
127
- seg_prettify(im_list[i])
128
- ocr_detection_prettify(im_list[i])
129
- obj_detection_prettify(im_list[i])
130
-
131
- pretty = {'depth': True, 'normal': False, 'edge': False,
132
- 'obj_detection': True, 'ocr_detection': True, 'seg_coco': True}
133
-
134
- # plot expert labels
135
- for im_path in im_list:
136
- fig, axs = plt.subplots(1, 7, figsize=(20, 4))
137
- rgb = plt.imread(im_path)
138
- axs[0].imshow(rgb)
139
- axs[0].axis('off')
140
- axs[0].set_title('RGB')
141
-
142
- for j in range(6):
143
- label_name = list(pretty.keys())[j]
144
- label_path = get_label_path(im_path, label_name, with_suffix=pretty[label_name])
145
- label = plt.imread(label_path)
146
- if label_name != 'edge':
147
- axs[j + 1].imshow(label)
148
- else:
149
- axs[j + 1].imshow(label, cmap='gray')
150
-
151
- axs[j + 1].axis('off')
152
- axs[j + 1].set_title(label_name)
153
-
154
- caption_path = ''.join(im_path.split('.')[:-1] + ['.txt'])
155
- with open(caption_path) as f:
156
- caption = f.readlines()[0]
157
-
158
- plt.suptitle(caption)
159
- plt.tight_layout()
160
-
161
- plt.show()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
prismer/download_checkpoints.py DELETED
@@ -1,124 +0,0 @@
1
- from huggingface_hub import hf_hub_download, hf_hub_url, get_hf_file_metadata
2
- from huggingface_hub.utils import disable_progress_bars
3
- from pathlib import Path
4
- from rich.progress import Progress
5
- from fire import Fire
6
- from typing import Union, List
7
-
8
- _EXPERTS = [
9
- "10_model.pth",
10
- "Unified_learned_OCIM_RS200_6x+2x.pth",
11
- "dpt_hybrid-midas-501f0c75.pt",
12
- "icdar2015_hourglass88.pth",
13
- "model_final_e0c58e.pkl",
14
- "model_final_f07440.pkl",
15
- "scannet.pt",
16
- ]
17
-
18
- _MODELS = [
19
- "vqa_prismer_base",
20
- "vqa_prismer_large",
21
- "vqa_prismerz_base",
22
- "vqa_prismerz_large",
23
- "caption_prismerz_base",
24
- "caption_prismerz_large",
25
- "caption_prismer_base",
26
- "caption_prismer_large",
27
- "pretrain_prismer_base",
28
- "pretrain_prismer_large",
29
- "pretrain_prismerz_base",
30
- "pretrain_prismerz_large",
31
- ]
32
-
33
- _REPO_ID = "lorenmt/prismer"
34
-
35
-
36
- def download_checkpoints(
37
- download_experts: bool = False,
38
- download_models: Union[bool, List] = False,
39
- hide_tqdm: bool = False,
40
- force_redownload: bool = False,
41
- ):
42
- if hide_tqdm:
43
- disable_progress_bars()
44
- # Convert to list and check for invalid names
45
- download_experts = _EXPERTS if download_experts else []
46
- if download_models:
47
- # only download single model
48
- if isinstance(download_models, str):
49
- download_models = [download_models]
50
-
51
- assert all([m in _MODELS for m in download_models]), f"Invalid model name. Must be one of {_MODELS}"
52
- download_models = _MODELS if isinstance(download_models, bool) else download_models
53
- else:
54
- download_models = []
55
-
56
- # Check if files already exist
57
- if not force_redownload:
58
- download_experts = [e for e in download_experts if not Path(f"./experts/expert_weights/{e}").exists()]
59
- download_models = [m for m in download_models if not Path(f"{m}/pytorch_model.bin").exists()]
60
-
61
- assert download_experts or download_models, "Nothing to download."
62
-
63
- with Progress() as progress:
64
- # Calculate total download size
65
- progress.print("[blue]Calculating download size...")
66
- total_size = 0
67
- for expert in download_experts:
68
- url = hf_hub_url(
69
- filename=expert,
70
- repo_id=_REPO_ID,
71
- subfolder="expert_weights"
72
- )
73
- total_size += get_hf_file_metadata(url).size
74
-
75
- for model in download_models:
76
- url = hf_hub_url(
77
- filename=f"pytorch_model.bin",
78
- repo_id=_REPO_ID,
79
- subfolder=model
80
- )
81
- total_size += get_hf_file_metadata(url).size
82
- progress.print(f"[blue]Total download size: {total_size / 1e9:.2f} GB")
83
-
84
- # Download files
85
- total_files = len(download_experts) + len(download_models)
86
- total_task = progress.add_task(f"[green]Downloading files", total=total_files)
87
- if download_experts:
88
- expert_task = progress.add_task(
89
- f"[green]Downloading experts...", total=len(download_experts)
90
- )
91
- out_folder = Path("experts/expert_weights")
92
- out_folder.mkdir(parents=True, exist_ok=True)
93
- for expert in download_experts:
94
- path = Path(hf_hub_download(
95
- filename=expert,
96
- repo_id=_REPO_ID,
97
- subfolder="expert_weights"
98
- ))
99
- path.resolve().rename(out_folder/path.name)
100
- path.unlink()
101
- progress.advance(expert_task)
102
- progress.advance(total_task)
103
-
104
- if download_models:
105
- model_task = progress.add_task(
106
- f"[green]Downloading models...", total=len(download_models)
107
- )
108
- for model in download_models:
109
- path = Path(hf_hub_download(
110
- filename=f"pytorch_model.bin",
111
- repo_id=_REPO_ID,
112
- subfolder=model
113
- ))
114
- out_folder = Path("./logging")/model
115
- out_folder.mkdir(parents=True, exist_ok=True)
116
- path.resolve().rename(out_folder/"pytorch_model.bin")
117
- path.unlink()
118
- progress.advance(model_task)
119
- progress.advance(total_task)
120
- progress.print("[green]Done!")
121
-
122
-
123
- if __name__ == "__main__":
124
- Fire(download_checkpoints)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
prismer/experts/depth/base_model.py DELETED
@@ -1,16 +0,0 @@
1
- import torch
2
-
3
-
4
- class BaseModel(torch.nn.Module):
5
- def load(self, path):
6
- """Load model from file.
7
-
8
- Args:
9
- path (str): file path
10
- """
11
- parameters = torch.load(path, map_location=torch.device("cpu"))
12
-
13
- if "optimizer" in parameters:
14
- parameters = parameters["model"]
15
-
16
- self.load_state_dict(parameters)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
prismer/experts/depth/blocks.py DELETED
@@ -1,383 +0,0 @@
1
- import torch
2
- import torch.nn as nn
3
-
4
- from .vit import (
5
- _make_pretrained_vitb_rn50_384,
6
- _make_pretrained_vitl16_384,
7
- _make_pretrained_vitb16_384,
8
- forward_vit,
9
- )
10
-
11
-
12
- def _make_encoder(
13
- backbone,
14
- features,
15
- use_pretrained,
16
- groups=1,
17
- expand=False,
18
- exportable=True,
19
- hooks=None,
20
- use_vit_only=False,
21
- use_readout="ignore",
22
- enable_attention_hooks=False,
23
- ):
24
- if backbone == "vitl16_384":
25
- pretrained = _make_pretrained_vitl16_384(
26
- use_pretrained,
27
- hooks=hooks,
28
- use_readout=use_readout,
29
- enable_attention_hooks=enable_attention_hooks,
30
- )
31
- scratch = _make_scratch(
32
- [256, 512, 1024, 1024], features, groups=groups, expand=expand
33
- ) # ViT-L/16 - 85.0% Top1 (backbone)
34
- elif backbone == "vitb_rn50_384":
35
- pretrained = _make_pretrained_vitb_rn50_384(
36
- use_pretrained,
37
- hooks=hooks,
38
- use_vit_only=use_vit_only,
39
- use_readout=use_readout,
40
- enable_attention_hooks=enable_attention_hooks,
41
- )
42
- scratch = _make_scratch(
43
- [256, 512, 768, 768], features, groups=groups, expand=expand
44
- ) # ViT-H/16 - 85.0% Top1 (backbone)
45
- elif backbone == "vitb16_384":
46
- pretrained = _make_pretrained_vitb16_384(
47
- use_pretrained,
48
- hooks=hooks,
49
- use_readout=use_readout,
50
- enable_attention_hooks=enable_attention_hooks,
51
- )
52
- scratch = _make_scratch(
53
- [96, 192, 384, 768], features, groups=groups, expand=expand
54
- ) # ViT-B/16 - 84.6% Top1 (backbone)
55
- elif backbone == "resnext101_wsl":
56
- pretrained = _make_pretrained_resnext101_wsl(use_pretrained)
57
- scratch = _make_scratch(
58
- [256, 512, 1024, 2048], features, groups=groups, expand=expand
59
- ) # efficientnet_lite3
60
- else:
61
- print(f"Backbone '{backbone}' not implemented")
62
- assert False
63
-
64
- return pretrained, scratch
65
-
66
-
67
- def _make_scratch(in_shape, out_shape, groups=1, expand=False):
68
- scratch = nn.Module()
69
-
70
- out_shape1 = out_shape
71
- out_shape2 = out_shape
72
- out_shape3 = out_shape
73
- out_shape4 = out_shape
74
- if expand == True:
75
- out_shape1 = out_shape
76
- out_shape2 = out_shape * 2
77
- out_shape3 = out_shape * 4
78
- out_shape4 = out_shape * 8
79
-
80
- scratch.layer1_rn = nn.Conv2d(
81
- in_shape[0],
82
- out_shape1,
83
- kernel_size=3,
84
- stride=1,
85
- padding=1,
86
- bias=False,
87
- groups=groups,
88
- )
89
- scratch.layer2_rn = nn.Conv2d(
90
- in_shape[1],
91
- out_shape2,
92
- kernel_size=3,
93
- stride=1,
94
- padding=1,
95
- bias=False,
96
- groups=groups,
97
- )
98
- scratch.layer3_rn = nn.Conv2d(
99
- in_shape[2],
100
- out_shape3,
101
- kernel_size=3,
102
- stride=1,
103
- padding=1,
104
- bias=False,
105
- groups=groups,
106
- )
107
- scratch.layer4_rn = nn.Conv2d(
108
- in_shape[3],
109
- out_shape4,
110
- kernel_size=3,
111
- stride=1,
112
- padding=1,
113
- bias=False,
114
- groups=groups,
115
- )
116
-
117
- return scratch
118
-
119
-
120
- def _make_resnet_backbone(resnet):
121
- pretrained = nn.Module()
122
- pretrained.layer1 = nn.Sequential(
123
- resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool, resnet.layer1
124
- )
125
-
126
- pretrained.layer2 = resnet.layer2
127
- pretrained.layer3 = resnet.layer3
128
- pretrained.layer4 = resnet.layer4
129
-
130
- return pretrained
131
-
132
-
133
- def _make_pretrained_resnext101_wsl(use_pretrained):
134
- resnet = torch.hub.load("facebookresearch/WSL-Images", "resnext101_32x8d_wsl")
135
- return _make_resnet_backbone(resnet)
136
-
137
-
138
- class Interpolate(nn.Module):
139
- """Interpolation module."""
140
-
141
- def __init__(self, scale_factor, mode, align_corners=False):
142
- """Init.
143
-
144
- Args:
145
- scale_factor (float): scaling
146
- mode (str): interpolation mode
147
- """
148
- super(Interpolate, self).__init__()
149
-
150
- self.interp = nn.functional.interpolate
151
- self.scale_factor = scale_factor
152
- self.mode = mode
153
- self.align_corners = align_corners
154
-
155
- def forward(self, x):
156
- """Forward pass.
157
-
158
- Args:
159
- x (tensor): input
160
-
161
- Returns:
162
- tensor: interpolated data_list
163
- """
164
-
165
- x = self.interp(
166
- x,
167
- scale_factor=self.scale_factor,
168
- mode=self.mode,
169
- align_corners=self.align_corners,
170
- )
171
-
172
- return x
173
-
174
-
175
- class ResidualConvUnit(nn.Module):
176
- """Residual convolution module."""
177
-
178
- def __init__(self, features):
179
- """Init.
180
-
181
- Args:
182
- features (int): number of features
183
- """
184
- super().__init__()
185
-
186
- self.conv1 = nn.Conv2d(
187
- features, features, kernel_size=3, stride=1, padding=1, bias=True
188
- )
189
-
190
- self.conv2 = nn.Conv2d(
191
- features, features, kernel_size=3, stride=1, padding=1, bias=True
192
- )
193
-
194
- self.relu = nn.ReLU(inplace=True)
195
-
196
- def forward(self, x):
197
- """Forward pass.
198
-
199
- Args:
200
- x (tensor): input
201
-
202
- Returns:
203
- tensor: output
204
- """
205
- out = self.relu(x)
206
- out = self.conv1(out)
207
- out = self.relu(out)
208
- out = self.conv2(out)
209
-
210
- return out + x
211
-
212
-
213
- class FeatureFusionBlock(nn.Module):
214
- """Feature fusion block."""
215
-
216
- def __init__(self, features):
217
- """Init.
218
-
219
- Args:
220
- features (int): number of features
221
- """
222
- super(FeatureFusionBlock, self).__init__()
223
-
224
- self.resConfUnit1 = ResidualConvUnit(features)
225
- self.resConfUnit2 = ResidualConvUnit(features)
226
-
227
- def forward(self, *xs):
228
- """Forward pass.
229
-
230
- Returns:
231
- tensor: output
232
- """
233
- output = xs[0]
234
-
235
- if len(xs) == 2:
236
- output += self.resConfUnit1(xs[1])
237
-
238
- output = self.resConfUnit2(output)
239
-
240
- output = nn.functional.interpolate(
241
- output, scale_factor=2, mode="bilinear", align_corners=True
242
- )
243
-
244
- return output
245
-
246
-
247
- class ResidualConvUnit_custom(nn.Module):
248
- """Residual convolution module."""
249
-
250
- def __init__(self, features, activation, bn):
251
- """Init.
252
-
253
- Args:
254
- features (int): number of features
255
- """
256
- super().__init__()
257
-
258
- self.bn = bn
259
-
260
- self.groups = 1
261
-
262
- self.conv1 = nn.Conv2d(
263
- features,
264
- features,
265
- kernel_size=3,
266
- stride=1,
267
- padding=1,
268
- bias=not self.bn,
269
- groups=self.groups,
270
- )
271
-
272
- self.conv2 = nn.Conv2d(
273
- features,
274
- features,
275
- kernel_size=3,
276
- stride=1,
277
- padding=1,
278
- bias=not self.bn,
279
- groups=self.groups,
280
- )
281
-
282
- if self.bn == True:
283
- self.bn1 = nn.BatchNorm2d(features)
284
- self.bn2 = nn.BatchNorm2d(features)
285
-
286
- self.activation = activation
287
-
288
- self.skip_add = nn.quantized.FloatFunctional()
289
-
290
- def forward(self, x):
291
- """Forward pass.
292
-
293
- Args:
294
- x (tensor): input
295
-
296
- Returns:
297
- tensor: output
298
- """
299
-
300
- out = self.activation(x)
301
- out = self.conv1(out)
302
- if self.bn == True:
303
- out = self.bn1(out)
304
-
305
- out = self.activation(out)
306
- out = self.conv2(out)
307
- if self.bn == True:
308
- out = self.bn2(out)
309
-
310
- if self.groups > 1:
311
- out = self.conv_merge(out)
312
-
313
- return self.skip_add.add(out, x)
314
-
315
- # return out + x
316
-
317
-
318
- class FeatureFusionBlock_custom(nn.Module):
319
- """Feature fusion block."""
320
-
321
- def __init__(
322
- self,
323
- features,
324
- activation,
325
- deconv=False,
326
- bn=False,
327
- expand=False,
328
- align_corners=True,
329
- ):
330
- """Init.
331
-
332
- Args:
333
- features (int): number of features
334
- """
335
- super(FeatureFusionBlock_custom, self).__init__()
336
-
337
- self.deconv = deconv
338
- self.align_corners = align_corners
339
-
340
- self.groups = 1
341
-
342
- self.expand = expand
343
- out_features = features
344
- if self.expand == True:
345
- out_features = features // 2
346
-
347
- self.out_conv = nn.Conv2d(
348
- features,
349
- out_features,
350
- kernel_size=1,
351
- stride=1,
352
- padding=0,
353
- bias=True,
354
- groups=1,
355
- )
356
-
357
- self.resConfUnit1 = ResidualConvUnit_custom(features, activation, bn)
358
- self.resConfUnit2 = ResidualConvUnit_custom(features, activation, bn)
359
-
360
- self.skip_add = nn.quantized.FloatFunctional()
361
-
362
- def forward(self, *xs):
363
- """Forward pass.
364
-
365
- Returns:
366
- tensor: output
367
- """
368
- output = xs[0]
369
-
370
- if len(xs) == 2:
371
- res = self.resConfUnit1(xs[1])
372
- output = self.skip_add.add(output, res)
373
- # output += res
374
-
375
- output = self.resConfUnit2(output)
376
-
377
- output = nn.functional.interpolate(
378
- output, scale_factor=2, mode="bilinear", align_corners=self.align_corners
379
- )
380
-
381
- output = self.out_conv(output)
382
-
383
- return output
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
prismer/experts/depth/generate_dataset.py DELETED
@@ -1,32 +0,0 @@
1
- # Copyright (c) 2023, NVIDIA Corporation & Affiliates. All rights reserved.
2
- #
3
- # This work is made available under the Nvidia Source Code License-NC.
4
- # To view a copy of this license, visit
5
- # https://github.com/NVlabs/prismer/blob/main/LICENSE
6
-
7
- import glob
8
-
9
- from torch.utils.data import Dataset
10
- from PIL import ImageFile
11
- from dataset.utils import *
12
-
13
- ImageFile.LOAD_TRUNCATED_IMAGES = True
14
-
15
-
16
- class Dataset(Dataset):
17
- def __init__(self, data_path, transform):
18
- self.data_path = data_path
19
- self.transform = transform
20
- data_folders = glob.glob(f'{data_path}/*/')
21
- self.data_list = [data for f in data_folders for data in glob.glob(f + '*.JPEG')]
22
- self.data_list += [data for f in data_folders for data in glob.glob(f + '*.jpg')]
23
-
24
- def __len__(self):
25
- return len(self.data_list)
26
-
27
- def __getitem__(self, index):
28
- image_path = self.data_list[index]
29
- image = Image.open(image_path).convert('RGB')
30
- img_size = [image.size[0], image.size[1]]
31
- image = self.transform(image)
32
- return image, image_path, img_size
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
prismer/experts/depth/models.py DELETED
@@ -1,124 +0,0 @@
1
- import torch
2
- import torch.nn as nn
3
- import torch.nn.functional as F
4
-
5
- from .base_model import BaseModel
6
- from .blocks import (
7
- FeatureFusionBlock,
8
- FeatureFusionBlock_custom,
9
- Interpolate,
10
- _make_encoder,
11
- forward_vit,
12
- )
13
-
14
-
15
- def _make_fusion_block(features, use_bn):
16
- return FeatureFusionBlock_custom(
17
- features,
18
- nn.ReLU(False),
19
- deconv=False,
20
- bn=use_bn,
21
- expand=False,
22
- align_corners=True,
23
- )
24
-
25
-
26
- class DPT(BaseModel):
27
- def __init__(
28
- self,
29
- head,
30
- features=256,
31
- backbone="vitb_rn50_384",
32
- readout="project",
33
- channels_last=False,
34
- use_bn=False,
35
- enable_attention_hooks=False,
36
- ):
37
-
38
- super(DPT, self).__init__()
39
-
40
- self.channels_last = channels_last
41
-
42
- hooks = {
43
- "vitb_rn50_384": [0, 1, 8, 11],
44
- "vitb16_384": [2, 5, 8, 11],
45
- "vitl16_384": [5, 11, 17, 23],
46
- }
47
-
48
- # Instantiate backbone and reassemble blocks
49
- self.pretrained, self.scratch = _make_encoder(
50
- backbone,
51
- features,
52
- False, # Set to true of you want to train from scratch, uses ImageNet weights
53
- groups=1,
54
- expand=False,
55
- exportable=False,
56
- hooks=hooks[backbone],
57
- use_readout=readout,
58
- enable_attention_hooks=enable_attention_hooks,
59
- )
60
-
61
- self.scratch.refinenet1 = _make_fusion_block(features, use_bn)
62
- self.scratch.refinenet2 = _make_fusion_block(features, use_bn)
63
- self.scratch.refinenet3 = _make_fusion_block(features, use_bn)
64
- self.scratch.refinenet4 = _make_fusion_block(features, use_bn)
65
-
66
- self.scratch.output_conv = head
67
-
68
- def forward(self, x):
69
- if self.channels_last == True:
70
- x.contiguous(memory_format=torch.channels_last)
71
-
72
- layer_1, layer_2, layer_3, layer_4 = forward_vit(self.pretrained, x)
73
-
74
- layer_1_rn = self.scratch.layer1_rn(layer_1)
75
- layer_2_rn = self.scratch.layer2_rn(layer_2)
76
- layer_3_rn = self.scratch.layer3_rn(layer_3)
77
- layer_4_rn = self.scratch.layer4_rn(layer_4)
78
-
79
- path_4 = self.scratch.refinenet4(layer_4_rn)
80
- path_3 = self.scratch.refinenet3(path_4, layer_3_rn)
81
- path_2 = self.scratch.refinenet2(path_3, layer_2_rn)
82
- path_1 = self.scratch.refinenet1(path_2, layer_1_rn)
83
-
84
- out = self.scratch.output_conv(path_1)
85
-
86
- return out
87
-
88
-
89
- class DPTDepthModel(DPT):
90
- def __init__(
91
- self, path=None, non_negative=True, scale=1.0, shift=0.0, invert=False, **kwargs
92
- ):
93
- features = kwargs["features"] if "features" in kwargs else 256
94
-
95
- self.scale = scale
96
- self.shift = shift
97
- self.invert = invert
98
-
99
- head = nn.Sequential(
100
- nn.Conv2d(features, features // 2, kernel_size=3, stride=1, padding=1),
101
- Interpolate(scale_factor=2, mode="bilinear", align_corners=True),
102
- nn.Conv2d(features // 2, 32, kernel_size=3, stride=1, padding=1),
103
- nn.ReLU(True),
104
- nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),
105
- nn.ReLU(True) if non_negative else nn.Identity(),
106
- nn.Identity(),
107
- )
108
-
109
- super().__init__(head, **kwargs)
110
-
111
- if path is not None:
112
- self.load(path)
113
-
114
- def forward(self, x):
115
- inv_depth = super().forward(x).squeeze(dim=1)
116
-
117
- if self.invert:
118
- depth = self.scale * inv_depth + self.shift
119
- depth[depth < 1e-8] = 1e-8
120
- depth = 1.0 / depth
121
- return depth
122
- else:
123
- return inv_depth
124
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
prismer/experts/depth/vit.py DELETED
@@ -1,576 +0,0 @@
1
- import torch
2
- import torch.nn as nn
3
- import timm
4
- import types
5
- import math
6
- import torch.nn.functional as F
7
-
8
-
9
- activations = {}
10
-
11
-
12
- def get_activation(name):
13
- def hook(model, input, output):
14
- activations[name] = output
15
-
16
- return hook
17
-
18
-
19
- attention = {}
20
-
21
-
22
- def get_attention(name):
23
- def hook(module, input, output):
24
- x = input[0]
25
- B, N, C = x.shape
26
- qkv = (
27
- module.qkv(x)
28
- .reshape(B, N, 3, module.num_heads, C // module.num_heads)
29
- .permute(2, 0, 3, 1, 4)
30
- )
31
- q, k, v = (
32
- qkv[0],
33
- qkv[1],
34
- qkv[2],
35
- ) # make torchscript happy (cannot use tensor as tuple)
36
-
37
- attn = (q @ k.transpose(-2, -1)) * module.scale
38
-
39
- attn = attn.softmax(dim=-1) # [:,:,1,1:]
40
- attention[name] = attn
41
-
42
- return hook
43
-
44
-
45
- def get_mean_attention_map(attn, token, shape):
46
- attn = attn[:, :, token, 1:]
47
- attn = attn.unflatten(2, torch.Size([shape[2] // 16, shape[3] // 16])).float()
48
- attn = torch.nn.functional.interpolate(
49
- attn, size=shape[2:], mode="bicubic", align_corners=False
50
- ).squeeze(0)
51
-
52
- all_attn = torch.mean(attn, 0)
53
-
54
- return all_attn
55
-
56
-
57
- class Slice(nn.Module):
58
- def __init__(self, start_index=1):
59
- super(Slice, self).__init__()
60
- self.start_index = start_index
61
-
62
- def forward(self, x):
63
- return x[:, self.start_index :]
64
-
65
-
66
- class AddReadout(nn.Module):
67
- def __init__(self, start_index=1):
68
- super(AddReadout, self).__init__()
69
- self.start_index = start_index
70
-
71
- def forward(self, x):
72
- if self.start_index == 2:
73
- readout = (x[:, 0] + x[:, 1]) / 2
74
- else:
75
- readout = x[:, 0]
76
- return x[:, self.start_index :] + readout.unsqueeze(1)
77
-
78
-
79
- class ProjectReadout(nn.Module):
80
- def __init__(self, in_features, start_index=1):
81
- super(ProjectReadout, self).__init__()
82
- self.start_index = start_index
83
-
84
- self.project = nn.Sequential(nn.Linear(2 * in_features, in_features), nn.GELU())
85
-
86
- def forward(self, x):
87
- readout = x[:, 0].unsqueeze(1).expand_as(x[:, self.start_index :])
88
- features = torch.cat((x[:, self.start_index :], readout), -1)
89
-
90
- return self.project(features)
91
-
92
-
93
- class Transpose(nn.Module):
94
- def __init__(self, dim0, dim1):
95
- super(Transpose, self).__init__()
96
- self.dim0 = dim0
97
- self.dim1 = dim1
98
-
99
- def forward(self, x):
100
- x = x.transpose(self.dim0, self.dim1)
101
- return x
102
-
103
-
104
- def forward_vit(pretrained, x):
105
- b, c, h, w = x.shape
106
-
107
- glob = pretrained.model.forward_flex(x)
108
-
109
- layer_1 = pretrained.activations["1"]
110
- layer_2 = pretrained.activations["2"]
111
- layer_3 = pretrained.activations["3"]
112
- layer_4 = pretrained.activations["4"]
113
-
114
- layer_1 = pretrained.act_postprocess1[0:2](layer_1)
115
- layer_2 = pretrained.act_postprocess2[0:2](layer_2)
116
- layer_3 = pretrained.act_postprocess3[0:2](layer_3)
117
- layer_4 = pretrained.act_postprocess4[0:2](layer_4)
118
-
119
- unflatten = nn.Sequential(
120
- nn.Unflatten(
121
- 2,
122
- torch.Size(
123
- [
124
- h // pretrained.model.patch_size[1],
125
- w // pretrained.model.patch_size[0],
126
- ]
127
- ),
128
- )
129
- )
130
-
131
- if layer_1.ndim == 3:
132
- layer_1 = unflatten(layer_1)
133
- if layer_2.ndim == 3:
134
- layer_2 = unflatten(layer_2)
135
- if layer_3.ndim == 3:
136
- layer_3 = unflatten(layer_3)
137
- if layer_4.ndim == 3:
138
- layer_4 = unflatten(layer_4)
139
-
140
- layer_1 = pretrained.act_postprocess1[3 : len(pretrained.act_postprocess1)](layer_1)
141
- layer_2 = pretrained.act_postprocess2[3 : len(pretrained.act_postprocess2)](layer_2)
142
- layer_3 = pretrained.act_postprocess3[3 : len(pretrained.act_postprocess3)](layer_3)
143
- layer_4 = pretrained.act_postprocess4[3 : len(pretrained.act_postprocess4)](layer_4)
144
-
145
- return layer_1, layer_2, layer_3, layer_4
146
-
147
-
148
- def _resize_pos_embed(self, posemb, gs_h, gs_w):
149
- posemb_tok, posemb_grid = (
150
- posemb[:, : self.start_index],
151
- posemb[0, self.start_index :],
152
- )
153
-
154
- gs_old = int(math.sqrt(len(posemb_grid)))
155
-
156
- posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2)
157
- posemb_grid = F.interpolate(posemb_grid, size=(gs_h, gs_w), mode="bilinear")
158
- posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_h * gs_w, -1)
159
-
160
- posemb = torch.cat([posemb_tok, posemb_grid], dim=1)
161
-
162
- return posemb
163
-
164
-
165
- def forward_flex(self, x):
166
- b, c, h, w = x.shape
167
-
168
- pos_embed = self._resize_pos_embed(
169
- self.pos_embed, h // self.patch_size[1], w // self.patch_size[0]
170
- )
171
-
172
- B = x.shape[0]
173
-
174
- if hasattr(self.patch_embed, "backbone"):
175
- x = self.patch_embed.backbone(x)
176
- if isinstance(x, (list, tuple)):
177
- x = x[-1] # last feature if backbone outputs list/tuple of features
178
-
179
- x = self.patch_embed.proj(x).flatten(2).transpose(1, 2)
180
-
181
- if getattr(self, "dist_token", None) is not None:
182
- cls_tokens = self.cls_token.expand(
183
- B, -1, -1
184
- ) # stole cls_tokens impl from Phil Wang, thanks
185
- dist_token = self.dist_token.expand(B, -1, -1)
186
- x = torch.cat((cls_tokens, dist_token, x), dim=1)
187
- else:
188
- cls_tokens = self.cls_token.expand(
189
- B, -1, -1
190
- ) # stole cls_tokens impl from Phil Wang, thanks
191
- x = torch.cat((cls_tokens, x), dim=1)
192
-
193
- x = x + pos_embed
194
- x = self.pos_drop(x)
195
-
196
- for blk in self.blocks:
197
- x = blk(x)
198
-
199
- x = self.norm(x)
200
-
201
- return x
202
-
203
-
204
- def get_readout_oper(vit_features, features, use_readout, start_index=1):
205
- if use_readout == "ignore":
206
- readout_oper = [Slice(start_index)] * len(features)
207
- elif use_readout == "add":
208
- readout_oper = [AddReadout(start_index)] * len(features)
209
- elif use_readout == "project":
210
- readout_oper = [
211
- ProjectReadout(vit_features, start_index) for out_feat in features
212
- ]
213
- else:
214
- assert (
215
- False
216
- ), "wrong operation for readout token, use_readout can be 'ignore', 'add', or 'project'"
217
-
218
- return readout_oper
219
-
220
-
221
- def _make_vit_b16_backbone(
222
- model,
223
- features=[96, 192, 384, 768],
224
- size=[384, 384],
225
- hooks=[2, 5, 8, 11],
226
- vit_features=768,
227
- use_readout="ignore",
228
- start_index=1,
229
- enable_attention_hooks=False,
230
- ):
231
- pretrained = nn.Module()
232
-
233
- pretrained.model = model
234
- pretrained.model.blocks[hooks[0]].register_forward_hook(get_activation("1"))
235
- pretrained.model.blocks[hooks[1]].register_forward_hook(get_activation("2"))
236
- pretrained.model.blocks[hooks[2]].register_forward_hook(get_activation("3"))
237
- pretrained.model.blocks[hooks[3]].register_forward_hook(get_activation("4"))
238
-
239
- pretrained.activations = activations
240
-
241
- if enable_attention_hooks:
242
- pretrained.model.blocks[hooks[0]].attn.register_forward_hook(
243
- get_attention("attn_1")
244
- )
245
- pretrained.model.blocks[hooks[1]].attn.register_forward_hook(
246
- get_attention("attn_2")
247
- )
248
- pretrained.model.blocks[hooks[2]].attn.register_forward_hook(
249
- get_attention("attn_3")
250
- )
251
- pretrained.model.blocks[hooks[3]].attn.register_forward_hook(
252
- get_attention("attn_4")
253
- )
254
- pretrained.attention = attention
255
-
256
- readout_oper = get_readout_oper(vit_features, features, use_readout, start_index)
257
-
258
- # 32, 48, 136, 384
259
- pretrained.act_postprocess1 = nn.Sequential(
260
- readout_oper[0],
261
- Transpose(1, 2),
262
- nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
263
- nn.Conv2d(
264
- in_channels=vit_features,
265
- out_channels=features[0],
266
- kernel_size=1,
267
- stride=1,
268
- padding=0,
269
- ),
270
- nn.ConvTranspose2d(
271
- in_channels=features[0],
272
- out_channels=features[0],
273
- kernel_size=4,
274
- stride=4,
275
- padding=0,
276
- bias=True,
277
- dilation=1,
278
- groups=1,
279
- ),
280
- )
281
-
282
- pretrained.act_postprocess2 = nn.Sequential(
283
- readout_oper[1],
284
- Transpose(1, 2),
285
- nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
286
- nn.Conv2d(
287
- in_channels=vit_features,
288
- out_channels=features[1],
289
- kernel_size=1,
290
- stride=1,
291
- padding=0,
292
- ),
293
- nn.ConvTranspose2d(
294
- in_channels=features[1],
295
- out_channels=features[1],
296
- kernel_size=2,
297
- stride=2,
298
- padding=0,
299
- bias=True,
300
- dilation=1,
301
- groups=1,
302
- ),
303
- )
304
-
305
- pretrained.act_postprocess3 = nn.Sequential(
306
- readout_oper[2],
307
- Transpose(1, 2),
308
- nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
309
- nn.Conv2d(
310
- in_channels=vit_features,
311
- out_channels=features[2],
312
- kernel_size=1,
313
- stride=1,
314
- padding=0,
315
- ),
316
- )
317
-
318
- pretrained.act_postprocess4 = nn.Sequential(
319
- readout_oper[3],
320
- Transpose(1, 2),
321
- nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
322
- nn.Conv2d(
323
- in_channels=vit_features,
324
- out_channels=features[3],
325
- kernel_size=1,
326
- stride=1,
327
- padding=0,
328
- ),
329
- nn.Conv2d(
330
- in_channels=features[3],
331
- out_channels=features[3],
332
- kernel_size=3,
333
- stride=2,
334
- padding=1,
335
- ),
336
- )
337
-
338
- pretrained.model.start_index = start_index
339
- pretrained.model.patch_size = [16, 16]
340
-
341
- # We inject this function into the VisionTransformer instances so that
342
- # we can use it with interpolated position embeddings without modifying the library source.
343
- pretrained.model.forward_flex = types.MethodType(forward_flex, pretrained.model)
344
- pretrained.model._resize_pos_embed = types.MethodType(
345
- _resize_pos_embed, pretrained.model
346
- )
347
-
348
- return pretrained
349
-
350
-
351
- def _make_vit_b_rn50_backbone(
352
- model,
353
- features=[256, 512, 768, 768],
354
- size=[384, 384],
355
- hooks=[0, 1, 8, 11],
356
- vit_features=768,
357
- use_vit_only=False,
358
- use_readout="ignore",
359
- start_index=1,
360
- enable_attention_hooks=False,
361
- ):
362
- pretrained = nn.Module()
363
-
364
- pretrained.model = model
365
-
366
- if use_vit_only == True:
367
- pretrained.model.blocks[hooks[0]].register_forward_hook(get_activation("1"))
368
- pretrained.model.blocks[hooks[1]].register_forward_hook(get_activation("2"))
369
- else:
370
- pretrained.model.patch_embed.backbone.stages[0].register_forward_hook(
371
- get_activation("1")
372
- )
373
- pretrained.model.patch_embed.backbone.stages[1].register_forward_hook(
374
- get_activation("2")
375
- )
376
-
377
- pretrained.model.blocks[hooks[2]].register_forward_hook(get_activation("3"))
378
- pretrained.model.blocks[hooks[3]].register_forward_hook(get_activation("4"))
379
-
380
- if enable_attention_hooks:
381
- pretrained.model.blocks[2].attn.register_forward_hook(get_attention("attn_1"))
382
- pretrained.model.blocks[5].attn.register_forward_hook(get_attention("attn_2"))
383
- pretrained.model.blocks[8].attn.register_forward_hook(get_attention("attn_3"))
384
- pretrained.model.blocks[11].attn.register_forward_hook(get_attention("attn_4"))
385
- pretrained.attention = attention
386
-
387
- pretrained.activations = activations
388
-
389
- readout_oper = get_readout_oper(vit_features, features, use_readout, start_index)
390
-
391
- if use_vit_only == True:
392
- pretrained.act_postprocess1 = nn.Sequential(
393
- readout_oper[0],
394
- Transpose(1, 2),
395
- nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
396
- nn.Conv2d(
397
- in_channels=vit_features,
398
- out_channels=features[0],
399
- kernel_size=1,
400
- stride=1,
401
- padding=0,
402
- ),
403
- nn.ConvTranspose2d(
404
- in_channels=features[0],
405
- out_channels=features[0],
406
- kernel_size=4,
407
- stride=4,
408
- padding=0,
409
- bias=True,
410
- dilation=1,
411
- groups=1,
412
- ),
413
- )
414
-
415
- pretrained.act_postprocess2 = nn.Sequential(
416
- readout_oper[1],
417
- Transpose(1, 2),
418
- nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
419
- nn.Conv2d(
420
- in_channels=vit_features,
421
- out_channels=features[1],
422
- kernel_size=1,
423
- stride=1,
424
- padding=0,
425
- ),
426
- nn.ConvTranspose2d(
427
- in_channels=features[1],
428
- out_channels=features[1],
429
- kernel_size=2,
430
- stride=2,
431
- padding=0,
432
- bias=True,
433
- dilation=1,
434
- groups=1,
435
- ),
436
- )
437
- else:
438
- pretrained.act_postprocess1 = nn.Sequential(
439
- nn.Identity(), nn.Identity(), nn.Identity()
440
- )
441
- pretrained.act_postprocess2 = nn.Sequential(
442
- nn.Identity(), nn.Identity(), nn.Identity()
443
- )
444
-
445
- pretrained.act_postprocess3 = nn.Sequential(
446
- readout_oper[2],
447
- Transpose(1, 2),
448
- nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
449
- nn.Conv2d(
450
- in_channels=vit_features,
451
- out_channels=features[2],
452
- kernel_size=1,
453
- stride=1,
454
- padding=0,
455
- ),
456
- )
457
-
458
- pretrained.act_postprocess4 = nn.Sequential(
459
- readout_oper[3],
460
- Transpose(1, 2),
461
- nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
462
- nn.Conv2d(
463
- in_channels=vit_features,
464
- out_channels=features[3],
465
- kernel_size=1,
466
- stride=1,
467
- padding=0,
468
- ),
469
- nn.Conv2d(
470
- in_channels=features[3],
471
- out_channels=features[3],
472
- kernel_size=3,
473
- stride=2,
474
- padding=1,
475
- ),
476
- )
477
-
478
- pretrained.model.start_index = start_index
479
- pretrained.model.patch_size = [16, 16]
480
-
481
- # We inject this function into the VisionTransformer instances so that
482
- # we can use it with interpolated position embeddings without modifying the library source.
483
- pretrained.model.forward_flex = types.MethodType(forward_flex, pretrained.model)
484
-
485
- # We inject this function into the VisionTransformer instances so that
486
- # we can use it with interpolated position embeddings without modifying the library source.
487
- pretrained.model._resize_pos_embed = types.MethodType(
488
- _resize_pos_embed, pretrained.model
489
- )
490
-
491
- return pretrained
492
-
493
-
494
- def _make_pretrained_vitb_rn50_384(
495
- pretrained,
496
- use_readout="ignore",
497
- hooks=None,
498
- use_vit_only=False,
499
- enable_attention_hooks=False,
500
- ):
501
- model = timm.create_model("vit_base_resnet50_384", pretrained=pretrained)
502
-
503
- hooks = [0, 1, 8, 11] if hooks == None else hooks
504
- return _make_vit_b_rn50_backbone(
505
- model,
506
- features=[256, 512, 768, 768],
507
- size=[384, 384],
508
- hooks=hooks,
509
- use_vit_only=use_vit_only,
510
- use_readout=use_readout,
511
- enable_attention_hooks=enable_attention_hooks,
512
- )
513
-
514
-
515
- def _make_pretrained_vitl16_384(
516
- pretrained, use_readout="ignore", hooks=None, enable_attention_hooks=False
517
- ):
518
- model = timm.create_model("vit_large_patch16_384", pretrained=pretrained)
519
-
520
- hooks = [5, 11, 17, 23] if hooks == None else hooks
521
- return _make_vit_b16_backbone(
522
- model,
523
- features=[256, 512, 1024, 1024],
524
- hooks=hooks,
525
- vit_features=1024,
526
- use_readout=use_readout,
527
- enable_attention_hooks=enable_attention_hooks,
528
- )
529
-
530
-
531
- def _make_pretrained_vitb16_384(
532
- pretrained, use_readout="ignore", hooks=None, enable_attention_hooks=False
533
- ):
534
- model = timm.create_model("vit_base_patch16_384", pretrained=pretrained)
535
-
536
- hooks = [2, 5, 8, 11] if hooks == None else hooks
537
- return _make_vit_b16_backbone(
538
- model,
539
- features=[96, 192, 384, 768],
540
- hooks=hooks,
541
- use_readout=use_readout,
542
- enable_attention_hooks=enable_attention_hooks,
543
- )
544
-
545
-
546
- def _make_pretrained_deitb16_384(
547
- pretrained, use_readout="ignore", hooks=None, enable_attention_hooks=False
548
- ):
549
- model = timm.create_model("vit_deit_base_patch16_384", pretrained=pretrained)
550
-
551
- hooks = [2, 5, 8, 11] if hooks == None else hooks
552
- return _make_vit_b16_backbone(
553
- model,
554
- features=[96, 192, 384, 768],
555
- hooks=hooks,
556
- use_readout=use_readout,
557
- enable_attention_hooks=enable_attention_hooks,
558
- )
559
-
560
-
561
- def _make_pretrained_deitb16_distil_384(
562
- pretrained, use_readout="ignore", hooks=None, enable_attention_hooks=False
563
- ):
564
- model = timm.create_model(
565
- "vit_deit_base_distilled_patch16_384", pretrained=pretrained
566
- )
567
-
568
- hooks = [2, 5, 8, 11] if hooks == None else hooks
569
- return _make_vit_b16_backbone(
570
- model,
571
- features=[96, 192, 384, 768],
572
- hooks=hooks,
573
- use_readout=use_readout,
574
- start_index=2,
575
- enable_attention_hooks=enable_attention_hooks,
576
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
prismer/experts/edge/generate_dataset.py DELETED
@@ -1,32 +0,0 @@
1
- # Copyright (c) 2023, NVIDIA Corporation & Affiliates. All rights reserved.
2
- #
3
- # This work is made available under the Nvidia Source Code License-NC.
4
- # To view a copy of this license, visit
5
- # https://github.com/NVlabs/prismer/blob/main/LICENSE
6
-
7
- import glob
8
-
9
- from torch.utils.data import Dataset
10
- from dataset.utils import *
11
- from PIL import ImageFile
12
-
13
- ImageFile.LOAD_TRUNCATED_IMAGES = True
14
-
15
-
16
- class Dataset(Dataset):
17
- def __init__(self, data_path, transform):
18
- self.data_path = data_path
19
- self.transform = transform
20
- data_folders = glob.glob(f'{data_path}/*/')
21
- self.data_list = [data for f in data_folders for data in glob.glob(f + '*.JPEG')]
22
- self.data_list += [data for f in data_folders for data in glob.glob(f + '*.jpg')]
23
-
24
- def __len__(self):
25
- return len(self.data_list)
26
-
27
- def __getitem__(self, index):
28
- image_path = self.data_list[index]
29
- image = Image.open(image_path).convert('RGB')
30
- img_size = [image.size[0], image.size[1]]
31
- image = self.transform(image)
32
- return torch.flip(image, dims=(0, )) * 255., image_path, img_size
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
prismer/experts/edge/images.py DELETED
@@ -1,50 +0,0 @@
1
- import cv2
2
- import numpy as np
3
- import torch
4
-
5
-
6
- def image_normalization(img, img_min=0, img_max=255,
7
- epsilon=1e-12):
8
- """This is a typical image normalization function
9
- where the minimum and maximum of the image is needed
10
- source: https://en.wikipedia.org/wiki/Normalization_(image_processing)
11
-
12
- :param img: an image could be gray scale or color
13
- :param img_min: for default is 0
14
- :param img_max: for default is 255
15
-
16
- :return: a normalized image, if max is 255 the dtype is uint8
17
- """
18
-
19
- img = np.float32(img)
20
- # whenever an inconsistent image
21
- img = (img - np.min(img)) * (img_max - img_min) / \
22
- ((np.max(img) - np.min(img)) + epsilon) + img_min
23
- return img
24
-
25
-
26
- def fuse_edge(pred):
27
- edge_maps = []
28
- for i in pred:
29
- tmp = torch.sigmoid(i).cpu().detach().numpy()
30
- edge_maps.append(tmp)
31
- tensor = np.array(edge_maps)
32
-
33
- fuses = []
34
- for idx in range(tensor.shape[1]):
35
- tmp = tensor[:, idx, ...]
36
- tmp = np.squeeze(tmp)
37
-
38
- # Iterate our all 7 NN outputs for a particular image
39
- for i in range(tmp.shape[0]):
40
- tmp_img = tmp[i]
41
- tmp_img = np.uint8(image_normalization(tmp_img))
42
- tmp_img = cv2.bitwise_not(tmp_img)
43
-
44
- if i == 6:
45
- fuse = tmp_img
46
- fuse = fuse.astype(np.uint8)
47
- fuses.append(fuse)
48
- return fuses
49
-
50
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
prismer/experts/edge/model.py DELETED
@@ -1,286 +0,0 @@
1
- import torch
2
- import torch.nn as nn
3
- import torch.nn.functional as F
4
-
5
-
6
- def weight_init(m):
7
- if isinstance(m, (nn.Conv2d,)):
8
- # torch.nn.init.xavier_uniform_(m.weight, gain=1.0)
9
- torch.nn.init.xavier_normal_(m.weight, gain=1.0)
10
- # torch.nn.init.normal_(m.weight, mean=0.0, std=0.01)
11
- if m.weight.data.shape[1] == torch.Size([1]):
12
- torch.nn.init.normal_(m.weight, mean=0.0)
13
-
14
- if m.bias is not None:
15
- torch.nn.init.zeros_(m.bias)
16
-
17
- # for fusion layer
18
- if isinstance(m, (nn.ConvTranspose2d,)):
19
- # torch.nn.init.xavier_uniform_(m.weight, gain=1.0)
20
- torch.nn.init.xavier_normal_(m.weight, gain=1.0)
21
- # torch.nn.init.normal_(m.weight, mean=0.0, std=0.01)
22
-
23
- if m.weight.data.shape[1] == torch.Size([1]):
24
- torch.nn.init.normal_(m.weight, std=0.1)
25
- if m.bias is not None:
26
- torch.nn.init.zeros_(m.bias)
27
-
28
-
29
- class CoFusion(nn.Module):
30
-
31
- def __init__(self, in_ch, out_ch):
32
- super(CoFusion, self).__init__()
33
- self.conv1 = nn.Conv2d(in_ch, 64, kernel_size=3,
34
- stride=1, padding=1)
35
- self.conv2 = nn.Conv2d(64, 64, kernel_size=3,
36
- stride=1, padding=1)
37
- self.conv3 = nn.Conv2d(64, out_ch, kernel_size=3,
38
- stride=1, padding=1)
39
- self.relu = nn.ReLU()
40
-
41
- self.norm_layer1 = nn.GroupNorm(4, 64)
42
- self.norm_layer2 = nn.GroupNorm(4, 64)
43
-
44
- def forward(self, x):
45
- # fusecat = torch.cat(x, dim=1)
46
- attn = self.relu(self.norm_layer1(self.conv1(x)))
47
- attn = self.relu(self.norm_layer2(self.conv2(attn)))
48
- attn = F.softmax(self.conv3(attn), dim=1)
49
-
50
- # return ((fusecat * attn).sum(1)).unsqueeze(1)
51
- return ((x * attn).sum(1)).unsqueeze(1)
52
-
53
- class _DenseLayer(nn.Sequential):
54
- def __init__(self, input_features, out_features):
55
- super(_DenseLayer, self).__init__()
56
-
57
- # self.add_module('relu2', nn.ReLU(inplace=True)),
58
- self.add_module('conv1', nn.Conv2d(input_features, out_features,
59
- kernel_size=3, stride=1, padding=2, bias=True)),
60
- self.add_module('norm1', nn.BatchNorm2d(out_features)),
61
- self.add_module('relu1', nn.ReLU(inplace=True)),
62
- self.add_module('conv2', nn.Conv2d(out_features, out_features,
63
- kernel_size=3, stride=1, bias=True)),
64
- self.add_module('norm2', nn.BatchNorm2d(out_features))
65
-
66
- def forward(self, x):
67
- x1, x2 = x
68
-
69
- new_features = super(_DenseLayer, self).forward(F.relu(x1)) # F.relu()
70
- # if new_features.shape[-1]!=x2.shape[-1]:
71
- # new_features =F.interpolate(new_features,size=(x2.shape[2],x2.shape[-1]), mode='bicubic',
72
- # align_corners=False)
73
- return 0.5 * (new_features + x2), x2
74
-
75
-
76
- class _DenseBlock(nn.Sequential):
77
- def __init__(self, num_layers, input_features, out_features):
78
- super(_DenseBlock, self).__init__()
79
- for i in range(num_layers):
80
- layer = _DenseLayer(input_features, out_features)
81
- self.add_module('denselayer%d' % (i + 1), layer)
82
- input_features = out_features
83
-
84
-
85
- class UpConvBlock(nn.Module):
86
- def __init__(self, in_features, up_scale):
87
- super(UpConvBlock, self).__init__()
88
- self.up_factor = 2
89
- self.constant_features = 16
90
-
91
- layers = self.make_deconv_layers(in_features, up_scale)
92
- assert layers is not None, layers
93
- self.features = nn.Sequential(*layers)
94
-
95
- def make_deconv_layers(self, in_features, up_scale):
96
- layers = []
97
- all_pads=[0,0,1,3,7]
98
- for i in range(up_scale):
99
- kernel_size = 2 ** up_scale
100
- pad = all_pads[up_scale] # kernel_size-1
101
- out_features = self.compute_out_features(i, up_scale)
102
- layers.append(nn.Conv2d(in_features, out_features, 1))
103
- layers.append(nn.ReLU(inplace=True))
104
- layers.append(nn.ConvTranspose2d(
105
- out_features, out_features, kernel_size, stride=2, padding=pad))
106
- in_features = out_features
107
- return layers
108
-
109
- def compute_out_features(self, idx, up_scale):
110
- return 1 if idx == up_scale - 1 else self.constant_features
111
-
112
- def forward(self, x):
113
- return self.features(x)
114
-
115
-
116
- class SingleConvBlock(nn.Module):
117
- def __init__(self, in_features, out_features, stride,
118
- use_bs=True
119
- ):
120
- super(SingleConvBlock, self).__init__()
121
- self.use_bn = use_bs
122
- self.conv = nn.Conv2d(in_features, out_features, 1, stride=stride,
123
- bias=True)
124
- self.bn = nn.BatchNorm2d(out_features)
125
-
126
- def forward(self, x):
127
- x = self.conv(x)
128
- if self.use_bn:
129
- x = self.bn(x)
130
- return x
131
-
132
-
133
- class DoubleConvBlock(nn.Module):
134
- def __init__(self, in_features, mid_features,
135
- out_features=None,
136
- stride=1,
137
- use_act=True):
138
- super(DoubleConvBlock, self).__init__()
139
-
140
- self.use_act = use_act
141
- if out_features is None:
142
- out_features = mid_features
143
- self.conv1 = nn.Conv2d(in_features, mid_features,
144
- 3, padding=1, stride=stride)
145
- self.bn1 = nn.BatchNorm2d(mid_features)
146
- self.conv2 = nn.Conv2d(mid_features, out_features, 3, padding=1)
147
- self.bn2 = nn.BatchNorm2d(out_features)
148
- self.relu = nn.ReLU(inplace=True)
149
-
150
- def forward(self, x):
151
- x = self.conv1(x)
152
- x = self.bn1(x)
153
- x = self.relu(x)
154
- x = self.conv2(x)
155
- x = self.bn2(x)
156
- if self.use_act:
157
- x = self.relu(x)
158
- return x
159
-
160
-
161
- class DexiNed(nn.Module):
162
- """ Definition of the DXtrem network. """
163
-
164
- def __init__(self):
165
- super(DexiNed, self).__init__()
166
- self.block_1 = DoubleConvBlock(3, 32, 64, stride=2,)
167
- self.block_2 = DoubleConvBlock(64, 128, use_act=False)
168
- self.dblock_3 = _DenseBlock(2, 128, 256) # [128,256,100,100]
169
- self.dblock_4 = _DenseBlock(3, 256, 512)
170
- self.dblock_5 = _DenseBlock(3, 512, 512)
171
- self.dblock_6 = _DenseBlock(3, 512, 256)
172
- self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
173
-
174
- # left skip connections, figure in Journal
175
- self.side_1 = SingleConvBlock(64, 128, 2)
176
- self.side_2 = SingleConvBlock(128, 256, 2)
177
- self.side_3 = SingleConvBlock(256, 512, 2)
178
- self.side_4 = SingleConvBlock(512, 512, 1)
179
- self.side_5 = SingleConvBlock(512, 256, 1) # Sory I forget to comment this line :(
180
-
181
- # right skip connections, figure in Journal paper
182
- self.pre_dense_2 = SingleConvBlock(128, 256, 2)
183
- self.pre_dense_3 = SingleConvBlock(128, 256, 1)
184
- self.pre_dense_4 = SingleConvBlock(256, 512, 1)
185
- self.pre_dense_5 = SingleConvBlock(512, 512, 1)
186
- self.pre_dense_6 = SingleConvBlock(512, 256, 1)
187
-
188
-
189
- self.up_block_1 = UpConvBlock(64, 1)
190
- self.up_block_2 = UpConvBlock(128, 1)
191
- self.up_block_3 = UpConvBlock(256, 2)
192
- self.up_block_4 = UpConvBlock(512, 3)
193
- self.up_block_5 = UpConvBlock(512, 4)
194
- self.up_block_6 = UpConvBlock(256, 4)
195
- self.block_cat = SingleConvBlock(6, 1, stride=1, use_bs=False) # hed fusion method
196
- # self.block_cat = CoFusion(6,6)# cats fusion method
197
-
198
-
199
- self.apply(weight_init)
200
-
201
- def slice(self, tensor, slice_shape):
202
- t_shape = tensor.shape
203
- height, width = slice_shape
204
- if t_shape[-1]!=slice_shape[-1]:
205
- new_tensor = F.interpolate(
206
- tensor, size=(height, width), mode='bicubic',align_corners=False)
207
- else:
208
- new_tensor=tensor
209
- # tensor[..., :height, :width]
210
- return new_tensor
211
-
212
- def forward(self, x):
213
- assert x.ndim == 4, x.shape
214
-
215
- # Block 1
216
- block_1 = self.block_1(x)
217
- block_1_side = self.side_1(block_1)
218
-
219
- # Block 2
220
- block_2 = self.block_2(block_1)
221
- block_2_down = self.maxpool(block_2)
222
- block_2_add = block_2_down + block_1_side
223
- block_2_side = self.side_2(block_2_add)
224
-
225
- # Block 3
226
- block_3_pre_dense = self.pre_dense_3(block_2_down)
227
- block_3, _ = self.dblock_3([block_2_add, block_3_pre_dense])
228
- block_3_down = self.maxpool(block_3) # [128,256,50,50]
229
- block_3_add = block_3_down + block_2_side
230
- block_3_side = self.side_3(block_3_add)
231
-
232
- # Block 4
233
- block_2_resize_half = self.pre_dense_2(block_2_down)
234
- block_4_pre_dense = self.pre_dense_4(block_3_down+block_2_resize_half)
235
- block_4, _ = self.dblock_4([block_3_add, block_4_pre_dense])
236
- block_4_down = self.maxpool(block_4)
237
- block_4_add = block_4_down + block_3_side
238
- block_4_side = self.side_4(block_4_add)
239
-
240
- # Block 5
241
- block_5_pre_dense = self.pre_dense_5(
242
- block_4_down) #block_5_pre_dense_512 +block_4_down
243
- block_5, _ = self.dblock_5([block_4_add, block_5_pre_dense])
244
- block_5_add = block_5 + block_4_side
245
-
246
- # Block 6
247
- block_6_pre_dense = self.pre_dense_6(block_5)
248
- block_6, _ = self.dblock_6([block_5_add, block_6_pre_dense])
249
-
250
- # upsampling blocks
251
- out_1 = self.up_block_1(block_1)
252
- out_2 = self.up_block_2(block_2)
253
- out_3 = self.up_block_3(block_3)
254
- out_4 = self.up_block_4(block_4)
255
- out_5 = self.up_block_5(block_5)
256
- out_6 = self.up_block_6(block_6)
257
- results = [out_1, out_2, out_3, out_4, out_5, out_6]
258
-
259
- # concatenate multiscale outputs
260
- block_cat = torch.cat(results, dim=1) # Bx6xHxW
261
- block_cat = self.block_cat(block_cat) # Bx1xHxW
262
-
263
- # return results
264
- results.append(block_cat)
265
- return results
266
-
267
-
268
- if __name__ == '__main__':
269
- batch_size = 8
270
- img_height = 352
271
- img_width = 352
272
-
273
- # device = "cuda" if torch.cuda.is_available() else "cpu"
274
- device = "cpu"
275
- input = torch.rand(batch_size, 3, img_height, img_width).to(device)
276
- # target = torch.rand(batch_size, 1, img_height, img_width).to(device)
277
- print(f"input shape: {input.shape}")
278
- model = DexiNed().to(device)
279
- output = model(input)
280
- print(f"output shapes: {[t.shape for t in output]}")
281
-
282
- # for i in range(20000):
283
- # print(i)
284
- # output = model(input)
285
- # loss = nn.MSELoss()(output[-1], target)
286
- # loss.backward()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
prismer/experts/generate_depth.py DELETED
@@ -1,56 +0,0 @@
1
- # Copyright (c) 2023, NVIDIA Corporation & Affiliates. All rights reserved.
2
- #
3
- # This work is made available under the Nvidia Source Code License-NC.
4
- # To view a copy of this license, visit
5
- # https://github.com/NVlabs/prismer/blob/main/LICENSE
6
-
7
- import torch
8
- import os
9
- try:
10
- import ruamel_yaml as yaml
11
- except ModuleNotFoundError:
12
- import ruamel.yaml as yaml
13
-
14
- from experts.model_bank import load_expert_model
15
- from experts.depth.generate_dataset import Dataset
16
- import PIL.Image as Image
17
- from accelerate import Accelerator
18
- from tqdm import tqdm
19
-
20
- model, transform = load_expert_model(task='depth')
21
- accelerator = Accelerator(mixed_precision='fp16')
22
-
23
- config = yaml.load(open('configs/experts.yaml', 'r'), Loader=yaml.Loader)
24
- data_path = config['data_path']
25
- save_path = os.path.join(config['save_path'], 'depth')
26
-
27
- batch_size = 64
28
- dataset = Dataset(data_path, transform)
29
- data_loader = torch.utils.data.DataLoader(
30
- dataset=dataset,
31
- batch_size=batch_size,
32
- shuffle=False,
33
- num_workers=4,
34
- pin_memory=True
35
- )
36
-
37
- model, data_loader = accelerator.prepare(model, data_loader)
38
-
39
- with torch.no_grad():
40
- for i, (test_data, img_path, img_size) in enumerate(tqdm(data_loader)):
41
- test_pred = model(test_data)
42
-
43
- for k in range(len(test_pred)):
44
- img_path_split = img_path[k].split('/')
45
- ps = img_path[k].split('.')[-1]
46
- im_save_path = os.path.join(save_path, img_path_split[-3], img_path_split[-2])
47
- os.makedirs(im_save_path, exist_ok=True)
48
-
49
- im_size = img_size[0][k].item(), img_size[1][k].item()
50
- depth = test_pred[k]
51
- depth = (depth - depth.min()) / (depth.max() - depth.min())
52
- depth = torch.nn.functional.interpolate(depth.unsqueeze(0).unsqueeze(1), size=(im_size[1], im_size[0]), mode='bilinear', align_corners=True)
53
- depth_im = Image.fromarray(255 * depth[0, 0].detach().cpu().numpy()).convert('L')
54
- depth_im.save(os.path.join(im_save_path, img_path_split[-1].replace(f'.{ps}', '.png')))
55
-
56
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
prismer/experts/generate_edge.py DELETED
@@ -1,57 +0,0 @@
1
- # Copyright (c) 2023, NVIDIA Corporation & Affiliates. All rights reserved.
2
- #
3
- # This work is made available under the Nvidia Source Code License-NC.
4
- # To view a copy of this license, visit
5
- # https://github.com/NVlabs/prismer/blob/main/LICENSE
6
-
7
- import torch
8
- import os
9
- try:
10
- import ruamel_yaml as yaml
11
- except ModuleNotFoundError:
12
- import ruamel.yaml as yaml
13
-
14
- from experts.model_bank import load_expert_model
15
- from experts.edge.generate_dataset import Dataset
16
- from experts.edge.images import fuse_edge
17
- import PIL.Image as Image
18
- from accelerate import Accelerator
19
- from tqdm import tqdm
20
-
21
-
22
- model, transform = load_expert_model(task='edge')
23
- accelerator = Accelerator(mixed_precision='fp16')
24
-
25
- config = yaml.load(open('configs/experts.yaml', 'r'), Loader=yaml.Loader)
26
- data_path = config['data_path']
27
- save_path = os.path.join(config['save_path'], 'edge')
28
-
29
- batch_size = 64
30
- dataset = Dataset(data_path, transform)
31
- data_loader = torch.utils.data.DataLoader(
32
- dataset=dataset,
33
- batch_size=batch_size,
34
- shuffle=False,
35
- num_workers=4,
36
- pin_memory=True
37
- )
38
-
39
- model, data_loader = accelerator.prepare(model, data_loader)
40
-
41
- with torch.no_grad():
42
- for i, (test_data, img_path, img_size) in enumerate(tqdm(data_loader)):
43
- test_pred = model(test_data)
44
- fuses = fuse_edge(test_pred)
45
- for k in range(len(fuses)):
46
- edge = fuses[k]
47
- img_path_split = img_path[k].split('/')
48
- ps = img_path[k].split('.')[-1]
49
- im_save_path = os.path.join(save_path, img_path_split[-3], img_path_split[-2])
50
- os.makedirs(im_save_path, exist_ok=True)
51
-
52
- im_size = img_size[0][k].item(), img_size[1][k].item()
53
- edge = Image.fromarray(edge).convert('L')
54
- edge = edge.resize((im_size[0], im_size[1]), resample=Image.Resampling.BILINEAR)
55
- edge.save(os.path.join(im_save_path, img_path_split[-1].replace(f'.{ps}', '.png')))
56
-
57
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
prismer/experts/generate_normal.py DELETED
@@ -1,58 +0,0 @@
1
- # Copyright (c) 2023, NVIDIA Corporation & Affiliates. All rights reserved.
2
- #
3
- # This work is made available under the Nvidia Source Code License-NC.
4
- # To view a copy of this license, visit
5
- # https://github.com/NVlabs/prismer/blob/main/LICENSE
6
-
7
- import torch
8
- import os
9
- try:
10
- import ruamel_yaml as yaml
11
- except ModuleNotFoundError:
12
- import ruamel.yaml as yaml
13
-
14
- from experts.model_bank import load_expert_model
15
- from experts.normal.generate_dataset import CustomDataset
16
- import PIL.Image as Image
17
- from accelerate import Accelerator
18
- from tqdm import tqdm
19
- import numpy as np
20
-
21
-
22
- model, transform = load_expert_model(task='normal')
23
- accelerator = Accelerator(mixed_precision='fp16')
24
-
25
- config = yaml.load(open('configs/experts.yaml', 'r'), Loader=yaml.Loader)
26
- data_path = config['data_path']
27
- save_path = os.path.join(config['save_path'], 'normal')
28
-
29
- batch_size = 64
30
- dataset = CustomDataset(data_path, transform)
31
- data_loader = torch.utils.data.DataLoader(
32
- dataset=dataset,
33
- batch_size=batch_size,
34
- shuffle=False,
35
- num_workers=4,
36
- pin_memory=True
37
- )
38
-
39
- model, data_loader = accelerator.prepare(model, data_loader)
40
-
41
- with torch.no_grad():
42
- for i, (test_data, img_path, img_size) in enumerate(tqdm(data_loader)):
43
- test_pred = model(test_data)
44
- pred_norm = test_pred[0][-1][:, :3]
45
- for k in range(len(pred_norm)):
46
- img_path_split = img_path[k].split('/')
47
- ps = img_path[k].split('.')[-1]
48
- im_save_path = os.path.join(save_path, img_path_split[-3], img_path_split[-2])
49
- os.makedirs(im_save_path, exist_ok=True)
50
-
51
- im_size = img_size[0][k].item(), img_size[1][k].item()
52
- norm = pred_norm[k]
53
- norm = ((norm + 1) * 0.5).clip(0, 1)
54
- norm = torch.nn.functional.interpolate(norm.unsqueeze(0), size=(im_size[1], im_size[0]), mode='bilinear', align_corners=True)
55
- norm_im = Image.fromarray((norm[0] * 255).permute(1, 2, 0).detach().cpu().numpy().astype(np.uint8)).convert('RGB')
56
- norm_im.save(os.path.join(im_save_path, img_path_split[-1].replace(f'.{ps}', '.png')))
57
-
58
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
prismer/experts/generate_objdet.py DELETED
@@ -1,115 +0,0 @@
1
- # Copyright (c) 2023, NVIDIA Corporation & Affiliates. All rights reserved.
2
- #
3
- # This work is made available under the Nvidia Source Code License-NC.
4
- # To view a copy of this license, visit
5
- # https://github.com/NVlabs/prismer/blob/main/LICENSE
6
-
7
- import torch
8
- import os
9
- import json
10
- import copy
11
- import PIL.Image as Image
12
- try:
13
- import ruamel_yaml as yaml
14
- except ModuleNotFoundError:
15
- import ruamel.yaml as yaml
16
-
17
- from experts.model_bank import load_expert_model
18
- from experts.obj_detection.generate_dataset import Dataset, collate_fn
19
- from accelerate import Accelerator
20
- from tqdm import tqdm
21
-
22
- model, transform = load_expert_model(task='obj_detection')
23
- accelerator = Accelerator(mixed_precision='fp16')
24
-
25
- config = yaml.load(open('configs/experts.yaml', 'r'), Loader=yaml.Loader)
26
- data_path = config['data_path']
27
- save_path = config['save_path']
28
-
29
- depth_path = os.path.join(save_path, 'depth', data_path.split('/')[-1])
30
- batch_size = 32
31
- dataset = Dataset(data_path, depth_path, transform)
32
- data_loader = torch.utils.data.DataLoader(
33
- dataset=dataset,
34
- batch_size=batch_size,
35
- shuffle=False,
36
- num_workers=4,
37
- pin_memory=True,
38
- collate_fn=collate_fn,
39
- )
40
-
41
- model, data_loader = accelerator.prepare(model, data_loader)
42
-
43
-
44
- def get_mask_labels(depth, instance_boxes, instance_id):
45
- obj_masks = []
46
- obj_ids = []
47
- for i in range(len(instance_boxes)):
48
- is_duplicate = False
49
- mask = torch.zeros_like(depth)
50
- x1, y1, x2, y2 = instance_boxes[i][0].item(), instance_boxes[i][1].item(), \
51
- instance_boxes[i][2].item(), instance_boxes[i][3].item()
52
- mask[int(y1):int(y2), int(x1):int(x2)] = 1
53
- for j in range(len(obj_masks)):
54
- if ((mask + obj_masks[j]) == 2).sum() / ((mask + obj_masks[j]) > 0).sum() > 0.95:
55
- is_duplicate = True
56
- break
57
- if not is_duplicate:
58
- obj_masks.append(mask)
59
- obj_ids.append(instance_id[i])
60
-
61
- obj_masked_modified = copy.deepcopy(obj_masks[:])
62
- for i in range(len(obj_masks) - 1):
63
- mask1 = obj_masks[i]
64
- mask1_ = obj_masked_modified[i]
65
- for j in range(i + 1, len(obj_masks)):
66
- mask2 = obj_masks[j]
67
- mask2_ = obj_masked_modified[j]
68
- # case 1: if they don't intersect we don't touch them
69
- if ((mask1 + mask2) == 2).sum() == 0:
70
- continue
71
- # case 2: the entire object 1 is inside of object 2, we say object 1 is in front of object 2:
72
- elif (((mask1 + mask2) == 2).float() - mask1).sum() == 0:
73
- mask2_ -= mask1_
74
- # case 3: the entire object 2 is inside of object 1, we say object 2 is in front of object 1:
75
- elif (((mask1 + mask2) == 2).float() - mask2).sum() == 0:
76
- mask1_ -= mask2_
77
- # case 4: use depth to check object order:
78
- else:
79
- # object 1 is closer
80
- if (depth * mask1).sum() / mask1.sum() > (depth * mask2).sum() / mask2.sum():
81
- mask2_ -= ((mask1 + mask2) == 2).float()
82
- # object 2 is closer
83
- if (depth * mask1).sum() / mask1.sum() < (depth * mask2).sum() / mask2.sum():
84
- mask1_ -= ((mask1 + mask2) == 2).float()
85
-
86
- final_mask = torch.ones_like(depth) * 255
87
- instance_labels = {}
88
- for i in range(len(obj_masked_modified)):
89
- final_mask = final_mask.masked_fill(obj_masked_modified[i] > 0, i)
90
- instance_labels[i] = obj_ids[i].item()
91
- return final_mask, instance_labels
92
-
93
-
94
- with torch.no_grad():
95
- for i, test_data in enumerate(tqdm(data_loader)):
96
- test_pred = model(test_data)
97
- for k in range(len(test_pred)):
98
- instance_boxes = test_pred[k]['instances'].get_fields()['pred_boxes'].tensor
99
- instance_id = test_pred[k]['instances'].get_fields()['pred_classes']
100
- depth = test_data[k]['depth']
101
-
102
- final_mask, instance_labels = get_mask_labels(depth, instance_boxes, instance_id)
103
-
104
- img_path_split = test_data[k]['image_path'].split('/')
105
- im_save_path = os.path.join(save_path, 'obj_detection', img_path_split[-3], img_path_split[-2])
106
- ps = test_data[k]['image_path'].split('.')[-1]
107
- os.makedirs(im_save_path, exist_ok=True)
108
-
109
- height, width = test_data[k]['true_height'], test_data[k]['true_width']
110
- final_mask = Image.fromarray(final_mask.cpu().numpy()).convert('L')
111
- final_mask = final_mask.resize((height, width), resample=Image.Resampling.NEAREST)
112
- final_mask.save(os.path.join(im_save_path, img_path_split[-1].replace(f'.{ps}', '.png')))
113
-
114
- with open(os.path.join(im_save_path, img_path_split[-1].replace(f'.{ps}', '.json')), 'w') as fp:
115
- json.dump(instance_labels, fp)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
prismer/experts/generate_ocrdet.py DELETED
@@ -1,86 +0,0 @@
1
- # Copyright (c) 2023, NVIDIA Corporation & Affiliates. All rights reserved.
2
- #
3
- # This work is made available under the Nvidia Source Code License-NC.
4
- # To view a copy of this license, visit
5
- # https://github.com/NVlabs/prismer/blob/main/LICENSE
6
-
7
- import torch
8
- import os
9
- import PIL.Image as Image
10
- import numpy as np
11
- import cv2
12
- import clip
13
- import pickle as pk
14
- try:
15
- import ruamel_yaml as yaml
16
- except ModuleNotFoundError:
17
- import ruamel.yaml as yaml
18
-
19
- from experts.model_bank import load_expert_model
20
- from experts.ocr_detection.generate_dataset import Dataset
21
- from accelerate import Accelerator
22
- from tqdm import tqdm
23
-
24
-
25
- model, transform = load_expert_model(task='ocr_detection')
26
- accelerator = Accelerator(mixed_precision='fp16')
27
- pca_clip = pk.load(open('dataset/clip_pca.pkl', 'rb'))
28
-
29
- config = yaml.load(open('configs/experts.yaml', 'r'), Loader=yaml.Loader)
30
- data_path = config['data_path']
31
- save_path = os.path.join(config['save_path'], 'ocr_detection')
32
-
33
- batch_size = 32
34
- dataset = Dataset(data_path, transform)
35
- data_loader = torch.utils.data.DataLoader(
36
- dataset=dataset,
37
- batch_size=batch_size,
38
- shuffle=False,
39
- num_workers=4,
40
- pin_memory=True,
41
- )
42
-
43
- clip_model, _ = clip.load("ViT-L/14", device=accelerator.device)
44
- model, data_loader = accelerator.prepare(model, data_loader)
45
-
46
-
47
- def get_label(w, h, word_instances):
48
- word_lists = []
49
- final_mask = np.ones([h, w], dtype=np.uint8) * 255
50
- counter = 0
51
- for word_instance in word_instances[::-1]:
52
- mask = np.zeros([h ,w])
53
- mask = cv2.fillPoly(mask, [np.int32(word_instance.word_bbox.reshape(-1, 2))], 1)
54
- text = word_instance.text.lower()
55
- final_mask[mask > 0] = counter
56
- word_lists.append(text)
57
- counter += 1
58
- return final_mask, word_lists
59
-
60
-
61
- with torch.no_grad():
62
- for i, (test_data, image_path, scale_w, scale_h, original_w, original_h) in enumerate(tqdm(data_loader)):
63
- word_instance_lists = model(test_data, scale_w, scale_h, original_w, original_h)
64
- for k in range(len(word_instance_lists)):
65
- word_instance = word_instance_lists[k]
66
- if len(word_instance) == 0:
67
- continue
68
- else:
69
- final_mask, word_lists = get_label(original_w[k], original_h[k], word_instance)
70
-
71
- final_mask = Image.fromarray(final_mask)
72
- img_path_split = image_path[k].split('/')
73
- ps = image_path[k].split('.')[-1]
74
- im_save_path = os.path.join(save_path, img_path_split[-3], img_path_split[-2])
75
- os.makedirs(im_save_path, exist_ok=True)
76
-
77
- final_mask.save(os.path.join(im_save_path, img_path_split[-1].replace(f'.{ps}', '.png')))
78
-
79
- if len(word_lists) > 0:
80
- word_embed = clip.tokenize(word_lists).to(accelerator.device)
81
- word_features = pca_clip.transform(clip_model.encode_text(word_embed).float().cpu())
82
- word_lists = {j: {'features': torch.from_numpy(word_features[j]).float(),
83
- 'text': word_lists[j]} for j in range(len(word_lists))}
84
- torch.save(word_lists, os.path.join(im_save_path, img_path_split[-1].replace(f'.{ps}', '.pt')))
85
-
86
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
prismer/experts/generate_segmentation.py DELETED
@@ -1,56 +0,0 @@
1
- # Copyright (c) 2023, NVIDIA Corporation & Affiliates. All rights reserved.
2
- #
3
- # This work is made available under the Nvidia Source Code License-NC.
4
- # To view a copy of this license, visit
5
- # https://github.com/NVlabs/prismer/blob/main/LICENSE
6
-
7
- import torch
8
- import os
9
- import PIL.Image as Image
10
- try:
11
- import ruamel_yaml as yaml
12
- except ModuleNotFoundError:
13
- import ruamel.yaml as yaml
14
-
15
- from experts.model_bank import load_expert_model
16
- from experts.segmentation.generate_dataset import Dataset, collate_fn
17
- from accelerate import Accelerator
18
- from tqdm import tqdm
19
-
20
- model, transform = load_expert_model(task='seg_coco')
21
- accelerator = Accelerator(mixed_precision='fp16')
22
-
23
- config = yaml.load(open('configs/experts.yaml', 'r'), Loader=yaml.Loader)
24
- data_path = config['data_path']
25
- save_path = os.path.join(config['save_path'], 'seg_coco')
26
-
27
- batch_size = 4
28
- dataset = Dataset(data_path, transform)
29
- data_loader = torch.utils.data.DataLoader(
30
- dataset=dataset,
31
- batch_size=batch_size,
32
- shuffle=False,
33
- num_workers=4,
34
- pin_memory=True,
35
- collate_fn=collate_fn,
36
- )
37
-
38
-
39
- model, data_loader = accelerator.prepare(model, data_loader)
40
-
41
- with torch.no_grad():
42
- for i, test_data in enumerate(tqdm(data_loader)):
43
- test_pred = model(test_data)
44
-
45
- for k in range(len(test_pred)):
46
- pred = test_pred[k]['sem_seg']
47
- labels = torch.argmax(pred, dim=0)
48
-
49
- img_path_split = test_data[k]['image_path'].split('/')
50
- ps = test_data[k]['image_path'].split('.')[-1]
51
- im_save_path = os.path.join(save_path, img_path_split[-3], img_path_split[-2])
52
- os.makedirs(im_save_path, exist_ok=True)
53
-
54
- seg = Image.fromarray(labels.float().detach().cpu().numpy()).convert('L')
55
- seg.save(os.path.join(im_save_path, img_path_split[-1].replace(f'.{ps}', '.png')))
56
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
prismer/experts/model_bank.py DELETED
@@ -1,139 +0,0 @@
1
- # Copyright (c) 2023, NVIDIA Corporation & Affiliates. All rights reserved.
2
- #
3
- # This work is made available under the Nvidia Source Code License-NC.
4
- # To view a copy of this license, visit
5
- # https://github.com/NVlabs/prismer/blob/main/LICENSE
6
-
7
- import torch
8
- import torchvision.transforms as transforms
9
-
10
-
11
- def load_expert_model(task=None):
12
- if task == 'depth':
13
- # DPT model is a standard pytorch model class
14
- from experts.depth.models import DPTDepthModel
15
-
16
- model = DPTDepthModel(path='experts/expert_weights/dpt_hybrid-midas-501f0c75.pt',
17
- backbone="vitb_rn50_384",
18
- non_negative=True,
19
- enable_attention_hooks=False)
20
- transform = transforms.Compose([
21
- transforms.Resize([480, 480]),
22
- transforms.ToTensor(),
23
- transforms.Normalize(mean=0.5, std=0.5)]
24
- )
25
-
26
- elif task == 'seg_coco':
27
- # Mask2Former is wrapped in detection2,
28
- # the model takes input in the format of: {"image": image (BGR), "height": height, "width": width}
29
- import argparse
30
- from detectron2.engine.defaults import DefaultPredictor
31
- from experts.segmentation.utils import setup_cfg
32
-
33
- parser = argparse.ArgumentParser()
34
- parser.add_argument("--mode", default="client")
35
- parser.add_argument("--port", default=2)
36
- args = parser.parse_args()
37
-
38
- args.config_file = 'experts/segmentation/configs/coco/panoptic-segmentation/swin/maskformer2_swin_large_IN21k_384_bs16_100ep.yaml'
39
- args.opts = ['MODEL.WEIGHTS', 'experts/expert_weights/model_final_f07440.pkl']
40
- cfg = setup_cfg(args)
41
- model = DefaultPredictor(cfg).model
42
- transform = transforms.Compose([
43
- transforms.Resize(size=479, max_size=480)
44
- ])
45
-
46
- elif task == 'seg_ade':
47
- # Mask2Former is wrapped in detection2,
48
- # the model takes input in the format of: {"image": image (BGR), "height": height, "width": width}
49
- import argparse
50
- from detectron2.engine.defaults import DefaultPredictor
51
- from experts.segmentation.utils import setup_cfg
52
-
53
- parser = argparse.ArgumentParser()
54
- parser.add_argument("--mode", default="client")
55
- parser.add_argument("--port", default=2)
56
- args = parser.parse_args()
57
-
58
- args.config_file = 'experts/segmentation/configs/ade20k/panoptic-segmentation/swin/maskformer2_swin_large_IN21k_384_bs16_160k.yaml'
59
- args.opts = ['MODEL.WEIGHTS', 'experts/expert_weights/model_final_e0c58e.pkl']
60
- cfg = setup_cfg(args)
61
- model = DefaultPredictor(cfg).model
62
- transform = transforms.Compose([
63
- transforms.Resize(size=479, max_size=480)
64
- ])
65
-
66
- elif task == 'obj_detection':
67
- # UniDet is wrapped in detection2,
68
- # the model takes input in the format of: {"image": image (BGR), "height": height, "width": width}
69
- import argparse
70
- from detectron2.engine.defaults import DefaultPredictor
71
- from experts.obj_detection.utils import setup_cfg
72
- parser = argparse.ArgumentParser()
73
- parser.add_argument("--mode", default="client")
74
- parser.add_argument("--port", default=2)
75
- parser.add_argument("--confidence-threshold", type=float, default=0.5)
76
- args = parser.parse_args()
77
-
78
- args.config_file = 'experts/obj_detection/configs/Unified_learned_OCIM_RS200_6x+2x.yaml'
79
- args.opts = ['MODEL.WEIGHTS', 'experts/expert_weights/Unified_learned_OCIM_RS200_6x+2x.pth']
80
-
81
- cfg = setup_cfg(args)
82
- model = DefaultPredictor(cfg).model
83
- transform = transforms.Compose([
84
- transforms.Resize(size=479, max_size=480)
85
- ])
86
-
87
- elif task == 'ocr_detection':
88
- from experts.ocr_detection.charnet.modeling.model import CharNet
89
- model = CharNet()
90
- model.load_state_dict(torch.load('experts/expert_weights/icdar2015_hourglass88.pth'))
91
- transform = transforms.Compose([
92
- transforms.ToTensor(),
93
- transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
94
- ])
95
-
96
- elif task == 'normal':
97
- # NLL-AngMF model is a standard pytorch model class
98
- import argparse
99
- from experts.normal.models.NNET import NNET
100
- from experts.normal.utils import utils
101
-
102
- parser = argparse.ArgumentParser()
103
- parser.add_argument("--mode", default="client")
104
- parser.add_argument("--port", default=2)
105
- parser.add_argument('--architecture', default='BN', type=str, help='{BN, GN}')
106
- parser.add_argument("--pretrained", default='scannet', type=str, help="{nyu, scannet}")
107
- parser.add_argument('--sampling_ratio', type=float, default=0.4)
108
- parser.add_argument('--importance_ratio', type=float, default=0.7)
109
- args = parser.parse_args()
110
- model = NNET(args)
111
- model = utils.load_checkpoint('experts/expert_weights/scannet.pt', model)
112
-
113
- transform = transforms.Compose([
114
- transforms.Resize([480, 480]),
115
- transforms.ToTensor(),
116
- transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
117
- ])
118
-
119
- elif task == 'edge':
120
- # NLL-AngMF model is a standard pytorch model class
121
- from experts.edge.model import DexiNed
122
- model = DexiNed()
123
- model.load_state_dict(torch.load('experts/expert_weights/10_model.pth', map_location='cpu'))
124
- transform = transforms.Compose([
125
- transforms.Resize([480, 480]),
126
- transforms.ToTensor(),
127
- transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[1.0, 1.0, 1.0])
128
- ])
129
- else:
130
- print('Task not supported')
131
- model = None
132
- transform = None
133
-
134
- model.eval()
135
- return model, transform
136
-
137
-
138
-
139
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
prismer/experts/normal/generate_dataset.py DELETED
@@ -1,34 +0,0 @@
1
- # Copyright (c) 2023, NVIDIA Corporation & Affiliates. All rights reserved.
2
- #
3
- # This work is made available under the Nvidia Source Code License-NC.
4
- # To view a copy of this license, visit
5
- # https://github.com/NVlabs/prismer/blob/main/LICENSE
6
-
7
- import glob
8
-
9
- from torch.utils.data import Dataset
10
- from PIL import Image
11
- from PIL import ImageFile
12
-
13
- ImageFile.LOAD_TRUNCATED_IMAGES = True
14
-
15
-
16
- class CustomDataset(Dataset):
17
- def __init__(self, data_path, transform):
18
- self.data_path = data_path
19
- self.transform = transform
20
- data_folders = glob.glob(f'{data_path}/*/')
21
- self.data_list = [data for f in data_folders for data in glob.glob(f + '*.JPEG')]
22
- self.data_list += [data for f in data_folders for data in glob.glob(f + '*.jpg')]
23
-
24
- def __len__(self):
25
- return len(self.data_list)
26
-
27
- def __getitem__(self, index):
28
- image_path = self.data_list[index]
29
- image = Image.open(image_path).convert('RGB')
30
- img_size = [image.size[0], image.size[1]]
31
- image = self.transform(image)
32
- return image, image_path, img_size
33
-
34
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
prismer/experts/normal/models/NNET.py DELETED
@@ -1,22 +0,0 @@
1
- import torch
2
- import torch.nn as nn
3
- import torch.nn.functional as F
4
-
5
- from experts.normal.models.submodules.encoder import Encoder
6
- from experts.normal.models.submodules.decoder import Decoder
7
-
8
-
9
- class NNET(nn.Module):
10
- def __init__(self, args):
11
- super(NNET, self).__init__()
12
- self.encoder = Encoder()
13
- self.decoder = Decoder(args)
14
-
15
- def get_1x_lr_params(self): # lr/10 learning rate
16
- return self.encoder.parameters()
17
-
18
- def get_10x_lr_params(self): # lr learning rate
19
- return self.decoder.parameters()
20
-
21
- def forward(self, img, **kwargs):
22
- return self.decoder(self.encoder(img), **kwargs)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
prismer/experts/normal/models/baseline.py DELETED
@@ -1,85 +0,0 @@
1
- import torch
2
- import torch.nn as nn
3
- import torch.nn.functional as F
4
-
5
- from experts.normal.models.submodules.submodules import UpSampleBN, norm_normalize
6
-
7
-
8
- # This is the baseline encoder-decoder we used in the ablation study
9
- class NNET(nn.Module):
10
- def __init__(self, args=None):
11
- super(NNET, self).__init__()
12
- self.encoder = Encoder()
13
- self.decoder = Decoder(num_classes=4)
14
-
15
- def forward(self, x, **kwargs):
16
- out = self.decoder(self.encoder(x), **kwargs)
17
-
18
- # Bilinearly upsample the output to match the input resolution
19
- up_out = F.interpolate(out, size=[x.size(2), x.size(3)], mode='bilinear', align_corners=False)
20
-
21
- # L2-normalize the first three channels / ensure positive value for concentration parameters (kappa)
22
- up_out = norm_normalize(up_out)
23
- return up_out
24
-
25
- def get_1x_lr_params(self): # lr/10 learning rate
26
- return self.encoder.parameters()
27
-
28
- def get_10x_lr_params(self): # lr learning rate
29
- modules = [self.decoder]
30
- for m in modules:
31
- yield from m.parameters()
32
-
33
-
34
- # Encoder
35
- class Encoder(nn.Module):
36
- def __init__(self):
37
- super(Encoder, self).__init__()
38
-
39
- basemodel_name = 'tf_efficientnet_b5_ap'
40
- basemodel = torch.hub.load('rwightman/gen-efficientnet-pytorch', basemodel_name, pretrained=True)
41
-
42
- # Remove last layer
43
- basemodel.global_pool = nn.Identity()
44
- basemodel.classifier = nn.Identity()
45
-
46
- self.original_model = basemodel
47
-
48
- def forward(self, x):
49
- features = [x]
50
- for k, v in self.original_model._modules.items():
51
- if (k == 'blocks'):
52
- for ki, vi in v._modules.items():
53
- features.append(vi(features[-1]))
54
- else:
55
- features.append(v(features[-1]))
56
- return features
57
-
58
-
59
- # Decoder (no pixel-wise MLP, no uncertainty-guided sampling)
60
- class Decoder(nn.Module):
61
- def __init__(self, num_classes=4):
62
- super(Decoder, self).__init__()
63
- self.conv2 = nn.Conv2d(2048, 2048, kernel_size=1, stride=1, padding=0)
64
- self.up1 = UpSampleBN(skip_input=2048 + 176, output_features=1024)
65
- self.up2 = UpSampleBN(skip_input=1024 + 64, output_features=512)
66
- self.up3 = UpSampleBN(skip_input=512 + 40, output_features=256)
67
- self.up4 = UpSampleBN(skip_input=256 + 24, output_features=128)
68
- self.conv3 = nn.Conv2d(128, num_classes, kernel_size=3, stride=1, padding=1)
69
-
70
- def forward(self, features):
71
- x_block0, x_block1, x_block2, x_block3, x_block4 = features[4], features[5], features[6], features[8], features[11]
72
- x_d0 = self.conv2(x_block4)
73
- x_d1 = self.up1(x_d0, x_block3)
74
- x_d2 = self.up2(x_d1, x_block2)
75
- x_d3 = self.up3(x_d2, x_block1)
76
- x_d4 = self.up4(x_d3, x_block0)
77
- out = self.conv3(x_d4)
78
- return out
79
-
80
-
81
- if __name__ == '__main__':
82
- model = Baseline()
83
- x = torch.rand(2, 3, 480, 640)
84
- out = model(x)
85
- print(out.shape)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
prismer/experts/normal/models/submodules/decoder.py DELETED
@@ -1,202 +0,0 @@
1
- import torch
2
- import torch.nn as nn
3
- import torch.nn.functional as F
4
- from experts.normal.models.submodules.submodules import UpSampleBN, UpSampleGN, norm_normalize, sample_points
5
-
6
-
7
- class Decoder(nn.Module):
8
- def __init__(self, args):
9
- super(Decoder, self).__init__()
10
-
11
- # hyper-parameter for sampling
12
- self.sampling_ratio = args.sampling_ratio
13
- self.importance_ratio = args.importance_ratio
14
-
15
- # feature-map
16
- self.conv2 = nn.Conv2d(2048, 2048, kernel_size=1, stride=1, padding=0)
17
- if args.architecture == 'BN':
18
- self.up1 = UpSampleBN(skip_input=2048 + 176, output_features=1024)
19
- self.up2 = UpSampleBN(skip_input=1024 + 64, output_features=512)
20
- self.up3 = UpSampleBN(skip_input=512 + 40, output_features=256)
21
- self.up4 = UpSampleBN(skip_input=256 + 24, output_features=128)
22
-
23
- elif args.architecture == 'GN':
24
- self.up1 = UpSampleGN(skip_input=2048 + 176, output_features=1024)
25
- self.up2 = UpSampleGN(skip_input=1024 + 64, output_features=512)
26
- self.up3 = UpSampleGN(skip_input=512 + 40, output_features=256)
27
- self.up4 = UpSampleGN(skip_input=256 + 24, output_features=128)
28
-
29
- else:
30
- raise Exception('invalid architecture')
31
-
32
- # produces 1/8 res output
33
- self.out_conv_res8 = nn.Conv2d(512, 4, kernel_size=3, stride=1, padding=1)
34
-
35
- # produces 1/4 res output
36
- self.out_conv_res4 = nn.Sequential(
37
- nn.Conv1d(512 + 4, 128, kernel_size=1), nn.ReLU(),
38
- nn.Conv1d(128, 128, kernel_size=1), nn.ReLU(),
39
- nn.Conv1d(128, 128, kernel_size=1), nn.ReLU(),
40
- nn.Conv1d(128, 4, kernel_size=1),
41
- )
42
-
43
- # produces 1/2 res output
44
- self.out_conv_res2 = nn.Sequential(
45
- nn.Conv1d(256 + 4, 128, kernel_size=1), nn.ReLU(),
46
- nn.Conv1d(128, 128, kernel_size=1), nn.ReLU(),
47
- nn.Conv1d(128, 128, kernel_size=1), nn.ReLU(),
48
- nn.Conv1d(128, 4, kernel_size=1),
49
- )
50
-
51
- # produces 1/1 res output
52
- self.out_conv_res1 = nn.Sequential(
53
- nn.Conv1d(128 + 4, 128, kernel_size=1), nn.ReLU(),
54
- nn.Conv1d(128, 128, kernel_size=1), nn.ReLU(),
55
- nn.Conv1d(128, 128, kernel_size=1), nn.ReLU(),
56
- nn.Conv1d(128, 4, kernel_size=1),
57
- )
58
-
59
- def forward(self, features, gt_norm_mask=None, mode='test'):
60
- x_block0, x_block1, x_block2, x_block3, x_block4 = features[4], features[5], features[6], features[8], features[11]
61
-
62
- # generate feature-map
63
-
64
- x_d0 = self.conv2(x_block4) # x_d0 : [2, 2048, 15, 20] 1/32 res
65
- x_d1 = self.up1(x_d0, x_block3) # x_d1 : [2, 1024, 30, 40] 1/16 res
66
- x_d2 = self.up2(x_d1, x_block2) # x_d2 : [2, 512, 60, 80] 1/8 res
67
- x_d3 = self.up3(x_d2, x_block1) # x_d3: [2, 256, 120, 160] 1/4 res
68
- x_d4 = self.up4(x_d3, x_block0) # x_d4: [2, 128, 240, 320] 1/2 res
69
-
70
- # 1/8 res output
71
- out_res8 = self.out_conv_res8(x_d2) # out_res8: [2, 4, 60, 80] 1/8 res output
72
- out_res8 = norm_normalize(out_res8) # out_res8: [2, 4, 60, 80] 1/8 res output
73
-
74
- ################################################################################################################
75
- # out_res4
76
- ################################################################################################################
77
-
78
- if mode == 'train':
79
- # upsampling ... out_res8: [2, 4, 60, 80] -> out_res8_res4: [2, 4, 120, 160]
80
- out_res8_res4 = F.interpolate(out_res8, scale_factor=2, mode='bilinear', align_corners=True)
81
- B, _, H, W = out_res8_res4.shape
82
-
83
- # samples: [B, 1, N, 2]
84
- point_coords_res4, rows_int, cols_int = sample_points(out_res8_res4.detach(), gt_norm_mask,
85
- sampling_ratio=self.sampling_ratio,
86
- beta=self.importance_ratio)
87
-
88
- # output (needed for evaluation / visualization)
89
- out_res4 = out_res8_res4
90
-
91
- # grid_sample feature-map
92
- feat_res4 = F.grid_sample(x_d2, point_coords_res4, mode='bilinear', align_corners=True) # (B, 512, 1, N)
93
- init_pred = F.grid_sample(out_res8, point_coords_res4, mode='bilinear', align_corners=True) # (B, 4, 1, N)
94
- feat_res4 = torch.cat([feat_res4, init_pred], dim=1) # (B, 512+4, 1, N)
95
-
96
- # prediction (needed to compute loss)
97
- samples_pred_res4 = self.out_conv_res4(feat_res4[:, :, 0, :]) # (B, 4, N)
98
- samples_pred_res4 = norm_normalize(samples_pred_res4) # (B, 4, N) - normalized
99
-
100
- for i in range(B):
101
- out_res4[i, :, rows_int[i, :], cols_int[i, :]] = samples_pred_res4[i, :, :]
102
-
103
- else:
104
- # grid_sample feature-map
105
- feat_map = F.interpolate(x_d2, scale_factor=2, mode='bilinear', align_corners=True)
106
- init_pred = F.interpolate(out_res8, scale_factor=2, mode='bilinear', align_corners=True)
107
- feat_map = torch.cat([feat_map, init_pred], dim=1) # (B, 512+4, H, W)
108
- B, _, H, W = feat_map.shape
109
-
110
- # try all pixels
111
- out_res4 = self.out_conv_res4(feat_map.view(B, 512 + 4, -1)) # (B, 4, N)
112
- out_res4 = norm_normalize(out_res4) # (B, 4, N) - normalized
113
- out_res4 = out_res4.view(B, 4, H, W)
114
- samples_pred_res4 = point_coords_res4 = None
115
-
116
- ################################################################################################################
117
- # out_res2
118
- ################################################################################################################
119
-
120
- if mode == 'train':
121
-
122
- # upsampling ... out_res4: [2, 4, 120, 160] -> out_res4_res2: [2, 4, 240, 320]
123
- out_res4_res2 = F.interpolate(out_res4, scale_factor=2, mode='bilinear', align_corners=True)
124
- B, _, H, W = out_res4_res2.shape
125
-
126
- # samples: [B, 1, N, 2]
127
- point_coords_res2, rows_int, cols_int = sample_points(out_res4_res2.detach(), gt_norm_mask,
128
- sampling_ratio=self.sampling_ratio,
129
- beta=self.importance_ratio)
130
-
131
- # output (needed for evaluation / visualization)
132
- out_res2 = out_res4_res2
133
-
134
- # grid_sample feature-map
135
- feat_res2 = F.grid_sample(x_d3, point_coords_res2, mode='bilinear', align_corners=True) # (B, 256, 1, N)
136
- init_pred = F.grid_sample(out_res4, point_coords_res2, mode='bilinear', align_corners=True) # (B, 4, 1, N)
137
- feat_res2 = torch.cat([feat_res2, init_pred], dim=1) # (B, 256+4, 1, N)
138
-
139
- # prediction (needed to compute loss)
140
- samples_pred_res2 = self.out_conv_res2(feat_res2[:, :, 0, :]) # (B, 4, N)
141
- samples_pred_res2 = norm_normalize(samples_pred_res2) # (B, 4, N) - normalized
142
-
143
- for i in range(B):
144
- out_res2[i, :, rows_int[i, :], cols_int[i, :]] = samples_pred_res2[i, :, :]
145
-
146
- else:
147
- # grid_sample feature-map
148
- feat_map = F.interpolate(x_d3, scale_factor=2, mode='bilinear', align_corners=True)
149
- init_pred = F.interpolate(out_res4, scale_factor=2, mode='bilinear', align_corners=True)
150
- feat_map = torch.cat([feat_map, init_pred], dim=1) # (B, 512+4, H, W)
151
- B, _, H, W = feat_map.shape
152
-
153
- out_res2 = self.out_conv_res2(feat_map.view(B, 256 + 4, -1)) # (B, 4, N)
154
- out_res2 = norm_normalize(out_res2) # (B, 4, N) - normalized
155
- out_res2 = out_res2.view(B, 4, H, W)
156
- samples_pred_res2 = point_coords_res2 = None
157
-
158
- ################################################################################################################
159
- # out_res1
160
- ################################################################################################################
161
-
162
- if mode == 'train':
163
- # upsampling ... out_res4: [2, 4, 120, 160] -> out_res4_res2: [2, 4, 240, 320]
164
- out_res2_res1 = F.interpolate(out_res2, scale_factor=2, mode='bilinear', align_corners=True)
165
- B, _, H, W = out_res2_res1.shape
166
-
167
- # samples: [B, 1, N, 2]
168
- point_coords_res1, rows_int, cols_int = sample_points(out_res2_res1.detach(), gt_norm_mask,
169
- sampling_ratio=self.sampling_ratio,
170
- beta=self.importance_ratio)
171
-
172
- # output (needed for evaluation / visualization)
173
- out_res1 = out_res2_res1
174
-
175
- # grid_sample feature-map
176
- feat_res1 = F.grid_sample(x_d4, point_coords_res1, mode='bilinear', align_corners=True) # (B, 128, 1, N)
177
- init_pred = F.grid_sample(out_res2, point_coords_res1, mode='bilinear', align_corners=True) # (B, 4, 1, N)
178
- feat_res1 = torch.cat([feat_res1, init_pred], dim=1) # (B, 128+4, 1, N)
179
-
180
- # prediction (needed to compute loss)
181
- samples_pred_res1 = self.out_conv_res1(feat_res1[:, :, 0, :]) # (B, 4, N)
182
- samples_pred_res1 = norm_normalize(samples_pred_res1) # (B, 4, N) - normalized
183
-
184
- for i in range(B):
185
- out_res1[i, :, rows_int[i, :], cols_int[i, :]] = samples_pred_res1[i, :, :]
186
-
187
- else:
188
- # grid_sample feature-map
189
- feat_map = F.interpolate(x_d4, scale_factor=2, mode='bilinear', align_corners=True)
190
- init_pred = F.interpolate(out_res2, scale_factor=2, mode='bilinear', align_corners=True)
191
- feat_map = torch.cat([feat_map, init_pred], dim=1) # (B, 512+4, H, W)
192
- B, _, H, W = feat_map.shape
193
-
194
- out_res1 = self.out_conv_res1(feat_map.view(B, 128 + 4, -1)) # (B, 4, N)
195
- out_res1 = norm_normalize(out_res1) # (B, 4, N) - normalized
196
- out_res1 = out_res1.view(B, 4, H, W)
197
- samples_pred_res1 = point_coords_res1 = None
198
-
199
- return [out_res8, out_res4, out_res2, out_res1], \
200
- [out_res8, samples_pred_res4, samples_pred_res2, samples_pred_res1], \
201
- [None, point_coords_res4, point_coords_res2, point_coords_res1]
202
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
prismer/experts/normal/models/submodules/encoder.py DELETED
@@ -1,32 +0,0 @@
1
- import torch
2
- import torch.nn as nn
3
- import torch.nn.functional as F
4
-
5
-
6
- class Encoder(nn.Module):
7
- def __init__(self):
8
- super(Encoder, self).__init__()
9
-
10
- basemodel_name = 'tf_efficientnet_b5_ap'
11
- print('Loading base model ()...'.format(basemodel_name), end='')
12
- basemodel = torch.hub.load('rwightman/gen-efficientnet-pytorch', basemodel_name, pretrained=True)
13
- print('Done.')
14
-
15
- # Remove last layer
16
- print('Removing last two layers (global_pool & classifier).')
17
- basemodel.global_pool = nn.Identity()
18
- basemodel.classifier = nn.Identity()
19
-
20
- self.original_model = basemodel
21
-
22
- def forward(self, x):
23
- features = [x]
24
- for k, v in self.original_model._modules.items():
25
- if (k == 'blocks'):
26
- for ki, vi in v._modules.items():
27
- features.append(vi(features[-1]))
28
- else:
29
- features.append(v(features[-1]))
30
- return features
31
-
32
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
prismer/experts/normal/models/submodules/submodules.py DELETED
@@ -1,140 +0,0 @@
1
- import torch
2
- import torch.nn as nn
3
- import torch.nn.functional as F
4
-
5
-
6
- ########################################################################################################################
7
-
8
-
9
- # Upsample + BatchNorm
10
- class UpSampleBN(nn.Module):
11
- def __init__(self, skip_input, output_features):
12
- super(UpSampleBN, self).__init__()
13
-
14
- self._net = nn.Sequential(nn.Conv2d(skip_input, output_features, kernel_size=3, stride=1, padding=1),
15
- nn.BatchNorm2d(output_features),
16
- nn.LeakyReLU(),
17
- nn.Conv2d(output_features, output_features, kernel_size=3, stride=1, padding=1),
18
- nn.BatchNorm2d(output_features),
19
- nn.LeakyReLU())
20
-
21
- def forward(self, x, concat_with):
22
- up_x = F.interpolate(x, size=[concat_with.size(2), concat_with.size(3)], mode='bilinear', align_corners=True)
23
- f = torch.cat([up_x, concat_with], dim=1)
24
- return self._net(f)
25
-
26
-
27
- # Upsample + GroupNorm + Weight Standardization
28
- class UpSampleGN(nn.Module):
29
- def __init__(self, skip_input, output_features):
30
- super(UpSampleGN, self).__init__()
31
-
32
- self._net = nn.Sequential(Conv2d(skip_input, output_features, kernel_size=3, stride=1, padding=1),
33
- nn.GroupNorm(8, output_features),
34
- nn.LeakyReLU(),
35
- Conv2d(output_features, output_features, kernel_size=3, stride=1, padding=1),
36
- nn.GroupNorm(8, output_features),
37
- nn.LeakyReLU())
38
-
39
- def forward(self, x, concat_with):
40
- up_x = F.interpolate(x, size=[concat_with.size(2), concat_with.size(3)], mode='bilinear', align_corners=True)
41
- f = torch.cat([up_x, concat_with], dim=1)
42
- return self._net(f)
43
-
44
-
45
- # Conv2d with weight standardization
46
- class Conv2d(nn.Conv2d):
47
- def __init__(self, in_channels, out_channels, kernel_size, stride=1,
48
- padding=0, dilation=1, groups=1, bias=True):
49
- super(Conv2d, self).__init__(in_channels, out_channels, kernel_size, stride,
50
- padding, dilation, groups, bias)
51
-
52
- def forward(self, x):
53
- weight = self.weight
54
- weight_mean = weight.mean(dim=1, keepdim=True).mean(dim=2,
55
- keepdim=True).mean(dim=3, keepdim=True)
56
- weight = weight - weight_mean
57
- std = weight.view(weight.size(0), -1).std(dim=1).view(-1, 1, 1, 1) + 1e-5
58
- weight = weight / std.expand_as(weight)
59
- return F.conv2d(x, weight, self.bias, self.stride,
60
- self.padding, self.dilation, self.groups)
61
-
62
-
63
- # normalize
64
- def norm_normalize(norm_out):
65
- min_kappa = 0.01
66
- norm_x, norm_y, norm_z, kappa = torch.split(norm_out, 1, dim=1)
67
- norm = torch.sqrt(norm_x ** 2.0 + norm_y ** 2.0 + norm_z ** 2.0) + 1e-10
68
- kappa = F.elu(kappa) + 1.0 + min_kappa
69
- final_out = torch.cat([norm_x / norm, norm_y / norm, norm_z / norm, kappa], dim=1)
70
- return final_out
71
-
72
-
73
- # uncertainty-guided sampling (only used during training)
74
- @torch.no_grad()
75
- def sample_points(init_normal, gt_norm_mask, sampling_ratio, beta):
76
- device = init_normal.device
77
- B, _, H, W = init_normal.shape
78
- N = int(sampling_ratio * H * W)
79
- beta = beta
80
-
81
- # uncertainty map
82
- uncertainty_map = -1 * init_normal[:, 3, :, :] # B, H, W
83
-
84
- # gt_invalid_mask (B, H, W)
85
- if gt_norm_mask is not None:
86
- gt_invalid_mask = F.interpolate(gt_norm_mask.float(), size=[H, W], mode='nearest')
87
- gt_invalid_mask = gt_invalid_mask[:, 0, :, :] < 0.5
88
- uncertainty_map[gt_invalid_mask] = -1e4
89
-
90
- # (B, H*W)
91
- _, idx = uncertainty_map.view(B, -1).sort(1, descending=True)
92
-
93
- # importance sampling
94
- if int(beta * N) > 0:
95
- importance = idx[:, :int(beta * N)] # B, beta*N
96
-
97
- # remaining
98
- remaining = idx[:, int(beta * N):] # B, H*W - beta*N
99
-
100
- # coverage
101
- num_coverage = N - int(beta * N)
102
-
103
- if num_coverage <= 0:
104
- samples = importance
105
- else:
106
- coverage_list = []
107
- for i in range(B):
108
- idx_c = torch.randperm(remaining.size()[1]) # shuffles "H*W - beta*N"
109
- coverage_list.append(remaining[i, :][idx_c[:num_coverage]].view(1, -1)) # 1, N-beta*N
110
- coverage = torch.cat(coverage_list, dim=0) # B, N-beta*N
111
- samples = torch.cat((importance, coverage), dim=1) # B, N
112
-
113
- else:
114
- # remaining
115
- remaining = idx[:, :] # B, H*W
116
-
117
- # coverage
118
- num_coverage = N
119
-
120
- coverage_list = []
121
- for i in range(B):
122
- idx_c = torch.randperm(remaining.size()[1]) # shuffles "H*W - beta*N"
123
- coverage_list.append(remaining[i, :][idx_c[:num_coverage]].view(1, -1)) # 1, N-beta*N
124
- coverage = torch.cat(coverage_list, dim=0) # B, N-beta*N
125
- samples = coverage
126
-
127
- # point coordinates
128
- rows_int = samples // W # 0 for first row, H-1 for last row
129
- rows_float = rows_int / float(H-1) # 0 to 1.0
130
- rows_float = (rows_float * 2.0) - 1.0 # -1.0 to 1.0
131
-
132
- cols_int = samples % W # 0 for first column, W-1 for last column
133
- cols_float = cols_int / float(W-1) # 0 to 1.0
134
- cols_float = (cols_float * 2.0) - 1.0 # -1.0 to 1.0
135
-
136
- point_coords = torch.zeros(B, 1, N, 2)
137
- point_coords[:, 0, :, 0] = cols_float # x coord
138
- point_coords[:, 0, :, 1] = rows_float # y coord
139
- point_coords = point_coords.to(device)
140
- return point_coords, rows_int, cols_int
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
prismer/experts/normal/utils/losses.py DELETED
@@ -1,178 +0,0 @@
1
- import torch
2
- import torch.nn as nn
3
- import numpy as np
4
- import torch.nn.functional as F
5
-
6
-
7
- # compute loss
8
- class compute_loss(nn.Module):
9
- def __init__(self, args):
10
- """args.loss_fn can be one of following:
11
- - L1 - L1 loss (no uncertainty)
12
- - L2 - L2 loss (no uncertainty)
13
- - AL - Angular loss (no uncertainty)
14
- - NLL_vMF - NLL of vonMF distribution
15
- - NLL_ours - NLL of Angular vonMF distribution
16
- - UG_NLL_vMF - NLL of vonMF distribution (+ pixel-wise MLP + uncertainty-guided sampling)
17
- - UG_NLL_ours - NLL of Angular vonMF distribution (+ pixel-wise MLP + uncertainty-guided sampling)
18
- """
19
- super(compute_loss, self).__init__()
20
- self.loss_type = args.loss_fn
21
- if self.loss_type in ['L1', 'L2', 'AL', 'NLL_vMF', 'NLL_ours']:
22
- self.loss_fn = self.forward_R
23
- elif self.loss_type in ['UG_NLL_vMF', 'UG_NLL_ours']:
24
- self.loss_fn = self.forward_UG
25
- else:
26
- raise Exception('invalid loss type')
27
-
28
- def forward(self, *args):
29
- return self.loss_fn(*args)
30
-
31
- def forward_R(self, norm_out, gt_norm, gt_norm_mask):
32
- pred_norm, pred_kappa = norm_out[:, 0:3, :, :], norm_out[:, 3:, :, :]
33
-
34
- if self.loss_type == 'L1':
35
- l1 = torch.sum(torch.abs(gt_norm - pred_norm), dim=1, keepdim=True)
36
- loss = torch.mean(l1[gt_norm_mask])
37
-
38
- elif self.loss_type == 'L2':
39
- l2 = torch.sum(torch.square(gt_norm - pred_norm), dim=1, keepdim=True)
40
- loss = torch.mean(l2[gt_norm_mask])
41
-
42
- elif self.loss_type == 'AL':
43
- dot = torch.cosine_similarity(pred_norm, gt_norm, dim=1)
44
-
45
- valid_mask = gt_norm_mask[:, 0, :, :].float() \
46
- * (dot.detach() < 0.999).float() \
47
- * (dot.detach() > -0.999).float()
48
- valid_mask = valid_mask > 0.0
49
-
50
- al = torch.acos(dot[valid_mask])
51
- loss = torch.mean(al)
52
-
53
- elif self.loss_type == 'NLL_vMF':
54
- dot = torch.cosine_similarity(pred_norm, gt_norm, dim=1)
55
-
56
- valid_mask = gt_norm_mask[:, 0, :, :].float() \
57
- * (dot.detach() < 0.999).float() \
58
- * (dot.detach() > -0.999).float()
59
- valid_mask = valid_mask > 0.0
60
-
61
- dot = dot[valid_mask]
62
- kappa = pred_kappa[:, 0, :, :][valid_mask]
63
-
64
- loss_pixelwise = - torch.log(kappa) \
65
- - (kappa * (dot - 1)) \
66
- + torch.log(1 - torch.exp(- 2 * kappa))
67
- loss = torch.mean(loss_pixelwise)
68
-
69
- elif self.loss_type == 'NLL_ours':
70
- dot = torch.cosine_similarity(pred_norm, gt_norm, dim=1)
71
-
72
- valid_mask = gt_norm_mask[:, 0, :, :].float() \
73
- * (dot.detach() < 0.999).float() \
74
- * (dot.detach() > -0.999).float()
75
- valid_mask = valid_mask > 0.0
76
-
77
- dot = dot[valid_mask]
78
- kappa = pred_kappa[:, 0, :, :][valid_mask]
79
-
80
- loss_pixelwise = - torch.log(torch.square(kappa) + 1) \
81
- + kappa * torch.acos(dot) \
82
- + torch.log(1 + torch.exp(-kappa * np.pi))
83
- loss = torch.mean(loss_pixelwise)
84
-
85
- else:
86
- raise Exception('invalid loss type')
87
-
88
- return loss
89
-
90
-
91
- def forward_UG(self, pred_list, coord_list, gt_norm, gt_norm_mask):
92
- loss = 0.0
93
- for (pred, coord) in zip(pred_list, coord_list):
94
- if coord is None:
95
- pred = F.interpolate(pred, size=[gt_norm.size(2), gt_norm.size(3)], mode='bilinear', align_corners=True)
96
- pred_norm, pred_kappa = pred[:, 0:3, :, :], pred[:, 3:, :, :]
97
-
98
- if self.loss_type == 'UG_NLL_vMF':
99
- dot = torch.cosine_similarity(pred_norm, gt_norm, dim=1)
100
-
101
- valid_mask = gt_norm_mask[:, 0, :, :].float() \
102
- * (dot.detach() < 0.999).float() \
103
- * (dot.detach() > -0.999).float()
104
- valid_mask = valid_mask > 0.5
105
-
106
- # mask
107
- dot = dot[valid_mask]
108
- kappa = pred_kappa[:, 0, :, :][valid_mask]
109
-
110
- loss_pixelwise = - torch.log(kappa) \
111
- - (kappa * (dot - 1)) \
112
- + torch.log(1 - torch.exp(- 2 * kappa))
113
- loss = loss + torch.mean(loss_pixelwise)
114
-
115
- elif self.loss_type == 'UG_NLL_ours':
116
- dot = torch.cosine_similarity(pred_norm, gt_norm, dim=1)
117
-
118
- valid_mask = gt_norm_mask[:, 0, :, :].float() \
119
- * (dot.detach() < 0.999).float() \
120
- * (dot.detach() > -0.999).float()
121
- valid_mask = valid_mask > 0.5
122
-
123
- dot = dot[valid_mask]
124
- kappa = pred_kappa[:, 0, :, :][valid_mask]
125
-
126
- loss_pixelwise = - torch.log(torch.square(kappa) + 1) \
127
- + kappa * torch.acos(dot) \
128
- + torch.log(1 + torch.exp(-kappa * np.pi))
129
- loss = loss + torch.mean(loss_pixelwise)
130
-
131
- else:
132
- raise Exception
133
-
134
- else:
135
- # coord: B, 1, N, 2
136
- # pred: B, 4, N
137
- gt_norm_ = F.grid_sample(gt_norm, coord, mode='nearest', align_corners=True) # (B, 3, 1, N)
138
- gt_norm_mask_ = F.grid_sample(gt_norm_mask.float(), coord, mode='nearest', align_corners=True) # (B, 1, 1, N)
139
- gt_norm_ = gt_norm_[:, :, 0, :] # (B, 3, N)
140
- gt_norm_mask_ = gt_norm_mask_[:, :, 0, :] > 0.5 # (B, 1, N)
141
-
142
- pred_norm, pred_kappa = pred[:, 0:3, :], pred[:, 3:, :]
143
-
144
- if self.loss_type == 'UG_NLL_vMF':
145
- dot = torch.cosine_similarity(pred_norm, gt_norm_, dim=1) # (B, N)
146
-
147
- valid_mask = gt_norm_mask_[:, 0, :].float() \
148
- * (dot.detach() < 0.999).float() \
149
- * (dot.detach() > -0.999).float()
150
- valid_mask = valid_mask > 0.5
151
-
152
- dot = dot[valid_mask]
153
- kappa = pred_kappa[:, 0, :][valid_mask]
154
-
155
- loss_pixelwise = - torch.log(kappa) \
156
- - (kappa * (dot - 1)) \
157
- + torch.log(1 - torch.exp(- 2 * kappa))
158
- loss = loss + torch.mean(loss_pixelwise)
159
-
160
- elif self.loss_type == 'UG_NLL_ours':
161
- dot = torch.cosine_similarity(pred_norm, gt_norm_, dim=1) # (B, N)
162
-
163
- valid_mask = gt_norm_mask_[:, 0, :].float() \
164
- * (dot.detach() < 0.999).float() \
165
- * (dot.detach() > -0.999).float()
166
- valid_mask = valid_mask > 0.5
167
-
168
- dot = dot[valid_mask]
169
- kappa = pred_kappa[:, 0, :][valid_mask]
170
-
171
- loss_pixelwise = - torch.log(torch.square(kappa) + 1) \
172
- + kappa * torch.acos(dot) \
173
- + torch.log(1 + torch.exp(-kappa * np.pi))
174
- loss = loss + torch.mean(loss_pixelwise)
175
-
176
- else:
177
- raise Exception
178
- return loss
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
prismer/experts/normal/utils/utils.py DELETED
@@ -1,191 +0,0 @@
1
- import os
2
- import numpy as np
3
- from PIL import Image
4
-
5
- import torch
6
- import torch.nn as nn
7
- import torch.nn.functional as F
8
-
9
- import matplotlib
10
- matplotlib.use('Agg')
11
- import matplotlib.pyplot as plt
12
-
13
-
14
-
15
- # convert arg line to args
16
- def convert_arg_line_to_args(arg_line):
17
- for arg in arg_line.split():
18
- if not arg.strip():
19
- continue
20
- yield str(arg)
21
-
22
-
23
- # save args
24
- def save_args(args, filename):
25
- with open(filename, 'w') as f:
26
- for arg in vars(args):
27
- f.write('{}: {}\n'.format(arg, getattr(args, arg)))
28
-
29
-
30
- # concatenate images
31
- def concat_image(image_path_list, concat_image_path):
32
- imgs = [Image.open(i).convert("RGB").resize((640, 480), resample=Image.BILINEAR) for i in image_path_list]
33
- imgs_list = []
34
- for i in range(len(imgs)):
35
- img = imgs[i]
36
- imgs_list.append(np.asarray(img))
37
-
38
- H, W, _ = np.asarray(img).shape
39
- imgs_list.append(255 * np.ones((H, 20, 3)).astype('uint8'))
40
-
41
- imgs_comb = np.hstack(imgs_list[:-1])
42
- imgs_comb = Image.fromarray(imgs_comb)
43
- imgs_comb.save(concat_image_path)
44
-
45
-
46
- # load model
47
- def load_checkpoint(fpath, model):
48
- ckpt = torch.load(fpath, map_location='cpu')['model']
49
-
50
- load_dict = {}
51
- for k, v in ckpt.items():
52
- if k.startswith('module.'):
53
- k_ = k.replace('module.', '')
54
- load_dict[k_] = v
55
- else:
56
- load_dict[k] = v
57
-
58
- model.load_state_dict(load_dict)
59
- return model
60
-
61
-
62
- # compute normal errors
63
- def compute_normal_errors(total_normal_errors):
64
- metrics = {
65
- 'mean': np.average(total_normal_errors),
66
- 'median': np.median(total_normal_errors),
67
- 'rmse': np.sqrt(np.sum(total_normal_errors * total_normal_errors) / total_normal_errors.shape),
68
- 'a1': 100.0 * (np.sum(total_normal_errors < 5) / total_normal_errors.shape[0]),
69
- 'a2': 100.0 * (np.sum(total_normal_errors < 7.5) / total_normal_errors.shape[0]),
70
- 'a3': 100.0 * (np.sum(total_normal_errors < 11.25) / total_normal_errors.shape[0]),
71
- 'a4': 100.0 * (np.sum(total_normal_errors < 22.5) / total_normal_errors.shape[0]),
72
- 'a5': 100.0 * (np.sum(total_normal_errors < 30) / total_normal_errors.shape[0])
73
- }
74
- return metrics
75
-
76
-
77
- # log normal errors
78
- def log_normal_errors(metrics, where_to_write, first_line):
79
- print(first_line)
80
- print("mean median rmse 5 7.5 11.25 22.5 30")
81
- print("%.3f %.3f %.3f %.3f %.3f %.3f %.3f %.3f" % (
82
- metrics['mean'], metrics['median'], metrics['rmse'],
83
- metrics['a1'], metrics['a2'], metrics['a3'], metrics['a4'], metrics['a5']))
84
-
85
- with open(where_to_write, 'a') as f:
86
- f.write('%s\n' % first_line)
87
- f.write("mean median rmse 5 7.5 11.25 22.5 30\n")
88
- f.write("%.3f %.3f %.3f %.3f %.3f %.3f %.3f %.3f\n\n" % (
89
- metrics['mean'], metrics['median'], metrics['rmse'],
90
- metrics['a1'], metrics['a2'], metrics['a3'], metrics['a4'], metrics['a5']))
91
-
92
-
93
- # makedir
94
- def makedir(dirpath):
95
- if not os.path.exists(dirpath):
96
- os.makedirs(dirpath)
97
-
98
-
99
- # makedir from list
100
- def make_dir_from_list(dirpath_list):
101
- for dirpath in dirpath_list:
102
- makedir(dirpath)
103
-
104
-
105
-
106
- ########################################################################################################################
107
- # Visualization
108
- ########################################################################################################################
109
-
110
-
111
- # unnormalize image
112
- __imagenet_stats = {'mean': [0.485, 0.456, 0.406], 'std': [0.229, 0.224, 0.225]}
113
- def unnormalize(img_in):
114
- img_out = np.zeros(img_in.shape)
115
- for ich in range(3):
116
- img_out[:, :, ich] = img_in[:, :, ich] * __imagenet_stats['std'][ich]
117
- img_out[:, :, ich] += __imagenet_stats['mean'][ich]
118
- img_out = (img_out * 255).astype(np.uint8)
119
- return img_out
120
-
121
-
122
- # kappa to exp error (only applicable to AngMF distribution)
123
- def kappa_to_alpha(pred_kappa):
124
- alpha = ((2 * pred_kappa) / ((pred_kappa ** 2.0) + 1)) \
125
- + ((np.exp(- pred_kappa * np.pi) * np.pi) / (1 + np.exp(- pred_kappa * np.pi)))
126
- alpha = np.degrees(alpha)
127
- return alpha
128
-
129
-
130
- # normal vector to rgb values
131
- def norm_to_rgb(norm):
132
- # norm: (B, H, W, 3)
133
- norm_rgb = ((norm[0, ...] + 1) * 0.5) * 255
134
- norm_rgb = np.clip(norm_rgb, a_min=0, a_max=255)
135
- norm_rgb = norm_rgb.astype(np.uint8)
136
- return norm_rgb
137
-
138
-
139
- # visualize during training
140
- def visualize(args, img, gt_norm, gt_norm_mask, norm_out_list, total_iter):
141
- B, _, H, W = gt_norm.shape
142
-
143
- pred_norm_list = []
144
- pred_kappa_list = []
145
- for norm_out in norm_out_list:
146
- norm_out = F.interpolate(norm_out, size=[gt_norm.size(2), gt_norm.size(3)], mode='nearest')
147
- pred_norm = norm_out[:, :3, :, :] # (B, 3, H, W)
148
- pred_norm = pred_norm.detach().cpu().permute(0, 2, 3, 1).numpy() # (B, H, W, 3)
149
- pred_norm_list.append(pred_norm)
150
-
151
- pred_kappa = norm_out[:, 3:, :, :] # (B, 1, H, W)
152
- pred_kappa = pred_kappa.detach().cpu().permute(0, 2, 3, 1).numpy() # (B, H, W, 1)
153
- pred_kappa_list.append(pred_kappa)
154
-
155
- # to numpy arrays
156
- img = img.detach().cpu().permute(0, 2, 3, 1).numpy() # (B, H, W, 3)
157
- gt_norm = gt_norm.detach().cpu().permute(0, 2, 3, 1).numpy() # (B, H, W, 3)
158
- gt_norm_mask = gt_norm_mask.detach().cpu().permute(0, 2, 3, 1).numpy() # (B, H, W, 1)
159
-
160
- # input image
161
- target_path = '%s/%08d_img.jpg' % (args.exp_vis_dir, total_iter)
162
- img = unnormalize(img[0, ...])
163
- plt.imsave(target_path, img)
164
-
165
- # gt norm
166
- gt_norm_rgb = ((gt_norm[0, ...] + 1) * 0.5) * 255
167
- gt_norm_rgb = np.clip(gt_norm_rgb, a_min=0, a_max=255)
168
- gt_norm_rgb = gt_norm_rgb.astype(np.uint8)
169
-
170
- target_path = '%s/%08d_gt_norm.jpg' % (args.exp_vis_dir, total_iter)
171
- plt.imsave(target_path, gt_norm_rgb * gt_norm_mask[0, ...])
172
-
173
- # pred_norm
174
- for i in range(len(pred_norm_list)):
175
- pred_norm = pred_norm_list[i]
176
- pred_norm_rgb = norm_to_rgb(pred_norm)
177
- target_path = '%s/%08d_pred_norm_%d.jpg' % (args.exp_vis_dir, total_iter, i)
178
- plt.imsave(target_path, pred_norm_rgb)
179
-
180
- pred_kappa = pred_kappa_list[i]
181
- pred_alpha = kappa_to_alpha(pred_kappa)
182
- target_path = '%s/%08d_pred_alpha_%d.jpg' % (args.exp_vis_dir, total_iter, i)
183
- plt.imsave(target_path, pred_alpha[0, :, :, 0], vmin=0, vmax=60, cmap='jet')
184
-
185
- # error in angles
186
- DP = np.sum(gt_norm * pred_norm, axis=3, keepdims=True) # (B, H, W, 1)
187
- DP = np.clip(DP, -1, 1)
188
- E = np.degrees(np.arccos(DP)) # (B, H, W, 1)
189
- E = E * gt_norm_mask
190
- target_path = '%s/%08d_pred_error_%d.jpg' % (args.exp_vis_dir, total_iter, i)
191
- plt.imsave(target_path, E[0, :, :, 0], vmin=0, vmax=60, cmap='jet')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
prismer/experts/obj_detection/configs/Base-CRCNN-COCO.yaml DELETED
@@ -1,48 +0,0 @@
1
- MODEL:
2
- META_ARCHITECTURE: "GeneralizedRCNN"
3
- BACKBONE:
4
- NAME: "build_p67_resnet_fpn_backbone"
5
- WEIGHTS: "detectron2://ImageNetPretrained/MSRA/R-50.pkl"
6
- RESNETS:
7
- OUT_FEATURES: ["res3", "res4", "res5"]
8
- DEPTH: 50
9
- FPN:
10
- IN_FEATURES: ["res3", "res4", "res5"]
11
- ANCHOR_GENERATOR:
12
- SIZES: !!python/object/apply:eval ["[[x, x * 2**(1.0/3), x * 2**(2.0/3) ] for x in [32, 64, 128, 256, 512 ]]"]
13
- ASPECT_RATIOS: [[0.5, 1.0, 2.0]]
14
- RPN:
15
- IN_FEATURES: ["p3", "p4", "p5", "p6", "p7"]
16
- PRE_NMS_TOPK_TRAIN: 2000
17
- PRE_NMS_TOPK_TEST: 1000
18
- POST_NMS_TOPK_TRAIN: 2000
19
- POST_NMS_TOPK_TEST: 1000
20
- ROI_HEADS:
21
- NUM_CLASSES: 80
22
- NAME: CustomCascadeROIHeads
23
- IN_FEATURES: ["p3", "p4", "p5"]
24
- SCORE_THRESH_TEST: 0.0001
25
- ROI_BOX_HEAD:
26
- NAME: "FastRCNNConvFCHead"
27
- NUM_FC: 2
28
- POOLER_RESOLUTION: 7
29
- CLS_AGNOSTIC_BBOX_REG: True
30
- DATASETS:
31
- TRAIN: ("coco_2017_train",)
32
- TEST: ("coco_2017_val",)
33
- TEST:
34
- DETECTIONS_PER_IMAGE: 300
35
- SOLVER:
36
- IMS_PER_BATCH: 16
37
- BASE_LR: 0.01
38
- STEPS: (60000, 80000)
39
- MAX_ITER: 90000
40
- CHECKPOINT_PERIOD: 1000000
41
- WARMUP_ITERS: 4000
42
- WARMUP_FACTOR: 0.00025
43
- CLIP_GRADIENTS:
44
- ENABLED: True
45
- INPUT:
46
- MIN_SIZE_TRAIN: (640, 672, 704, 736, 768, 800)
47
- VERSION: 2
48
- OUTPUT_DIR: "output/UniDet/auto"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
prismer/experts/obj_detection/configs/O365_CRFR50_CAS_2x.yaml DELETED
@@ -1,15 +0,0 @@
1
- _BASE_: "Base-CRCNN-COCO"
2
- MODEL:
3
- ROI_HEADS:
4
- NUM_CLASSES: 365
5
- DATASETS:
6
- TRAIN: ("objects365_train",)
7
- TEST: ("objects365_val",)
8
- SOLVER:
9
- IMS_PER_BATCH: 16
10
- BASE_LR: 0.01
11
- STEPS: (120000, 160000,)
12
- MAX_ITER: 180000
13
- CHECKPOINT_PERIOD: 120000
14
- DATALOADER:
15
- SAMPLER_TRAIN: "ClassAwareSampler"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
prismer/experts/obj_detection/configs/OID_CRFR50_CAS_2x.yaml DELETED
@@ -1,22 +0,0 @@
1
- _BASE_: "Base-CRCNN-COCO"
2
- MODEL:
3
- ROI_HEADS:
4
- NUM_CLASSES: 500
5
- ROI_BOX_HEAD:
6
- USE_SIGMOID_CE: True
7
- USE_EQL_LOSS: True
8
- EQL_FREQ_CAT: 200
9
- EQL_CAT_INFO: 'datasets/oid/annotations/openimages_challenge_2019_train_v2_cat_info.json'
10
- HIERARCHY_IGNORE: True
11
- HIERARCHY_POS_PARENTS: True
12
- DATASETS:
13
- TRAIN: ("oid_train",)
14
- TEST: ("oid_val_expanded",)
15
- SOLVER:
16
- IMS_PER_BATCH: 16
17
- BASE_LR: 0.01
18
- STEPS: (120000, 160000,)
19
- MAX_ITER: 180000
20
- CHECKPOINT_PERIOD: 120000
21
- DATALOADER:
22
- SAMPLER_TRAIN: "ClassAwareSampler"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
prismer/experts/obj_detection/configs/Partitioned_COIM_R50_6x+2x.yaml DELETED
@@ -1,28 +0,0 @@
1
- _BASE_: "Base-CRCNN-COCO.yaml"
2
- MODEL:
3
- META_ARCHITECTURE: "SplitClassifierRCNN"
4
- ROI_HEADS:
5
- NUM_CLASSES: -1
6
- NAME: "MultiDatasetCascadeROIHeads"
7
- ROI_BOX_HEAD:
8
- USE_SIGMOID_CE: True
9
- # USE_EQL_LOSS: True
10
- HIERARCHY_IGNORE: True
11
- HIERARCHY_POS_PARENTS: True
12
- DATASETS:
13
- TRAIN: ("coco_2017_train","objects365_train","oid_train","mapillary_960_train")
14
- TEST: ('coco_2017_val','oid_val_expanded','objects365_val',"mapillary_val")
15
- DATALOADER:
16
- SAMPLER_TRAIN: "MultiDatasetSampler"
17
- NUM_WORKERS: 1
18
- MULTI_DATASET:
19
- ENABLED: True
20
- DATASETS: ['coco', 'objects365', 'oid', 'mapillary']
21
- NUM_CLASSES: [80, 365, 500, 37]
22
- DATA_RATIO: [1, 1, 1, 1]
23
- USE_CAS: [False, True, True, False]
24
- SOLVER:
25
- IMS_PER_BATCH: 16
26
- BASE_LR: 0.001
27
- STEPS: (160000,)
28
- MAX_ITER: 180000
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
prismer/experts/obj_detection/configs/Partitioned_COI_R50_2x.yaml DELETED
@@ -1,29 +0,0 @@
1
- _BASE_: "Base-CRCNN-COCO.yaml"
2
- MODEL:
3
- META_ARCHITECTURE: "SplitClassifierRCNN"
4
- ROI_HEADS:
5
- NUM_CLASSES: -1
6
- NAME: "MultiDatasetCascadeROIHeads"
7
- ROI_BOX_HEAD:
8
- USE_SIGMOID_CE: True
9
- # USE_EQL_LOSS: True
10
- HIERARCHY_IGNORE: True
11
- HIERARCHY_POS_PARENTS: True
12
- DATASETS:
13
- TRAIN: ("coco_2017_train","objects365_train","oid_train",)
14
- TEST: ('coco_2017_val','oid_val_expanded','objects365_val',)
15
- DATALOADER:
16
- SAMPLER_TRAIN: "MultiDatasetSampler"
17
- NUM_WORKERS: 1
18
- MULTI_DATASET:
19
- ENABLED: True
20
- DATASETS: ['coco', 'objects365', 'oid']
21
- NUM_CLASSES: [80, 365, 500]
22
- DATA_RATIO: [1, 1, 1]
23
- USE_CAS: [False, True, True]
24
- SOLVER:
25
- IMS_PER_BATCH: 16
26
- STEPS: (120000, 160000)
27
- MAX_ITER: 180000
28
- CHECKPOINT_PERIOD: 120000
29
-