Spaces: Runtime error
first version submission
This view is limited to 50 files because it contains too many changes. See raw diff
- app.py +59 -4
- main_gradio.py +84 -0
- models/__pycache__/blip2_model.cpython-38.pyc +0 -0
- models/__pycache__/blip2_model.cpython-39.pyc +0 -0
- models/__pycache__/controlnet_model.cpython-38.pyc +0 -0
- models/__pycache__/gpt_model.cpython-38.pyc +0 -0
- models/__pycache__/grit_model.cpython-38.pyc +0 -0
- models/__pycache__/image_text_transformation.cpython-38.pyc +0 -0
- models/__pycache__/image_text_transformation.cpython-39.pyc +0 -0
- models/__pycache__/region_semantic.cpython-38.pyc +0 -0
- models/blip2_model.py +38 -0
- models/controlnet_model.py +51 -0
- models/gpt_model.py +40 -0
- models/grit_model.py +26 -0
- models/grit_src/__pycache__/image_dense_captions.cpython-38.pyc +0 -0
- models/grit_src/configs/Base.yaml +77 -0
- models/grit_src/configs/GRiT_B_DenseCap.yaml +20 -0
- models/grit_src/configs/GRiT_B_DenseCap_ObjectDet.yaml +23 -0
- models/grit_src/configs/GRiT_B_ObjectDet.yaml +20 -0
- models/grit_src/configs/GRiT_H_ObjectDet.yaml +21 -0
- models/grit_src/configs/GRiT_L_ObjectDet.yaml +20 -0
- models/grit_src/grit/__init__.py +7 -0
- models/grit_src/grit/__pycache__/__init__.cpython-38.pyc +0 -0
- models/grit_src/grit/__pycache__/config.cpython-38.pyc +0 -0
- models/grit_src/grit/__pycache__/predictor.cpython-38.pyc +0 -0
- models/grit_src/grit/config.py +50 -0
- models/grit_src/grit/custom_solver.py +88 -0
- models/grit_src/grit/data/__pycache__/custom_build_augmentation.cpython-38.pyc +0 -0
- models/grit_src/grit/data/__pycache__/custom_dataset_mapper.cpython-38.pyc +0 -0
- models/grit_src/grit/data/custom_build_augmentation.py +44 -0
- models/grit_src/grit/data/custom_dataset_dataloader.py +250 -0
- models/grit_src/grit/data/custom_dataset_mapper.py +149 -0
- models/grit_src/grit/data/datasets/__pycache__/grit_coco.cpython-38.pyc +0 -0
- models/grit_src/grit/data/datasets/__pycache__/object365.cpython-38.pyc +0 -0
- models/grit_src/grit/data/datasets/__pycache__/vg.cpython-38.pyc +0 -0
- models/grit_src/grit/data/datasets/grit_coco.py +112 -0
- models/grit_src/grit/data/datasets/object365.py +111 -0
- models/grit_src/grit/data/datasets/vg.py +98 -0
- models/grit_src/grit/data/transforms/__pycache__/custom_augmentation_impl.cpython-38.pyc +0 -0
- models/grit_src/grit/data/transforms/__pycache__/custom_transform.cpython-38.pyc +0 -0
- models/grit_src/grit/data/transforms/custom_augmentation_impl.py +52 -0
- models/grit_src/grit/data/transforms/custom_transform.py +115 -0
- models/grit_src/grit/evaluation/eval.py +156 -0
- models/grit_src/grit/modeling/__pycache__/soft_nms.cpython-38.pyc +0 -0
- models/grit_src/grit/modeling/backbone/__pycache__/utils.cpython-38.pyc +0 -0
- models/grit_src/grit/modeling/backbone/__pycache__/vit.cpython-38.pyc +0 -0
- models/grit_src/grit/modeling/backbone/utils.py +186 -0
- models/grit_src/grit/modeling/backbone/vit.py +538 -0
- models/grit_src/grit/modeling/meta_arch/__pycache__/grit.cpython-38.pyc +0 -0
- models/grit_src/grit/modeling/meta_arch/grit.py +66 -0
app.py
CHANGED
@@ -1,7 +1,62 @@
import gradio as gr
import cv2
import numpy as np
from PIL import Image
import base64
from io import BytesIO
from models.image_text_transformation import ImageTextTransformation

def pil_image_to_base64(image):
    buffered = BytesIO()
    image.save(buffered, format="JPEG")
    img_str = base64.b64encode(buffered.getvalue()).decode()
    return img_str

def add_logo():
    with open("examples/logo.png", "rb") as f:
        logo_base64 = base64.b64encode(f.read()).decode()
    return logo_base64

def process_image(image_src, processor):
    gen_text = processor.image_to_text(image_src)
    gen_image = processor.text_to_image(gen_text)
    gen_image_str = pil_image_to_base64(gen_image)
    # Combine the outputs into a single HTML output
    custom_output = f'''
    <h2>Image->Text->Image:</h2>
    <div style="display: flex; flex-wrap: wrap;">
        <div style="flex: 1;">
            <h3>Image2Text</h3>
            <p>{gen_text}</p>
        </div>
        <div style="flex: 1;">
            <h3>Text2Image</h3>
            <img src="data:image/jpeg;base64,{gen_image_str}" width="100%" />
        </div>
    </div>
    '''

    return custom_output

processor = ImageTextTransformation()

# Create Gradio input and output components
image_input = gr.inputs.Image(type='filepath', label="Input Image")

logo_base64 = add_logo()
# Create the title with the logo
title_with_logo = f'<img src="data:image/jpeg;base64,{logo_base64}" width="400" style="vertical-align: middle;"> Understanding Image with Text'

# Create Gradio interface
interface = gr.Interface(
    fn=lambda image: process_image(image, processor),  # Pass the processor object using a lambda function
    inputs=image_input,
    outputs=gr.outputs.HTML(),
    title=title_with_logo,
    description="""
    This code supports image-to-text transformation. The generated text can then be used for retrieval, question answering, and other zero-shot tasks.
    """
)

# Launch the interface
interface.launch()
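Note on the Gradio calls above: gr.inputs.Image and gr.outputs.HTML (also used in main_gradio.py below) come from the legacy Gradio 2.x/3.x namespaces, which newer Gradio releases removed in favour of top-level components. A minimal sketch of the same interface written against the components API, assuming a recent Gradio version is installed in the Space:

import gradio as gr

# Same behaviour as above, but with gr.Image / gr.HTML instead of gr.inputs / gr.outputs
interface = gr.Interface(
    fn=lambda image: process_image(image, processor),
    inputs=gr.Image(type="filepath", label="Input Image"),
    outputs=gr.HTML(),
    title=title_with_logo,
)
interface.launch()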
main_gradio.py
ADDED
@@ -0,0 +1,84 @@
import gradio as gr
import cv2
import numpy as np
from PIL import Image
import base64
from io import BytesIO
from models.image_text_transformation import ImageTextTransformation

def pil_image_to_base64(image):
    buffered = BytesIO()
    image.save(buffered, format="JPEG")
    img_str = base64.b64encode(buffered.getvalue()).decode()
    return img_str

def add_logo():
    with open("examples/logo.png", "rb") as f:
        logo_base64 = base64.b64encode(f.read()).decode()
    return logo_base64

def process_image(image_src, processor):
    gen_text = processor.image_to_text(image_src)
    gen_image = processor.text_to_image(gen_text)
    gen_image_str = pil_image_to_base64(gen_image)
    # Combine the outputs into a single HTML output
    custom_output = f'''
    <h2>Image->Text->Image:</h2>
    <div style="display: flex; flex-wrap: wrap;">
        <div style="flex: 1;">
            <h3>Image2Text</h3>
            <p>{gen_text}</p>
        </div>
        <div style="flex: 1;">
            <h3>Text2Image</h3>
            <img src="data:image/jpeg;base64,{gen_image_str}" width="100%" />
        </div>
    </div>
    <h2>Using Source Image to do Retrieval on COCO:</h2>
    <div style="display: flex; flex-wrap: wrap;">
        <div style="flex: 1;">
            <h3>Retrieval Top-3 Text</h3>
            <p>{gen_text}</p>
        </div>
        <div style="flex: 1;">
            <h3>Retrieval Top-3 Image</h3>
            <img src="data:image/jpeg;base64,{gen_image_str}" width="100%" />
        </div>
    </div>
    <h2>Using Generated texts to do Retrieval on COCO:</h2>
    <div style="display: flex; flex-wrap: wrap;">
        <div style="flex: 1;">
            <h3>Retrieval Top-3 Text</h3>
            <p>{gen_text}</p>
        </div>
        <div style="flex: 1;">
            <h3>Retrieval Top-3 Image</h3>
            <img src="data:image/jpeg;base64,{gen_image_str}" width="100%" />
        </div>
    </div>
    '''

    return custom_output

processor = ImageTextTransformation()

# Create Gradio input and output components
image_input = gr.inputs.Image(type='filepath', label="Input Image")

logo_base64 = add_logo()
# Create the title with the logo
title_with_logo = f'<img src="data:image/jpeg;base64,{logo_base64}" width="400" style="vertical-align: middle;"> Understanding Image with Text'

# Create Gradio interface
interface = gr.Interface(
    fn=lambda image: process_image(image, processor),  # Pass the processor object using a lambda function
    inputs=image_input,
    outputs=gr.outputs.HTML(),
    title=title_with_logo,
    description="""
    This code supports image-to-text transformation. The generated text can then be used for retrieval, question answering, and other zero-shot tasks.
    """
)

# Launch the interface
interface.launch()
models/__pycache__/blip2_model.cpython-38.pyc
ADDED
Binary file (1.88 kB)

models/__pycache__/blip2_model.cpython-39.pyc
ADDED
Binary file (1.88 kB)

models/__pycache__/controlnet_model.cpython-38.pyc
ADDED
Binary file (1.88 kB)

models/__pycache__/gpt_model.cpython-38.pyc
ADDED
Binary file (2.28 kB)

models/__pycache__/grit_model.cpython-38.pyc
ADDED
Binary file (1.38 kB)

models/__pycache__/image_text_transformation.cpython-38.pyc
ADDED
Binary file (2.55 kB)

models/__pycache__/image_text_transformation.cpython-39.pyc
ADDED
Binary file (2.55 kB)

models/__pycache__/region_semantic.cpython-38.pyc
ADDED
Binary file (2.2 kB)
models/blip2_model.py
ADDED
@@ -0,0 +1,38 @@
from PIL import Image
import requests
from transformers import Blip2Processor, Blip2ForConditionalGeneration
import torch


class ImageCaptioning:
    def __init__(self) -> None:
        self.device = None
        # self.processor, self.model = None, None
        self.processor, self.model = self.initialize_model()

    def initialize_model(self):
        # device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = "cpu"  # for low gpu memory devices
        if self.device == 'cpu':
            self.data_type = torch.float32
        else:
            self.data_type = torch.float16
        processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b")
        model = Blip2ForConditionalGeneration.from_pretrained(
            "Salesforce/blip2-opt-2.7b", torch_dtype=self.data_type
        )
        model.to(self.device)
        return processor, model

    def image_caption(self, image_src):
        image = Image.open(image_src)
        inputs = self.processor(images=image, return_tensors="pt").to(self.device, self.data_type)
        generated_ids = self.model.generate(**inputs)
        generated_text = self.processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
        print('*' * 100 + '\nStep1, BLIP2 caption:')
        print(generated_text)
        print('\n' + '*' * 100)
        return generated_text

    def image_caption_debug(self, image_src):
        return "A dish with salmon, broccoli, and something yellow."
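For reference, a minimal usage sketch of the captioner above; the image path is hypothetical, and on CPU the 2.7B-parameter BLIP-2 checkpoint is slow and memory-hungry:

from models.blip2_model import ImageCaptioning

captioner = ImageCaptioning()                            # downloads Salesforce/blip2-opt-2.7b on first use
caption = captioner.image_caption("examples/dish.jpg")   # hypothetical image path
print(caption)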
models/controlnet_model.py
ADDED
@@ -0,0 +1,51 @@
import cv2
import torch
import numpy as np
from PIL import Image
from diffusers import (
    StableDiffusionControlNetPipeline,
    ControlNetModel,
    UniPCMultistepScheduler,
)


class TextToImage:
    def __init__(self):
        # self.model = None
        self.model = self.initialize_model()

    def initialize_model(self):
        controlnet = ControlNetModel.from_pretrained(
            "fusing/stable-diffusion-v1-5-controlnet-canny",
            torch_dtype=torch.float16,
        )
        pipeline = StableDiffusionControlNetPipeline.from_pretrained(
            "runwayml/stable-diffusion-v1-5",
            controlnet=controlnet,
            safety_checker=None,
            torch_dtype=torch.float16,
        )
        pipeline.scheduler = UniPCMultistepScheduler.from_config(
            pipeline.scheduler.config
        )
        pipeline.enable_model_cpu_offload()
        return pipeline

    @staticmethod
    def preprocess_image(image):
        image = np.array(image)
        low_threshold = 100
        high_threshold = 200
        image = cv2.Canny(image, low_threshold, high_threshold)
        image = np.stack([image, image, image], axis=2)
        image = Image.fromarray(image)
        return image

    def text_to_image(self, text, image):
        image = self.preprocess_image(image)
        generated_image = self.model(text, image, num_inference_steps=20).images[0]
        return generated_image

    def text_to_image_debug(self, text, image):
        print("text_to_image_debug")
        return image
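A hedged usage sketch of TextToImage; the prompt and file names are made up, and the float16 ControlNet pipeline assumes a CUDA device is available:

from PIL import Image
from models.controlnet_model import TextToImage

t2i = TextToImage()                                # loads the Canny ControlNet + SD 1.5 pipeline
source = Image.open("examples/dish.jpg")           # hypothetical source image
prompt = "a dish with salmon and broccoli"         # hypothetical prompt
result = t2i.text_to_image(prompt, source)         # Canny edges of `source` condition the generation
result.save("generated.jpg")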
models/gpt_model.py
ADDED
@@ -0,0 +1,40 @@
import openai

class ImageToText:
    def __init__(self, api_key):
        self.template = self.initialize_template()
        openai.api_key = api_key

    def initialize_template(self):
        prompt_prefix_1 = """Generate only an informative and nature paragraph based on the given information(a,b,c,d):\n"""
        prompt_prefix_2 = """\n a. Image Resolution: """
        prompt_prefix_3 = """\n b. Image Caption: """
        prompt_prefix_4 = """\n c. Dense Caption: """
        prompt_prefix_5 = """\n d. Region Semantic: """
        prompt_suffix = """\n There are some rules:
        Show object, color and position.
        Use nouns rather than coordinates to show position information of each object.
        No more than 7 sentences.
        Only use one paragraph.
        Do not appear number.
        """
        template = f"{prompt_prefix_1}{prompt_prefix_2}{{width}}X{{height}}{prompt_prefix_3}{{caption}}{prompt_prefix_4}{{dense_caption}}{prompt_prefix_5}{{region_semantic}}{prompt_suffix}"
        return template

    def paragraph_summary_with_gpt(self, caption, dense_caption, region_semantic, width, height):
        question = self.template.format(width=width, height=height, caption=caption, dense_caption=dense_caption, region_semantic=region_semantic)
        print('*' * 100)
        print("question:", question)
        completion = openai.ChatCompletion.create(
            model="gpt-3.5-turbo",
            messages=[
                {"role": "user", "content": question}]
        )
        print("chatgpt response:", completion['choices'][0]['message']['content'])
        print('*' * 100)
        return completion['choices'][0]['message']['content']

    def paragraph_summary_with_gpt_debug(self, caption, dense_caption, width, height):
        question = self.template.format(width=width, height=height, caption=caption, dense_caption=dense_caption)
        print("paragraph_summary_with_gpt_debug:")
        return question
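To make the template concrete, a small sketch of how it expands before being sent to gpt-3.5-turbo; the API key and the caption strings are placeholders, not real model outputs (note the openai.ChatCompletion call above targets the pre-1.0 openai client):

from models.gpt_model import ImageToText

model = ImageToText(api_key="YOUR_OPENAI_KEY")      # placeholder key
prompt = model.template.format(
    width=512, height=512,
    caption="a plate of salmon and vegetables",                     # placeholder BLIP-2 caption
    dense_caption="1. the broccoli is green, [0, 0, 333, 325];",    # placeholder dense caption
    region_semantic="salmon, broccoli, fork",                       # placeholder region semantics
)
print(prompt)   # the filled-in prompt that paragraph_summary_with_gpt sends to the chat model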
models/grit_model.py
ADDED
@@ -0,0 +1,26 @@
import os
from models.grit_src.image_dense_captions import image_caption_api

class DenseCaptioning():
    def __init__(self) -> None:
        self.model = None


    def initialize_model(self):
        pass

    def image_dense_caption_debug(self, image_src):
        dense_caption = """
        1. the broccoli is green, [0, 0, 333, 325];
        2. a piece of broccoli, [0, 147, 143, 324];
        3. silver fork on plate, [4, 547, 252, 612];
        """
        return dense_caption

    def image_dense_caption(self, image_src):
        dense_caption = image_caption_api(image_src)
        print("Step2, Dense Caption:\n")
        print(dense_caption)
        print('\n' + '*' * 100)
        return dense_caption
models/grit_src/__pycache__/image_dense_captions.cpython-38.pyc
ADDED
Binary file (2.54 kB)
models/grit_src/configs/Base.yaml
ADDED
@@ -0,0 +1,77 @@
MODEL:
  META_ARCHITECTURE: "GRiT"
  MASK_ON: True
  PROPOSAL_GENERATOR:
    NAME: "CenterNet"
  FPN:
    IN_FEATURES: ["layer3", "layer4", "layer5"]
  PIXEL_MEAN: [123.675, 116.280, 103.530]
  PIXEL_STD: [58.395, 57.12, 57.375]
  ROI_HEADS:
    NAME: GRiTROIHeadsAndTextDecoder
    IN_FEATURES: ["p3", "p4", "p5"]
    IOU_THRESHOLDS: [0.6]
    NUM_CLASSES: 1
    SCORE_THRESH_TEST: 0.02
    NMS_THRESH_TEST: 0.5
    OBJECT_FEAT_POOLER_RES: 14
  ROI_BOX_CASCADE_HEAD:
    IOUS: [0.6, 0.7, 0.8]
  ROI_BOX_HEAD:
    NAME: "FastRCNNConvFCHead"
    NUM_FC: 2
    POOLER_RESOLUTION: 7
    CLS_AGNOSTIC_BBOX_REG: True
    MULT_PROPOSAL_SCORE: True
  ROI_MASK_HEAD:
    NAME: "MaskRCNNConvUpsampleHead"
    NUM_CONV: 4
    POOLER_RESOLUTION: 14
    CLS_AGNOSTIC_MASK: True
  CENTERNET:
    NUM_CLASSES: 1
    REG_WEIGHT: 1.
    NOT_NORM_REG: True
    ONLY_PROPOSAL: True
    WITH_AGN_HM: True
    INFERENCE_TH: 0.0001
    PRE_NMS_TOPK_TRAIN: 4000
    POST_NMS_TOPK_TRAIN: 2000
    PRE_NMS_TOPK_TEST: 1000
    POST_NMS_TOPK_TEST: 256
    NMS_TH_TRAIN: 0.9
    NMS_TH_TEST: 0.9
    POS_WEIGHT: 0.5
    NEG_WEIGHT: 0.5
    IGNORE_HIGH_FP: 0.85
DATASETS:
  TRAIN: ("coco_2017_train",)
  TEST: ("coco_2017_val",)
DATALOADER:
  SAMPLER_TRAIN: "MultiDatasetSampler"
  DATASET_RATIO: [1]
  DATASET_INPUT_SIZE: [1024]
  DATASET_INPUT_SCALE: [[0.1, 2.0]]
  FILTER_EMPTY_ANNOTATIONS: False
  NUM_WORKERS: 8
TEST:
  DETECTIONS_PER_IMAGE: 256
SOLVER:
  LR_SCHEDULER_NAME: "WarmupCosineLR"
  CHECKPOINT_PERIOD: 10000
  WARMUP_ITERS: 1000
  WARMUP_FACTOR: 0.001
  USE_CUSTOM_SOLVER: True
  OPTIMIZER: "ADAMW"
  MAX_ITER: 180000
  IMS_PER_BATCH: 64
  BASE_LR: 0.00008
  VIT_LAYER_DECAY: True
  CLIP_GRADIENTS:
    ENABLED: True
INPUT:
  FORMAT: RGB
  CUSTOM_AUG: EfficientDetResizeCrop
  TRAIN_SIZE: 640
USE_ACT_CHECKPOINT: True
VERSION: 2
models/grit_src/configs/GRiT_B_DenseCap.yaml
ADDED
@@ -0,0 +1,20 @@
_BASE_: "Base.yaml"
MODEL:
  TRAIN_TASK: ["DenseCap"]
  TEST_TASK: "DenseCap"
  MASK_ON: False
  ROI_HEADS:
    SOFT_NMS_ENABLED: False
  BEAM_SIZE: 1
  WEIGHTS: "detectron2://ImageNetPretrained/MAE/mae_pretrain_vit_base.pth"
  BACKBONE:
    NAME: build_vit_fpn_backbone
  VIT_LAYERS: 12
SOLVER:
  VIT_LAYER_DECAY_RATE: 0.7
DATASETS:
  TRAIN: ("vg_train",)
  TEST: ("vg_test",)
DATALOADER:
  DATASET_BS: 2
OUTPUT_DIR: "./output/GRiT_B_DenseCap"
models/grit_src/configs/GRiT_B_DenseCap_ObjectDet.yaml
ADDED
@@ -0,0 +1,23 @@
_BASE_: "Base.yaml"
MODEL:
  TRAIN_TASK: ["ObjectDet", "DenseCap"]
  TEST_TASK: "DenseCap"  # DenseCap or ObjectDet: Choose one for testing
  MASK_ON: True
  ROI_HEADS:
    SOFT_NMS_ENABLED: False
  BEAM_SIZE: 1
  WEIGHTS: "detectron2://ImageNetPretrained/MAE/mae_pretrain_vit_base.pth"
  BACKBONE:
    NAME: build_vit_fpn_backbone
  VIT_LAYERS: 12
SOLVER:
  VIT_LAYER_DECAY_RATE: 0.7
DATASETS:
  TRAIN: ("GRiT_coco2017_train", "vg_train")
  TEST: ("coco_2017_test-dev",)
DATALOADER:
  DATASET_RATIO: [1, 1]
  DATASET_BS: 2
  DATASET_INPUT_SIZE: [1024, 1024]
  DATASET_INPUT_SCALE: [[0.1, 2.0], [0.1, 2.0]]
OUTPUT_DIR: "./output/GRiT_B_DenseCap_ObjectDet"
models/grit_src/configs/GRiT_B_ObjectDet.yaml
ADDED
@@ -0,0 +1,20 @@
_BASE_: "Base.yaml"
MODEL:
  TRAIN_TASK: ["ObjectDet"]
  TEST_TASK: "ObjectDet"
  MASK_ON: True
  ROI_HEADS:
    SOFT_NMS_ENABLED: True
  BEAM_SIZE: 3
  WEIGHTS: "detectron2://ImageNetPretrained/MAE/mae_pretrain_vit_base.pth"
  BACKBONE:
    NAME: build_vit_fpn_backbone
  VIT_LAYERS: 12
SOLVER:
  VIT_LAYER_DECAY_RATE: 0.7
DATASETS:
  TRAIN: ("GRiT_coco2017_train",)
  TEST: ("coco_2017_val",)
DATALOADER:
  DATASET_BS: 2
OUTPUT_DIR: "./output/GRiT_B_ObjectDet"
models/grit_src/configs/GRiT_H_ObjectDet.yaml
ADDED
@@ -0,0 +1,21 @@
_BASE_: "Base.yaml"
MODEL:
  TRAIN_TASK: ["ObjectDet"]
  TEST_TASK: "ObjectDet"
  MASK_ON: True
  ROI_HEADS:
    SOFT_NMS_ENABLED: True
  BEAM_SIZE: 3
  WEIGHTS: "detectron2://ImageNetPretrained/MAE/mae_pretrain_vit_huge_p14to16.pth"
  BACKBONE:
    NAME: build_vit_fpn_backbone_huge
  VIT_LAYERS: 32
SOLVER:
  MAX_ITER: 135000
  VIT_LAYER_DECAY_RATE: 0.9
DATASETS:
  TRAIN: ("GRiT_coco2017_train",)
  TEST: ("coco_2017_val",)
DATALOADER:
  DATASET_BS: 1
OUTPUT_DIR: "./output/GRiT_H_ObjectDet"
models/grit_src/configs/GRiT_L_ObjectDet.yaml
ADDED
@@ -0,0 +1,20 @@
_BASE_: "Base.yaml"
MODEL:
  TRAIN_TASK: ["ObjectDet"]
  TEST_TASK: "ObjectDet"
  MASK_ON: True
  ROI_HEADS:
    SOFT_NMS_ENABLED: True
  BEAM_SIZE: 3
  WEIGHTS: "detectron2://ImageNetPretrained/MAE/mae_pretrain_vit_large.pth"
  BACKBONE:
    NAME: build_vit_fpn_backbone_large
  VIT_LAYERS: 24
SOLVER:
  VIT_LAYER_DECAY_RATE: 0.8
DATASETS:
  TRAIN: ("GRiT_coco2017_train",)
  TEST: ("coco_2017_val",)
DATALOADER:
  DATASET_BS: 1
OUTPUT_DIR: "./output/GRiT_L_ObjectDet"
models/grit_src/grit/__init__.py
ADDED
@@ -0,0 +1,7 @@
from .modeling.meta_arch import grit
from .modeling.roi_heads import grit_roi_heads
from .modeling.backbone import vit

from .data.datasets import object365
from .data.datasets import vg
from .data.datasets import grit_coco
models/grit_src/grit/__pycache__/__init__.cpython-38.pyc
ADDED
Binary file (405 Bytes)

models/grit_src/grit/__pycache__/config.cpython-38.pyc
ADDED
Binary file (1.4 kB)

models/grit_src/grit/__pycache__/predictor.cpython-38.pyc
ADDED
Binary file (2.65 kB)
models/grit_src/grit/config.py
ADDED
@@ -0,0 +1,50 @@
from detectron2.config import CfgNode as CN


def add_grit_config(cfg):
    _C = cfg

    _C.MODEL.BEAM_SIZE = 1
    _C.MODEL.TRAIN_TASK = ["ObjectDet", "DenseCap"]
    _C.MODEL.TEST_TASK = "DenseCap"  # This can be varied if the model is jointly trained on multiple tasks

    _C.MODEL.ROI_BOX_HEAD.USE_BIAS = 0.0  # >= 0: not use
    _C.MODEL.ROI_BOX_HEAD.MULT_PROPOSAL_SCORE = False

    _C.MODEL.ROI_HEADS.MASK_WEIGHT = 1.0
    _C.MODEL.ROI_HEADS.OBJECT_FEAT_POOLER_RES = 14
    _C.MODEL.ROI_HEADS.SOFT_NMS_ENABLED = False

    # Backbones
    _C.MODEL.VIT_LAYERS = 12

    # Text Decoder
    _C.TEXT_DECODER = CN()
    _C.TEXT_DECODER.VOCAB_SIZE = 30522
    _C.TEXT_DECODER.HIDDEN_SIZE = 768
    _C.TEXT_DECODER.NUM_LAYERS = 6
    _C.TEXT_DECODER.ATTENTION_HEADS = 12
    _C.TEXT_DECODER.FEEDFORWARD_SIZE = 768 * 4

    # Multi-dataset dataloader
    _C.DATALOADER.DATASET_RATIO = [1, 1]  # sample ratio
    _C.DATALOADER.DATASET_BS = 1
    _C.DATALOADER.DATASET_INPUT_SIZE = [1024, 1024]
    _C.DATALOADER.DATASET_INPUT_SCALE = [(0.1, 2.0), (0.1, 2.0)]
    _C.DATALOADER.DATASET_MIN_SIZES = [(640, 800), (640, 800)]
    _C.DATALOADER.DATASET_MAX_SIZES = [1333, 1333]

    _C.SOLVER.USE_CUSTOM_SOLVER = True
    _C.SOLVER.OPTIMIZER = 'ADAMW'
    _C.SOLVER.VIT_LAYER_DECAY = True
    _C.SOLVER.VIT_LAYER_DECAY_RATE = 0.7

    _C.INPUT.CUSTOM_AUG = 'EfficientDetResizeCrop'
    _C.INPUT.TRAIN_SIZE = 1024
    _C.INPUT.TEST_SIZE = 1024
    _C.INPUT.SCALE_RANGE = (0.1, 2.)
    # 'default' for fixed short / long edge
    _C.INPUT.TEST_INPUT_TYPE = 'default'

    _C.FIND_UNUSED_PARAM = True
    _C.USE_ACT_CHECKPOINT = True
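A sketch of how add_grit_config is typically consumed, following the usual detectron2 pattern; note that the YAML configs above also rely on CenterNet keys (e.g. MODEL.CENTERNET) that would need to be registered by code outside this 50-file view before those files can be merged:

from detectron2.config import get_cfg
from models.grit_src.grit.config import add_grit_config

cfg = get_cfg()        # detectron2 defaults
add_grit_config(cfg)   # registers the GRiT keys defined above (TEXT_DECODER, DATALOADER.*, SOLVER.*, ...)
print(cfg.TEXT_DECODER.HIDDEN_SIZE)   # 768
print(cfg.MODEL.TEST_TASK)            # "DenseCap"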
models/grit_src/grit/custom_solver.py
ADDED
@@ -0,0 +1,88 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
# Modified by Jialian Wu from https://github.com/facebookresearch/Detic/blob/main/detic/custom_solver.py
import itertools
from typing import Any, Callable, Dict, Iterable, List, Set, Type, Union
import torch

from detectron2.config import CfgNode

from detectron2.solver.build import maybe_add_gradient_clipping


def build_custom_optimizer(cfg: CfgNode, model: torch.nn.Module) -> torch.optim.Optimizer:
    params: List[Dict[str, Any]] = []
    memo: Set[torch.nn.parameter.Parameter] = set()
    optimizer_type = cfg.SOLVER.OPTIMIZER

    for key, value in model.named_parameters(recurse=True):
        if not value.requires_grad:
            continue
        # Avoid duplicating parameters
        if value in memo:
            continue
        memo.add(value)
        lr = cfg.SOLVER.BASE_LR
        weight_decay = cfg.SOLVER.WEIGHT_DECAY

        if cfg.SOLVER.VIT_LAYER_DECAY:
            lr = lr * get_vit_lr_decay_rate(key, cfg.SOLVER.VIT_LAYER_DECAY_RATE, cfg.MODEL.VIT_LAYERS)

        param = {"params": [value], "lr": lr}
        if optimizer_type != 'ADAMW':
            param['weight_decay'] = weight_decay
        params += [param]

    def maybe_add_full_model_gradient_clipping(optim):  # optim: the optimizer class
        # detectron2 doesn't have full model gradient clipping now
        clip_norm_val = cfg.SOLVER.CLIP_GRADIENTS.CLIP_VALUE
        enable = (
            cfg.SOLVER.CLIP_GRADIENTS.ENABLED
            and cfg.SOLVER.CLIP_GRADIENTS.CLIP_TYPE == "full_model"
            and clip_norm_val > 0.0
        )

        class FullModelGradientClippingOptimizer(optim):
            def step(self, closure=None):
                all_params = itertools.chain(*[x["params"] for x in self.param_groups])
                torch.nn.utils.clip_grad_norm_(all_params, clip_norm_val)
                super().step(closure=closure)

        return FullModelGradientClippingOptimizer if enable else optim


    if optimizer_type == 'SGD':
        optimizer = maybe_add_full_model_gradient_clipping(torch.optim.SGD)(
            params, cfg.SOLVER.BASE_LR, momentum=cfg.SOLVER.MOMENTUM,
            nesterov=cfg.SOLVER.NESTEROV
        )
    elif optimizer_type == 'ADAMW':
        optimizer = maybe_add_full_model_gradient_clipping(torch.optim.AdamW)(
            params, cfg.SOLVER.BASE_LR,
            weight_decay=cfg.SOLVER.WEIGHT_DECAY
        )
    else:
        raise NotImplementedError(f"no optimizer type {optimizer_type}")
    if not cfg.SOLVER.CLIP_GRADIENTS.CLIP_TYPE == "full_model":
        optimizer = maybe_add_gradient_clipping(cfg, optimizer)
    return optimizer


def get_vit_lr_decay_rate(name, lr_decay_rate=1.0, num_layers=12):
    """
    Calculate lr decay rate for different ViT blocks.
    Args:
        name (string): parameter name.
        lr_decay_rate (float): base lr decay rate.
        num_layers (int): number of ViT blocks.

    Returns:
        lr decay rate for the given parameter.
    """
    layer_id = num_layers + 1
    if name.startswith("backbone"):
        if ".pos_embed" in name or ".patch_embed" in name:
            layer_id = 0
        elif ".blocks." in name and ".residual." not in name:
            layer_id = int(name[name.find(".blocks.") :].split(".")[2]) + 1

    return lr_decay_rate ** (num_layers + 1 - layer_id)
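A worked example of the layer-wise decay above, with lr_decay_rate=0.7 and num_layers=12 as in the ViT-B configs; the parameter names are illustrative:

from models.grit_src.grit.custom_solver import get_vit_lr_decay_rate

get_vit_lr_decay_rate("backbone.bottom_up.patch_embed.proj.weight", 0.7, 12)   # layer_id 0  -> 0.7 ** 13
get_vit_lr_decay_rate("backbone.bottom_up.blocks.0.attn.qkv.weight", 0.7, 12)  # layer_id 1  -> 0.7 ** 12
get_vit_lr_decay_rate("backbone.bottom_up.blocks.11.mlp.fc1.weight", 0.7, 12)  # layer_id 12 -> 0.7 ** 1
get_vit_lr_decay_rate("roi_heads.text_decoder.layer.0.weight", 0.7, 12)        # non-backbone -> 1.0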
models/grit_src/grit/data/__pycache__/custom_build_augmentation.cpython-38.pyc
ADDED
Binary file (1.21 kB)

models/grit_src/grit/data/__pycache__/custom_dataset_mapper.cpython-38.pyc
ADDED
Binary file (5.68 kB)
models/grit_src/grit/data/custom_build_augmentation.py
ADDED
@@ -0,0 +1,44 @@
# Copyright (c) Facebook, Inc. and its affiliates.
from detectron2.data import transforms as T
from .transforms.custom_augmentation_impl import EfficientDetResizeCrop


def build_custom_augmentation(cfg, is_train, scale=None, size=None, \
    min_size=None, max_size=None):
    """
    Create a list of default :class:`Augmentation` from config.
    Now it includes resizing and flipping.

    Returns:
        list[Augmentation]
    """
    if cfg.INPUT.CUSTOM_AUG == 'ResizeShortestEdge':
        if is_train:
            min_size = cfg.INPUT.MIN_SIZE_TRAIN if min_size is None else min_size
            max_size = cfg.INPUT.MAX_SIZE_TRAIN if max_size is None else max_size
            sample_style = cfg.INPUT.MIN_SIZE_TRAIN_SAMPLING
        else:
            min_size = cfg.INPUT.MIN_SIZE_TEST
            max_size = cfg.INPUT.MAX_SIZE_TEST
            sample_style = "choice"
        augmentation = [T.ResizeShortestEdge(min_size, max_size, sample_style)]
    elif cfg.INPUT.CUSTOM_AUG == 'EfficientDetResizeCrop':
        if is_train:
            scale = cfg.INPUT.SCALE_RANGE if scale is None else scale
            size = cfg.INPUT.TRAIN_SIZE if size is None else size
        else:
            scale = (1, 1)
            size = cfg.INPUT.TEST_SIZE
        augmentation = [EfficientDetResizeCrop(size, scale)]
    else:
        assert 0, cfg.INPUT.CUSTOM_AUG

    if is_train:
        augmentation.append(T.RandomFlip())
    return augmentation


build_custom_transform_gen = build_custom_augmentation
"""
Alias for backward-compatibility.
"""
models/grit_src/grit/data/custom_dataset_dataloader.py
ADDED
@@ -0,0 +1,250 @@
# Copyright (c) Facebook, Inc. and its affiliates.
# Modified by Jialian Wu from https://github.com/facebookresearch/Detic/blob/main/detic/data/custom_dataset_dataloader.py
import operator
import torch
import torch.utils.data
from detectron2.utils.comm import get_world_size

from detectron2.config import configurable
from torch.utils.data.sampler import BatchSampler, Sampler
from detectron2.data.common import DatasetFromList, MapDataset
from detectron2.data.dataset_mapper import DatasetMapper
from detectron2.data.build import get_detection_dataset_dicts, build_batch_data_loader
from detectron2.data.samplers import TrainingSampler
from detectron2.data.build import worker_init_reset_seed, print_instances_class_histogram
from detectron2.data.build import filter_images_with_only_crowd_annotations
from detectron2.data.build import filter_images_with_few_keypoints
from detectron2.data.build import check_metadata_consistency
from detectron2.data.catalog import MetadataCatalog, DatasetCatalog
from detectron2.utils import comm
import itertools
from typing import Optional


def _custom_train_loader_from_config(cfg, mapper=None, *, dataset=None, sampler=None):
    sampler_name = cfg.DATALOADER.SAMPLER_TRAIN
    if 'MultiDataset' in sampler_name:
        dataset_dicts = get_detection_dataset_dicts_with_source(
            cfg.DATASETS.TRAIN,
            filter_empty=cfg.DATALOADER.FILTER_EMPTY_ANNOTATIONS,
            min_keypoints=cfg.MODEL.ROI_KEYPOINT_HEAD.MIN_KEYPOINTS_PER_IMAGE
            if cfg.MODEL.KEYPOINT_ON else 0,
            proposal_files=cfg.DATASETS.PROPOSAL_FILES_TRAIN if cfg.MODEL.LOAD_PROPOSALS else None,
        )
    else:
        dataset_dicts = get_detection_dataset_dicts(
            cfg.DATASETS.TRAIN,
            filter_empty=cfg.DATALOADER.FILTER_EMPTY_ANNOTATIONS,
            min_keypoints=cfg.MODEL.ROI_KEYPOINT_HEAD.MIN_KEYPOINTS_PER_IMAGE
            if cfg.MODEL.KEYPOINT_ON else 0,
            proposal_files=cfg.DATASETS.PROPOSAL_FILES_TRAIN if cfg.MODEL.LOAD_PROPOSALS else None,
        )

    if mapper is None:
        mapper = DatasetMapper(cfg, True)

    if sampler is not None:
        pass
    elif sampler_name == "TrainingSampler":
        sampler = TrainingSampler(len(dataset))
    elif sampler_name == "MultiDatasetSampler":
        sampler = MultiDatasetSampler(
            dataset_dicts,
            dataset_ratio=cfg.DATALOADER.DATASET_RATIO,
        )
    else:
        raise ValueError("Unknown training sampler: {}".format(sampler_name))

    return {
        "dataset": dataset_dicts,
        "sampler": sampler,
        "mapper": mapper,
        "total_batch_size": cfg.SOLVER.IMS_PER_BATCH,
        "num_workers": cfg.DATALOADER.NUM_WORKERS,
        'dataset_bs': cfg.DATALOADER.DATASET_BS,
        'num_datasets': len(cfg.DATASETS.TRAIN)
    }


@configurable(from_config=_custom_train_loader_from_config)
def build_custom_train_loader(
        dataset, *, mapper, sampler,
        total_batch_size=16,
        num_workers=0,
        num_datasets=1,
        dataset_bs=1
):

    if isinstance(dataset, list):
        dataset = DatasetFromList(dataset, copy=False)
    if mapper is not None:
        dataset = MapDataset(dataset, mapper)
    if sampler is None:
        sampler = TrainingSampler(len(dataset))
    assert isinstance(sampler, torch.utils.data.sampler.Sampler)

    return build_dataset_batch_data_loader(
        dataset_bs,
        dataset,
        sampler,
        total_batch_size,
        num_datasets=num_datasets,
        num_workers=num_workers,
    )


def build_dataset_batch_data_loader(
        dataset_bs, dataset, sampler, total_batch_size, num_datasets, num_workers=0
):

    world_size = get_world_size()
    assert (
        total_batch_size > 0 and total_batch_size % world_size == 0
    ), "Total batch size ({}) must be divisible by the number of gpus ({}).".format(
        total_batch_size, world_size
    )

    data_loader = torch.utils.data.DataLoader(
        dataset,
        sampler=sampler,
        num_workers=num_workers,
        batch_sampler=None,
        collate_fn=operator.itemgetter(0),  # don't batch, but yield individual elements
        worker_init_fn=worker_init_reset_seed,
    )

    if num_datasets > 1:
        return MultiDatasets(data_loader, dataset_bs, num_datasets)
    else:
        return SingleDataset(data_loader, dataset_bs)


def get_detection_dataset_dicts_with_source(
        dataset_names, filter_empty=True, min_keypoints=0, proposal_files=None
):
    assert len(dataset_names)
    dataset_dicts = [DatasetCatalog.get(dataset_name) for dataset_name in dataset_names]
    for dataset_name, dicts in zip(dataset_names, dataset_dicts):
        assert len(dicts), "Dataset '{}' is empty!".format(dataset_name)

    for source_id, (dataset_name, dicts) in \
            enumerate(zip(dataset_names, dataset_dicts)):
        assert len(dicts), "Dataset '{}' is empty!".format(dataset_name)
        for d in dicts:
            d['dataset_source'] = source_id

        if "annotations" in dicts[0]:
            try:
                class_names = MetadataCatalog.get(dataset_name).thing_classes
                check_metadata_consistency("thing_classes", dataset_name)
                print_instances_class_histogram(dicts, class_names)
            except AttributeError:  # class names are not available for this dataset
                pass

    assert proposal_files is None

    dataset_dicts = list(itertools.chain.from_iterable(dataset_dicts))

    has_instances = "annotations" in dataset_dicts[0]
    if filter_empty and has_instances:
        dataset_dicts = filter_images_with_only_crowd_annotations(dataset_dicts)
    if min_keypoints > 0 and has_instances:
        dataset_dicts = filter_images_with_few_keypoints(dataset_dicts, min_keypoints)

    return dataset_dicts


class MultiDatasetSampler(Sampler):
    def __init__(
            self,
            dataset_dicts,
            dataset_ratio,
            seed: Optional[int] = None,
    ):
        sizes = [0 for _ in range(len(dataset_ratio))]
        for d in dataset_dicts:
            sizes[d['dataset_source']] += 1
        print('dataset sizes', sizes)
        self.sizes = sizes
        assert len(dataset_ratio) == len(sizes), \
            'length of dataset ratio {} should be equal to number if dataset {}'.format(
                len(dataset_ratio), len(sizes)
            )
        if seed is None:
            seed = comm.shared_random_seed()
        self._seed = int(seed)
        self._rank = comm.get_rank()
        self._world_size = comm.get_world_size()

        self.dataset_ids = torch.tensor(
            [d['dataset_source'] for d in dataset_dicts], dtype=torch.long)
        self.dataset_ratio = dataset_ratio

        dataset_weight = [torch.ones(s) * max(sizes) / s * r / sum(dataset_ratio) \
            for i, (r, s) in enumerate(zip(dataset_ratio, sizes))]
        dataset_weight = torch.cat(dataset_weight)

        self.weights = dataset_weight
        self.sample_epoch_size = len(self.weights)

    def __iter__(self):
        start = self._rank
        yield from itertools.islice(
            self._infinite_indices(), start, None, self._world_size)

    def _infinite_indices(self):
        g = torch.Generator()
        g.manual_seed(self._seed)
        while True:
            if len(self.dataset_ratio) > 1:
                # multiple datasets
                ids = torch.multinomial(
                    self.weights, self.sample_epoch_size, generator=g,
                    replacement=True)
                nums = [(self.dataset_ids[ids] == i).sum().int().item() \
                    for i in range(len(self.sizes))]
                yield from ids
            else:
                # single dataset
                yield from torch.randperm(self.sizes[0], generator=g).tolist()


class SingleDataset(torch.utils.data.IterableDataset):
    def __init__(self, dataset, batch_sizes):
        self.dataset = dataset
        self.batch_sizes = batch_sizes
        self._buckets = [[] for _ in range(2)]

    def __iter__(self):
        for d in self.dataset:
            w, h = d["width"], d["height"]
            aspect_ratio_bucket_id = 0 if w > h else 1
            bucket_id = aspect_ratio_bucket_id
            bucket = self._buckets[bucket_id]
            bucket.append(d)
            if len(bucket) == self.batch_sizes:
                yield bucket[:]
                del bucket[:]


class MultiDatasets(torch.utils.data.IterableDataset):
    def __init__(self, dataset, batch_sizes, num_datasets):
        self.dataset = dataset
        self.batch_sizes = batch_sizes
        self._buckets = [[] for _ in range(2 * num_datasets)]
        self.iter_idx = 0
        self.num_datasets = num_datasets

    def __iter__(self):
        for d in self.dataset:
            w, h = d["width"], d["height"]
            aspect_ratio_bucket_id = 0 if w > h else 1
            bucket_id = d['dataset_source'] * 2 + aspect_ratio_bucket_id
            bucket = self._buckets[bucket_id]
            if len(bucket) < self.batch_sizes:
                bucket.append(d)
            selected_dataset = self.iter_idx % self.num_datasets
            if len(bucket) == self.batch_sizes and selected_dataset == d['dataset_source']:
                self.iter_idx += 1
                yield bucket[:]
                del bucket[:]
models/grit_src/grit/data/custom_dataset_mapper.py
ADDED
@@ -0,0 +1,149 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
# Modified by Jialian Wu from https://github.com/facebookresearch/Detic/blob/main/detic/data/custom_dataset_mapper.py
import copy
import numpy as np
import torch

from detectron2.config import configurable

from detectron2.data import detection_utils as utils
from detectron2.data import transforms as T
from detectron2.data.dataset_mapper import DatasetMapper
from .custom_build_augmentation import build_custom_augmentation
from itertools import compress
import logging

__all__ = ["CustomDatasetMapper", "ObjDescription"]
logger = logging.getLogger(__name__)


class CustomDatasetMapper(DatasetMapper):
    @configurable
    def __init__(self, is_train: bool,
                 dataset_augs=[],
                 **kwargs):
        if is_train:
            self.dataset_augs = [T.AugmentationList(x) for x in dataset_augs]
        super().__init__(is_train, **kwargs)

    @classmethod
    def from_config(cls, cfg, is_train: bool = True):
        ret = super().from_config(cfg, is_train)
        if is_train:
            if cfg.INPUT.CUSTOM_AUG == 'EfficientDetResizeCrop':
                dataset_scales = cfg.DATALOADER.DATASET_INPUT_SCALE
                dataset_sizes = cfg.DATALOADER.DATASET_INPUT_SIZE
                ret['dataset_augs'] = [
                    build_custom_augmentation(cfg, True, scale, size) \
                    for scale, size in zip(dataset_scales, dataset_sizes)]
            else:
                assert cfg.INPUT.CUSTOM_AUG == 'ResizeShortestEdge'
                min_sizes = cfg.DATALOADER.DATASET_MIN_SIZES
                max_sizes = cfg.DATALOADER.DATASET_MAX_SIZES
                ret['dataset_augs'] = [
                    build_custom_augmentation(
                        cfg, True, min_size=mi, max_size=ma) \
                    for mi, ma in zip(min_sizes, max_sizes)]
        else:
            ret['dataset_augs'] = []

        return ret

    def __call__(self, dataset_dict):
        dataset_dict_out = self.prepare_data(dataset_dict)

        # When augmented image is too small, do re-augmentation
        retry = 0
        while (dataset_dict_out["image"].shape[1] < 32 or dataset_dict_out["image"].shape[2] < 32):
            retry += 1
            if retry == 100:
                logger.info('Retry 100 times for augmentation. Make sure the image size is not too small.')
                logger.info('Find image information below')
                logger.info(dataset_dict)
            dataset_dict_out = self.prepare_data(dataset_dict)

        return dataset_dict_out

    def prepare_data(self, dataset_dict_in):
        dataset_dict = copy.deepcopy(dataset_dict_in)
        if 'file_name' in dataset_dict:
            ori_image = utils.read_image(
                dataset_dict["file_name"], format=self.image_format)
        else:
            ori_image, _, _ = self.tar_dataset[dataset_dict["tar_index"]]
            ori_image = utils._apply_exif_orientation(ori_image)
            ori_image = utils.convert_PIL_to_numpy(ori_image, self.image_format)
        utils.check_image_size(dataset_dict, ori_image)

        aug_input = T.AugInput(copy.deepcopy(ori_image), sem_seg=None)
        if self.is_train:
            transforms = \
                self.dataset_augs[dataset_dict['dataset_source']](aug_input)
        else:
            transforms = self.augmentations(aug_input)
        image, sem_seg_gt = aug_input.image, aug_input.sem_seg

        image_shape = image.shape[:2]
        dataset_dict["image"] = torch.as_tensor(
            np.ascontiguousarray(image.transpose(2, 0, 1)))

        if not self.is_train:
            # USER: Modify this if you want to keep them for some reason.
            dataset_dict.pop("annotations", None)
            return dataset_dict

        if "annotations" in dataset_dict:
            if len(dataset_dict["annotations"]) > 0:
                object_descriptions = [an['object_description'] for an in dataset_dict["annotations"]]
            else:
                object_descriptions = []
            # USER: Modify this if you want to keep them for some reason.
            for anno in dataset_dict["annotations"]:
                if not self.use_instance_mask:
                    anno.pop("segmentation", None)
                if not self.use_keypoint:
                    anno.pop("keypoints", None)

            all_annos = [
                (utils.transform_instance_annotations(
                    obj, transforms, image_shape,
                    keypoint_hflip_indices=self.keypoint_hflip_indices,
                ), obj.get("iscrowd", 0))
                for obj in dataset_dict.pop("annotations")
            ]
            annos = [ann[0] for ann in all_annos if ann[1] == 0]
            instances = utils.annotations_to_instances(
                annos, image_shape, mask_format=self.instance_mask_format
            )

            instances.gt_object_descriptions = ObjDescription(object_descriptions)

            del all_annos
            if self.recompute_boxes:
                instances.gt_boxes = instances.gt_masks.get_bounding_boxes()
            dataset_dict["instances"] = utils.filter_empty_instances(instances)

        return dataset_dict


class ObjDescription:
    def __init__(self, object_descriptions):
        self.data = object_descriptions

    def __getitem__(self, item):
        assert type(item) == torch.Tensor
        assert item.dim() == 1
        if len(item) > 0:
            assert item.dtype == torch.int64 or item.dtype == torch.bool
            if item.dtype == torch.int64:
                return ObjDescription([self.data[x.item()] for x in item])
            elif item.dtype == torch.bool:
                return ObjDescription(list(compress(self.data, item)))

        return ObjDescription(list(compress(self.data, item)))

    def __len__(self):
        return len(self.data)

    def __repr__(self):
        return "ObjDescription({})".format(self.data)
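A small sketch of how ObjDescription indexing behaves, mirroring how detectron2 Instances fields are filtered with bool or int64 tensors; the description strings are made up:

import torch
from models.grit_src.grit.data.custom_dataset_mapper import ObjDescription

descs = ObjDescription(["red car", "green broccoli", "silver fork"])
print(descs[torch.tensor([True, False, True])])   # ObjDescription(['red car', 'silver fork'])
print(descs[torch.tensor([2, 0])])                # ObjDescription(['silver fork', 'red car'])
print(len(descs))                                 # 3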
models/grit_src/grit/data/datasets/__pycache__/grit_coco.cpython-38.pyc
ADDED
Binary file (3.94 kB)

models/grit_src/grit/data/datasets/__pycache__/object365.cpython-38.pyc
ADDED
Binary file (3.7 kB)

models/grit_src/grit/data/datasets/__pycache__/vg.cpython-38.pyc
ADDED
Binary file (3.28 kB)
models/grit_src/grit/data/datasets/grit_coco.py
ADDED
@@ -0,0 +1,112 @@
import logging
import os
from fvcore.common.timer import Timer
from detectron2.structures import BoxMode
from fvcore.common.file_io import PathManager
from detectron2.data import DatasetCatalog, MetadataCatalog
from lvis import LVIS

logger = logging.getLogger(__name__)

__all__ = ["load_GRiTcoco_json", "register_GRiTcoco_instances"]


def register_GRiTcoco_instances(name, metadata, json_file, image_root):
    """
    """
    DatasetCatalog.register(name, lambda: load_GRiTcoco_json(
        json_file, image_root, name))
    MetadataCatalog.get(name).set(
        json_file=json_file, image_root=image_root,
        evaluator_type="coco", **metadata
    )


def get_GRiTcoco_meta():
    categories = [{'supercategory': 'object', 'id': 1, 'name': 'object'}]
    categories = sorted(categories, key=lambda x: x["id"])
    thing_classes = [k["name"] for k in categories]
    meta = {"thing_classes": thing_classes}
    return meta


def load_GRiTcoco_json(json_file, image_root, dataset_name=None):
    '''
    Load COCO class name text for object description for GRiT
    '''

    json_file = PathManager.get_local_path(json_file)

    timer = Timer()
    lvis_api = LVIS(json_file)
    if timer.seconds() > 1:
        logger.info("Loading {} takes {:.2f} seconds.".format(
            json_file, timer.seconds()))

    class_names = {}
    sort_cat = sorted(lvis_api.dataset['categories'], key=lambda x: x['id'])
    for x in sort_cat:
        class_names[x['id']] = x['name']

    img_ids = sorted(lvis_api.imgs.keys())
    imgs = lvis_api.load_imgs(img_ids)
    anns = [lvis_api.img_ann_map[img_id] for img_id in img_ids]

    ann_ids = [ann["id"] for anns_per_image in anns for ann in anns_per_image]
    assert len(set(ann_ids)) == len(ann_ids), \
        "Annotation ids in '{}' are not unique".format(json_file)

    imgs_anns = list(zip(imgs, anns))
    logger.info("Loaded {} images in the LVIS v1 format from {}".format(
        len(imgs_anns), json_file))

    dataset_dicts = []

    for (img_dict, anno_dict_list) in imgs_anns:
        record = {}
        if "file_name" in img_dict:
            file_name = img_dict["file_name"]
            record["file_name"] = os.path.join(image_root, file_name)

        record["height"] = int(img_dict["height"])
        record["width"] = int(img_dict["width"])
        image_id = record["image_id"] = img_dict["id"]

        objs = []
        for anno in anno_dict_list:
            assert anno["image_id"] == image_id
            if anno.get('iscrowd', 0) > 0:
                continue
            obj = {"bbox": anno["bbox"], "bbox_mode": BoxMode.XYWH_ABS}
            obj["category_id"] = 0
            obj["object_description"] = class_names[anno['category_id']]
            if 'segmentation' in anno:
                segm = anno["segmentation"]
                valid_segm = [poly for poly in segm \
                    if len(poly) % 2 == 0 and len(poly) >= 6]
                if not len(segm) == len(valid_segm):
                    print('Annotation contains an invalid polygon with < 3 points')
                assert len(segm) > 0
                obj["segmentation"] = segm
            objs.append(obj)
        record["annotations"] = objs
        if len(record["annotations"]) == 0:
            continue
        record["task"] = "ObjectDet"
        dataset_dicts.append(record)

    return dataset_dicts


_CUSTOM_SPLITS_LVIS = {
    "GRiT_coco2017_train": ("coco/train2017/", "coco/annotations/instances_train2017.json"),
}


for key, (image_root, json_file) in _CUSTOM_SPLITS_LVIS.items():
    register_GRiTcoco_instances(
        key,
        get_GRiTcoco_meta(),
        os.path.join("datasets", json_file) if "://" not in json_file else json_file,
        os.path.join("datasets", image_root),
    )
models/grit_src/grit/data/datasets/object365.py
ADDED
@@ -0,0 +1,111 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import logging
import os
from fvcore.common.timer import Timer
from detectron2.structures import BoxMode
from fvcore.common.file_io import PathManager
from detectron2.data import DatasetCatalog, MetadataCatalog
from lvis import LVIS

logger = logging.getLogger(__name__)

__all__ = ["load_o365_json", "register_o365_instances"]


def register_o365_instances(name, metadata, json_file, image_root):
    DatasetCatalog.register(name, lambda: load_o365_json(
        json_file, image_root, name))
    MetadataCatalog.get(name).set(
        json_file=json_file, image_root=image_root,
        evaluator_type="lvis", **metadata
    )


def get_o365_meta():
    categories = [{'supercategory': 'object', 'id': 1, 'name': 'object'}]
    o365_categories = sorted(categories, key=lambda x: x["id"])
    thing_classes = [k["name"] for k in o365_categories]
    meta = {"thing_classes": thing_classes}
    return meta


def load_o365_json(json_file, image_root, dataset_name=None):
    '''
    Load Object365 class name text for object description for GRiT
    '''

    json_file = PathManager.get_local_path(json_file)

    timer = Timer()
    lvis_api = LVIS(json_file)
    if timer.seconds() > 1:
        logger.info("Loading {} takes {:.2f} seconds.".format(
            json_file, timer.seconds()))

    class_names = {}
    sort_cat = sorted(lvis_api.dataset['categories'], key=lambda x: x['id'])
    for x in sort_cat:
        if '/' in x['name']:
            text = ''
            for xx in x['name'].split('/'):
                text += xx
                text += ' '
            text = text[:-1]
        else:
            text = x['name']
        class_names[x['id']] = text

    img_ids = sorted(lvis_api.imgs.keys())
    imgs = lvis_api.load_imgs(img_ids)
    anns = [lvis_api.img_ann_map[img_id] for img_id in img_ids]

    ann_ids = [ann["id"] for anns_per_image in anns for ann in anns_per_image]
    assert len(set(ann_ids)) == len(ann_ids), \
        "Annotation ids in '{}' are not unique".format(json_file)

    imgs_anns = list(zip(imgs, anns))
    logger.info("Loaded {} images in the LVIS v1 format from {}".format(
        len(imgs_anns), json_file))

    dataset_dicts = []

    for (img_dict, anno_dict_list) in imgs_anns:
        record = {}
        if "file_name" in img_dict:
            file_name = img_dict["file_name"]
            record["file_name"] = os.path.join(image_root, file_name)

        record["height"] = int(img_dict["height"])
        record["width"] = int(img_dict["width"])
        image_id = record["image_id"] = img_dict["id"]

        objs = []
        for anno in anno_dict_list:
            assert anno["image_id"] == image_id
            if anno.get('iscrowd', 0) > 0:
                continue
            obj = {"bbox": anno["bbox"], "bbox_mode": BoxMode.XYWH_ABS}
            obj["category_id"] = 0
            obj["object_description"] = class_names[anno['category_id']]

            objs.append(obj)
        record["annotations"] = objs
        if len(record["annotations"]) == 0:
            continue
        record["task"] = "ObjectDet"
        dataset_dicts.append(record)

    return dataset_dicts


_CUSTOM_SPLITS_LVIS = {
    "object365_train": ("object365/images/train/", "object365/annotations/train_v1.json"),
}


for key, (image_root, json_file) in _CUSTOM_SPLITS_LVIS.items():
    register_o365_instances(
        key,
        get_o365_meta(),
        os.path.join("datasets", json_file) if "://" not in json_file else json_file,
        os.path.join("datasets", image_root),
    )
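For clarity, the '/'-splitting in load_o365_json only turns a slash-separated Objects365 category name into a space-separated object description. A tiny self-contained illustration (the category name below is made up):

name = "sneakers/trainers"                               # hypothetical Objects365 category name
text = " ".join(name.split("/")) if "/" in name else name  # same result as the loop above
print(text)                                              # "sneakers trainers"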
models/grit_src/grit/data/datasets/vg.py
ADDED
@@ -0,0 +1,98 @@
import logging
import os
from fvcore.common.timer import Timer
from detectron2.structures import BoxMode
from fvcore.common.file_io import PathManager
from detectron2.data import DatasetCatalog, MetadataCatalog
from lvis import LVIS

logger = logging.getLogger(__name__)

__all__ = ["load_vg_json", "register_vg_instances"]


def register_vg_instances(name, metadata, json_file, image_root):
    """
    """
    DatasetCatalog.register(name, lambda: load_vg_json(
        json_file, image_root, name))
    MetadataCatalog.get(name).set(
        json_file=json_file, image_root=image_root,
        evaluator_type="vg", **metadata
    )


def get_vg_meta():
    categories = [{'supercategory': 'object', 'id': 1, 'name': 'object'}]
    vg_categories = sorted(categories, key=lambda x: x["id"])
    thing_classes = [k["name"] for k in vg_categories]
    meta = {"thing_classes": thing_classes}
    return meta


def load_vg_json(json_file, image_root, dataset_name=None):

    json_file = PathManager.get_local_path(json_file)

    timer = Timer()
    lvis_api = LVIS(json_file)
    if timer.seconds() > 1:
        logger.info("Loading {} takes {:.2f} seconds.".format(
            json_file, timer.seconds()))

    img_ids = sorted(lvis_api.imgs.keys())
    imgs = lvis_api.load_imgs(img_ids)
    anns = [lvis_api.img_ann_map[img_id] for img_id in img_ids]

    ann_ids = [ann["id"] for anns_per_image in anns for ann in anns_per_image]
    assert len(set(ann_ids)) == len(ann_ids), \
        "Annotation ids in '{}' are not unique".format(json_file)

    imgs_anns = list(zip(imgs, anns))
    logger.info("Loaded {} images in the LVIS v1 format from {}".format(
        len(imgs_anns), json_file))

    dataset_dicts = []

    for (img_dict, anno_dict_list) in imgs_anns:
        record = {}
        if "file_name" in img_dict:
            file_name = img_dict["file_name"]
            record["file_name"] = os.path.join(image_root, file_name)

        record["height"] = int(img_dict["height"])
        record["width"] = int(img_dict["width"])
        image_id = record["image_id"] = img_dict["id"]

        objs = []
        for anno in anno_dict_list:
            assert anno["image_id"] == image_id
            if anno.get('iscrowd', 0) > 0:
                continue
            obj = {"bbox": anno["bbox"], "bbox_mode": BoxMode.XYWH_ABS}
            obj["category_id"] = 0
            obj["object_description"] = anno["caption"]

            objs.append(obj)
        record["annotations"] = objs
        if len(record["annotations"]) == 0:
            continue
        record["task"] = "DenseCap"
        dataset_dicts.append(record)

    return dataset_dicts


_CUSTOM_SPLITS_LVIS = {
    "vg_train": ("vg/images", "vg/annotations/train.json"),
    "vg_test": ("vg/images", "vg/annotations/test.json"),
}


for key, (image_root, json_file) in _CUSTOM_SPLITS_LVIS.items():
    register_vg_instances(
        key,
        get_vg_meta(),
        os.path.join("datasets", json_file) if "://" not in json_file else json_file,
        os.path.join("datasets", image_root),
    )
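As a rough illustration (all values below are made up), each Visual Genome record produced by load_vg_json has the shape shown here: the region caption travels in object_description, and the "DenseCap" task tag routes the record to the dense-captioning head rather than the detection head.

from detectron2.structures import BoxMode

# Hypothetical record returned by load_vg_json (values illustrative only).
record = {
    "file_name": "datasets/vg/images/1.jpg",
    "height": 600,
    "width": 800,
    "image_id": 1,
    "task": "DenseCap",
    "annotations": [{
        "bbox": [10.0, 20.0, 150.0, 90.0],   # XYWH, absolute pixels
        "bbox_mode": BoxMode.XYWH_ABS,
        "category_id": 0,                    # single generic 'object' class
        "object_description": "a red bicycle leaning against a brick wall",
    }],
}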
models/grit_src/grit/data/transforms/__pycache__/custom_augmentation_impl.cpython-38.pyc
ADDED
Binary file (1.73 kB).
models/grit_src/grit/data/transforms/__pycache__/custom_transform.cpython-38.pyc
ADDED
Binary file (3.89 kB).
models/grit_src/grit/data/transforms/custom_augmentation_impl.py
ADDED
@@ -0,0 +1,52 @@
# -*- coding: utf-8 -*-
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
# Part of the code is from https://github.com/rwightman/efficientdet-pytorch/blob/master/effdet/data/transforms.py
# Modified by Xingyi Zhou
# The original code is under Apache-2.0 License
import numpy as np
from PIL import Image

from detectron2.data.transforms.augmentation import Augmentation
from .custom_transform import EfficientDetResizeCropTransform

__all__ = [
    "EfficientDetResizeCrop",
]


class EfficientDetResizeCrop(Augmentation):
    """
    Scale the shorter edge to the given size, with a limit of `max_size` on the longer edge.
    If `max_size` is reached, then downscale so that the longer edge does not exceed max_size.
    """

    def __init__(
        self, size, scale, interp=Image.BILINEAR
    ):
        """
        """
        super().__init__()
        self.target_size = (size, size)
        self.scale = scale
        self.interp = interp

    def get_transform(self, img):
        # Select a random scale factor.
        scale_factor = np.random.uniform(*self.scale)
        scaled_target_height = scale_factor * self.target_size[0]
        scaled_target_width = scale_factor * self.target_size[1]
        # Recompute the accurate scale_factor using rounded scaled image size.
        width, height = img.shape[1], img.shape[0]
        img_scale_y = scaled_target_height / height
        img_scale_x = scaled_target_width / width
        img_scale = min(img_scale_y, img_scale_x)

        # Select non-zero random offset (x, y) if scaled image is larger than target size
        scaled_h = int(height * img_scale)
        scaled_w = int(width * img_scale)
        offset_y = scaled_h - self.target_size[0]
        offset_x = scaled_w - self.target_size[1]
        offset_y = int(max(0.0, float(offset_y)) * np.random.uniform(0, 1))
        offset_x = int(max(0.0, float(offset_x)) * np.random.uniform(0, 1))
        return EfficientDetResizeCropTransform(
            scaled_h, scaled_w, offset_y, offset_x, img_scale, self.target_size, self.interp)
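A minimal usage sketch, assuming the repo layout makes the module importable under the path below and using an arbitrary scale range. The augmentation samples a random scale factor and crop offset per call, so outputs differ between runs.

import numpy as np
from models.grit_src.grit.data.transforms.custom_augmentation_impl import EfficientDetResizeCrop  # import path assumed

aug = EfficientDetResizeCrop(size=512, scale=(0.5, 1.5))        # scale range is illustrative
img = np.random.randint(0, 256, (480, 640, 3), dtype=np.uint8)   # dummy HWC uint8 image
tfm = aug.get_transform(img)   # EfficientDetResizeCropTransform with the sampled scale/offset
out = tfm.apply_image(img)     # resized, then cropped to at most 512x512
print(out.shape)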
models/grit_src/grit/data/transforms/custom_transform.py
ADDED
@@ -0,0 +1,115 @@
# -*- coding: utf-8 -*-
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
# Part of the code is from https://github.com/rwightman/efficientdet-pytorch/blob/master/effdet/data/transforms.py
# Modified by Xingyi Zhou
# The original code is under Apache-2.0 License
import numpy as np
import torch
import torch.nn.functional as F
from fvcore.transforms.transform import (
    CropTransform,
    HFlipTransform,
    NoOpTransform,
    Transform,
    TransformList,
)
from PIL import Image

try:
    import cv2  # noqa
except ImportError:
    # OpenCV is an optional dependency at the moment
    pass

__all__ = [
    "EfficientDetResizeCropTransform",
]


class EfficientDetResizeCropTransform(Transform):
    """
    """

    def __init__(self, scaled_h, scaled_w, offset_y, offset_x, img_scale,
                 target_size, interp=None):
        """
        Args:
            h, w (int): original image size
            new_h, new_w (int): new image size
            interp: PIL interpolation methods, defaults to bilinear.
        """
        # TODO decide on PIL vs opencv
        super().__init__()
        if interp is None:
            interp = Image.BILINEAR
        self._set_attributes(locals())

    def apply_image(self, img, interp=None):
        assert len(img.shape) <= 4

        if img.dtype == np.uint8:
            pil_image = Image.fromarray(img)
            interp_method = interp if interp is not None else self.interp
            pil_image = pil_image.resize((self.scaled_w, self.scaled_h), interp_method)
            ret = np.asarray(pil_image)
            right = min(self.scaled_w, self.offset_x + self.target_size[1])
            lower = min(self.scaled_h, self.offset_y + self.target_size[0])
            if len(ret.shape) <= 3:
                ret = ret[self.offset_y: lower, self.offset_x: right]
            else:
                ret = ret[..., self.offset_y: lower, self.offset_x: right, :]
        else:
            # PIL only supports uint8
            img = torch.from_numpy(img)
            shape = list(img.shape)
            shape_4d = shape[:2] + [1] * (4 - len(shape)) + shape[2:]
            img = img.view(shape_4d).permute(2, 3, 0, 1)  # hw(c) -> nchw
            _PIL_RESIZE_TO_INTERPOLATE_MODE = {Image.BILINEAR: "bilinear", Image.BICUBIC: "bicubic"}
            mode = _PIL_RESIZE_TO_INTERPOLATE_MODE[self.interp]
            img = F.interpolate(img, (self.scaled_h, self.scaled_w), mode=mode, align_corners=False)
            shape[:2] = (self.scaled_h, self.scaled_w)
            ret = img.permute(2, 3, 0, 1).view(shape).numpy()  # nchw -> hw(c)
            right = min(self.scaled_w, self.offset_x + self.target_size[1])
            lower = min(self.scaled_h, self.offset_y + self.target_size[0])
            if len(ret.shape) <= 3:
                ret = ret[self.offset_y: lower, self.offset_x: right]
            else:
                ret = ret[..., self.offset_y: lower, self.offset_x: right, :]
        return ret

    def apply_coords(self, coords):
        coords[:, 0] = coords[:, 0] * self.img_scale
        coords[:, 1] = coords[:, 1] * self.img_scale
        coords[:, 0] -= self.offset_x
        coords[:, 1] -= self.offset_y
        return coords

    def apply_segmentation(self, segmentation):
        segmentation = self.apply_image(segmentation, interp=Image.NEAREST)
        return segmentation

    def inverse(self):
        raise NotImplementedError

    def inverse_apply_coords(self, coords):
        coords[:, 0] += self.offset_x
        coords[:, 1] += self.offset_y
        coords[:, 0] = coords[:, 0] / self.img_scale
        coords[:, 1] = coords[:, 1] / self.img_scale
        return coords

    def inverse_apply_box(self, box: np.ndarray) -> np.ndarray:
        """
        """
        idxs = np.array([(0, 1), (2, 1), (0, 3), (2, 3)]).flatten()
        coords = np.asarray(box).reshape(-1, 4)[:, idxs].reshape(-1, 2)
        coords = self.inverse_apply_coords(coords).reshape((-1, 4, 2))
        minxy = coords.min(axis=1)
        maxxy = coords.max(axis=1)
        trans_boxes = np.concatenate((minxy, maxxy), axis=1)
        return trans_boxes
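A small sketch of the coordinate bookkeeping, constructing the transform directly with made-up parameters: apply_box (inherited from fvcore's Transform base class) maps XYXY boxes into the resized-and-cropped frame, and inverse_apply_box maps them back to the original image.

import numpy as np
from models.grit_src.grit.data.transforms.custom_transform import EfficientDetResizeCropTransform  # import path assumed

# Made-up parameters: image scaled by 0.5, then cropped with a (y=32, x=48) offset.
tfm = EfficientDetResizeCropTransform(
    scaled_h=512, scaled_w=512, offset_y=32, offset_x=48,
    img_scale=0.5, target_size=(448, 448))

box = np.array([[120.0, 160.0, 400.0, 360.0]])   # XYXY in the original image
scaled = tfm.apply_box(box)                       # -> [[12., 48., 152., 148.]]
restored = tfm.inverse_apply_box(scaled)          # back to the original coordinates
print(np.allclose(restored, box))                 # True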
models/grit_src/grit/evaluation/eval.py
ADDED
@@ -0,0 +1,156 @@
import itertools
import json
import os
from detectron2.structures import Boxes, BoxMode, pairwise_iou
from detectron2.utils.file_io import PathManager
import numpy as np
import pycocotools.mask as mask_util
from detectron2.evaluation.coco_evaluation import COCOEvaluator
from detectron2.evaluation.coco_evaluation import _evaluate_predictions_on_coco


class GRiTCOCOEvaluator(COCOEvaluator):
    def process(self, inputs, outputs):
        for input, output in zip(inputs, outputs):
            prediction = {"image_id": input["image_id"]}

            if "instances" in output:
                instances = output["instances"].to(self._cpu_device)
                prediction["instances"] = instances_to_coco_json(instances, input["image_id"])

            if len(prediction) > 1:
                self._predictions.append(prediction)

    def _eval_predictions(self, predictions, img_ids=None):
        self._logger.info("Preparing results for COCO format ...")
        coco_results = list(itertools.chain(*[x["instances"] for x in predictions]))
        tasks = self._tasks or self._tasks_from_predictions(coco_results)

        if self._output_dir:
            file_path = os.path.join(self._output_dir, "coco_instances_results.json")
            self._logger.info("Saving results to {}".format(file_path))
            with PathManager.open(file_path, "w") as f:
                f.write(json.dumps(coco_results))
                f.flush()

        if not self._do_evaluation:
            self._logger.info("Annotations are not available for evaluation.")
            return

        self._logger.info(
            "Evaluating predictions with {} COCO API...".format(
                "unofficial" if self._use_fast_impl else "official"
            )
        )

        coco_results = self.convert_classname_to_id(coco_results)

        for task in sorted(tasks):
            assert task in {"bbox", "segm", "keypoints"}, f"Got unknown task: {task}!"
            coco_eval = (
                _evaluate_predictions_on_coco(
                    self._coco_api,
                    coco_results,
                    task,
                    kpt_oks_sigmas=self._kpt_oks_sigmas,
                    use_fast_impl=self._use_fast_impl,
                    img_ids=img_ids,
                    max_dets_per_image=self._max_dets_per_image,
                )
                if len(coco_results) > 0
                else None  # cocoapi does not handle empty results very well
            )

            res = self._derive_coco_results(
                coco_eval, task, class_names=self._metadata.get("thing_classes")
            )
            self._results[task] = res

    def convert_classname_to_id(self, results):
        outputs = []
        class_name_to_id = {}
        categories = sorted(self._coco_api.dataset['categories'], key=lambda x: x['id'])

        for cat in categories:
            class_name_to_id[cat['name']] = cat['id']

        for pred in results:
            if pred['object_descriptions'] in class_name_to_id:
                pred['category_id'] = class_name_to_id[pred['object_descriptions']]
                del pred['object_descriptions']
                outputs.append(pred)

        return outputs


class GRiTVGEvaluator(COCOEvaluator):
    def process(self, inputs, outputs):
        for input, output in zip(inputs, outputs):
            assert input["image_id"] == int(input['file_name'].split('/')[-1].split('.')[0])
            prediction = {"image_id": input["image_id"]}

            if "instances" in output:
                instances = output["instances"].to(self._cpu_device)
                prediction["instances"] = instances_to_coco_json(instances, input["image_id"], output_logits=True)
                h = input['height']
                w = input['width']
                scale = 720.0 / max(h, w)
                scaled_inst = []
                for inst in prediction["instances"]:
                    inst['bbox'][0] = inst['bbox'][0] * scale
                    inst['bbox'][1] = inst['bbox'][1] * scale
                    inst['bbox'][2] = inst['bbox'][2] * scale
                    inst['bbox'][3] = inst['bbox'][3] * scale
                    scaled_inst.append(inst)
                if len(scaled_inst) > 0:
                    prediction["instances"] = scaled_inst
            if len(prediction) > 1:
                self._predictions.append(prediction)

    def _eval_predictions(self, predictions, img_ids=None):
        '''
        This is only for saving the results to json file
        '''
        self._logger.info("Preparing results for COCO format ...")
        coco_results = list(itertools.chain(*[x["instances"] for x in predictions]))

        if self._output_dir:
            file_path = os.path.join(self._output_dir, "vg_instances_results.json")
            self._logger.info("Saving results to {}".format(file_path))
            with PathManager.open(file_path, "w") as f:
                f.write(json.dumps(coco_results))
                f.flush()


def instances_to_coco_json(instances, img_id, output_logits=False):
    """
    Add object_descriptions and logit (if applicable) to
    detectron2's instances_to_coco_json
    """
    num_instance = len(instances)
    if num_instance == 0:
        return []

    boxes = instances.pred_boxes.tensor.numpy()
    boxes = BoxMode.convert(boxes, BoxMode.XYXY_ABS, BoxMode.XYWH_ABS)
    boxes = boxes.tolist()
    scores = instances.scores.tolist()
    classes = instances.pred_classes.tolist()
    object_descriptions = instances.pred_object_descriptions.data
    if output_logits:
        logits = instances.logits.tolist()

    results = []
    for k in range(num_instance):
        result = {
            "image_id": img_id,
            "category_id": classes[k],
            "bbox": boxes[k],
            "score": scores[k],
            'object_descriptions': object_descriptions[k],
        }
        if output_logits:
            result["logit"] = logits[k]

        results.append(result)
    return results
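For reference, the 720-pixel rescaling in GRiTVGEvaluator.process is plain arithmetic on XYWH boxes, mapping the longest image side to 720 before the results are written out. A tiny worked example with made-up numbers:

h, w = 375, 500                        # hypothetical original VG image size
scale = 720.0 / max(h, w)              # = 1.44
bbox = [10.0, 20.0, 100.0, 50.0]       # XYWH box in original-image coordinates
bbox_720 = [v * scale for v in bbox]   # [14.4, 28.8, 144.0, 72.0]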
models/grit_src/grit/modeling/__pycache__/soft_nms.cpython-38.pyc
ADDED
Binary file (5.99 kB).
models/grit_src/grit/modeling/backbone/__pycache__/utils.cpython-38.pyc
ADDED
Binary file (6.12 kB).
models/grit_src/grit/modeling/backbone/__pycache__/vit.cpython-38.pyc
ADDED
Binary file (15.6 kB).
models/grit_src/grit/modeling/backbone/utils.py
ADDED
@@ -0,0 +1,186 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
# This code is from https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/utils.py
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

__all__ = [
    "window_partition",
    "window_unpartition",
    "add_decomposed_rel_pos",
    "get_abs_pos",
    "PatchEmbed",
]


def window_partition(x, window_size):
    """
    Partition into non-overlapping windows with padding if needed.
    Args:
        x (tensor): input tokens with [B, H, W, C].
        window_size (int): window size.

    Returns:
        windows: windows after partition with [B * num_windows, window_size, window_size, C].
        (Hp, Wp): padded height and width before partition
    """
    B, H, W, C = x.shape

    pad_h = (window_size - H % window_size) % window_size
    pad_w = (window_size - W % window_size) % window_size
    if pad_h > 0 or pad_w > 0:
        x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h))
    Hp, Wp = H + pad_h, W + pad_w

    x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C)
    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
    return windows, (Hp, Wp)


def window_unpartition(windows, window_size, pad_hw, hw):
    """
    Window unpartition into original sequences and removing padding.
    Args:
        x (tensor): input tokens with [B * num_windows, window_size, window_size, C].
        window_size (int): window size.
        pad_hw (Tuple): padded height and width (Hp, Wp).
        hw (Tuple): original height and width (H, W) before padding.

    Returns:
        x: unpartitioned sequences with [B, H, W, C].
    """
    Hp, Wp = pad_hw
    H, W = hw
    B = windows.shape[0] // (Hp * Wp // window_size // window_size)
    x = windows.view(B, Hp // window_size, Wp // window_size, window_size, window_size, -1)
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1)

    if Hp > H or Wp > W:
        x = x[:, :H, :W, :].contiguous()
    return x


def get_rel_pos(q_size, k_size, rel_pos):
    """
    Get relative positional embeddings according to the relative positions of
    query and key sizes.
    Args:
        q_size (int): size of query q.
        k_size (int): size of key k.
        rel_pos (Tensor): relative position embeddings (L, C).

    Returns:
        Extracted positional embeddings according to relative positions.
    """
    max_rel_dist = int(2 * max(q_size, k_size) - 1)
    # Interpolate rel pos if needed.
    if rel_pos.shape[0] != max_rel_dist:
        # Interpolate rel pos.
        rel_pos_resized = F.interpolate(
            rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1),
            size=max_rel_dist,
            mode="linear",
        )
        rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0)
    else:
        rel_pos_resized = rel_pos

    # Scale the coords with short length if shapes for q and k are different.
    q_coords = torch.arange(q_size)[:, None] * max(k_size / q_size, 1.0)
    k_coords = torch.arange(k_size)[None, :] * max(q_size / k_size, 1.0)
    relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0)

    return rel_pos_resized[relative_coords.long()]


def add_decomposed_rel_pos(attn, q, rel_pos_h, rel_pos_w, q_size, k_size):
    """
    Calculate decomposed Relative Positional Embeddings from :paper:`mvitv2`.
    https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py  # noqa B950
    Args:
        attn (Tensor): attention map.
        q (Tensor): query q in the attention layer with shape (B, q_h * q_w, C).
        rel_pos_h (Tensor): relative position embeddings (Lh, C) for height axis.
        rel_pos_w (Tensor): relative position embeddings (Lw, C) for width axis.
        q_size (Tuple): spatial sequence size of query q with (q_h, q_w).
        k_size (Tuple): spatial sequence size of key k with (k_h, k_w).

    Returns:
        attn (Tensor): attention map with added relative positional embeddings.
    """
    q_h, q_w = q_size
    k_h, k_w = k_size
    Rh = get_rel_pos(q_h, k_h, rel_pos_h)
    Rw = get_rel_pos(q_w, k_w, rel_pos_w)

    B, _, dim = q.shape
    r_q = q.reshape(B, q_h, q_w, dim)
    rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, Rh)
    rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, Rw)

    attn = (
        attn.view(B, q_h, q_w, k_h, k_w) + rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :]
    ).view(B, q_h * q_w, k_h * k_w)

    return attn


def get_abs_pos(abs_pos, has_cls_token, hw):
    """
    Calculate absolute positional embeddings. If needed, resize embeddings and remove cls_token
    dimension for the original embeddings.
    Args:
        abs_pos (Tensor): absolute positional embeddings with (1, num_position, C).
        has_cls_token (bool): If true, has 1 embedding in abs_pos for cls token.
        hw (Tuple): size of input image tokens.

    Returns:
        Absolute positional embeddings after processing with shape (1, H, W, C)
    """
    h, w = hw
    if has_cls_token:
        abs_pos = abs_pos[:, 1:]
    xy_num = abs_pos.shape[1]
    size = int(math.sqrt(xy_num))
    assert size * size == xy_num

    if size != h or size != w:
        new_abs_pos = F.interpolate(
            abs_pos.reshape(1, size, size, -1).permute(0, 3, 1, 2),
            size=(h, w),
            mode="bicubic",
            align_corners=False,
        )

        return new_abs_pos.permute(0, 2, 3, 1)
    else:
        return abs_pos.reshape(1, h, w, -1)


class PatchEmbed(nn.Module):
    """
    Image to Patch Embedding.
    """

    def __init__(
        self, kernel_size=(16, 16), stride=(16, 16), padding=(0, 0), in_chans=3, embed_dim=768
    ):
        """
        Args:
            kernel_size (Tuple): kernel size of the projection layer.
            stride (Tuple): stride of the projection layer.
            padding (Tuple): padding size of the projection layer.
            in_chans (int): Number of input image channels.
            embed_dim (int): Patch embedding dimension.
        """
        super().__init__()

        self.proj = nn.Conv2d(
            in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding
        )

    def forward(self, x):
        x = self.proj(x)
        # B C H W -> B H W C
        x = x.permute(0, 2, 3, 1)
        return x
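A quick sanity-check sketch for the windowing helpers (import path assumed from this repo's layout): padding is added before partitioning and stripped again on unpartition, so the round trip reproduces the input exactly.

import torch
from models.grit_src.grit.modeling.backbone.utils import window_partition, window_unpartition  # path assumed

x = torch.randn(2, 50, 62, 768)                     # (B, H, W, C); 50 and 62 are not multiples of 14
windows, pad_hw = window_partition(x, 14)            # pads to (56, 70), then splits into 14x14 windows
y = window_unpartition(windows, 14, pad_hw, (50, 62))
print(pad_hw, windows.shape, torch.equal(x, y))      # (56, 70) torch.Size([40, 14, 14, 768]) True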
models/grit_src/grit/modeling/backbone/vit.py
ADDED
@@ -0,0 +1,538 @@
# Modified by Jialian Wu from https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/vit.py
import logging
import math
import fvcore.nn.weight_init as weight_init
import torch
import torch.nn as nn
from functools import partial

from detectron2.layers import CNNBlockBase, Conv2d, get_norm
from detectron2.modeling.backbone.build import BACKBONE_REGISTRY
from detectron2.layers import ShapeSpec
from centernet.modeling.backbone.fpn_p5 import LastLevelP6P7_P5

import torch.utils.checkpoint as checkpoint
from timm.models.layers import DropPath, Mlp, trunc_normal_

from detectron2.modeling.backbone.backbone import Backbone
from .utils import (
    PatchEmbed,
    add_decomposed_rel_pos,
    get_abs_pos,
    window_partition,
    window_unpartition,
)

logger = logging.getLogger(__name__)


__all__ = ["ViT"]


class Attention(nn.Module):
    """Multi-head Attention block with relative position embeddings."""

    def __init__(
        self,
        dim,
        num_heads=8,
        qkv_bias=True,
        use_rel_pos=False,
        rel_pos_zero_init=True,
        input_size=None,
    ):
        """
        Args:
            dim (int): Number of input channels.
            num_heads (int): Number of attention heads.
            qkv_bias (bool): If True, add a learnable bias to query, key, value.
            use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
            rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
            input_size (int or None): Input resolution for calculating the relative positional
                parameter size.
        """
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim**-0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.proj = nn.Linear(dim, dim)

        self.use_rel_pos = use_rel_pos
        if self.use_rel_pos:
            # initialize relative positional embeddings
            self.rel_pos_h = nn.Parameter(torch.zeros(2 * input_size[0] - 1, head_dim))
            self.rel_pos_w = nn.Parameter(torch.zeros(2 * input_size[1] - 1, head_dim))

            if not rel_pos_zero_init:
                trunc_normal_(self.rel_pos_h, std=0.02)
                trunc_normal_(self.rel_pos_w, std=0.02)

    def forward(self, x):
        B, H, W, _ = x.shape
        # qkv with shape (3, B, nHead, H * W, C)
        qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
        # q, k, v with shape (B * nHead, H * W, C)
        q, k, v = qkv.reshape(3, B * self.num_heads, H * W, -1).unbind(0)

        attn = (q * self.scale) @ k.transpose(-2, -1)

        if self.use_rel_pos:
            attn = add_decomposed_rel_pos(attn, q, self.rel_pos_h, self.rel_pos_w, (H, W), (H, W))

        attn = attn.softmax(dim=-1)
        x = (attn @ v).view(B, self.num_heads, H, W, -1).permute(0, 2, 3, 1, 4).reshape(B, H, W, -1)
        x = self.proj(x)

        return x


class ResBottleneckBlock(CNNBlockBase):
    """
    The standard bottleneck residual block without the last activation layer.
    It contains 3 conv layers with kernels 1x1, 3x3, 1x1.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        bottleneck_channels,
        norm="LN",
        act_layer=nn.GELU,
    ):
        """
        Args:
            in_channels (int): Number of input channels.
            out_channels (int): Number of output channels.
            bottleneck_channels (int): number of output channels for the 3x3
                "bottleneck" conv layers.
            norm (str or callable): normalization for all conv layers.
                See :func:`layers.get_norm` for supported format.
            act_layer (callable): activation for all conv layers.
        """
        super().__init__(in_channels, out_channels, 1)

        self.conv1 = Conv2d(in_channels, bottleneck_channels, 1, bias=False)
        self.norm1 = get_norm(norm, bottleneck_channels)
        self.act1 = act_layer()

        self.conv2 = Conv2d(
            bottleneck_channels,
            bottleneck_channels,
            3,
            padding=1,
            bias=False,
        )
        self.norm2 = get_norm(norm, bottleneck_channels)
        self.act2 = act_layer()

        self.conv3 = Conv2d(bottleneck_channels, out_channels, 1, bias=False)
        self.norm3 = get_norm(norm, out_channels)

        for layer in [self.conv1, self.conv2, self.conv3]:
            weight_init.c2_msra_fill(layer)
        for layer in [self.norm1, self.norm2]:
            layer.weight.data.fill_(1.0)
            layer.bias.data.zero_()
        # zero init last norm layer.
        self.norm3.weight.data.zero_()
        self.norm3.bias.data.zero_()

    def forward(self, x):
        out = x
        for layer in self.children():
            out = layer(out)

        out = x + out
        return out


class Block(nn.Module):
    """Transformer blocks with support of window attention and residual propagation blocks"""

    def __init__(
        self,
        dim,
        num_heads,
        mlp_ratio=4.0,
        qkv_bias=True,
        drop_path=0.0,
        norm_layer=nn.LayerNorm,
        act_layer=nn.GELU,
        use_rel_pos=False,
        rel_pos_zero_init=True,
        window_size=0,
        use_residual_block=False,
        input_size=None,
    ):
        """
        Args:
            dim (int): Number of input channels.
            num_heads (int): Number of attention heads in each ViT block.
            mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
            qkv_bias (bool): If True, add a learnable bias to query, key, value.
            drop_path (float): Stochastic depth rate.
            norm_layer (nn.Module): Normalization layer.
            act_layer (nn.Module): Activation layer.
            use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
            rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
            window_size (int): Window size for window attention blocks. If it equals 0, window
                attention is not used.
            use_residual_block (bool): If True, use a residual block after the MLP block.
            input_size (int or None): Input resolution for calculating the relative positional
                parameter size.
        """
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            use_rel_pos=use_rel_pos,
            rel_pos_zero_init=rel_pos_zero_init,
            input_size=input_size if window_size == 0 else (window_size, window_size),
        )

        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
        self.norm2 = norm_layer(dim)
        self.mlp = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer)

        self.window_size = window_size

        self.use_residual_block = use_residual_block
        if use_residual_block:
            # Use a residual block with bottleneck channel as dim // 2
            self.residual = ResBottleneckBlock(
                in_channels=dim,
                out_channels=dim,
                bottleneck_channels=dim // 2,
                norm="LN",
                act_layer=act_layer,
            )

    def forward(self, x):
        shortcut = x
        x = self.norm1(x)
        # Window partition
        if self.window_size > 0:
            H, W = x.shape[1], x.shape[2]
            x, pad_hw = window_partition(x, self.window_size)

        x = self.attn(x)
        # Reverse window partition
        if self.window_size > 0:
            x = window_unpartition(x, self.window_size, pad_hw, (H, W))

        x = shortcut + self.drop_path(x)
        x = x + self.drop_path(self.mlp(self.norm2(x)))

        if self.use_residual_block:
            x = self.residual(x.permute(0, 3, 1, 2)).permute(0, 2, 3, 1)

        return x


class ViT(Backbone):
    """
    This module implements Vision Transformer (ViT) backbone in :paper:`vitdet`.
    "Exploring Plain Vision Transformer Backbones for Object Detection",
    https://arxiv.org/abs/2203.16527
    """

    def __init__(
        self,
        img_size=1024,
        patch_size=16,
        in_chans=3,
        embed_dim=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4.0,
        qkv_bias=True,
        drop_path_rate=0.0,
        norm_layer=nn.LayerNorm,
        act_layer=nn.GELU,
        use_abs_pos=True,
        use_rel_pos=False,
        rel_pos_zero_init=True,
        window_size=0,
        window_block_indexes=(),
        residual_block_indexes=(),
        use_act_checkpoint=True,
        pretrain_img_size=224,
        pretrain_use_cls_token=True,
        out_feature="last_feat",
    ):
        """
        Args:
            img_size (int): Input image size.
            patch_size (int): Patch size.
            in_chans (int): Number of input image channels.
            embed_dim (int): Patch embedding dimension.
            depth (int): Depth of ViT.
            num_heads (int): Number of attention heads in each ViT block.
            mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
            qkv_bias (bool): If True, add a learnable bias to query, key, value.
            drop_path_rate (float): Stochastic depth rate.
            norm_layer (nn.Module): Normalization layer.
            act_layer (nn.Module): Activation layer.
            use_abs_pos (bool): If True, use absolute positional embeddings.
            use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
            rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
            window_size (int): Window size for window attention blocks.
            window_block_indexes (list): Indexes for blocks using window attention.
            residual_block_indexes (list): Indexes for blocks using conv propagation.
            use_act_checkpoint (bool): If True, use activation checkpointing.
            pretrain_img_size (int): input image size for pretraining models.
            pretrain_use_cls_token (bool): If True, pretraining models use class token.
            out_feature (str): name of the feature from the last block.
        """
        super().__init__()
        self.pretrain_use_cls_token = pretrain_use_cls_token
        self.use_act_checkpoint = use_act_checkpoint

        self.patch_embed = PatchEmbed(
            kernel_size=(patch_size, patch_size),
            stride=(patch_size, patch_size),
            in_chans=in_chans,
            embed_dim=embed_dim,
        )

        if use_abs_pos:
            # Initialize absolute positional embedding with pretrain image size.
            num_patches = (pretrain_img_size // patch_size) * (pretrain_img_size // patch_size)
            num_positions = (num_patches + 1) if pretrain_use_cls_token else num_patches
            self.pos_embed = nn.Parameter(torch.zeros(1, num_positions, embed_dim))
        else:
            self.pos_embed = None

        # stochastic depth decay rule
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]

        self.blocks = nn.ModuleList()
        for i in range(depth):
            block = Block(
                dim=embed_dim,
                num_heads=num_heads,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                drop_path=dpr[i],
                norm_layer=norm_layer,
                act_layer=act_layer,
                use_rel_pos=use_rel_pos,
                rel_pos_zero_init=rel_pos_zero_init,
                window_size=window_size if i in window_block_indexes else 0,
                use_residual_block=i in residual_block_indexes,
                input_size=(img_size // patch_size, img_size // patch_size),
            )
            self.blocks.append(block)

        self._out_feature_channels = {out_feature: embed_dim}
        self._out_feature_strides = {out_feature: patch_size}
        self._out_features = [out_feature]

        if self.pos_embed is not None:
            trunc_normal_(self.pos_embed, std=0.02)

        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=0.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def forward(self, x):
        x = self.patch_embed(x)
        if self.pos_embed is not None:
            x = x + get_abs_pos(
                self.pos_embed, self.pretrain_use_cls_token, (x.shape[1], x.shape[2])
            )

        for blk in self.blocks:
            if self.use_act_checkpoint:
                x = checkpoint.checkpoint(blk, x)
            else:
                x = blk(x)

        return x.permute(0, 3, 1, 2)


class ViT_FPN(Backbone):
    def __init__(self, bottom_up=None, top_block=None, out_channels=None, strides=None, vit_out_dim=None):
        super(ViT_FPN, self).__init__()
        assert isinstance(bottom_up, Backbone)
        self.bottom_up = bottom_up
        self.top_block = top_block

        self._out_feature_strides = {"p{}".format(int(math.log2(s))): s for s in strides}
        self._out_features = list(self._out_feature_strides.keys())
        self._out_feature_channels = {k: out_channels for k in self._out_features}
        self._size_divisibility = strides[2]

        self.maxpool = nn.MaxPool2d(2, stride=2)
        self.fpn_stride_16_8 = nn.ConvTranspose2d(vit_out_dim, vit_out_dim, 2, stride=2, bias=False)
        self.fpn_stride8_conv1 = nn.Conv2d(in_channels=vit_out_dim, out_channels=out_channels, kernel_size=1, bias=False)
        self.fpn_stride8_norm1 = nn.LayerNorm(out_channels)
        self.fpn_stride8_conv2 = nn.Conv2d(in_channels=out_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=1, bias=False)
        self.fpn_stride8_norm2 = nn.LayerNorm(out_channels)

        self.fpn_stride16_conv1 = nn.Conv2d(in_channels=vit_out_dim, out_channels=out_channels, kernel_size=1, bias=False)
        self.fpn_stride16_norm1 = nn.LayerNorm(out_channels)
        self.fpn_stride16_conv2 = nn.Conv2d(in_channels=out_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=1, bias=False)
        self.fpn_stride16_norm2 = nn.LayerNorm(out_channels)

        self.fpn_stride32_conv1 = nn.Conv2d(in_channels=vit_out_dim, out_channels=out_channels, kernel_size=1, bias=False)
        self.fpn_stride32_norm1 = nn.LayerNorm(out_channels)
        self.fpn_stride32_conv2 = nn.Conv2d(in_channels=out_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=1, bias=False)
        self.fpn_stride32_norm2 = nn.LayerNorm(out_channels)

    def forward(self, x):
        vit_output_featuremap = self.bottom_up(x)

        stride8_feature = self.fpn_stride_16_8(vit_output_featuremap)
        stride8_feature = self.fpn_stride8_norm1(self.fpn_stride8_conv1(stride8_feature)
                                                 .permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
        stride8_feature = self.fpn_stride8_norm2(self.fpn_stride8_conv2(stride8_feature)
                                                 .permute(0, 2, 3, 1)).permute(0, 3, 1, 2)

        stride32_feature = self.maxpool(vit_output_featuremap)
        stride32_feature = self.fpn_stride32_norm1(self.fpn_stride32_conv1(stride32_feature)
                                                   .permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
        stride32_feature = self.fpn_stride32_norm2(self.fpn_stride32_conv2(stride32_feature)
                                                   .permute(0, 2, 3, 1)).permute(0, 3, 1, 2)

        stride16_feature = self.fpn_stride16_norm1(self.fpn_stride16_conv1(vit_output_featuremap)
                                                   .permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
        stride16_feature = self.fpn_stride16_norm2(self.fpn_stride16_conv2(stride16_feature)
                                                   .permute(0, 2, 3, 1)).permute(0, 3, 1, 2)

        results = [stride8_feature, stride16_feature, stride32_feature]

        results.extend(self.top_block(stride32_feature))

        assert len(self._out_features) == len(results)
        fpn_out = {f: res for f, res in zip(self._out_features, results)}

        return fpn_out

    @property
    def size_divisibility(self):
        return self._size_divisibility

    def output_shape(self):
        return {
            name: ShapeSpec(
                channels=self._out_feature_channels[name], stride=self._out_feature_strides[name]
            )
            for name in self._out_features
        }


@BACKBONE_REGISTRY.register()
def build_vit_fpn_backbone(cfg, input_shape: ShapeSpec):
    embed_dim = 768
    vit_out_dim = embed_dim
    bottom_up = ViT(  # Single-scale ViT backbone
        img_size=1024,
        patch_size=16,
        embed_dim=embed_dim,
        depth=12,
        num_heads=12,
        drop_path_rate=0.1,
        window_size=14,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        window_block_indexes=[
            # 2, 5, 8, 11 for global attention
            0,
            1,
            3,
            4,
            6,
            7,
            9,
            10,
        ],
        residual_block_indexes=[],
        use_act_checkpoint=cfg.USE_ACT_CHECKPOINT,
        use_rel_pos=True,
        out_feature="last_feat",
    )

    out_channels = cfg.MODEL.FPN.OUT_CHANNELS
    assert out_channels == 256 or out_channels == 768 or out_channels == 1024
    backbone = ViT_FPN(bottom_up=bottom_up,
                       top_block=LastLevelP6P7_P5(out_channels, out_channels),
                       out_channels=out_channels,
                       strides=[8, 16, 32, 64, 128],
                       vit_out_dim=vit_out_dim)
    return backbone


@BACKBONE_REGISTRY.register()
def build_vit_fpn_backbone_large(cfg, input_shape: ShapeSpec):
    window_block_indexes = (list(range(0, 5)) + list(range(6, 11)) + list(range(12, 17)) + list(range(18, 23)))
    embed_dim = 1024
    vit_out_dim = embed_dim
    bottom_up = ViT(  # Single-scale ViT backbone
        img_size=1024,
        patch_size=16,
        embed_dim=embed_dim,
        depth=24,
        num_heads=16,
        drop_path_rate=0.4,
        window_size=14,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        window_block_indexes=window_block_indexes,
        residual_block_indexes=[],
        use_act_checkpoint=cfg.USE_ACT_CHECKPOINT,
        use_rel_pos=True,
        out_feature="last_feat",
    )

    out_channels = cfg.MODEL.FPN.OUT_CHANNELS
    assert out_channels == 256 or out_channels == 768 or out_channels == 1024
    backbone = ViT_FPN(bottom_up=bottom_up,
                       top_block=LastLevelP6P7_P5(out_channels, out_channels),
                       out_channels=out_channels,
                       strides=[8, 16, 32, 64, 128],
                       vit_out_dim=vit_out_dim)
    return backbone


@BACKBONE_REGISTRY.register()
def build_vit_fpn_backbone_huge(cfg, input_shape: ShapeSpec):
    window_block_indexes = (list(range(0, 7)) + list(range(8, 15)) + list(range(16, 23)) + list(range(24, 31)))
    embed_dim = 1280
    vit_out_dim = embed_dim
    bottom_up = ViT(  # Single-scale ViT backbone
        img_size=1024,
        patch_size=16,
        embed_dim=embed_dim,
        depth=32,
        num_heads=16,
        drop_path_rate=0.5,
        window_size=14,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        window_block_indexes=window_block_indexes,
        residual_block_indexes=[],
        use_act_checkpoint=cfg.USE_ACT_CHECKPOINT,
        use_rel_pos=True,
        out_feature="last_feat",
    )

    out_channels = cfg.MODEL.FPN.OUT_CHANNELS
    assert out_channels == 256 or out_channels == 768 or out_channels == 1024
    backbone = ViT_FPN(bottom_up=bottom_up,
                       top_block=LastLevelP6P7_P5(out_channels, out_channels),
                       out_channels=out_channels,
                       strides=[8, 16, 32, 64, 128],
                       vit_out_dim=vit_out_dim)
    return backbone
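A minimal sketch of the plain ViT backbone on its own, with hyperparameters shrunk to arbitrary small values; importing this module also requires detectron2, timm, and the CenterNet2 package that provides centernet.modeling.backbone.fpn_p5, and the import path below is assumed from this repo's layout.

import torch
from models.grit_src.grit.modeling.backbone.vit import ViT  # path assumed

tiny_vit = ViT(img_size=224, patch_size=16, embed_dim=192, depth=2, num_heads=3,
               window_size=7, window_block_indexes=[0],
               use_rel_pos=True, use_act_checkpoint=False)   # toy sizes, not the GRiT config
x = torch.randn(1, 3, 224, 224)
feat = tiny_vit(x)
print(feat.shape)   # torch.Size([1, 192, 14, 14]) -- single-scale, stride-16 feature map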
models/grit_src/grit/modeling/meta_arch/__pycache__/grit.cpython-38.pyc
ADDED
Binary file (2.49 kB).
models/grit_src/grit/modeling/meta_arch/grit.py
ADDED
@@ -0,0 +1,66 @@
from typing import Dict, List, Optional, Tuple
import torch
from detectron2.config import configurable
from detectron2.structures import ImageList, Instances, Boxes
from detectron2.modeling.meta_arch.build import META_ARCH_REGISTRY
from detectron2.modeling.meta_arch.rcnn import GeneralizedRCNN


@META_ARCH_REGISTRY.register()
class GRiT(GeneralizedRCNN):
    @configurable
    def __init__(
            self,
            **kwargs):
        super().__init__(**kwargs)
        assert self.proposal_generator is not None

    @classmethod
    def from_config(cls, cfg):
        ret = super().from_config(cfg)
        return ret

    def inference(
        self,
        batched_inputs: Tuple[Dict[str, torch.Tensor]],
        detected_instances: Optional[List[Instances]] = None,
        do_postprocess: bool = True,
    ):
        assert not self.training
        assert detected_instances is None

        images = self.preprocess_image(batched_inputs)
        features = self.backbone(images.tensor)
        proposals, _ = self.proposal_generator(images, features, None)
        results, _ = self.roi_heads(features, proposals)
        if do_postprocess:
            assert not torch.jit.is_scripting(), \
                "Scripting is not supported for postprocess."
            return GRiT._postprocess(
                results, batched_inputs, images.image_sizes)
        else:
            return results

    def forward(self, batched_inputs: List[Dict[str, torch.Tensor]]):
        if not self.training:
            return self.inference(batched_inputs)

        images = self.preprocess_image(batched_inputs)

        gt_instances = [x["instances"].to(self.device) for x in batched_inputs]

        targets_task = batched_inputs[0]['task']
        for anno_per_image in batched_inputs:
            assert targets_task == anno_per_image['task']

        features = self.backbone(images.tensor)
        proposals, proposal_losses = self.proposal_generator(
            images, features, gt_instances)
        proposals, roihead_textdecoder_losses = self.roi_heads(
            features, proposals, gt_instances, targets_task=targets_task)

        losses = {}
        losses.update(roihead_textdecoder_losses)
        losses.update(proposal_losses)

        return losses
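To make the expected I/O concrete, the sketch below shows the shape of one inference input as consumed by GeneralizedRCNN.preprocess_image and GRiT.forward above. Building an actual GRiT instance still requires a full config (backbone, RoI heads, text decoder), so the model call is left commented out; all values are illustrative.

import torch

# One image in detectron2's standard batched-input format (values illustrative).
batched_inputs = [{
    "image": torch.zeros(3, 800, 1216, dtype=torch.uint8),  # CHW, 0-255, already resized
    "height": 600,     # original resolution used when postprocessing the predicted boxes
    "width": 912,
    "image_id": 0,
}]
# outputs = grit_model(batched_inputs)   # eval mode -> list of {"instances": Instances} per image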