Spaces • Runtime error

stellaathena committed • Commit bb5cd12 • Parent(s): 23ee17f

This should work
Files changed:
- LICENSE +21 -0
- app.py +54 -0
- configs/MAGMA_v1.yml +33 -0
- configs/MAGMA_v2.yml +36 -0
- example_inference.py +27 -0
- examples/magma_oracle.png +0 -0
- examples/magma_present.jpg +0 -0
- examples/magma_social.png +0 -0
- examples/magma_treasure.png +0 -0
- examples/magma_tree.jpg +0 -0
- examples/model.jpg +0 -0
- magma/__init__.py +20 -0
- magma/adapters.py +116 -0
- magma/config.py +144 -0
- magma/datasets/__init__.py +5 -0
- magma/datasets/convert_datasets.py +118 -0
- magma/datasets/dataset.py +160 -0
- magma/image_encoders.py +91 -0
- magma/image_input.py +24 -0
- magma/image_prefix.py +109 -0
- magma/language_model.py +45 -0
- magma/magma.py +301 -0
- magma/sampling.py +121 -0
- magma/train_loop.py +98 -0
- magma/transforms.py +134 -0
- magma/utils.py +372 -0
- requirements.txt +9 -0
- test.py +43 -0
- train.py +192 -0
LICENSE
ADDED
@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2022 Aleph Alpha GmbH

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

app.py
ADDED
@@ -0,0 +1,54 @@
import gradio as gr
import re
from magma import Magma
from magma.image_input import ImageInput

model = Magma.from_checkpoint(
    config_path = "configs/MAGMA_v1.yml",
    checkpoint_path = "./mp_rank_00_model_states.pt",
    device = 'cuda:0'
)

def generate(context, length, temperature, top_k):
    context = context.strip()

    url_regex = r'https?:\/\/(www\.)?[-a-zA-Z0-9@:%._\+~#=]{1,256}\.[a-zA-Z0-9()]{1,6}\b([-a-zA-Z0-9()@:%_\+.~#?&//=]*)'
    lines = context.split('\n')
    inputs = []
    for line in lines:
        if re.match(url_regex, line):
            try:
                inputs.append(ImageInput(line))
            except Exception as e:
                return str(e)
        else:
            inputs.append(line)

    ## returns a tensor of shape: (1, 149, 4096)
    embeddings = model.preprocess_inputs(inputs)

    ## returns a list of length embeddings.shape[0] (batch size)
    output = model.generate(
        embeddings = embeddings,
        max_steps = length,
        temperature = (0.01 if temperature == 0 else temperature),
        top_k = top_k
    )

    return context + output[0]

iface = gr.Interface(
    fn=generate,
    inputs=[
        gr.inputs.Textbox(
            label="Prompt (image URLs need to be on their own lines):",
            default="https://www.art-prints-on-demand.com/kunst/thomas_cole/woods_hi.jpg\nDescribe the painting:",
            lines=7),
        gr.inputs.Slider(minimum=1, maximum=100, default=15, step=1, label="Output tokens:"),
        gr.inputs.Slider(minimum=0.0, maximum=1.0, default=0.7, label='Temperature'),
        gr.inputs.Slider(minimum=0, maximum=100, default=0, step=1, label='Top K')
    ],
    outputs=["textbox"]
).launch(share=True)

configs/MAGMA_v1.yml
ADDED
@@ -0,0 +1,33 @@
{
    # image encoder settings
    encoder_name: 'clip_resnet_large',
    adapter_config: {"mlp": {"adapter_type": "normal", "downsample_factor": 4}},
    freeze_img_encoder: false,

    # train settings
    batch_size: 256,
    train_steps: 150000,
    lr: 8.0e-4,
    min_lr: 0.0,
    lr_decay_iters: 300000,
    image_enc_lr: 2.0e-6,
    use_image_embed_layernorm: true,
    image_embed_dropout_prob: 0.1,
    image_size: 384,

    gradient_accumulation_steps: 8,
    zero_stage: 2,
    gradient_clipping: 1.0,

    # dataset / save / load settings
    train_dataset_name: 'conceptual_captions',
    train_dataset_dir: '/mnt/localdisk/conceptual_captions',
    eval_dataset_name: 'coco',
    eval_dataset_dir: '/mnt/localdisk/coco_data',

    save: "/mnt/shared_vol/checkpoints/multimodal_transformer_rn50x16",
    load: "/mnt/shared_vol/checkpoints/multimodal_transformer_rn50x16",

    eval_every: 100,

}

configs/MAGMA_v2.yml
ADDED
@@ -0,0 +1,36 @@
{
    # image encoder settings
    encoder_name: 'clip_resnet_large',
    adapter_config: {"mlp": {"adapter_type": "normal", "downsample_factor": 8}, "attention": {"adapter_type": "normal", "downsample_factor": 8}},
    freeze_img_encoder: false,

    # train settings
    batch_size: 256,
    train_steps: 150000,
    lr: 8.0e-4,
    min_lr: 0.0,
    lr_decay_iters: 300000,
    image_enc_lr: 2.0e-6,
    use_image_embed_layernorm: true,
    image_embed_dropout_prob: 0.1,
    image_size: 384,

    gradient_accumulation_steps: 4,
    zero_stage: 2,
    gradient_clipping: 1.0,

    # dataset / save / load settings
    dataset_type: 'new',
    train_dataset_dir: ['/mnt/localdisk/laion', '/mnt/brick/CC3M_converted', '/mnt/localdisk/localized_narratives', '/mnt/localdisk/visual_genome_converted', '/mnt/localdisk/hateful_memes_converted', '/mnt/localdisk/coco_converted', '/mnt/brick/wit_converted', '/mnt/localdisk/gqa_train_converted', '/mnt/localdisk/vqa_train_converted', '/mnt/localdisk/okvqa_train_converted'], #'/mnt/brick/wit_converted'

    eval_dataset_dir: null, # if this is none, train dataset will be split
    vqa_dir: "/mnt/localdisk/vqa_val_converted",
    gqa_dir: "/mnt/localdisk/gqa_val_converted",

    save: "/mnt/shared_vol/checkpoints/MAGMA_RN50x16",
    load: "/mnt/shared_vol/checkpoints/MAGMA_RN50x16",

    eval_every: 250,
    wandb_project: "MAGMA_training",
    name: "MAGMA_RN50x16_v1"
}

example_inference.py
ADDED
@@ -0,0 +1,27 @@
from magma import Magma
from magma.image_input import ImageInput

model = Magma.from_checkpoint(
    config_path = "configs/MAGMA_v1.yml",
    checkpoint_path = "./mp_rank_00_model_states.pt",
    device = 'cuda:0'
)

inputs = [
    ## supports urls and path/to/image
    ImageInput('https://www.art-prints-on-demand.com/kunst/thomas_cole/woods_hi.jpg'),
    'Describe the painting:'
]

## returns a tensor of shape: (1, 149, 4096)
embeddings = model.preprocess_inputs(inputs)

## returns a list of length embeddings.shape[0] (batch size)
output = model.generate(
    embeddings = embeddings,
    max_steps = 6,
    temperature = 0.7,
    top_k = 0,
)

print(output[0])  ## A cabin on a lake

examples/magma_oracle.png
ADDED
examples/magma_present.jpg
ADDED
examples/magma_social.png
ADDED
examples/magma_treasure.png
ADDED
examples/magma_tree.jpg
ADDED
examples/model.jpg
ADDED
magma/__init__.py
ADDED
@@ -0,0 +1,20 @@
from .config import MultimodalConfig
from .magma import Magma
from .language_model import get_gptj
from .transforms import get_transforms
from .utils import (
    count_parameters,
    is_main,
    cycle,
    get_tokenizer,
    parse_args,
    wandb_log,
    wandb_init,
    save_model,
    load_model,
    print_main,
    configure_param_groups,
    log_table,
)
from .train_loop import eval_step, inference_step, train_step
from .datasets import collate_fn

magma/adapters.py
ADDED
@@ -0,0 +1,116 @@
import torch
import torch.nn as nn
from torchtyping import TensorType


class Adapter(nn.Module):
    def __init__(
        self,
        dim: int,
        downsample_factor: int = 4,
        activation: nn.Module = nn.ReLU,
        add_layernorm: bool = False,
    ):
        super().__init__()
        layers = []
        if add_layernorm:
            layers.append(nn.LayerNorm(dim))
        layers.extend(
            [
                nn.Linear(dim, dim // downsample_factor),
                activation(),
                nn.Linear(dim // downsample_factor, dim),
            ]
        )
        self.adapter = nn.Sequential(*layers)
        self.adapter.apply(self.init_weights)

    def init_weights(self, m: nn.Module, std=1e-3):
        if isinstance(m, nn.Linear):
            torch.nn.init.normal_(m.weight, std=std)
            torch.nn.init.normal_(m.bias, std=std)
            m.weight.data = torch.clamp(m.weight.data, min=-2 * std, max=2 * std)
            m.bias.data = torch.clamp(m.bias.data, min=-2 * std, max=2 * std)
        elif isinstance(m, nn.LayerNorm):
            m.bias.data.zero_()
            m.weight.data.fill_(1.0)

    def forward(self, x: TensorType["b", "s", "d"]) -> TensorType["b", "s", "d"]:
        return self.adapter(x) + x


class ParallelAdapter(Adapter):
    def __init__(
        self,
        module: nn.Module,
        dim: int,
        downsample_factor: int = 4,
        scaled: bool = False,
        add_layernorm: bool = False,
        activation: nn.Module = nn.ReLU,
    ):
        super().__init__(
            dim, downsample_factor, add_layernorm=add_layernorm, activation=activation
        )
        self.module = module

        if scaled:
            # init scaling param
            self.adapter_scale = nn.Parameter(torch.ones(1))
        else:
            self.adapter_scale = 1

    def forward(self, x: TensorType["b", "s", "d"], **module_kwargs):
        y = self.module(x, **module_kwargs)
        z = self.adapter(x)
        return y + (z * self.adapter_scale)


class ParallelAdapterWrapper(ParallelAdapter):
    # used to add an adapter to the attention block

    def __init__(
        self,
        module: nn.Module,
        dim: int,
        downsample_factor: int = 4,
        scaled: bool = False,
        add_layernorm: bool = False,
        activation: nn.Module = nn.ReLU,
    ):
        super().__init__(
            module, dim, downsample_factor, scaled, add_layernorm, activation
        )

    def forward(self, x: TensorType["b", "s", "d"], *attn_args, **attn_kwargs):
        attn_outputs = self.module(x, *attn_args, **attn_kwargs)
        attn_output, outputs = (
            attn_outputs[0],
            attn_outputs[1:],
        )  # output_attn: a, present, (attentions)
        hidden_states = attn_output + (self.adapter(x) * self.adapter_scale)
        return (hidden_states,) + outputs


class AdapterWrapper(Adapter):
    # used to add an adapter to the attention block

    def __init__(
        self,
        attn_block: nn.Module,
        dim: int,
        downsample_factor: int = 4,
        activation: nn.Module = nn.ReLU,
        add_layernorm: bool = False,
    ):
        super().__init__(dim, downsample_factor, activation, add_layernorm)
        self.attn_block = attn_block

    def forward(self, x: TensorType["b", "s", "d"], *attn_args, **attn_kwargs):
        attn_outputs = self.attn_block(x, *attn_args, **attn_kwargs)
        attn_output, outputs = (
            attn_outputs[0],
            attn_outputs[1:],
        )  # output_attn: a, present, (attentions)
        hidden_states = self.adapter(attn_output) + attn_output
        return (hidden_states,) + outputs

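
The adapter classes above are generic bottleneck modules and can be exercised in isolation. The following sketch is not part of the commit; it wraps a hypothetical feed-forward block the same two ways Magma.add_adapters does (nn.Sequential for a "normal" adapter, ParallelAdapter for a parallel one), and all dimensions are illustrative.

import torch
import torch.nn as nn
from magma.adapters import Adapter, ParallelAdapter

dim = 16  # stand-in for the LM hidden size
ff = nn.Sequential(nn.Linear(dim, 4 * dim), nn.GELU(), nn.Linear(4 * dim, dim))

# "normal": run the block, then a residual bottleneck on its output
sequential_block = nn.Sequential(ff, Adapter(dim=dim, downsample_factor=4))

# "parallel": bottleneck branch of the input added onto the block's output
parallel_block = ParallelAdapter(module=ff, dim=dim, downsample_factor=4, scaled=True)

x = torch.randn(2, 5, dim)  # (batch, seq, dim)
print(sequential_block(x).shape, parallel_block(x).shape)  # both torch.Size([2, 5, 16])
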
magma/config.py
ADDED
@@ -0,0 +1,144 @@
from dataclasses import dataclass, asdict
import yaml
from pprint import pprint
from .utils import is_main
import os
from pathlib import Path
import uuid


def load_config(path, config_dir=Path("configs")):
    if not path.endswith(".yml"):
        path += ".yml"
    if not os.path.exists(path):
        path = config_dir / path
    with open(path, "r") as stream:
        config = yaml.safe_load(stream)
    return config


@dataclass
class MultimodalConfig:

    # Training:
    # ------------------------------------------------------------

    batch_size: int
    train_steps: int
    optimizer_name: str = "AdamW"
    lr: float = 8.0e-4
    image_enc_lr: float = None
    min_lr: float = 0.0
    lr_decay_iters: int = None
    gradient_accumulation_steps: int = 1
    image_size: int = 256
    eval_every: int = 250
    eval_steps: int = 25
    zero_stage: int = 2
    gradient_clipping: float = 1.0
    warmup_num_steps: int = 100
    weight_decay: float = 0.00
    run_blind: bool = False
    fine_tune: bool = False
    load_optimizer: bool = True

    # Checkpointing:
    # ------------------------------------------------------------
    save_every: int = 2500
    save: str = None
    load: str = None

    # Data:
    # ------------------------------------------------------------
    train_dataset_name: str = "conceptual_captions"
    eval_dataset_name: str = "/data/conceptual_captions"
    train_dataset_dir: str = "/data/coco_data"
    eval_dataset_dir: str = "/data/coco_data"
    eval_dataset_pct: float = 0.1

    # Model architecture:
    # ------------------------------------------------------------
    encoder_name: str = "clip"
    tokenizer_name: str = "gpt2"
    lm_name: str = "EleutherAI/gpt-j-6B"
    image_seq_len: int = 2
    pretrained_img_encoder: bool = False
    seq_len: int = None

    # Layer Freezing settings:
    # ------------------------------------------------------------
    freeze_lm: bool = True
    freeze_img_encoder: bool = True

    image_embed_dropout_prob: float = 0.0
    use_image_embed_layernorm: bool = False

    # Adapter settings:
    # ------------------------------------------------------------
    adapter_config: dict = None

    # Classification Finetuning settings:
    # ------------------------------------------------------------
    class_dict: dict = None  # {num_classes: .., ckpt_path: .., classifier_type:, .., interface_type: .., interface_position: .., freeze_model: ..}

    # Logging settings:
    # ------------------------------------------------------------
    name: str = None  # name, just used for wandb logging
    log_every: int = 1
    wandb_project: str = "magma"

    def print(self):
        if is_main():
            print("-" * 100)
            pprint(self.__dict__, indent=4)
            print("-" * 100)

    def __post_init__(self):
        self.is_classifier = self.class_dict is not None
        if self.adapter_config is None:
            self.adapter_config = {}

        # Deepspeed Settings:
        # ------------------------------------------------------------
        if self.lr_decay_iters is None:
            self.lr_scheduler = "WarmupLR"
            self.scheduler_dict = {
                "type": self.lr_scheduler,
                "params": {
                    "warmup_min_lr": self.min_lr,
                    "warmup_max_lr": self.lr,
                    "warmup_num_steps": self.warmup_num_steps,
                },
            }
        else:
            self.lr_scheduler = "WarmupDecayLR"
            self.scheduler_dict = {
                "type": self.lr_scheduler,
                "params": {
                    "total_num_steps": self.lr_decay_iters,
                    "warmup_min_lr": self.min_lr,
                    "warmup_max_lr": self.lr,
                    "warmup_num_steps": self.warmup_num_steps,
                },
            }
        self.deepspeed_config_params = {
            "train_batch_size": self.batch_size,
            "gradient_accumulation_steps": self.gradient_accumulation_steps,
            "gradient_clipping": self.gradient_clipping,
            "fp16": {"enabled": True, "loss_scale_window": 250},
            "scheduler": self.scheduler_dict,
            "zero_optimization": {
                "stage": self.zero_stage,
                "load_from_fp32_weights": False,
            },
        }

        if self.name is None:
            self.name = str(uuid.uuid4())[:8]

    @classmethod
    def from_yml(cls, path):
        return cls(**load_config(path))

    def to_dict(self):
        return asdict(self)

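
For reference, a short sketch (not part of the commit) of how this dataclass is typically instantiated from one of the YAML files added above, and what __post_init__ derives from it; the printed values follow from the settings in MAGMA_v1.yml.

from magma.config import MultimodalConfig

config = MultimodalConfig.from_yml("configs/MAGMA_v1.yml")
config.print()

# lr_decay_iters is set in MAGMA_v1.yml, so __post_init__ picks the decaying scheduler
print(config.lr_scheduler)  # WarmupDecayLR
print(config.deepspeed_config_params["scheduler"]["params"]["total_num_steps"])  # 300000
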
magma/datasets/__init__.py
ADDED
@@ -0,0 +1,5 @@
from .dataset import (
    ImgCptDataset,
    collate_fn,
)

magma/datasets/convert_datasets.py
ADDED
@@ -0,0 +1,118 @@
from PIL import Image
from PIL import UnidentifiedImageError
import os
import json
from pathlib import Path
from tqdm import tqdm
import shutil


def save_to_jsons(data_list, target_dir, starting_idx=0):
    pbar = tqdm(
        enumerate(data_list), desc=f"saving {len(data_list)} jsons to {str(target_dir)}"
    )
    for k, data in pbar:
        filename = Path(target_dir) / Path(f"{k+starting_idx}.json")
        with open(filename, "w") as f:
            json.dump(data, f)

    return None


def save_images(img_list, target_dir, mode="mv"):
    for img_path in tqdm(
        img_list,
        desc=f"saving {len(img_list)} images (mode={mode}) to {str(target_dir)}",
    ):
        if mode == "mv":
            shutil.move(img_path, target_dir)
        elif mode == "cp":
            shutil.copy(img_path, target_dir)


def convert_dataset(
    data_dir,
    dir_size=10000,
    hash_fn=None,
    mode="mv",
    ds_iterator=None,
):
    """
    Builds a dataset directory in our standard format. ds_iterator should return data of the form
    image_path, {"captions": [...], "metadata": {...}, }, where image_path should be a Path object, captions should map to a list of strings
    and metadata can contain any custom data about the image. If a hash_fn is specified (such as phash), the image hash gets saved in metadata.
    """

    data_dir = Path(data_dir)

    # folders for images and corresponding data which is stored in a json file for each image
    os.makedirs(data_dir / "images", exist_ok=True)
    os.makedirs(data_dir / "image_data", exist_ok=True)

    img_data_list = []
    img_path_list = []
    save_img_dir = data_dir / "images" / "0"
    save_data_dir = data_dir / "image_data" / "0"
    num_img_dirs = 0

    # save the new locations of all img files in case some datafiles point to the same image
    new_img_locations = {}

    pbar = tqdm(
        enumerate(ds_iterator),
        desc="converting dataset to standard format...",
    )

    for k, (img_path, data) in pbar:
        img_cpt_data = {}
        # get img data
        img_cpt_data.update(data)

        if str(img_path) in new_img_locations.keys():
            # if filename is in the dictionary, it already has a new location
            new_img_path = new_img_locations[str(img_path)]["new_img_path"]
            img_cpt_data["image_path"] = new_img_path
            if hash_fn is not None:
                img_cpt_data["metadata"]["image_hash"] = new_img_locations[
                    str(img_path)
                ]["hash"]
        else:
            # if file exists in the old location, it will get moved to a new directory
            new_img_path = f"images/{save_img_dir.name}/{img_path.name}"
            img_cpt_data["image_path"] = new_img_path
            new_img_locations[str(img_path)] = {"new_img_path": new_img_path}
            # original location is saved and later moved to the new directory
            img_path_list.append(img_path)

            # if given, apply hash fn
            if hash_fn is not None:
                try:
                    img = Image.open(img_path).convert("RGB")
                    hash_str = str(hash_fn(img))
                    img_cpt_data["metadata"]["image_hash"] = hash_str
                    # save hash so it does not have to be recomputed
                    new_img_locations[str(img_path)]["hash"] = hash_str
                except (UnidentifiedImageError, FileNotFoundError):
                    print("Warning: corrupted or non-existent Image")

        img_data_list.append(img_cpt_data)

        # save images in specified images folder (maximum of dir_size images per folder)
        if (len(img_path_list) % dir_size == 0 and len(img_path_list) > 0) or (
            k == len(ds_iterator) - 1
        ):
            os.makedirs(save_img_dir, exist_ok=True)
            save_images(img_path_list, save_img_dir, mode=mode)
            img_path_list = []
            num_img_dirs += 1
            save_img_dir = data_dir / "images" / f"{num_img_dirs}/"

        # save json data in specified image_data folder with consecutive labeling of the json files
        if ((k + 1) % dir_size == 0) or (k == len(ds_iterator) - 1):
            os.makedirs(save_data_dir, exist_ok=True)
            save_to_jsons(
                img_data_list, save_data_dir, starting_idx=max(k + 1 - dir_size, 0)
            )
            # empty path and data lists and update save directories for next saving step
            img_data_list = []
            save_data_dir = data_dir / "image_data" / f"{int((k+1)/dir_size)}/"

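
The docstring above fixes the contract for ds_iterator. Below is a minimal sketch (not part of the commit) of an iterator that satisfies it; the source directory, the captions.json mapping of file names to caption lists, and the output directory are all hypothetical names used only for illustration.

import json
from pathlib import Path
from magma.datasets.convert_datasets import convert_dataset

raw_dir = Path("/mnt/localdisk/my_raw_dataset")   # hypothetical source directory
with open(raw_dir / "captions.json") as f:        # hypothetical {filename: [captions, ...]} file
    captions = json.load(f)

# convert_dataset calls len() on ds_iterator, so build a list rather than a generator
ds_iterator = [
    (raw_dir / fname, {"captions": cpts, "metadata": {"source": "my_raw_dataset"}})
    for fname, cpts in captions.items()
]

convert_dataset(
    data_dir="/mnt/localdisk/my_raw_dataset_converted",  # hypothetical output directory
    dir_size=10000,
    mode="cp",             # copy instead of move, keeping the originals in place
    ds_iterator=ds_iterator,
)
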
magma/datasets/dataset.py
ADDED
@@ -0,0 +1,160 @@
import torch
from torch.utils.data import Dataset
from PIL import Image
from PIL.Image import Image as img
from PIL.Image import DecompressionBombError
from PIL import UnidentifiedImageError
import json
from pathlib import Path

from tqdm import tqdm
from typing import List, Tuple, Generator
import random
from multiprocessing import Pool, cpu_count

from PIL import Image
from torch.utils.data import Dataset
from typing import Tuple
from torchtyping import TensorType
import traceback


def read_jsonl(filename: str) -> Generator[List, None, None]:
    """
    Iterator over data from a jsonl file
    """
    with open(filename) as file:
        for line in file:
            yield json.loads(line.rstrip("\n|\r"))


def read_img_captions(filename: str) -> List[Tuple[str, str]]:
    """
    Yields image_path, image_caption from cc jsonl files
    """
    img_captions = []
    for item in read_jsonl(filename):
        if not "N/A" in item[-2:]:
            img_captions.append((item[-1], item[-2]))
    return img_captions


def load_json(filename):
    try:
        with open(filename) as f:
            return json.load(f)
    except Exception:
        print(f"ERROR: Error loading json file {filename}")
        traceback.print_exc()


def _read_image_data(data_dir):
    image_data = []
    img_data_dir = data_dir / "image_data"
    paths = _load_paths(data_dir)
    pbar = tqdm(
        paths,
        desc=f"loading dataset from {str(data_dir)}",
    )
    # read data with multiprocessing
    with Pool(cpu_count()) as pool:
        for img_data in pool.imap(load_json, pbar):
            if img_data is not None:
                image_data.append(img_data)
    return image_data


def _load_paths(data_dir, sort=True):
    paths = []
    img_data_dir = data_dir / "image_data"
    for p in tqdm(
        Path(img_data_dir).glob("*/*.json"),
        desc=f"loading dataset paths from {str(data_dir)}",
    ):
        paths.append(p)
    return sorted(paths)


class LazyLoader:
    def __init__(self, data_dir):
        self.paths = _load_paths(data_dir)

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx):
        data = load_json(self.paths[idx])
        if data is None:
            return self[random.randint(0, len(self) - 1)]
        return data


class ImgCptDataset(Dataset):
    """
    Dataset which loads image caption data from our standard format and transforms them into tensors that can be input to the model.
    Images are expected to be stored in data_dir/images, image data in data_dir/image_data and each data item is a json file with format {"image_path": img_path, "captions": [caption1, caption2,...], "metadata":{...}}
    """

    def __init__(
        self, data_dir, tokenizer, transforms, seq_len=2048, load_data_in_memory=False
    ):
        self.data_dir = Path(data_dir)
        self.tokenizer = tokenizer
        self.transforms = transforms
        self.seq_len = seq_len
        self.load_data_in_memory = load_data_in_memory
        if self.load_data_in_memory:
            self.data = _read_image_data(self.data_dir)
        else:
            self.data = LazyLoader(self.data_dir)

    def __len__(self):
        return len(self.data)

    def __getitem__(
        self, idx
    ) -> Tuple[TensorType["b", "c", "h", "w"], TensorType["b", "s"]]:
        img_data = self.data[idx]
        try:
            try:
                img_path = self.data_dir / img_data["image_path"]
            except KeyError as e:
                # if no image path is found, assume path is same as .json, but .jpg
                if not self.load_data_in_memory:
                    p = self.data.paths[idx]
                    img_path = (
                        self.data_dir
                        / "images"
                        / Path(p.parent).name
                        / Path(p.name).with_suffix(".jpg")
                    )
                else:
                    raise e
            img = Image.open(img_path)
            img_tensor = self.transforms(img)
            caption = random.choice(img_data["captions"])
            caption_tensor = self.tokenizer.encode(
                caption,
                return_tensors="pt",
                max_length=self.seq_len,
                padding="max_length",
                truncation=True,
            )
            return img_tensor, caption_tensor
        except (
            UnidentifiedImageError,
            OSError,
            DecompressionBombError,
            IndexError,
        ) as e:
            # return random index if image is corrupt
            print(f"Warning: Could not load image {str(img_path)}")
            return self[random.randint(0, len(self) - 1)]


def collate_fn(batch_data: List[Tuple[torch.Tensor, torch.Tensor]], seq_len=2048):

    all_images, all_captions = list(
        zip(*batch_data)
    )  # [(img1, caption1), (img2, caption2), ... ] -> [(img1, img2, ... ), (caption1, caption2, ... )]
    return torch.cat(all_images), torch.cat([i[:, :seq_len] for i in all_captions])

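
A minimal sketch (not part of the commit) of wiring ImgCptDataset into a DataLoader. The tokenizer call matches how magma.py uses get_tokenizer; the torchvision pipeline is only a stand-in for magma's own get_transforms, and the data directory is a hypothetical path in the converted format.

from functools import partial
from torch.utils.data import DataLoader
from torchvision import transforms as T

from magma.datasets import ImgCptDataset, collate_fn
from magma.utils import get_tokenizer

tokenizer = get_tokenizer("gpt2", sequence_length=2048)
image_transform = T.Compose([
    T.Lambda(lambda im: im.convert("RGB")),
    T.Resize((384, 384)),
    T.ToTensor(),
    T.Lambda(lambda x: x.unsqueeze(0)),  # collate_fn concatenates along dim 0, so keep a leading batch dim
])

dataset = ImgCptDataset(
    data_dir="/mnt/localdisk/coco_converted",  # hypothetical directory in the standard format
    tokenizer=tokenizer,
    transforms=image_transform,
    seq_len=2048,
)

loader = DataLoader(dataset, batch_size=8, shuffle=True, collate_fn=partial(collate_fn, seq_len=2048))
images, captions = next(iter(loader))  # images: (8, 3, 384, 384), captions: (8, 2048)
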
magma/image_encoders.py
ADDED
@@ -0,0 +1,91 @@
import torch
import torch.nn as nn
from typing import Callable, Union
from torchtyping import patch_typeguard
from einops import rearrange
import timm
import clip
from functools import partial

# ----------------------------- Utils --------------------------------------

clip.model.LayerNorm = (
    nn.LayerNorm
)  # we need to patch this for clip to work with deepspeed
patch_typeguard()  # needed for torchtyping typechecks to work


class Lambda(torch.nn.Module):
    def __init__(self, fn: Callable):
        super().__init__()
        assert hasattr(fn, "__call__")
        self.fn = fn

    def forward(self, x):
        return self.fn(x)


# ------------------------- Image encoders ----------------------------------


def nfresnet50(
    device: Union[torch.device, str] = None, pretrained: bool = True
) -> nn.Module:
    """
    Loads nfresnet50 model, removing the pooling layer and replacing it with
    an adaptive pooling layer.
    """
    encoder = torch.nn.Sequential(
        *list(timm.create_model("nf_resnet50", pretrained=pretrained).children())[:-1]
    )
    pooling = torch.nn.AdaptiveAvgPool2d((1, 1))
    encoder = torch.nn.Sequential(encoder, pooling)
    if device is not None:
        encoder = encoder.to(device)
    return encoder


def clip_encoder(
    device: Union[torch.device, str] = None, name: str = "clip",
) -> nn.Module:
    """
    Loads clip's image encoder module, discarding the lm component.

    If the variant is a resnet model, we also remove the attention pooling.
    """
    if name in ["clip", "ViT-B/32"]:
        name = "ViT-B/32"
    elif name in ["clip_resnet", "RN50x4"]:
        name = "RN50x4"
    elif name in ["clip_resnet_large", "RN50x16"]:
        name = "RN50x16"
    else:
        raise ValueError(f"encoder {name} not recognized")

    encoder = clip.load(name, device=device)[0].visual

    if device is not None:
        encoder = encoder.to(device)

    if "RN" in name:
        # remove attention pooling
        encoder.attnpool = Lambda(
            partial(rearrange, pattern="b d h w -> b (h w) d")
        )  # remove attn pooling, just use reshaped features

    return encoder


def get_image_encoder(
    name: str, device: Union[torch.device, str] = None, pretrained: bool = False
) -> torch.nn.Module:
    """
    Loads image encoder module
    """
    if name == "nfresnet50":
        encoder = nfresnet50(device=device, pretrained=pretrained)
    elif "clip" in name:
        encoder = clip_encoder(device=device, name=name)
    else:
        raise ValueError(f"image encoder {name} not recognized")
    return encoder

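
A quick shape check (not part of the commit) for the nfresnet50 branch, which keeps the adaptive pooling and therefore emits a single 2048-d feature per image (cf. ENCODER_OUT_DIMS in image_prefix.py); pretrained=False avoids a weight download, and the input size is illustrative.

import torch
from magma.image_encoders import get_image_encoder

enc = get_image_encoder("nfresnet50", pretrained=False)
with torch.no_grad():
    feats = enc(torch.randn(2, 3, 256, 256))
print(feats.shape)  # torch.Size([2, 2048, 1, 1]); ImagePrefix rearranges this to (2, 2048)
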
magma/image_input.py
ADDED
@@ -0,0 +1,24 @@
import requests
from io import BytesIO
import PIL.Image as PilImage
from typing import Callable


class ImageInput():
    """Wrapper to handle image inputs both from local paths and urls
    Args:
        path_or_url (str): path or link to image.
    """
    def __init__(self, path_or_url):

        self.path_or_url = path_or_url
        if self.path_or_url.startswith("http://") or self.path_or_url.startswith("https://"):
            try:
                response = requests.get(path_or_url)
                self.pil_image = PilImage.open(BytesIO(response.content))
            except:
                raise Exception(f'Could not retrieve image from url:\n{self.path_or_url}')
        else:
            self.pil_image = PilImage.open(path_or_url)

    def get_transformed_image(self, transform_fn: Callable):  ## to be called internally
        return transform_fn(self.pil_image)

magma/image_prefix.py
ADDED
@@ -0,0 +1,109 @@
import torch
import torch.nn as nn
from torchtyping import TensorType
from einops import rearrange
from .image_encoders import get_image_encoder
from .config import MultimodalConfig

# ------------------------- Image prefix ----------------------------------

# for models that are fixed to a specific sequence length (i.e. clip models with no pooling), the sequence lengths are below
ENCODER_SEQ_LENS = {
    "clip_resnet": 49,
    "clip_resnet_large": 144,
}

ENCODER_OUT_DIMS = {
    "nfresnet50": 2048,
    "clip": 512,
    "clip_resnet": 2560,
    "clip_resnet_large": 3072,
}


class ImagePrefix(nn.Module):

    """
    Takes in a batch of images and returns a batch of embeddings of the
    same dimensions as the LM's word embeddings.

    :param config: MultimodalConfig object
    :param out_dim: output dimension of the embedding
    :param device: device to run the model on
    """

    def __init__(
        self,
        config: MultimodalConfig,
        out_dim: int = 2048,
        device=None,
    ):
        super().__init__()
        self.device = device or torch.device(
            "cuda" if torch.cuda.is_available() else "cpu"
        )
        self.config = config
        self.encoder_type = config.encoder_name

        # get image encoder backbone
        self.enc = get_image_encoder(
            config.encoder_name,
            pretrained=config.pretrained_img_encoder,
        )
        self.encoder_out_dim = ENCODER_OUT_DIMS[
            self.encoder_type
        ]  # out dim for image encoder

        self.out_dim = out_dim  # out dim for lm

        # set the out seq len to that specified in the config, or for some models, the hardcoded value
        self.out_seq_len = (
            config.image_seq_len
            if config.encoder_name not in ENCODER_SEQ_LENS
            else ENCODER_SEQ_LENS[config.encoder_name]
        )

        # get the output projection
        proj_out_dim = (
            (self.out_dim * self.out_seq_len)
            if self.encoder_type not in ENCODER_SEQ_LENS
            else self.out_dim
        )
        self.proj = nn.Linear(self.encoder_out_dim, proj_out_dim)
        self.dropout = nn.Dropout(config.image_embed_dropout_prob)
        self.use_layernorm = config.use_image_embed_layernorm
        if self.use_layernorm:
            self.ln = nn.LayerNorm(self.out_dim)

    def forward(
        self, x: TensorType["b", "c", "h", "w"]
    ) -> TensorType["b", "seq", "out_dim"]:

        # pass through image encoder
        logits = self.enc(x)

        # remove trailing dimensions of size 1 + pass through linear
        if logits.ndim == 4:
            logits = rearrange(logits, "b d 1 1 -> b d")
        elif logits.ndim == 3:
            assert self.encoder_type in ENCODER_SEQ_LENS
        else:
            assert logits.ndim == 2

        logits = self.proj(logits)

        # reshape to desired output shape
        if (
            self.encoder_type not in ENCODER_SEQ_LENS
        ):  # don't need to reshape those with fixed seq lens / no pooling
            logits = rearrange(
                logits, "b (s d) -> b s d", d=self.out_dim, s=self.out_seq_len
            )

        # pass through dropout and layer norm
        logits = self.dropout(logits)

        if self.use_layernorm:
            logits = self.ln(logits)

        return logits

magma/language_model.py
ADDED
@@ -0,0 +1,45 @@
import torch
from transformers import GPTNeoForCausalLM, AutoConfig, GPT2LMHeadModel
from .utils import print_main
from pathlib import Path
from transformers.modeling_utils import no_init_weights

LANGUAGE_MODELS = [
    "gptj",
]


def gptj_config():
    config = AutoConfig.from_pretrained("EleutherAI/gpt-neo-2.7B")
    config.attention_layers = ["global"] * 28
    config.attention_types = [["global"], 28]
    config.num_layers = 28
    config.num_heads = 16
    config.hidden_size = 256 * config.num_heads
    config.vocab_size = 50400
    config.rotary = True
    config.rotary_dim = 64
    config.jax = True
    config.gradient_checkpointing = True
    return config


def get_gptj(
    gradient_checkpointing: bool = True,
    from_pretrained=False,
) -> torch.nn.Module:
    """
    Loads GPTJ language model from HF
    """
    print_main("Loading GPTJ language model...")
    config = gptj_config()
    config.gradient_checkpointing = gradient_checkpointing
    if gradient_checkpointing:
        config.use_cache = False
    config.model_device = "cpu"
    if from_pretrained:
        raise NotImplementedError("GPTJ pretrained not implemented")
    else:
        with no_init_weights():
            model = GPTNeoForCausalLM(config=config)
    return model

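
Note that get_gptj builds a GPT-J-shaped GPTNeo model with randomly initialized weights (from_pretrained is deliberately unsupported); the weights come from the checkpoint that Magma.from_checkpoint loads later. A small sketch, not part of the commit, which allocates a ~6B-parameter model in memory:

from magma.language_model import get_gptj

lm = get_gptj(gradient_checkpointing=False)
print(lm.config.hidden_size)              # 4096 (= 256 * 16 heads)
print(lm.config.max_position_embeddings)  # 2048, reused as Magma's seq_len
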
magma/magma.py
ADDED
@@ -0,0 +1,301 @@
from pathlib import Path
from os.path import exists
import torch
import torch.nn as nn
from copy import deepcopy
from typing import Literal, Optional, List
from torchtyping import TensorType
from transformers.file_utils import ModelOutput
from magma.config import MultimodalConfig

from magma.utils import get_tokenizer
from .language_model import get_gptj
from .adapters import (
    Adapter,
    ParallelAdapter,
    AdapterWrapper,
    ParallelAdapterWrapper,
)
from .image_prefix import ImagePrefix
from .sampling import generate
from .utils import build_labels, is_url, print_main, download_checkpoint
from .image_input import ImageInput
from .transforms import get_transforms

# ------------------------- Magma main class ----------------------------------


class Magma(nn.Module):
    def __init__(self, config, device=None):
        super().__init__()

        if isinstance(config, (str, Path)):
            config = MultimodalConfig.from_yml(
                config
            )  # load config from yml file if config is a string
        else:
            assert isinstance(config, MultimodalConfig)

        self.device = device or torch.device(
            "cuda" if torch.cuda.is_available() else "cpu"
        )
        self.config = config
        self.lm = get_gptj().to(self.device)
        self.seq_len = self.lm.config.max_position_embeddings

        self.tokenizer = get_tokenizer("gpt2", sequence_length=self.seq_len)

        self.image_token = self.tokenizer.cls_token_id
        self.eos_token = self.tokenizer.eos_token_id
        self.lm.resize_token_embeddings(len(self.tokenizer))
        self.lm.config.pad_token_id = self.tokenizer.eos_token_id
        self.word_embedding = self.lm.transformer.wte.to(device)
        self.transformer = self.lm.transformer.h

        # adapter settings
        self.mlp_adapter_added, self.attn_adapter_added = False, False

        self.image_prefix = ImagePrefix(
            config=config,
            out_dim=self.lm.config.hidden_size,
        ).to(self.device)

        # might change based on the type of image encoder, so get from prefix instead of config
        self.image_prefix_seq_len = self.image_prefix.out_seq_len

        self.transforms = get_transforms(
            config.image_size,
            config.encoder_name,
            input_resolution=self.image_prefix.enc.input_resolution,
        )

        # add adapters
        if config.adapter_config:
            mlp_config = deepcopy(config.adapter_config.get("mlp", None))
            if mlp_config:
                assert mlp_config.get("adapter_type") is not None
                self.add_adapters(
                    location="mlp",
                    adapter_type=mlp_config.pop("adapter_type"),
                    downsample_factor=mlp_config.pop("downsample_factor", 4),
                    **mlp_config,
                )
            attn_config = deepcopy(config.adapter_config.get("attention", None))
            if attn_config:
                assert attn_config.get("adapter_type") is not None
                self.add_adapters(
                    location="attention",
                    adapter_type=attn_config.pop("adapter_type"),
                    **attn_config,
                )

        # freeze parameters
        if config.freeze_lm:
            for name, param in self.lm.named_parameters():  # freeze lm weights
                if config.adapter_config and "adapter" in name:
                    param.requires_grad = True

        if config.freeze_img_encoder:
            for param in self.image_prefix.enc.parameters():
                param.requires_grad = False

    def add_adapters(
        self,
        downsample_factor: int = 4,
        adapter_type: Literal["normal", "parallel", "scaled_parallel"] = "normal",
        location: Literal["mlp", "attention"] = "mlp",
        ff_attr: str = "mlp",
        attn_attr: str = "attn",
        **adapter_kwargs,
    ):
        """
        Adds an adapter layer to `self` at the specified location
        """
        assert adapter_type in [
            "normal",
            "parallel",
            "scaled_parallel",
        ], "adapter_type must be one of 'normal', 'parallel', or 'scaled_parallel'"
        assert location in [
            "mlp",
            "attention",
        ], "location must be one of 'mlp' or 'attention'"

        for l in range(len(self.transformer)):
            if location == "mlp":
                if self.mlp_adapter_added:
                    raise ValueError("Adapter layer already added")
                mlp = getattr(self.transformer[l], ff_attr)
                if adapter_type in ["parallel", "scaled_parallel"]:
                    adapter_layer = ParallelAdapter(
                        module=mlp,
                        dim=self.lm.config.hidden_size,
                        downsample_factor=downsample_factor,
                        scaled=adapter_type == "scaled_parallel",
                        **adapter_kwargs,
                    )
                else:
                    adpt = Adapter(
                        dim=self.lm.config.hidden_size,
                        downsample_factor=downsample_factor,
                        **adapter_kwargs,
                    )
                    adapter_layer = nn.Sequential(
                        *[
                            mlp,
                            adpt,
                        ]
                    )
                setattr(self.transformer[l], ff_attr, adapter_layer)
            else:
                if self.attn_adapter_added:
                    raise ValueError("Adapter layer already added")
                attn = getattr(self.transformer[l], attn_attr)
                if adapter_type in ["parallel", "scaled_parallel"]:
                    adapter_layer = ParallelAdapterWrapper(
                        module=attn,
                        dim=self.lm.config.hidden_size,
                        downsample_factor=downsample_factor,
                        scaled="scaled" in adapter_type,
                        **adapter_kwargs,
                    )
                else:
                    adapter_layer = AdapterWrapper(
                        attn_block=attn,
                        dim=self.lm.config.hidden_size,
                        downsample_factor=downsample_factor,
                        **adapter_kwargs,
                    )
                setattr(self.transformer[l], attn_attr, adapter_layer)

        if location == "mlp":
            self.mlp_adapter_added = True
        else:
            self.attn_adapter_added = True

    def preprocess_inputs(self, input_list: list, embed=True) -> List[torch.Tensor]:
        """
        Expects a list of strings and instances of ImageInput
        Converts them into a list of tensors and then optionally runs self.embed over it
        """
        for i in range(len(input_list)):
            inp = input_list[i]
            if isinstance(inp, str):
                input_list[i] = self.tokenizer.encode(inp, return_tensors="pt")
            elif isinstance(inp, ImageInput):
                input_list[i] = inp.get_transformed_image(transform_fn=self.transforms)
            else:
                raise Exception(f'Invalid input type: {type(inp)}')

        if embed == True:
            return self.embed(input_list)
        else:
            return input_list

    def embed(self, inputs: List[torch.Tensor]) -> TensorType["b", "s", "d"]:
        """
        Embeds a list of tensors in the correct format to input into the LM (b, s, d).
        For each tensor, if it's 2d assume it's text and use word embedding,
        if it's 4d, assume it's an image, and use image_prefix to embed.
        """
        emb_list = []
        for x in inputs:
            if x.ndim == 2:
                x = x.to(self.device)
                emb_list.append(self.word_embedding(x))
            elif x.ndim == 4:
                x = x.to(self.device).half()
                image_embeddings = self.image_prefix(x)
                emb_list.append(image_embeddings)
            else:
                raise ValueError(f"Expected 2d or 4d tensor, got {x.ndim}d")
        return torch.cat(emb_list, dim=1)

    @torch.no_grad()
    def generate(
        self,
        embeddings: TensorType["b", "s", "d"],
        max_steps: int = 100,
        temperature: float = 0.7,
        top_k: int = 0,
        top_p: float = 0.9,
        decode: bool = True,
    ):
        """
        Generates captions for a batch of embeddings.
        """

        return generate(
            self,
            embeddings=embeddings,
            max_steps=max_steps,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            decode=decode,
        )

    def forward(
        self,
        images: TensorType["b", "c", "h", "w"] = None,
        captions: Optional[TensorType["b", "seq"]] = None,
        output_hidden_states: bool = False,
        input_embeddings: TensorType["b", "s", "d"] = None,
    ) -> ModelOutput:
        assert captions is not None, "Must provide captions in training"
        assert any([i is not None for i in [images, input_embeddings]]) and not all(
            [i is not None for i in [images, input_embeddings]]
        ), "Pass in either images, or input embeddings, not both."
        assert (
            captions.shape[1] == self.seq_len
        ), f"in training, captions should be padded to sequence length ({self.seq_len}), but are length {captions.shape[1]}"

        if input_embeddings is None:
            input_embeddings = self.image_prefix(images)
        labels = build_labels(
            input_embeddings, captions, self.eos_token, self.device
        )  # build labels from input_embeddings
        word_embeddings = self.word_embedding(captions)

        # join together
        input_embeddings = torch.cat(
            (
                input_embeddings,
                word_embeddings[:, : -input_embeddings.shape[1], :],
            ),  # remove padding in the word embedding before concatenating
            dim=1,
        )

        # forward joined embeddings through lm
        lm_outputs = self.lm(
            inputs_embeds=input_embeddings,
            labels=labels,
            output_hidden_states=output_hidden_states,
        )

        return lm_outputs

    @classmethod
    def from_checkpoint(cls, config_path, checkpoint_path, device='cpu'):
        """
        Loads a model checkpoint from disk / downloads from url if not present
        """

        checkpoint_url = 'https://drive.google.com/u/0/uc?id=1EiAY3IcKWmGADaLDzdG25ykQghUwza6L&export=download'

        if exists(checkpoint_path) == False:
            print_main(f'checkpoint: {checkpoint_path} does not exist, downloading model')
            download_checkpoint(checkpoint_url=checkpoint_url, save_as=checkpoint_path)

        model = cls(config=config_path)

        sd = torch.load(checkpoint_path, map_location=torch.device("cpu"))
        if "module" in sd.keys():
            sd = sd["module"]

        print_main('loading checkpoint magma')
        model.load_state_dict(sd, strict=False)
        print_main("magma model successfully loaded")

        model.half().to(device)
        return model

magma/sampling.py
ADDED
@@ -0,0 +1,121 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import torch
import torch.nn.functional as F
from torchtyping import TensorType
from typing import Union, List


def top_p_filter(logits: TensorType[..., "vocab"], threshold: float = 0.9):
    """
    Nucleus sampling
    """
    sorted_logits, sorted_indices = torch.sort(logits, descending=True)
    cum_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

    sorted_indices_to_remove = cum_probs > (1 - threshold)
    sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
    sorted_indices_to_remove[..., 0] = 0

    sorted_logits[sorted_indices_to_remove] = float("-inf")
    return sorted_logits.scatter(1, sorted_indices, sorted_logits)


def top_k_filter(logits, k):
    """
    Top K sampling
    """
    assert k > 0
    val, ind = torch.topk(logits, k)
    probs = torch.full_like(logits, float("-inf"))
    probs.scatter_(1, ind, val)
    return probs


def remove_tokens_after_eos(tensor, eos_token, image_token):
    # any tokens produced after an end-of-sequence token are also set to the eos token, and then removed
    eos_index = (tensor == eos_token).nonzero()
    if eos_index.any():
        tensor[eos_index[0] :] = eos_token

    tensor = tensor.tolist()
    return [i for i in tensor if (not i == image_token) and (not i == eos_token)]


@torch.no_grad()
def generate(
    model: "Magma",
    embeddings: TensorType["b", "s", "d"],
    max_steps: int = 100,
    temperature: float = 0.7,
    top_k: int = 0,
    top_p: float = 0.9,
    eos_token: int = None,
    decode: bool = True,
) -> Union[List[str], TensorType["b", "s"]]:
    """
    Generates captions for a batch of embeddings.

    :param model: The model to use for generation.
    :param embeddings: The embeddings to generate captions for.
    :param max_steps: The maximum number of steps to generate captions for.
    :param temperature: The temperature to use for sampling.
    :param top_k: value for top k sampling. If 0, no top k filtering will be used.
    :param top_p: value for top p sampling. If 0, no top p filtering will be used.
    :param eos_token: The token to use for end of sequence.
    :param decode: Whether to decode the output into text, or return the raw tokens.
    """

    # init values
    eos_token = eos_token or model.eos_token
    was_training = model.training
    model.eval()
    b, s, _ = embeddings.shape
    past_key_values = None

    # init output with image tokens
    out = torch.zeros((b, s), dtype=torch.long).to(model.device) + model.image_token

    # do sampling
    for i in range(max_steps):
        if i == 0:
            # initial input
            outputs = model.lm(
                inputs_embeds=embeddings,
                use_cache=True,
                past_key_values=past_key_values,
            )
        else:
            # now caching past k/v so we can use only the last token
            outputs = model.lm(
                input_ids=out[:, -1:], use_cache=True, past_key_values=past_key_values
            )

        logits = outputs.logits[:, -1, :].float()
        past_key_values = outputs.past_key_values

        # filter / temperature sample
        if temperature == 0.0:
            next_token = torch.argmax(logits, dim=-1)
        else:
            if top_k > 0:
                logits = top_k_filter(logits, k=top_k)
            if top_p > 0:
                logits = top_p_filter(logits, threshold=top_p)

            probs = F.softmax(logits / temperature, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)

        out = torch.cat((out, next_token), dim=-1)

        if eos_token is not None and (next_token == eos_token).all():
            break

    if decode:
        captions = []
        for b in out:
            b = remove_tokens_after_eos(b, eos_token, model.image_token)
            caption = model.tokenizer.decode(b)
            captions.append(caption)
        out = captions

    model.train(was_training)
    return out
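
A note on conventions in this file: `top_p_filter` keeps only enough of the highest-probability tokens to cover roughly `1 - threshold` of the probability mass (so `threshold=0.9` is quite aggressive), and `top_k_filter` sets everything outside the top `k` logits to `-inf`. A small self-contained sketch on toy logits (the values are made up purely for illustration):

    import torch
    from magma.sampling import top_k_filter, top_p_filter

    # toy logits: batch of 1 over a 6-token vocabulary
    logits = torch.tensor([[2.0, 1.5, 1.0, 0.5, 0.0, -1.0]])

    top2 = top_k_filter(logits.clone(), k=2)     # everything outside the top 2 becomes -inf
    nucleus = top_p_filter(logits.clone(), 0.9)  # keeps tokens until cumulative prob exceeds 1 - 0.9
    probs = torch.softmax(top2, dim=-1)          # only the two kept tokens get non-zero probability

`generate` itself is called with the joint image/text embeddings produced by a loaded `Magma` model rather than with raw token ids, which is why its first positional argument is the model.
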
magma/train_loop.py
ADDED
@@ -0,0 +1,98 @@
import torch
from tqdm import tqdm
from .utils import reduce_losses, to_cuda_half
from torchvision.utils import make_grid


def train_step(config, train_loader, model_engine):
    losses = []

    for _ in range(config.gradient_accumulation_steps):
        images, captions = next(train_loader)
        images, captions = images.half().cuda(), captions.cuda()
        if config.run_blind:
            images = torch.zeros_like(images)
        outputs = model_engine(images, captions)
        loss = outputs.loss
        losses.append(loss)
        model_engine.backward(loss)
        model_engine.step()

    return reduce_losses(torch.mean(torch.stack(losses))).item()


def train_step_classification(config, train_loader, model_engine, return_accuracy=True):
    losses = []
    if return_accuracy:
        accuracies = []
    for _ in range(config.gradient_accumulation_steps):
        images, captions, class_labels = next(train_loader)
        images, captions, class_labels = to_cuda_half(images, captions, class_labels)
        if config.run_blind:
            images = torch.zeros_like(images)
        loss, logits = model_engine(images, captions, class_labels)
        losses.append(loss)
        if return_accuracy:
            argmax_pred = logits.argmax(dim=-1)
            accuracies.append((argmax_pred == class_labels).float().mean())
        model_engine.backward(loss)
        model_engine.step()

    loss_reduced = reduce_losses(torch.mean(torch.stack(losses))).item()
    if return_accuracy:
        accuracy_reduced = reduce_losses(torch.mean(torch.stack(accuracies))).item()
        return loss_reduced, accuracy_reduced
    return loss_reduced


def eval_step(config, eval_loader, model_engine):
    losses = []

    for i in tqdm(range(config.eval_steps), "evaluating..."):
        images, captions = next(eval_loader)
        images, captions = images.half().cuda(), captions.cuda()
        if config.run_blind:
            images = torch.zeros_like(images)
        outputs = model_engine(images, captions)
        loss = outputs.loss
        losses.append(loss)

    return reduce_losses(torch.mean(torch.stack(losses))).item()


def eval_step_classification(config, train_loader, model_engine, return_accuracy=True):
    losses = []
    if return_accuracy:
        accuracies = []
    for _ in range(config.gradient_accumulation_steps):
        images, captions, class_labels = next(train_loader)
        images, captions, class_labels = to_cuda_half(images, captions, class_labels)
        if config.run_blind:
            images = torch.zeros_like(images)
        loss, logits = model_engine(images, captions, class_labels)
        losses.append(loss)
        if return_accuracy:
            argmax_pred = logits.argmax(dim=-1)
            accuracies.append((argmax_pred == class_labels).float().mean())

    loss_reduced = reduce_losses(torch.mean(torch.stack(losses))).item()
    if return_accuracy:
        accuracy_reduced = reduce_losses(torch.mean(torch.stack(accuracies))).item()
        return loss_reduced, accuracy_reduced
    return loss_reduced


def inference_step(config, eval_loader, model_engine):
    images, _ = next(eval_loader)
    images = images.half().cuda()
    if config.run_blind:
        images = torch.zeros_like(images)
    captions = model_engine(
        images, captions=None, inference=True
    )  # [caption1, caption2, ... b]
    width = min(2, images.shape[0])
    image_grid = make_grid(images[:width])
    caption = ""
    for i in range(width):
        caption += f"Caption {i}: \n{captions[i]}\n"
    return image_grid, caption
magma/transforms.py
ADDED
@@ -0,0 +1,134 @@
from torchvision import transforms as T
import torch.nn.functional as F
from PIL import ImageOps
import PIL
import random


def pad_to_size(x, size=256):
    delta_w = size - x.size[0]
    delta_h = size - x.size[1]
    padding = (
        delta_w // 2,
        delta_h // 2,
        delta_w - (delta_w // 2),
        delta_h - (delta_h // 2),
    )
    new_im = ImageOps.expand(x, padding)
    return new_im


def pad_to_size_tensor(x, size=256):
    offset_dim_1 = size - x.shape[1]
    offset_dim_2 = size - x.shape[2]

    padding_dim_1 = max(offset_dim_1 // 2, 0)
    padding_dim_2 = max(offset_dim_2 // 2, 0)

    if offset_dim_1 % 2 == 0:
        pad_tuple_1 = (padding_dim_1, padding_dim_1)
    else:
        pad_tuple_1 = (padding_dim_1 + 1, padding_dim_1)

    if offset_dim_2 % 2 == 0:
        pad_tuple_2 = (padding_dim_2, padding_dim_2)
    else:
        pad_tuple_2 = (padding_dim_2 + 1, padding_dim_2)

    padded = F.pad(x, pad=(*pad_tuple_2, *pad_tuple_1, 0, 0))
    return padded


class RandCropResize(object):

    """
    Randomly crops, then randomly resizes, then randomly crops an image again, mirroring the augmentations from https://arxiv.org/abs/2102.12092
    """

    def __init__(self, target_size):
        self.target_size = target_size

    def __call__(self, img):
        img = pad_to_size(img, self.target_size)
        d_min = min(img.size)
        img = T.RandomCrop(size=d_min)(img)
        t_min = min(d_min, round(9 / 8 * self.target_size))
        t_max = min(d_min, round(12 / 8 * self.target_size))
        t = random.randint(t_min, t_max + 1)
        img = T.Resize(t)(img)
        if min(img.size) < 256:
            img = T.Resize(256)(img)
        return T.RandomCrop(size=self.target_size)(img)


def get_transforms(
    image_size, encoder_name, input_resolution=None, use_extra_transforms=False
):
    if "clip" in encoder_name:
        assert input_resolution is not None
        return clip_preprocess(input_resolution)

    base_transforms = [
        T.Lambda(lambda img: img.convert("RGB") if img.mode != "RGB" else img),
        RandCropResize(image_size),
        T.RandomHorizontalFlip(p=0.5),
    ]
    if use_extra_transforms:
        extra_transforms = [T.ColorJitter(0.1, 0.1, 0.1, 0.05)]
        base_transforms += extra_transforms
    base_transforms += [
        T.ToTensor(),
        maybe_add_batch_dim,
    ]
    base_transforms = T.Compose(base_transforms)
    return base_transforms


def maybe_add_batch_dim(t):
    if t.ndim == 3:
        return t.unsqueeze(0)
    else:
        return t


def pad_img(desired_size):
    def fn(im):
        old_size = im.size  # old_size is in (width, height) format

        ratio = float(desired_size) / max(old_size)
        new_size = tuple([int(x * ratio) for x in old_size])

        im = im.resize(new_size, PIL.Image.ANTIALIAS)
        # create a new image and paste the resized on it

        new_im = PIL.Image.new("RGB", (desired_size, desired_size))
        new_im.paste(
            im, ((desired_size - new_size[0]) // 2, (desired_size - new_size[1]) // 2)
        )

        return new_im

    return fn


def crop_or_pad(n_px, pad=False):
    if pad:
        return pad_img(n_px)
    else:
        return T.CenterCrop(n_px)


def clip_preprocess(n_px, use_pad=False):
    return T.Compose(
        [
            T.Resize(n_px, interpolation=T.InterpolationMode.BICUBIC),
            crop_or_pad(n_px, pad=use_pad),
            lambda image: image.convert("RGB"),
            T.ToTensor(),
            maybe_add_batch_dim,
            T.Normalize(
                (0.48145466, 0.4578275, 0.40821073),
                (0.26862954, 0.26130258, 0.27577711),
            ),
        ]
    )
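
For orientation, `get_transforms` returns a torchvision pipeline that ends in `maybe_add_batch_dim`, so a single PIL image comes out as a 4-D tensor ready for the image encoder. A minimal sketch, where the encoder name and the 384-pixel resolution are illustrative assumptions (any `encoder_name` containing "clip" is routed through `clip_preprocess`):

    from PIL import Image
    from magma.transforms import get_transforms

    # encoder name / resolution are assumptions, chosen only for illustration
    transforms = get_transforms(image_size=384, encoder_name="clip", input_resolution=384)
    img = Image.open("examples/magma_tree.jpg")
    x = transforms(img)
    print(x.shape)  # torch.Size([1, 3, 384, 384])
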
magma/utils.py
ADDED
@@ -0,0 +1,372 @@
import argparse
import torch.distributed as dist
from transformers import GPT2TokenizerFast
import deepspeed
from pathlib import Path
import wandb
import os
import yaml
import torch
from collections import defaultdict
from torchtyping import TensorType
import gdown


def is_main():
    if dist.is_initialized():
        return dist.get_rank() == 0
    return True


def print_main(*msg):
    if is_main():
        print(*msg)


def reduce_losses(losses):
    """Reduce a tensor of losses across all GPUs."""
    if dist.is_initialized():
        losses = losses.detach().clone()
        # We use `all_reduce` because it is better supported than `reduce`
        dist.all_reduce(losses, dist.ReduceOp.SUM)
        return losses / dist.get_world_size()
    else:
        return losses


def cycle(loader):
    while True:
        for data in loader:
            yield data


def get_tokenizer(name="gpt2", sequence_length=2048):
    """
    Gets tokenizer for LM
    """
    if name == "gpt2":
        tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
        tokenizer.pad_token_id = tokenizer.eos_token
        tokenizer.padding_side = "right"
        tokenizer.model_max_length = sequence_length
        # setup lm settings
        tokenizer.add_special_tokens(
            {"cls_token": "<|image|>"}
        )  # add special image token to tokenizer
    else:
        raise ValueError(f"Tokenizer {name} not recognized")
    return tokenizer


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--config", type=str, required=False, help="path to your training config"
    )
    parser.add_argument(
        "--local_rank",
        type=int,
        default=-1,
        help="local rank passed from distributed launcher",
    )
    deepspeed.add_config_arguments(parser)

    args = parser.parse_args()
    args.deepspeed = True
    return args


def wandb_log(*args, **kwargs):
    if is_main():
        wandb.log(*args, **kwargs)


def wandb_init(*args, **kwargs):
    if is_main():
        wandb.init(*args, **kwargs)


def save_model(model_engine, save_dir, global_step, config=None):
    os.makedirs(save_dir, exist_ok=True)
    if config is not None:
        config = config.to_dict()
        with open(str(Path(save_dir) / "config.yml"), "w") as f:
            yaml.dump(config, f, default_flow_style=False)
    sd = {"global_step": global_step, "config": config}
    model_engine.save_checkpoint(save_dir, client_state=sd)


def load_model(
    model_engine, load_dir, load_optimizer_states=True, load_lr_scheduler_states=True
):
    """
    Loads a model from disk and returns the global step to resume from if loading was successful, otherwise returns 0
    """
    try:
        load_path, sd = model_engine.load_checkpoint(
            load_dir,
            load_optimizer_states=load_optimizer_states,
            load_lr_scheduler_states=load_lr_scheduler_states,
        )
    except AssertionError as e:
        load_path = None
        print(e)
    if load_path is None:
        print("Model loading failed - starting from global step 0")
        return 0
    return sd["global_step"]


def get_params_for_weight_decay_optimization(module, config):
    """
    Divide params into with-weight-decay and without-weight-decay groups.
    Layernorms and biases will have no weight decay but the rest will.
    """
    weight_decay_params = {"params": []}
    no_weight_decay_params = {"params": [], "weight_decay": 0.0}
    blacklist_modules = (torch.nn.LayerNorm, torch.nn.Embedding)

    for module_ in module.modules():
        if isinstance(module_, blacklist_modules) or (
            config.weight_decay == 0.0
        ):  # also include all parameters here if no weight decay is being done
            no_weight_decay_params["params"].extend(
                [
                    p
                    for p in list(module_._parameters.values())
                    if (p is not None) and p.requires_grad
                ]
            )
        else:
            for n, p in list(module_._parameters.items()):
                if p is not None and p.requires_grad:
                    if n != "bias":
                        weight_decay_params["params"].append(p)
                    else:
                        no_weight_decay_params["params"].append(p)

    param_dict = {
        pn: p
        for pn, p in module.named_parameters()
        if p is not None and p.requires_grad
    }
    assert len(no_weight_decay_params["params"]) + len(
        weight_decay_params["params"]
    ) == len(
        param_dict.keys()
    ), "Number of params in both groups != total number of trainable params"
    if config.weight_decay == 0.0:
        # only return a single param group if no weight decay is being used anyway
        return [no_weight_decay_params]
    return [weight_decay_params, no_weight_decay_params]


def configure_param_groups(model, config):
    """
    Configures the different parameter groups in the model for training.
    If a separate learning rate for the image prefix is provided, we separate out the groups here.
    Additionally, parameters to which weight decay shouldn't be applied (layernorms / biases) are separated.
    """
    if config.image_enc_lr is not None:

        # get the params for the image prefix / proj
        image_enc_params = get_params_for_weight_decay_optimization(
            model.image_prefix.enc, config
        )
        for pdict in image_enc_params:
            pdict["lr"] = config.image_enc_lr
        image_proj_params = get_params_for_weight_decay_optimization(
            model.image_prefix.proj, config
        )

        # get the params for layernorm if it exists
        if config.use_image_embed_layernorm:
            image_ln_params = get_params_for_weight_decay_optimization(
                model.image_prefix.ln, config
            )
            image_proj_params += image_ln_params

        # get the params for the lm
        lm_params = get_params_for_weight_decay_optimization(model.lm, config)

        # get params for class head if it exists
        class_params = []
        if hasattr(model, "class_head") and model.class_head is not None:
            class_params = get_params_for_weight_decay_optimization(
                model.class_head, config
            )

        all_params = []
        for p in image_enc_params + lm_params + image_proj_params + class_params:
            if p["params"]:
                all_params.append(p)
    else:
        all_params = get_params_for_weight_decay_optimization(model, config)

    # merge param dicts with shared lr / wd values
    d = defaultdict(dict)
    for param_group in all_params:
        lr = param_group.get("lr", None)
        wd = param_group.get("weight_decay", None)
        key = f"lr_{lr}_wd_{wd}"
        if d[key].get("params") is None:
            d[key]["params"] = []
        d[key]["params"].extend(param_group["params"])
        if lr is not None:
            d[key]["lr"] = lr
        if wd is not None:
            d[key]["weight_decay"] = wd
    all_params = list(d.values())

    n_params = sum([len(d["params"]) for d in all_params])
    param_dict = {
        pn: p for pn, p in model.named_parameters() if p is not None and p.requires_grad
    }
    assert n_params == len(
        param_dict
    ), f"Some parameters are missing from param groups ({n_params} | {len(param_dict)})"

    # if we're using multiple param groups, set the min / max lr for each one
    # appropriately in deepspeed's scheduler
    config.deepspeed_config_params["scheduler"]["params"]["warmup_min_lr"] = [
        config.min_lr for _ in all_params
    ]
    config.deepspeed_config_params["scheduler"]["params"]["warmup_max_lr"] = [
        d.get("lr", config.lr) for d in all_params
    ]

    return all_params


def count_parameters(model):
    """
    Counts the number of trainable parameters in a model
    """
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


def log_table(name, model_outputs, gt_answers_list, global_step):
    results_table = wandb.Table(columns=["model output", "ground truth(s)"])
    for o, gt in zip(model_outputs, gt_answers_list):
        results_table.add_data(o, gt)
    wandb_log({f"eval/{name}": results_table}, step=global_step)


def get_world_info():
    local_rank = int(os.environ["LOCAL_RANK"])
    rank = int(os.environ["RANK"])
    world_size = int(os.environ["WORLD_SIZE"])
    return local_rank, rank, world_size


def init_distributed(backend="nccl"):
    if not torch.distributed.is_initialized():
        deepspeed.init_distributed(
            dist_backend=backend, verbose=True, auto_mpi_discovery=True
        )
    local_rank, rank, world_size = get_world_info()
    torch.cuda.set_device(local_rank)
    return local_rank, rank, world_size


def collate_fn_classification(batch_data, seq_len=2048):

    # for nvlr2: list(zip(*batch_data)) = [l_images, r_images, captions, class_labels]
    image_list = list(zip(*batch_data))[:-2]
    captions, class_labels = list(zip(*batch_data))[-2:]

    # images, captions, class_labels = list(zip(*batch_data))
    images_list = [torch.cat(image) for image in image_list]
    captions = torch.cat([i[:, :seq_len] for i in captions])
    class_labels = torch.stack(class_labels)
    return images_list, captions, class_labels


def infer_checkpoint_path_from_config(config):
    checkpoint_folder = config.save
    if checkpoint_folder is None:
        raise ValueError(
            "No checkpoint folder specified in config. Please provide a checkpoint."
        )

    # check for 'latest' tag in checkpoint folder
    if (Path(checkpoint_folder) / "latest").exists():
        latest_ckpt = (Path(checkpoint_folder) / "latest").read_text().strip()
    else:
        raise ValueError(
            f"No checkpoint found in {checkpoint_folder}. Please provide a checkpoint."
        )

    checkpoint_path = str(
        Path(checkpoint_folder) / latest_ckpt / "mp_rank_00_model_states.pt"
    )
    if not Path(checkpoint_path).exists():
        raise ValueError(
            f"No checkpoint found in {checkpoint_path}. Please provide a checkpoint."
        )

    return checkpoint_path


# [tensor_1, tensor_2], tensor_3, tensor_4 = to_cuda_half([tensor_1, tensor_2], tensor_3, tensor_4)
# probably not working yet
def to_cuda_half(*args):
    cuda_half_args = []
    for x in args:
        if isinstance(x, list):
            x_cuda_half = to_cuda_half(*x)
            cuda_half_args.append(x_cuda_half)
        elif isinstance(x, tuple):
            x_cuda_half = to_cuda_half(*x)
            cuda_half_args.append(x_cuda_half)
        else:
            if x.dtype in [torch.float32, torch.float16]:
                cuda_half_args.append(x.cuda().half())
            elif x.dtype == torch.long:
                cuda_half_args.append(x.cuda())

    if len(cuda_half_args) == 1:
        return cuda_half_args[0]
    else:
        return cuda_half_args


def build_labels(
    input_embeddings: TensorType["b", "s", "d"],
    captions: TensorType["b", "s"],
    eos_token,
    device,
) -> TensorType["b", "s"]:
    """
    Builds labels from input embeddings.

    Masks out the labels with -100 in positions up to the seq length of the embeddings, so loss is only computed for captions,
    and not for image tokens.
    Additionally, masks out everything *after* the first eos token.
    """
    shape = input_embeddings.shape[:2]  # b, s

    assert captions.shape[1] >= shape[1]

    # make sure to add masked embedding tokens in the appropriate locations in the labels
    embedding_tokens = torch.zeros(shape, dtype=torch.int64).to(device) - 100
    labels = torch.cat(
        (embedding_tokens, captions[:, : -shape[1]]), dim=1
    )  # we truncate the sequence length of the captions, as they are always padded to the full sequence length

    # mask out repeating eos tokens
    for label in labels:
        for k, token in enumerate(label):
            if token == eos_token:
                label[k + 1 :] = -100
                break

    return labels


def is_url(string):
    return string.startswith("http://") or string.startswith("https://")


def download_checkpoint(checkpoint_url, save_as):
    gdown.download(url = checkpoint_url, output = save_as, quiet=False)
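
The tokenizer returned by `get_tokenizer` is a plain GPT-2 tokenizer with one extra special token, `<|image|>`, registered as the `cls_token`; its id is what the rest of the code treats as the image token. A short sketch (assuming the pinned transformers fork behaves like upstream for these accessors):

    from magma.utils import get_tokenizer

    tokenizer = get_tokenizer("gpt2", sequence_length=2048)
    image_token_id = tokenizer.cls_token_id  # id of the added <|image|> token
    caption_ids = tokenizer.encode("A photo of a dog", return_tensors="pt")
    print(image_token_id, caption_ids.shape)
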
requirements.txt
ADDED
@@ -0,0 +1,9 @@
torchtyping
typeguard
git+https://github.com/finetuneanon/transformers.git#egg=transformers
gdown
tqdm
timm
git+https://github.com/openai/CLIP.git
deepspeed
wandb
test.py
ADDED
@@ -0,0 +1,43 @@
import torch
import numpy as np
from magma import Magma
from magma.language_model import get_language_model
from magma.utils import get_tokenizer

if __name__ == "__main__":
    # model = Magma.from_checkpoint(
    #     "configs/MAGMA_v1.yml",
    #     "/mnt/localdisk/mp_rank_00_model_states.pt",
    #     model_dir="/mnt/localdisk/gptj",
    #     lm_from_pretrained=True,
    # )
    # gptj_model = model.lm
    # model.half().cuda().eval()
    tokenizer = get_tokenizer()
    input_text = tokenizer.encode("this is a test", return_tensors="pt").cuda()
    input_img = torch.ones(1, 3, 384, 384).half().cuda()

    # input = model.embed([input_img, input_text])
    # logits = gptj_model(inputs_embeds=input).logits
    # logits = logits.detach().cpu().numpy()
    # np.save("/mnt/localdisk/logits_new.npy", logits)

    from transformers import GPTJForCausalLM
    import torch

    # load new model
    model = GPTJForCausalLM.from_pretrained(
        "EleutherAI/gpt-j-6B",
        revision="float16",
        torch_dtype=torch.float16,
        low_cpu_mem_usage=True,
    )
    model.cuda()

    model.eval()

    logits = model(input_text).logits
    logits = logits.detach().cpu().numpy()
    np.save("/mnt/localdisk/gptj_logits_new.npy", logits)

    print("test")
train.py
ADDED
@@ -0,0 +1,192 @@
import torch
import os
import deepspeed
import wandb
from torch.utils.data import random_split, ConcatDataset
from torch.optim import AdamW
from tqdm import tqdm
from functools import partial
from magma.datasets import (
    collate_fn,
    ImgCptDataset,
)
from magma.magma import (
    Magma,
)
from magma.utils import (
    is_main,
    cycle,
    parse_args,
    wandb_log,
    wandb_init,
    save_model,
    load_model,
    print_main,
    configure_param_groups,
)
from magma.train_loop import (
    eval_step,
    inference_step,
    train_step,
)


def _load_img_cpt_datasets(dataset_dir, tokenizer, transforms):
    if isinstance(dataset_dir, (list, tuple)):
        return ConcatDataset(
            [_load_img_cpt_datasets(d, tokenizer, transforms) for d in dataset_dir]
        )
    elif isinstance(dataset_dir, str):
        return ImgCptDataset(dataset_dir, tokenizer=tokenizer, transforms=transforms)
    else:
        raise TypeError("dataset dir wrong type")


def get_pretraining_datasets(config, tokenizer, transforms):
    # if config.train_dataset_dir is a list, load all datasets + join together
    train_dataset = _load_img_cpt_datasets(
        config.train_dataset_dir, tokenizer, transforms
    )
    # if no dedicated eval sets are given, use a percentage of the train dataset
    if config.eval_dataset_dir is None:
        eval_len = int(len(train_dataset) * config.eval_dataset_pct)
        train_len = len(train_dataset) - eval_len
        print(
            f"Randomly splitting train_dataset into two datasets of length {train_len} and {eval_len}"
        )
        train_dataset, eval_dataset = random_split(train_dataset, [train_len, eval_len])
    else:
        eval_dataset = _load_img_cpt_datasets(
            config.eval_dataset_dir, tokenizer, transforms
        )

    print_main(f"Loaded train dataset with {len(train_dataset)} samples")
    print_main(f"Loaded eval dataset with {len(eval_dataset)} samples")

    return train_dataset, eval_dataset


# tell tokenizers not to do parallelism
os.environ["TOKENIZERS_PARALLELISM"] = "false"

if __name__ == "__main__":

    # parse command line arguments:
    args = parse_args()
    deepspeed.init_distributed()

    # load model + tokenizer:
    model = Magma(
        args.config
    )  # for finetuning one might want to load the model via Magma.from_checkpoint(...) here
    tokenizer, config, transforms = model.tokenizer, model.config, model.transforms

    # filter frozen from trainable parameters:
    trainable_parameters = configure_param_groups(model, config)

    # load data:
    train_dataset, eval_dataset = get_pretraining_datasets(
        config, tokenizer, transforms
    )

    print_main(f"Loaded train dataset with {len(train_dataset)} samples")
    print_main(f"Loaded eval dataset with {len(eval_dataset)} samples")

    opt = AdamW(
        trainable_parameters,
        config.lr,
        betas=(0.9, 0.95),
        weight_decay=config.weight_decay,
    )

    model_engine, opt, train_loader, lr_scheduler = deepspeed.initialize(
        args=args,
        model=model,
        optimizer=opt,
        model_parameters=trainable_parameters,
        training_data=train_dataset,
        collate_fn=partial(collate_fn, seq_len=model.seq_len),
        config_params=config.deepspeed_config_params,
    )
    eval_loader = cycle(model_engine.deepspeed_io(eval_dataset))
    train_loader = cycle(train_loader)

    # initialize training
    global_step = 0
    if config.load:
        # loads a deepspeed checkpoint if provided. For finetuning, set load_optimizer to false
        previous_global_step = load_model(
            model_engine,
            config.load,
            load_optimizer_states=config.load_optimizer,
            load_lr_scheduler_states=config.load_optimizer,
        )

        if config.load_optimizer:
            global_step = previous_global_step

    pbar = tqdm(
        range(0, config.train_steps),
        desc="training...",
        initial=global_step,
        total=config.train_steps,
        disable=not is_main(),
    )
    wandb_init(
        project=config.wandb_project,
        name=config.name or wandb.util.generate_id(),
        config=config,
    )

    # training loop
    for i in pbar:
        if global_step >= config.train_steps:
            break

        ##### train step
        loss = train_step(config, train_loader, model_engine)

        global_step += 1

        if global_step % config.log_every == 0:
            pbar.set_description(f"training... Step: {global_step} Loss: {loss}")
            current_lr = (
                [lr for lr in lr_scheduler.get_lr()]
                if lr_scheduler is not None
                else config.lr
            )
            to_log = {"train/loss": loss, "train/lr": current_lr}
            wandb_log(to_log, step=global_step)

        ##### Evaluation phase
        if global_step % config.eval_every == 0:
            model_engine.eval()
            with torch.no_grad():

                ##### eval step:
                eval_loss = eval_step(config, eval_loader, model_engine)

                wandb_log({"eval/loss": eval_loss}, step=global_step)
                pbar.set_description(
                    f"evaluating... Step: {global_step} Eval Loss: {eval_loss}"
                )

                ##### inference:
                image_grid, caption = inference_step(config, eval_loader, model_engine)
                wandb_log(
                    {"inference/image": wandb.Image(image_grid, caption=caption)},
                    step=global_step,
                )

            model_engine.train()

        ##### Save model
        if global_step % config.save_every == 0:
            if config.save is not None:
                save_model(model_engine, config.save, global_step)
                print_main(f"saving model at step {global_step}")

    ##### Save model after training is finished
    if config.save is not None:
        save_model(model_engine, config.save, global_step)
        print_main(f"saving model at end of training (step {global_step})")
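
Since `parse_args` registers DeepSpeed's own launcher arguments and expects `--local_rank` to be injected by the launcher, train.py is presumably meant to be started with the DeepSpeed launcher, e.g. something along the lines of `deepspeed train.py --config <path to a training config>`; the config is expected to provide the training fields referenced above (dataset paths, learning rates, `deepspeed_config_params`, and so on).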