ydshieh
commited on
Commit
•
a244e91
1
Parent(s):
8364b8b
add codes
Browse files- .gitattributes +2 -0
- generate.py +63 -0
- run_summarization.py +832 -0
- test_vit_gpt2.py +83 -0
- test_wit_dataset_script.py +23 -0
- tests_load.py +48 -0
- tests_save.py +48 -0
- vit_gpt2/__init__.py +0 -0
- vit_gpt2/configuration_vit_gpt2.py +45 -0
- vit_gpt2/modeling_flax_gpt2.py +752 -0
- vit_gpt2/modeling_flax_vit_gpt2.py +704 -0
- vit_gpt2/modeling_flax_vit_gpt2_lm.py +684 -0
- wit_data_dir/dev/dev.tsv +3 -0
- wit_data_dir/test/test.tsv +3 -0
- wit_dataset_script.py +145 -0
.gitattributes
CHANGED
@@ -16,3 +16,5 @@
|
|
16 |
*.pth filter=lfs diff=lfs merge=lfs -text
|
17 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
18 |
wit_data_dir/train/train.tsv filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
16 |
*.pth filter=lfs diff=lfs merge=lfs -text
|
17 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
18 |
wit_data_dir/train/train.tsv filter=lfs diff=lfs merge=lfs -text
|
19 |
+
wit_data_dir/dev/dev.tsv filter=lfs diff=lfs merge=lfs -text
|
20 |
+
wit_data_dir/test/test.tsv filter=lfs diff=lfs merge=lfs -text
|
generate.py
ADDED
@@ -0,0 +1,63 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import sys, os
|
2 |
+
|
3 |
+
current_path = os.path.dirname(os.path.abspath(__file__))
|
4 |
+
sys.path.append(current_path)
|
5 |
+
|
6 |
+
|
7 |
+
|
8 |
+
# Vit - as encoder
|
9 |
+
from transformers import ViTFeatureExtractor
|
10 |
+
from PIL import Image
|
11 |
+
import requests
|
12 |
+
import numpy as np
|
13 |
+
|
14 |
+
url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
|
15 |
+
image = Image.open(requests.get(url, stream=True).raw)
|
16 |
+
|
17 |
+
feature_extractor = ViTFeatureExtractor.from_pretrained('google/vit-base-patch16-224-in21k')
|
18 |
+
encoder_inputs = feature_extractor(images=image, return_tensors="jax")
|
19 |
+
pixel_values = encoder_inputs.pixel_values
|
20 |
+
|
21 |
+
# GPT2 / GPT2LM - as decoder
|
22 |
+
from transformers import ViTFeatureExtractor, GPT2Tokenizer
|
23 |
+
|
24 |
+
name = 'asi/gpt-fr-cased-small'
|
25 |
+
tokenizer = GPT2Tokenizer.from_pretrained(name)
|
26 |
+
decoder_inputs = tokenizer("mon chien est mignon", return_tensors="jax", )
|
27 |
+
print(decoder_inputs)
|
28 |
+
|
29 |
+
# Setup the tokenizer for targets
|
30 |
+
with tokenizer.as_target_tokenizer():
|
31 |
+
labels = tokenizer(
|
32 |
+
['un chien super beau' + ' ' + tokenizer.eos_token, 'un chat' + ' ' + tokenizer.eos_token], max_length=5, padding="max_length", truncation=True, return_tensors="np"
|
33 |
+
)
|
34 |
+
print(labels)
|
35 |
+
exit(0)
|
36 |
+
|
37 |
+
inputs = dict(decoder_inputs)
|
38 |
+
inputs['pixel_values'] = pixel_values
|
39 |
+
#print(inputs)
|
40 |
+
|
41 |
+
|
42 |
+
# With the LM head in GPT2LM
|
43 |
+
from vit_gpt2.modeling_flax_vit_gpt2_lm import FlaxViTGPT2LMForConditionalGeneration
|
44 |
+
flax_vit_gpt2_lm = FlaxViTGPT2LMForConditionalGeneration.from_pretrained('./outputs-small-ds/ckpt_3',)
|
45 |
+
|
46 |
+
logits = flax_vit_gpt2_lm(**inputs)[0]
|
47 |
+
preds = np.argmax(logits, axis=-1)
|
48 |
+
print('=' * 60)
|
49 |
+
print('Flax: Vit + modified GPT2LM')
|
50 |
+
#print(preds)
|
51 |
+
|
52 |
+
max_length = 32
|
53 |
+
num_beams = 16
|
54 |
+
gen_kwargs = {"max_length": max_length, "num_beams": num_beams}
|
55 |
+
batch = {'pixel_values': pixel_values}
|
56 |
+
generation = flax_vit_gpt2_lm.generate(batch['pixel_values'], **gen_kwargs)
|
57 |
+
print(generation)
|
58 |
+
|
59 |
+
token_ids = np.array(generation.sequences)[0]
|
60 |
+
generation = tokenizer.decode(token_ids)
|
61 |
+
print(generation)
|
62 |
+
|
63 |
+
del flax_vit_gpt2_lm
|
run_summarization.py
ADDED
@@ -0,0 +1,832 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python
|
2 |
+
# coding=utf-8
|
3 |
+
# Copyright 2021 The HuggingFace Team All rights reserved.
|
4 |
+
#
|
5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
6 |
+
# you may not use this file except in compliance with the License.
|
7 |
+
# You may obtain a copy of the License at
|
8 |
+
#
|
9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
10 |
+
#
|
11 |
+
# Unless required by applicable law or agreed to in writing, software
|
12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
14 |
+
# See the License for the specific language governing permissions and
|
15 |
+
# limitations under the License.
|
16 |
+
"""
|
17 |
+
Fine-tuning the library models for summarization.
|
18 |
+
"""
|
19 |
+
# You can also adapt this script on your own sequence to sequence task. Pointers for this are left as comments.
|
20 |
+
|
21 |
+
import sys, os
|
22 |
+
|
23 |
+
current_path = os.path.dirname(os.path.abspath(__file__))
|
24 |
+
sys.path.append(current_path)
|
25 |
+
|
26 |
+
import logging
|
27 |
+
import os
|
28 |
+
import sys
|
29 |
+
import time
|
30 |
+
from dataclasses import dataclass, field
|
31 |
+
from functools import partial
|
32 |
+
from pathlib import Path
|
33 |
+
from typing import Callable, Optional
|
34 |
+
|
35 |
+
import datasets
|
36 |
+
import nltk # Here to have a nice missing dependency error message early on
|
37 |
+
import numpy as np
|
38 |
+
from datasets import Dataset, load_dataset, load_metric
|
39 |
+
from tqdm import tqdm
|
40 |
+
|
41 |
+
import jax
|
42 |
+
import jax.numpy as jnp
|
43 |
+
import optax
|
44 |
+
import transformers
|
45 |
+
from filelock import FileLock
|
46 |
+
from flax import jax_utils, traverse_util
|
47 |
+
from flax.jax_utils import unreplicate
|
48 |
+
from flax.training import train_state
|
49 |
+
from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key
|
50 |
+
from transformers import (
|
51 |
+
CONFIG_MAPPING,
|
52 |
+
FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
|
53 |
+
AutoConfig,
|
54 |
+
AutoTokenizer,
|
55 |
+
FlaxAutoModelForSeq2SeqLM,
|
56 |
+
HfArgumentParser,
|
57 |
+
TrainingArguments,
|
58 |
+
is_tensorboard_available,
|
59 |
+
)
|
60 |
+
from transformers.file_utils import is_offline_mode
|
61 |
+
|
62 |
+
from transformers import ViTFeatureExtractor, GPT2Tokenizer, GPT2Config
|
63 |
+
from vit_gpt2.modeling_flax_vit_gpt2_lm import FlaxViTGPT2LMForConditionalGeneration
|
64 |
+
|
65 |
+
logger = logging.getLogger(__name__)
|
66 |
+
|
67 |
+
try:
|
68 |
+
nltk.data.find("tokenizers/punkt")
|
69 |
+
except (LookupError, OSError):
|
70 |
+
if is_offline_mode():
|
71 |
+
raise LookupError(
|
72 |
+
"Offline mode: run this script without TRANSFORMERS_OFFLINE first to download nltk data files"
|
73 |
+
)
|
74 |
+
with FileLock(".lock") as lock:
|
75 |
+
nltk.download("punkt", quiet=True)
|
76 |
+
|
77 |
+
|
78 |
+
MODEL_CONFIG_CLASSES = list(FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING.keys())
|
79 |
+
MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)
|
80 |
+
|
81 |
+
|
82 |
+
@dataclass
|
83 |
+
class ModelArguments:
|
84 |
+
"""
|
85 |
+
Arguments pertaining to which model/config/tokenizer we are going to fine-tune, or train from scratch.
|
86 |
+
"""
|
87 |
+
|
88 |
+
model_name_or_path: Optional[str] = field(
|
89 |
+
default=None,
|
90 |
+
metadata={
|
91 |
+
"help": "The model checkpoint for weights initialization."
|
92 |
+
"Don't set if you want to train a model from scratch."
|
93 |
+
},
|
94 |
+
)
|
95 |
+
model_type: Optional[str] = field(
|
96 |
+
default=None,
|
97 |
+
metadata={"help": "If training from scratch, pass a model type from the list: " + ", ".join(MODEL_TYPES)},
|
98 |
+
)
|
99 |
+
config_name: Optional[str] = field(
|
100 |
+
default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
|
101 |
+
)
|
102 |
+
tokenizer_name: Optional[str] = field(
|
103 |
+
default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
|
104 |
+
)
|
105 |
+
cache_dir: Optional[str] = field(
|
106 |
+
default=None, metadata={"help": "Where do you want to store the pretrained models downloaded from s3"}
|
107 |
+
)
|
108 |
+
use_fast_tokenizer: bool = field(
|
109 |
+
default=True,
|
110 |
+
metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."},
|
111 |
+
)
|
112 |
+
dtype: Optional[str] = field(
|
113 |
+
default="float32",
|
114 |
+
metadata={
|
115 |
+
"help": "Floating-point format in which the model weights should be initialized and trained. Choose one of `[float32, float16, bfloat16]`."
|
116 |
+
},
|
117 |
+
)
|
118 |
+
|
119 |
+
|
120 |
+
@dataclass
|
121 |
+
class DataTrainingArguments:
|
122 |
+
"""
|
123 |
+
Arguments pertaining to what data we are going to input our model for training and eval.
|
124 |
+
"""
|
125 |
+
|
126 |
+
dataset_name: Optional[str] = field(
|
127 |
+
default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
|
128 |
+
)
|
129 |
+
dataset_config_name: Optional[str] = field(
|
130 |
+
default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
|
131 |
+
)
|
132 |
+
text_column: Optional[str] = field(
|
133 |
+
default=None,
|
134 |
+
metadata={"help": "The name of the column in the datasets containing the full texts (for summarization)."},
|
135 |
+
)
|
136 |
+
summary_column: Optional[str] = field(
|
137 |
+
default=None,
|
138 |
+
metadata={"help": "The name of the column in the datasets containing the summaries (for summarization)."},
|
139 |
+
)
|
140 |
+
train_file: Optional[str] = field(default=None, metadata={"help": "The input training data file (a text file)."})
|
141 |
+
validation_file: Optional[str] = field(
|
142 |
+
default=None,
|
143 |
+
metadata={"help": "An optional input evaluation data file to evaluate the perplexity on (a text file)."},
|
144 |
+
)
|
145 |
+
max_source_length: Optional[int] = field(
|
146 |
+
default=1024,
|
147 |
+
metadata={
|
148 |
+
"help": "The maximum total input sequence length after tokenization. Sequences longer "
|
149 |
+
"than this will be truncated, sequences shorter will be padded."
|
150 |
+
},
|
151 |
+
)
|
152 |
+
max_target_length: Optional[int] = field(
|
153 |
+
default=128,
|
154 |
+
metadata={
|
155 |
+
"help": "The maximum total sequence length for target text after tokenization. Sequences longer "
|
156 |
+
"than this will be truncated, sequences shorter will be padded."
|
157 |
+
},
|
158 |
+
)
|
159 |
+
val_max_target_length: Optional[int] = field(
|
160 |
+
default=None,
|
161 |
+
metadata={
|
162 |
+
"help": "The maximum total sequence length for validation target text after tokenization. Sequences longer "
|
163 |
+
"than this will be truncated, sequences shorter will be padded. Will default to `max_target_length`."
|
164 |
+
"This argument is also used to override the `max_length` param of `model.generate`, which is used "
|
165 |
+
"during evaluation."
|
166 |
+
},
|
167 |
+
)
|
168 |
+
max_train_samples: Optional[int] = field(
|
169 |
+
default=None,
|
170 |
+
metadata={
|
171 |
+
"help": "For debugging purposes or quicker training, truncate the number of training examples to this "
|
172 |
+
"value if set."
|
173 |
+
},
|
174 |
+
)
|
175 |
+
max_eval_samples: Optional[int] = field(
|
176 |
+
default=None,
|
177 |
+
metadata={
|
178 |
+
"help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
|
179 |
+
"value if set."
|
180 |
+
},
|
181 |
+
)
|
182 |
+
max_predict_samples: Optional[int] = field(
|
183 |
+
default=None,
|
184 |
+
metadata={
|
185 |
+
"help": "For debugging purposes or quicker training, truncate the number of prediction examples to this "
|
186 |
+
"value if set."
|
187 |
+
},
|
188 |
+
)
|
189 |
+
preprocessing_num_workers: Optional[int] = field(
|
190 |
+
default=None,
|
191 |
+
metadata={"help": "The number of processes to use for the preprocessing."},
|
192 |
+
)
|
193 |
+
source_prefix: Optional[str] = field(
|
194 |
+
default=None, metadata={"help": "A prefix to add before every source text (useful for T5 models)."}
|
195 |
+
)
|
196 |
+
predict_with_generate: bool = field(
|
197 |
+
default=False, metadata={"help": "Whether to use generate to calculate generative metrics (ROUGE, BLEU)."}
|
198 |
+
)
|
199 |
+
num_beams: Optional[int] = field(
|
200 |
+
default=None,
|
201 |
+
metadata={
|
202 |
+
"help": "Number of beams to use for evaluation. This argument will be passed to `model.generate`, "
|
203 |
+
"which is used during evaluation."
|
204 |
+
},
|
205 |
+
)
|
206 |
+
overwrite_cache: bool = field(
|
207 |
+
default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
|
208 |
+
)
|
209 |
+
|
210 |
+
def __post_init__(self):
|
211 |
+
if self.dataset_name is None and self.train_file is None and self.validation_file is None:
|
212 |
+
raise ValueError("Need either a dataset name or a training/validation file.")
|
213 |
+
else:
|
214 |
+
if self.train_file is not None:
|
215 |
+
extension = self.train_file.split(".")[-1]
|
216 |
+
assert extension in ["csv", "json"], "`train_file` should be a csv or a json file."
|
217 |
+
if self.validation_file is not None:
|
218 |
+
extension = self.validation_file.split(".")[-1]
|
219 |
+
assert extension in ["csv", "json"], "`validation_file` should be a csv or a json file."
|
220 |
+
if self.val_max_target_length is None:
|
221 |
+
self.val_max_target_length = self.max_target_length
|
222 |
+
|
223 |
+
|
224 |
+
summarization_name_mapping = {
|
225 |
+
"amazon_reviews_multi": ("review_body", "review_title"),
|
226 |
+
"big_patent": ("description", "abstract"),
|
227 |
+
"cnn_dailymail": ("article", "highlights"),
|
228 |
+
"orange_sum": ("text", "summary"),
|
229 |
+
"pn_summary": ("article", "summary"),
|
230 |
+
"psc": ("extract_text", "summary_text"),
|
231 |
+
"samsum": ("dialogue", "summary"),
|
232 |
+
"thaisum": ("body", "summary"),
|
233 |
+
"xglue": ("news_body", "news_title"),
|
234 |
+
"xsum": ("document", "summary"),
|
235 |
+
"wiki_summary": ("article", "highlights"),
|
236 |
+
}
|
237 |
+
|
238 |
+
|
239 |
+
class TrainState(train_state.TrainState):
|
240 |
+
dropout_rng: jnp.ndarray
|
241 |
+
|
242 |
+
def replicate(self):
|
243 |
+
return jax_utils.replicate(self).replace(dropout_rng=shard_prng_key(self.dropout_rng))
|
244 |
+
|
245 |
+
|
246 |
+
def data_loader(rng: jax.random.PRNGKey, dataset: Dataset, batch_size: int, shuffle: bool = False):
|
247 |
+
"""
|
248 |
+
Returns batches of size `batch_size` from truncated `dataset`, sharded over all local devices.
|
249 |
+
Shuffle batches if `shuffle` is `True`.
|
250 |
+
"""
|
251 |
+
steps_per_epoch = len(dataset) // batch_size
|
252 |
+
|
253 |
+
if shuffle:
|
254 |
+
batch_idx = jax.random.permutation(rng, len(dataset))
|
255 |
+
else:
|
256 |
+
batch_idx = jnp.arange(len(dataset))
|
257 |
+
|
258 |
+
batch_idx = batch_idx[: steps_per_epoch * batch_size] # Skip incomplete batch.
|
259 |
+
batch_idx = batch_idx.reshape((steps_per_epoch, batch_size))
|
260 |
+
|
261 |
+
for idx in batch_idx:
|
262 |
+
batch = dataset[idx]
|
263 |
+
batch = {k: jnp.array(v) for k, v in batch.items()}
|
264 |
+
|
265 |
+
batch = shard(batch)
|
266 |
+
|
267 |
+
yield batch
|
268 |
+
|
269 |
+
|
270 |
+
def write_metric(summary_writer, train_metrics, eval_metrics, train_time, step):
|
271 |
+
summary_writer.scalar("train_time", train_time, step)
|
272 |
+
|
273 |
+
train_metrics = get_metrics(train_metrics)
|
274 |
+
for key, vals in train_metrics.items():
|
275 |
+
tag = f"train_{key}"
|
276 |
+
for i, val in enumerate(vals):
|
277 |
+
summary_writer.scalar(tag, val, step - len(vals) + i + 1)
|
278 |
+
|
279 |
+
for metric_name, value in eval_metrics.items():
|
280 |
+
summary_writer.scalar(f"eval_{metric_name}", value, step)
|
281 |
+
|
282 |
+
|
283 |
+
def create_learning_rate_fn(
|
284 |
+
train_ds_size: int, train_batch_size: int, num_train_epochs: int, num_warmup_steps: int, learning_rate: float
|
285 |
+
) -> Callable[[int], jnp.array]:
|
286 |
+
"""Returns a linear warmup, linear_decay learning rate function."""
|
287 |
+
steps_per_epoch = train_ds_size // train_batch_size
|
288 |
+
num_train_steps = steps_per_epoch * num_train_epochs
|
289 |
+
warmup_fn = optax.linear_schedule(init_value=0.0, end_value=learning_rate, transition_steps=num_warmup_steps)
|
290 |
+
decay_fn = optax.linear_schedule(
|
291 |
+
init_value=learning_rate, end_value=0, transition_steps=num_train_steps - num_warmup_steps
|
292 |
+
)
|
293 |
+
schedule_fn = optax.join_schedules(schedules=[warmup_fn, decay_fn], boundaries=[num_warmup_steps])
|
294 |
+
return schedule_fn
|
295 |
+
|
296 |
+
|
297 |
+
def main():
|
298 |
+
# See all possible arguments in src/transformers/training_args.py
|
299 |
+
# or by passing the --help flag to this script.
|
300 |
+
# We now keep distinct sets of args, for a cleaner separation of concerns.
|
301 |
+
|
302 |
+
parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
|
303 |
+
if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
|
304 |
+
# If we pass only one argument to the script and it's the path to a json file,
|
305 |
+
# let's parse it to get our arguments.
|
306 |
+
model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
|
307 |
+
else:
|
308 |
+
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
|
309 |
+
|
310 |
+
if (
|
311 |
+
os.path.exists(training_args.output_dir)
|
312 |
+
and os.listdir(training_args.output_dir)
|
313 |
+
and training_args.do_train
|
314 |
+
and not training_args.overwrite_output_dir
|
315 |
+
):
|
316 |
+
raise ValueError(
|
317 |
+
f"Output directory ({training_args.output_dir}) already exists and is not empty."
|
318 |
+
"Use --overwrite_output_dir to overcome."
|
319 |
+
)
|
320 |
+
|
321 |
+
# Make one log on every process with the configuration for debugging.
|
322 |
+
logging.basicConfig(
|
323 |
+
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
324 |
+
datefmt="%m/%d/%Y %H:%M:%S",
|
325 |
+
level=logging.INFO,
|
326 |
+
)
|
327 |
+
# Setup logging, we only want one process per machine to log things on the screen.
|
328 |
+
logger.setLevel(logging.INFO if jax.process_index() == 0 else logging.ERROR)
|
329 |
+
if jax.process_index() == 0:
|
330 |
+
datasets.utils.logging.set_verbosity_warning()
|
331 |
+
transformers.utils.logging.set_verbosity_info()
|
332 |
+
else:
|
333 |
+
datasets.utils.logging.set_verbosity_error()
|
334 |
+
transformers.utils.logging.set_verbosity_error()
|
335 |
+
|
336 |
+
# Set the verbosity to info of the Transformers logger (on main process only):
|
337 |
+
logger.info(f"Training/evaluation parameters {training_args}")
|
338 |
+
|
339 |
+
# Get the datasets: you can either provide your own CSV/JSON training and evaluation files (see below)
|
340 |
+
# or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/
|
341 |
+
# (the dataset will be downloaded automatically from the datasets Hub).
|
342 |
+
#
|
343 |
+
# For CSV/JSON files this script will use the first column for the full texts and the second column for the
|
344 |
+
# summaries (unless you specify column names for this with the `text_column` and `summary_column` arguments).
|
345 |
+
#
|
346 |
+
if data_args.dataset_name is not None:
|
347 |
+
# Downloading and loading a dataset from the hub.
|
348 |
+
dataset = load_dataset(
|
349 |
+
data_args.dataset_name, data_args.dataset_config_name, cache_dir=model_args.cache_dir, keep_in_memory=False, data_dir='./wit_data_dir/'
|
350 |
+
)
|
351 |
+
else:
|
352 |
+
data_files = {}
|
353 |
+
if data_args.train_file is not None:
|
354 |
+
data_files["train"] = data_args.train_file
|
355 |
+
extension = data_args.train_file.split(".")[-1]
|
356 |
+
if data_args.validation_file is not None:
|
357 |
+
data_files["validation"] = data_args.validation_file
|
358 |
+
extension = data_args.validation_file.split(".")[-1]
|
359 |
+
if data_args.test_file is not None:
|
360 |
+
data_files["test"] = data_args.test_file
|
361 |
+
extension = data_args.test_file.split(".")[-1]
|
362 |
+
dataset = load_dataset(extension, data_files=data_files, cache_dir=model_args.cache_dir)
|
363 |
+
|
364 |
+
vit_name_path = 'google/vit-base-patch16-224-in21k'
|
365 |
+
gpt2_name_path = 'asi/gpt-fr-cased-small'
|
366 |
+
|
367 |
+
gpt2_config = GPT2Config.from_pretrained(gpt2_name_path)
|
368 |
+
gpt2_config.add_cross_attention = True
|
369 |
+
|
370 |
+
|
371 |
+
vit_gpt2_name_path = ''
|
372 |
+
|
373 |
+
feature_extractor = ViTFeatureExtractor.from_pretrained(vit_name_path)
|
374 |
+
|
375 |
+
tokenizer = GPT2Tokenizer.from_pretrained(gpt2_name_path)
|
376 |
+
|
377 |
+
if not vit_gpt2_name_path:
|
378 |
+
assert vit_name_path
|
379 |
+
assert gpt2_name_path
|
380 |
+
vit_gpt2_model = FlaxViTGPT2LMForConditionalGeneration.from_vit_gpt2_pretrained(
|
381 |
+
vit_name_path, gpt2_name_path
|
382 |
+
)
|
383 |
+
else:
|
384 |
+
vit_gpt2_model = FlaxViTGPT2LMForConditionalGeneration.from_pretrained(
|
385 |
+
vit_gpt2_name_path
|
386 |
+
)
|
387 |
+
|
388 |
+
model = vit_gpt2_model
|
389 |
+
model.config.is_encoder_decoder = True
|
390 |
+
model.config.decoder_start_token_id = gpt2_config.bos_token_id
|
391 |
+
model.config.bos_token_id = gpt2_config.bos_token_id
|
392 |
+
model.config.eos_token_id = gpt2_config.eos_token_id
|
393 |
+
model.config.pad_token_id = gpt2_config.pad_token_id
|
394 |
+
|
395 |
+
# Preprocessing the datasets.
|
396 |
+
# We need to tokenize inputs and targets.
|
397 |
+
if training_args.do_train:
|
398 |
+
column_names = dataset["train"].column_names
|
399 |
+
elif training_args.do_eval:
|
400 |
+
column_names = dataset["validation"].column_names
|
401 |
+
elif training_args.do_predict:
|
402 |
+
column_names = dataset["test"].column_names
|
403 |
+
else:
|
404 |
+
logger.info("There is nothing to do. Please pass `do_train`, `do_eval` and/or `do_predict`.")
|
405 |
+
return
|
406 |
+
|
407 |
+
image_file_column = 'image_file'
|
408 |
+
caption_column = 'caption'
|
409 |
+
pixels_file_column = 'pixels_file'
|
410 |
+
|
411 |
+
# Temporarily set max_target_length for training.
|
412 |
+
max_target_length = data_args.max_target_length
|
413 |
+
|
414 |
+
# In Flax, for seq2seq models we need to pass `decoder_input_ids`
|
415 |
+
# as the Flax models don't accept `labels`, we need to prepare the decoder_input_ids here
|
416 |
+
# for that dynamically import the `shift_tokens_right` function from the model file
|
417 |
+
model_module = __import__(vit_gpt2_model.__module__, fromlist=["shift_tokens_right"])
|
418 |
+
shift_tokens_right_fn = getattr(model_module, "shift_tokens_right")
|
419 |
+
|
420 |
+
# Setting padding="max_length" as we need fixed length inputs for jitted functions
|
421 |
+
def preprocess_function(examples):
|
422 |
+
|
423 |
+
pixels_file = examples[pixels_file_column]
|
424 |
+
if not pixels_file:
|
425 |
+
assert examples[image_file_column]
|
426 |
+
_pixel_values = []
|
427 |
+
for y in examples[image_file_column]:
|
428 |
+
with Image.open(y) as image:
|
429 |
+
encoder_inputs = feature_extractor(images=image, return_tensors="np")
|
430 |
+
x = encoder_inputs.pixel_values
|
431 |
+
_pixel_values.append(x)
|
432 |
+
pixel_values = np.concatenate(_pixel_values)
|
433 |
+
else:
|
434 |
+
pixel_values = np.concatenate([np.load(x) for x in pixels_file])
|
435 |
+
|
436 |
+
targets = examples[caption_column]
|
437 |
+
|
438 |
+
# Add eos_token!!
|
439 |
+
targets = [x + ' ' + tokenizer.eos_token for x in targets]
|
440 |
+
|
441 |
+
model_inputs = {}
|
442 |
+
model_inputs['pixel_values'] = pixel_values
|
443 |
+
|
444 |
+
# Setup the tokenizer for targets
|
445 |
+
with tokenizer.as_target_tokenizer():
|
446 |
+
labels = tokenizer(
|
447 |
+
targets, max_length=max_target_length, padding="max_length", truncation=True, return_tensors="np"
|
448 |
+
)
|
449 |
+
|
450 |
+
model_inputs["labels"] = labels["input_ids"]
|
451 |
+
|
452 |
+
#print(labels["input_ids"])
|
453 |
+
#print(gpt2_config.pad_token_id)
|
454 |
+
#rint(gpt2_config.bos_token_id)
|
455 |
+
|
456 |
+
decoder_input_ids = shift_tokens_right_fn(
|
457 |
+
jnp.array(labels["input_ids"]), gpt2_config.pad_token_id, gpt2_config.bos_token_id
|
458 |
+
)
|
459 |
+
model_inputs["input_ids"] = np.asarray(decoder_input_ids)
|
460 |
+
|
461 |
+
# We need decoder_attention_mask so we can ignore pad tokens from loss
|
462 |
+
model_inputs["attention_mask"] = labels["attention_mask"]
|
463 |
+
|
464 |
+
return model_inputs
|
465 |
+
|
466 |
+
if training_args.do_train:
|
467 |
+
if "train" not in dataset:
|
468 |
+
raise ValueError("--do_train requires a train dataset")
|
469 |
+
train_dataset = dataset["train"]
|
470 |
+
if data_args.max_train_samples is not None:
|
471 |
+
train_dataset = train_dataset.select(range(data_args.max_train_samples))
|
472 |
+
|
473 |
+
train_dataset = train_dataset.map(
|
474 |
+
preprocess_function,
|
475 |
+
batched=True,
|
476 |
+
num_proc=data_args.preprocessing_num_workers,
|
477 |
+
remove_columns=column_names,
|
478 |
+
load_from_cache_file=not data_args.overwrite_cache,
|
479 |
+
desc="Running tokenizer on train dataset",
|
480 |
+
)
|
481 |
+
|
482 |
+
if training_args.do_eval:
|
483 |
+
max_target_length = data_args.val_max_target_length
|
484 |
+
if "validation" not in dataset:
|
485 |
+
raise ValueError("--do_eval requires a validation dataset")
|
486 |
+
eval_dataset = dataset["validation"]
|
487 |
+
if data_args.max_eval_samples is not None:
|
488 |
+
eval_dataset = eval_dataset.select(range(data_args.max_eval_samples))
|
489 |
+
eval_dataset = eval_dataset.map(
|
490 |
+
preprocess_function,
|
491 |
+
batched=True,
|
492 |
+
num_proc=data_args.preprocessing_num_workers,
|
493 |
+
remove_columns=column_names,
|
494 |
+
load_from_cache_file=not data_args.overwrite_cache,
|
495 |
+
desc="Running tokenizer on validation dataset",
|
496 |
+
)
|
497 |
+
|
498 |
+
if training_args.do_predict:
|
499 |
+
max_target_length = data_args.val_max_target_length
|
500 |
+
if "test" not in dataset:
|
501 |
+
raise ValueError("--do_predict requires a test dataset")
|
502 |
+
predict_dataset = dataset["test"]
|
503 |
+
if data_args.max_predict_samples is not None:
|
504 |
+
predict_dataset = predict_dataset.select(range(data_args.max_predict_samples))
|
505 |
+
predict_dataset = predict_dataset.map(
|
506 |
+
preprocess_function,
|
507 |
+
batched=True,
|
508 |
+
num_proc=data_args.preprocessing_num_workers,
|
509 |
+
remove_columns=column_names,
|
510 |
+
load_from_cache_file=not data_args.overwrite_cache,
|
511 |
+
desc="Running tokenizer on prediction dataset",
|
512 |
+
)
|
513 |
+
|
514 |
+
# Metric
|
515 |
+
metric = load_metric("rouge")
|
516 |
+
|
517 |
+
def postprocess_text(preds, labels):
|
518 |
+
preds = [pred.strip() for pred in preds]
|
519 |
+
labels = [label.strip() for label in labels]
|
520 |
+
|
521 |
+
# rougeLSum expects newline after each sentence
|
522 |
+
preds = ["\n".join(nltk.sent_tokenize(pred)) for pred in preds]
|
523 |
+
labels = ["\n".join(nltk.sent_tokenize(label)) for label in labels]
|
524 |
+
|
525 |
+
return preds, labels
|
526 |
+
|
527 |
+
def compute_metrics(preds, labels):
|
528 |
+
decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
|
529 |
+
decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
|
530 |
+
|
531 |
+
# Some simple post-processing
|
532 |
+
decoded_preds, decoded_labels = postprocess_text(decoded_preds, decoded_labels)
|
533 |
+
|
534 |
+
result = metric.compute(predictions=decoded_preds, references=decoded_labels, use_stemmer=True)
|
535 |
+
# Extract a few results from ROUGE
|
536 |
+
result = {key: value.mid.fmeasure * 100 for key, value in result.items()}
|
537 |
+
|
538 |
+
prediction_lens = [np.count_nonzero(pred != tokenizer.pad_token_id) for pred in preds]
|
539 |
+
result["gen_len"] = np.mean(prediction_lens)
|
540 |
+
result = {k: round(v, 4) for k, v in result.items()}
|
541 |
+
return result
|
542 |
+
|
543 |
+
# Enable tensorboard only on the master node
|
544 |
+
has_tensorboard = is_tensorboard_available()
|
545 |
+
if has_tensorboard and jax.process_index() == 0:
|
546 |
+
try:
|
547 |
+
from flax.metrics.tensorboard import SummaryWriter
|
548 |
+
|
549 |
+
summary_writer = SummaryWriter(log_dir=Path(training_args.output_dir))
|
550 |
+
except ImportError as ie:
|
551 |
+
has_tensorboard = False
|
552 |
+
logger.warning(
|
553 |
+
f"Unable to display metrics through TensorBoard because some package are not installed: {ie}"
|
554 |
+
)
|
555 |
+
else:
|
556 |
+
logger.warning(
|
557 |
+
"Unable to display metrics through TensorBoard because the package is not installed: "
|
558 |
+
"Please run pip install tensorboard to enable."
|
559 |
+
)
|
560 |
+
|
561 |
+
# Initialize our training
|
562 |
+
rng = jax.random.PRNGKey(training_args.seed)
|
563 |
+
rng, dropout_rng = jax.random.split(rng)
|
564 |
+
|
565 |
+
# Store some constant
|
566 |
+
num_epochs = int(training_args.num_train_epochs)
|
567 |
+
train_batch_size = int(training_args.per_device_train_batch_size) * jax.device_count()
|
568 |
+
eval_batch_size = int(training_args.per_device_eval_batch_size) * jax.device_count()
|
569 |
+
steps_per_epoch = len(train_dataset) // train_batch_size
|
570 |
+
total_train_steps = steps_per_epoch * num_epochs
|
571 |
+
|
572 |
+
# Create learning rate schedule
|
573 |
+
linear_decay_lr_schedule_fn = create_learning_rate_fn(
|
574 |
+
len(train_dataset),
|
575 |
+
train_batch_size,
|
576 |
+
training_args.num_train_epochs,
|
577 |
+
training_args.warmup_steps,
|
578 |
+
training_args.learning_rate,
|
579 |
+
)
|
580 |
+
|
581 |
+
# We use Optax's "masking" functionality to not apply weight decay
|
582 |
+
# to bias and LayerNorm scale parameters. decay_mask_fn returns a
|
583 |
+
# mask boolean with the same structure as the parameters.
|
584 |
+
# The mask is True for parameters that should be decayed.
|
585 |
+
# Note that this mask is specifically adapted for FlaxBart.
|
586 |
+
# For FlaxT5, one should correct the layer norm parameter naming
|
587 |
+
# accordingly - see `run_t5_mlm_flax.py` e.g.
|
588 |
+
def decay_mask_fn(params):
|
589 |
+
flat_params = traverse_util.flatten_dict(params)
|
590 |
+
layer_norm_params = [
|
591 |
+
(name, "scale") for name in ["self_attn_layer_norm", "layernorm_embedding", "final_layer_norm"]
|
592 |
+
]
|
593 |
+
flat_mask = {path: (path[-1] != "bias" and path[-2:] not in layer_norm_params) for path in flat_params}
|
594 |
+
return traverse_util.unflatten_dict(flat_mask)
|
595 |
+
|
596 |
+
# create adam optimizer
|
597 |
+
adamw = optax.adamw(
|
598 |
+
learning_rate=linear_decay_lr_schedule_fn,
|
599 |
+
b1=training_args.adam_beta1,
|
600 |
+
b2=training_args.adam_beta2,
|
601 |
+
eps=training_args.adam_epsilon,
|
602 |
+
weight_decay=training_args.weight_decay,
|
603 |
+
mask=decay_mask_fn,
|
604 |
+
)
|
605 |
+
|
606 |
+
# Setup train state
|
607 |
+
state = TrainState.create(apply_fn=vit_gpt2_model.__call__, params=vit_gpt2_model.params, tx=adamw, dropout_rng=dropout_rng)
|
608 |
+
|
609 |
+
# label smoothed cross entropy
|
610 |
+
def loss_fn(logits, labels, padding_mask, label_smoothing_factor=0.0):
|
611 |
+
"""
|
612 |
+
The label smoothing implementation is adapted from Flax's official example:
|
613 |
+
https://github.com/google/flax/blob/87a211135c6a377c8f29048a1cac3840e38b9da4/examples/wmt/train.py#L104
|
614 |
+
"""
|
615 |
+
vocab_size = logits.shape[-1]
|
616 |
+
confidence = 1.0 - label_smoothing_factor
|
617 |
+
low_confidence = (1.0 - confidence) / (vocab_size - 1)
|
618 |
+
normalizing_constant = -(
|
619 |
+
confidence * jnp.log(confidence) + (vocab_size - 1) * low_confidence * jnp.log(low_confidence + 1e-20)
|
620 |
+
)
|
621 |
+
soft_labels = onehot(labels, vocab_size, on_value=confidence, off_value=low_confidence)
|
622 |
+
|
623 |
+
loss = optax.softmax_cross_entropy(logits, soft_labels)
|
624 |
+
loss = loss - normalizing_constant
|
625 |
+
|
626 |
+
# ignore padded tokens from loss
|
627 |
+
loss = loss * padding_mask
|
628 |
+
loss = loss.sum() / padding_mask.sum()
|
629 |
+
return loss
|
630 |
+
|
631 |
+
# Define gradient update step fn
|
632 |
+
def train_step(state, batch, label_smoothing_factor=0.0):
|
633 |
+
dropout_rng, new_dropout_rng = jax.random.split(state.dropout_rng)
|
634 |
+
|
635 |
+
def compute_loss(params):
|
636 |
+
labels = batch.pop("labels")
|
637 |
+
logits = state.apply_fn(**batch, params=params, dropout_rng=dropout_rng, train=True)[0]
|
638 |
+
loss = loss_fn(logits, labels, batch["attention_mask"], label_smoothing_factor)
|
639 |
+
return loss
|
640 |
+
|
641 |
+
grad_fn = jax.value_and_grad(compute_loss)
|
642 |
+
loss, grad = grad_fn(state.params)
|
643 |
+
grad = jax.lax.pmean(grad, "batch")
|
644 |
+
|
645 |
+
new_state = state.apply_gradients(grads=grad, dropout_rng=new_dropout_rng)
|
646 |
+
|
647 |
+
metrics = {"loss": loss, "learning_rate": linear_decay_lr_schedule_fn(state.step)}
|
648 |
+
metrics = jax.lax.pmean(metrics, axis_name="batch")
|
649 |
+
|
650 |
+
return new_state, metrics
|
651 |
+
|
652 |
+
# Define eval fn
|
653 |
+
def eval_step(params, batch, label_smoothing_factor=0.0):
|
654 |
+
labels = batch.pop("labels")
|
655 |
+
logits = model(**batch, params=params, train=False)[0]
|
656 |
+
loss = loss_fn(logits, labels, batch["attention_mask"], label_smoothing_factor)
|
657 |
+
|
658 |
+
# summarize metrics
|
659 |
+
metrics = {"loss": loss}
|
660 |
+
metrics = jax.lax.pmean(metrics, axis_name="batch")
|
661 |
+
return metrics
|
662 |
+
|
663 |
+
# Define generation function
|
664 |
+
max_length = (
|
665 |
+
data_args.val_max_target_length if data_args.val_max_target_length is not None else model.config.max_length
|
666 |
+
)
|
667 |
+
num_beams = data_args.num_beams if data_args.num_beams is not None else model.config.num_beams
|
668 |
+
gen_kwargs = {"max_length": max_length, "num_beams": num_beams}
|
669 |
+
|
670 |
+
def generate_step(params, batch):
|
671 |
+
model.params = params
|
672 |
+
# output_ids = model.generate(batch["pixel_values"], **gen_kwargs)
|
673 |
+
|
674 |
+
#encoder_outputs = model.encode(pixel_values=batch['pixel_values'])
|
675 |
+
#output_ids = model.generate(batch["input_ids"], attention_mask=batch["attention_mask"], encoder_outputs=encoder_outputs, **gen_kwargs)
|
676 |
+
|
677 |
+
# encoder_outputs = model.encode(pixel_values=batch['pixel_values'], params=params, train=False)
|
678 |
+
output_ids = model.generate(batch['pixel_values'], **gen_kwargs)
|
679 |
+
|
680 |
+
|
681 |
+
return output_ids.sequences
|
682 |
+
|
683 |
+
# Create parallel version of the train and eval step
|
684 |
+
p_train_step = jax.pmap(
|
685 |
+
partial(train_step, label_smoothing_factor=training_args.label_smoothing_factor), "batch", donate_argnums=(0,)
|
686 |
+
)
|
687 |
+
p_eval_step = jax.pmap(partial(eval_step, label_smoothing_factor=training_args.label_smoothing_factor), "batch")
|
688 |
+
p_generate_step = jax.pmap(generate_step, "batch")
|
689 |
+
|
690 |
+
# Replicate the train state on each device
|
691 |
+
state = state.replicate()
|
692 |
+
|
693 |
+
logger.info("***** Running training *****")
|
694 |
+
logger.info(f" Num examples = {len(train_dataset)}")
|
695 |
+
logger.info(f" Num Epochs = {num_epochs}")
|
696 |
+
logger.info(f" Instantaneous batch size per device = {training_args.per_device_train_batch_size}")
|
697 |
+
logger.info(f" Total train batch size (w. parallel & distributed) = {train_batch_size}")
|
698 |
+
logger.info(f" Total optimization steps = {total_train_steps}")
|
699 |
+
|
700 |
+
train_time = 0
|
701 |
+
epochs = tqdm(range(num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0)
|
702 |
+
for epoch in epochs:
|
703 |
+
# ======================== Training ================================
|
704 |
+
train_start = time.time()
|
705 |
+
|
706 |
+
# Create sampling rng
|
707 |
+
rng, input_rng = jax.random.split(rng)
|
708 |
+
train_metrics = []
|
709 |
+
|
710 |
+
# Generate an epoch by shuffling sampling indices from the train dataset
|
711 |
+
train_loader = data_loader(input_rng, train_dataset, train_batch_size, shuffle=True)
|
712 |
+
steps_per_epoch = len(train_dataset) // train_batch_size
|
713 |
+
# train
|
714 |
+
for _ in tqdm(range(steps_per_epoch), desc="Training...", position=1, leave=False):
|
715 |
+
batch = next(train_loader)
|
716 |
+
state, train_metric = p_train_step(state, batch)
|
717 |
+
train_metrics.append(train_metric)
|
718 |
+
|
719 |
+
train_time += time.time() - train_start
|
720 |
+
|
721 |
+
train_metric = unreplicate(train_metric)
|
722 |
+
|
723 |
+
desc = f"Epoch... ({epoch + 1}/{num_epochs} | Loss: {train_metric['loss']}, Learning Rate: {train_metric['learning_rate']})"
|
724 |
+
epochs.write(desc)
|
725 |
+
epochs.desc = desc
|
726 |
+
logger.info(desc)
|
727 |
+
with open(os.path.join(training_args.output_dir, f'report.txt'), 'a', encoding='UTF-8') as fp:
|
728 |
+
fp.write(desc + '\n')
|
729 |
+
|
730 |
+
|
731 |
+
# ======================== Evaluating ==============================
|
732 |
+
eval_metrics = []
|
733 |
+
eval_preds = []
|
734 |
+
eval_labels = []
|
735 |
+
|
736 |
+
eval_loader = data_loader(input_rng, eval_dataset, eval_batch_size)
|
737 |
+
eval_steps = len(eval_dataset) // eval_batch_size
|
738 |
+
for _ in tqdm(range(eval_steps), desc="Evaluating...", position=2, leave=False):
|
739 |
+
# Model forward
|
740 |
+
batch = next(eval_loader)
|
741 |
+
labels = batch["labels"]
|
742 |
+
|
743 |
+
metrics = p_eval_step(state.params, batch)
|
744 |
+
eval_metrics.append(metrics)
|
745 |
+
|
746 |
+
# generation
|
747 |
+
if data_args.predict_with_generate:
|
748 |
+
generated_ids = p_generate_step(state.params, batch)
|
749 |
+
eval_preds.extend(jax.device_get(generated_ids.reshape(-1, gen_kwargs["max_length"])))
|
750 |
+
eval_labels.extend(jax.device_get(labels.reshape(-1, labels.shape[-1])))
|
751 |
+
|
752 |
+
# normalize eval metrics
|
753 |
+
eval_metrics = get_metrics(eval_metrics)
|
754 |
+
eval_metrics = jax.tree_map(jnp.mean, eval_metrics)
|
755 |
+
|
756 |
+
# compute ROUGE metrics
|
757 |
+
rouge_desc = ""
|
758 |
+
if data_args.predict_with_generate:
|
759 |
+
rouge_metrics = compute_metrics(eval_preds, eval_labels)
|
760 |
+
eval_metrics.update(rouge_metrics)
|
761 |
+
rouge_desc = " ".join([f"Eval {key}: {value} |" for key, value in rouge_metrics.items()])
|
762 |
+
|
763 |
+
# Print metrics and update progress bar
|
764 |
+
desc = f"Epoch... ({epoch + 1}/{num_epochs} | Eval Loss: {eval_metrics['loss']} | {rouge_desc})"
|
765 |
+
epochs.write(desc)
|
766 |
+
epochs.desc = desc
|
767 |
+
logger.info(desc)
|
768 |
+
with open(os.path.join(training_args.output_dir, f'report.txt'), 'a', encoding='UTF-8') as fp:
|
769 |
+
fp.write(desc + '\n')
|
770 |
+
|
771 |
+
|
772 |
+
# Save metrics
|
773 |
+
if has_tensorboard and jax.process_index() == 0:
|
774 |
+
cur_step = epoch * (len(train_dataset) // train_batch_size)
|
775 |
+
write_metric(summary_writer, train_metrics, eval_metrics, train_time, cur_step)
|
776 |
+
|
777 |
+
# ======================== Prediction loop ==============================
|
778 |
+
if training_args.do_predict:
|
779 |
+
logger.info("*** Predict ***")
|
780 |
+
|
781 |
+
pred_metrics = []
|
782 |
+
pred_generations = []
|
783 |
+
pred_labels = []
|
784 |
+
|
785 |
+
pred_loader = data_loader(input_rng, predict_dataset, eval_batch_size)
|
786 |
+
pred_steps = len(predict_dataset) // eval_batch_size
|
787 |
+
for _ in tqdm(range(pred_steps), desc="Predicting...", position=2, leave=False):
|
788 |
+
# Model forward
|
789 |
+
batch = next(pred_loader)
|
790 |
+
labels = batch["labels"]
|
791 |
+
|
792 |
+
metrics = p_eval_step(state.params, batch)
|
793 |
+
pred_metrics.append(metrics)
|
794 |
+
|
795 |
+
# generation
|
796 |
+
if data_args.predict_with_generate:
|
797 |
+
generated_ids = p_generate_step(state.params, batch)
|
798 |
+
pred_generations.extend(jax.device_get(generated_ids.reshape(-1, gen_kwargs["max_length"])))
|
799 |
+
pred_labels.extend(jax.device_get(labels.reshape(-1, labels.shape[-1])))
|
800 |
+
|
801 |
+
# normalize prediction metrics
|
802 |
+
pred_metrics = get_metrics(pred_metrics)
|
803 |
+
pred_metrics = jax.tree_map(jnp.mean, pred_metrics)
|
804 |
+
|
805 |
+
# compute ROUGE metrics
|
806 |
+
rouge_desc = ""
|
807 |
+
if data_args.predict_with_generate:
|
808 |
+
rouge_metrics = compute_metrics(pred_generations, pred_labels)
|
809 |
+
pred_metrics.update(rouge_metrics)
|
810 |
+
rouge_desc = " ".join([f"Predict {key}: {value} |" for key, value in rouge_metrics.items()])
|
811 |
+
|
812 |
+
# Print metrics
|
813 |
+
desc = f"Predict Loss: {pred_metrics['loss']} | {rouge_desc})"
|
814 |
+
epochs.write(desc)
|
815 |
+
epochs.desc = desc
|
816 |
+
logger.info(desc)
|
817 |
+
with open(os.path.join(training_args.output_dir, f'report.txt'), 'a', encoding='UTF-8') as fp:
|
818 |
+
fp.write(desc + '\n')
|
819 |
+
|
820 |
+
|
821 |
+
# save checkpoint after each epoch and push checkpoint to the hub
|
822 |
+
if jax.process_index() == 0:
|
823 |
+
params = jax.device_get(jax.tree_map(lambda x: x[0], state.params))
|
824 |
+
model.save_pretrained(
|
825 |
+
os.path.join(training_args.output_dir, f'ckpt_{epoch+1}'),
|
826 |
+
params=params,
|
827 |
+
push_to_hub=training_args.push_to_hub,
|
828 |
+
commit_message=f"Saving weights and logs of epoch {epoch+1}",
|
829 |
+
)
|
830 |
+
|
831 |
+
if __name__ == "__main__":
|
832 |
+
main()
|
test_vit_gpt2.py
ADDED
@@ -0,0 +1,83 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import sys, os
|
2 |
+
|
3 |
+
current_path = os.path.dirname(os.path.abspath(__file__))
|
4 |
+
sys.path.append(current_path)
|
5 |
+
|
6 |
+
# Vit - as encoder
|
7 |
+
from transformers import ViTFeatureExtractor
|
8 |
+
from PIL import Image
|
9 |
+
import requests
|
10 |
+
import numpy as np
|
11 |
+
|
12 |
+
url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
|
13 |
+
image = Image.open(requests.get(url, stream=True).raw)
|
14 |
+
|
15 |
+
feature_extractor = ViTFeatureExtractor.from_pretrained('google/vit-base-patch16-224-in21k')
|
16 |
+
encoder_inputs = feature_extractor(images=image, return_tensors="jax")
|
17 |
+
pixel_values = encoder_inputs.pixel_values
|
18 |
+
|
19 |
+
# GPT2 / GPT2LM - as decoder
|
20 |
+
from transformers import ViTFeatureExtractor, GPT2Tokenizer
|
21 |
+
|
22 |
+
name = 'asi/gpt-fr-cased-small'
|
23 |
+
tokenizer = GPT2Tokenizer.from_pretrained(name)
|
24 |
+
decoder_inputs = tokenizer("mon chien est mignon", return_tensors="jax")
|
25 |
+
|
26 |
+
inputs = dict(decoder_inputs)
|
27 |
+
inputs['pixel_values'] = pixel_values
|
28 |
+
print(inputs)
|
29 |
+
|
30 |
+
# With new added LM head
|
31 |
+
from vit_gpt2.modeling_flax_vit_gpt2 import FlaxViTGPT2ForConditionalGeneration
|
32 |
+
flax_vit_gpt2 = FlaxViTGPT2ForConditionalGeneration.from_vit_gpt2_pretrained(
|
33 |
+
'google/vit-base-patch16-224-in21k', 'asi/gpt-fr-cased-small'
|
34 |
+
)
|
35 |
+
logits = flax_vit_gpt2(**inputs)[0]
|
36 |
+
preds = np.argmax(logits, axis=-1)
|
37 |
+
print('=' * 60)
|
38 |
+
print('Flax: Vit + modified GPT2 + LM')
|
39 |
+
print(preds)
|
40 |
+
|
41 |
+
del flax_vit_gpt2
|
42 |
+
|
43 |
+
# With the LM head in GPT2LM
|
44 |
+
from vit_gpt2.modeling_flax_vit_gpt2_lm import FlaxViTGPT2LMForConditionalGeneration
|
45 |
+
flax_vit_gpt2_lm = FlaxViTGPT2LMForConditionalGeneration.from_vit_gpt2_pretrained(
|
46 |
+
'google/vit-base-patch16-224-in21k', 'asi/gpt-fr-cased-small'
|
47 |
+
)
|
48 |
+
|
49 |
+
logits = flax_vit_gpt2_lm(**inputs)[0]
|
50 |
+
preds = np.argmax(logits, axis=-1)
|
51 |
+
print('=' * 60)
|
52 |
+
print('Flax: Vit + modified GPT2LM')
|
53 |
+
print(preds)
|
54 |
+
|
55 |
+
del flax_vit_gpt2_lm
|
56 |
+
|
57 |
+
# With PyTorch [Vit + unmodified GPT2LMHeadModel]
|
58 |
+
import torch
|
59 |
+
from transformers import ViTModel, GPT2Config, GPT2LMHeadModel
|
60 |
+
|
61 |
+
vit_model_pt = ViTModel.from_pretrained('google/vit-base-patch16-224-in21k')
|
62 |
+
encoder_inputs = feature_extractor(images=image, return_tensors="pt")
|
63 |
+
vit_outputs = vit_model_pt(**encoder_inputs)
|
64 |
+
vit_last_hidden_states = vit_outputs.last_hidden_state
|
65 |
+
|
66 |
+
del vit_model_pt
|
67 |
+
|
68 |
+
inputs_pt = tokenizer("mon chien est mignon", return_tensors="pt")
|
69 |
+
inputs_pt = dict(inputs_pt)
|
70 |
+
inputs_pt['encoder_hidden_states'] = vit_last_hidden_states
|
71 |
+
|
72 |
+
config = GPT2Config.from_pretrained('asi/gpt-fr-cased-small')
|
73 |
+
config.add_cross_attention = True
|
74 |
+
gpt2_model_pt = GPT2LMHeadModel.from_pretrained('asi/gpt-fr-cased-small', config=config)
|
75 |
+
|
76 |
+
gp2lm_outputs = gpt2_model_pt(**inputs_pt)
|
77 |
+
logits_pt = gp2lm_outputs.logits
|
78 |
+
preds_pt = torch.argmax(logits_pt, dim=-1).cpu().detach().numpy()
|
79 |
+
print('=' * 60)
|
80 |
+
print('Pytorch: Vit + unmodified GPT2LM')
|
81 |
+
print(preds_pt)
|
82 |
+
|
83 |
+
del gpt2_model_pt
|
test_wit_dataset_script.py
ADDED
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import csv
|
2 |
+
import json
|
3 |
+
import os
|
4 |
+
|
5 |
+
import datasets
|
6 |
+
import pandas as pd
|
7 |
+
import numpy as np
|
8 |
+
|
9 |
+
ds = datasets.load_dataset('./wit_dataset_script.py', data_dir='./wit_data_dir/')
|
10 |
+
test_ds = ds['test']
|
11 |
+
|
12 |
+
|
13 |
+
def transform(example):
|
14 |
+
|
15 |
+
example['pixel_values'] = np.load(example['pixels_file'])
|
16 |
+
return example
|
17 |
+
|
18 |
+
|
19 |
+
test_ds = test_ds.map(transform)
|
20 |
+
|
21 |
+
for x in test_ds:
|
22 |
+
print(x)
|
23 |
+
break
|
tests_load.py
ADDED
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import sys, os
|
2 |
+
|
3 |
+
current_path = os.path.dirname(os.path.abspath(__file__))
|
4 |
+
sys.path.append(current_path)
|
5 |
+
|
6 |
+
# Vit - as encoder
|
7 |
+
from transformers import ViTFeatureExtractor
|
8 |
+
from PIL import Image
|
9 |
+
import requests
|
10 |
+
import numpy as np
|
11 |
+
|
12 |
+
url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
|
13 |
+
image = Image.open(requests.get(url, stream=True).raw)
|
14 |
+
|
15 |
+
feature_extractor = ViTFeatureExtractor.from_pretrained('google/vit-base-patch16-224-in21k')
|
16 |
+
encoder_inputs = feature_extractor(images=image, return_tensors="jax")
|
17 |
+
pixel_values = encoder_inputs.pixel_values
|
18 |
+
|
19 |
+
# GPT2 / GPT2LM - as decoder
|
20 |
+
from transformers import ViTFeatureExtractor, GPT2Tokenizer
|
21 |
+
|
22 |
+
name = 'asi/gpt-fr-cased-small'
|
23 |
+
tokenizer = GPT2Tokenizer.from_pretrained(name)
|
24 |
+
decoder_inputs = tokenizer("mon chien est mignon", return_tensors="jax")
|
25 |
+
|
26 |
+
inputs = dict(decoder_inputs)
|
27 |
+
inputs['pixel_values'] = pixel_values
|
28 |
+
print(inputs)
|
29 |
+
|
30 |
+
|
31 |
+
|
32 |
+
|
33 |
+
|
34 |
+
# With the LM head in GPT2LM
|
35 |
+
from vit_gpt2.modeling_flax_vit_gpt2_lm import FlaxViTGPT2LMForConditionalGeneration
|
36 |
+
flax_vit_gpt2_lm = FlaxViTGPT2LMForConditionalGeneration.from_pretrained(
|
37 |
+
'.',
|
38 |
+
)
|
39 |
+
|
40 |
+
logits = flax_vit_gpt2_lm(**inputs)[0]
|
41 |
+
preds = np.argmax(logits, axis=-1)
|
42 |
+
print('=' * 60)
|
43 |
+
print('Flax: Vit + modified GPT2LM')
|
44 |
+
print(preds)
|
45 |
+
|
46 |
+
# flax_vit_gpt2_lm.save_pretrained('.')
|
47 |
+
|
48 |
+
del flax_vit_gpt2_lm
|
tests_save.py
ADDED
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import sys, os
|
2 |
+
|
3 |
+
current_path = os.path.dirname(os.path.abspath(__file__))
|
4 |
+
sys.path.append(current_path)
|
5 |
+
|
6 |
+
# Vit - as encoder
|
7 |
+
from transformers import ViTFeatureExtractor
|
8 |
+
from PIL import Image
|
9 |
+
import requests
|
10 |
+
import numpy as np
|
11 |
+
|
12 |
+
url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
|
13 |
+
image = Image.open(requests.get(url, stream=True).raw)
|
14 |
+
|
15 |
+
feature_extractor = ViTFeatureExtractor.from_pretrained('google/vit-base-patch16-224-in21k')
|
16 |
+
encoder_inputs = feature_extractor(images=image, return_tensors="jax")
|
17 |
+
pixel_values = encoder_inputs.pixel_values
|
18 |
+
|
19 |
+
# GPT2 / GPT2LM - as decoder
|
20 |
+
from transformers import ViTFeatureExtractor, GPT2Tokenizer
|
21 |
+
|
22 |
+
name = 'asi/gpt-fr-cased-small'
|
23 |
+
tokenizer = GPT2Tokenizer.from_pretrained(name)
|
24 |
+
decoder_inputs = tokenizer("mon chien est mignon", return_tensors="jax")
|
25 |
+
|
26 |
+
inputs = dict(decoder_inputs)
|
27 |
+
inputs['pixel_values'] = pixel_values
|
28 |
+
print(inputs)
|
29 |
+
|
30 |
+
|
31 |
+
|
32 |
+
|
33 |
+
|
34 |
+
# With the LM head in GPT2LM
|
35 |
+
from vit_gpt2.modeling_flax_vit_gpt2_lm import FlaxViTGPT2LMForConditionalGeneration
|
36 |
+
flax_vit_gpt2_lm = FlaxViTGPT2LMForConditionalGeneration.from_vit_gpt2_pretrained(
|
37 |
+
'google/vit-base-patch16-224-in21k', 'asi/gpt-fr-cased-small'
|
38 |
+
)
|
39 |
+
|
40 |
+
logits = flax_vit_gpt2_lm(**inputs)[0]
|
41 |
+
preds = np.argmax(logits, axis=-1)
|
42 |
+
print('=' * 60)
|
43 |
+
print('Flax: Vit + modified GPT2LM')
|
44 |
+
print(preds)
|
45 |
+
|
46 |
+
flax_vit_gpt2_lm.save_pretrained('.')
|
47 |
+
|
48 |
+
del flax_vit_gpt2_lm
|
vit_gpt2/__init__.py
ADDED
File without changes
|
vit_gpt2/configuration_vit_gpt2.py
ADDED
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import copy
|
2 |
+
|
3 |
+
from transformers import GPT2Config, ViTConfig
|
4 |
+
from transformers.configuration_utils import PretrainedConfig
|
5 |
+
from transformers.utils import logging
|
6 |
+
|
7 |
+
logger = logging.get_logger(__name__)
|
8 |
+
|
9 |
+
|
10 |
+
class ViTGPT2Config(PretrainedConfig):
|
11 |
+
|
12 |
+
model_type = "vit-gpt2"
|
13 |
+
is_composition = True
|
14 |
+
|
15 |
+
def __init__(self, **kwargs):
|
16 |
+
super().__init__(**kwargs)
|
17 |
+
|
18 |
+
if "vit_config" not in kwargs:
|
19 |
+
raise ValueError("`vit_config` can not be `None`.")
|
20 |
+
|
21 |
+
if "gpt2_config" not in kwargs:
|
22 |
+
raise ValueError("`gpt2_config` can not be `None`.")
|
23 |
+
|
24 |
+
vit_config = kwargs.pop("vit_config")
|
25 |
+
gpt2_config = kwargs.pop("gpt2_config")
|
26 |
+
|
27 |
+
self.vit_config = ViTConfig(**vit_config)
|
28 |
+
self.gpt2_config = GPT2Config(**gpt2_config)
|
29 |
+
|
30 |
+
@classmethod
|
31 |
+
def from_vit_gpt2_configs(
|
32 |
+
cls, vit_config: PretrainedConfig, gpt2_config: PretrainedConfig, **kwargs
|
33 |
+
):
|
34 |
+
return cls(
|
35 |
+
vit_config=vit_config.to_dict(),
|
36 |
+
gpt2_config=gpt2_config.to_dict(),
|
37 |
+
**kwargs
|
38 |
+
)
|
39 |
+
|
40 |
+
def to_dict(self):
|
41 |
+
output = copy.deepcopy(self.__dict__)
|
42 |
+
output["vit_config"] = self.vit_config.to_dict()
|
43 |
+
output["gpt2_config"] = self.gpt2_config.to_dict()
|
44 |
+
output["model_type"] = self.__class__.model_type
|
45 |
+
return output
|
vit_gpt2/modeling_flax_gpt2.py
ADDED
@@ -0,0 +1,752 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2021 The Google Flax Team Authors and The HuggingFace Inc. team.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
|
16 |
+
from typing import Any, Optional, Tuple
|
17 |
+
|
18 |
+
import flax.linen as nn
|
19 |
+
import jax
|
20 |
+
import jax.numpy as jnp
|
21 |
+
from flax.core.frozen_dict import FrozenDict, unfreeze
|
22 |
+
from flax.linen import combine_masks, make_causal_mask
|
23 |
+
from flax.linen.attention import dot_product_attention_weights
|
24 |
+
from jax import lax
|
25 |
+
|
26 |
+
from transformers.file_utils import add_start_docstrings, add_start_docstrings_to_model_forward
|
27 |
+
from transformers.modeling_flax_outputs import FlaxBaseModelOutput, FlaxBaseModelOutputWithPast, FlaxCausalLMOutput, FlaxBaseModelOutputWithPastAndCrossAttentions, FlaxSeq2SeqLMOutput
|
28 |
+
from transformers.modeling_flax_utils import ACT2FN, FlaxPreTrainedModel, append_call_sample_docstring
|
29 |
+
from transformers.utils import logging
|
30 |
+
from transformers.models.gpt2.configuration_gpt2 import GPT2Config
|
31 |
+
|
32 |
+
|
33 |
+
logger = logging.get_logger(__name__)
|
34 |
+
|
35 |
+
_CHECKPOINT_FOR_DOC = "gpt2"
|
36 |
+
_CONFIG_FOR_DOC = "GPT2Config"
|
37 |
+
_TOKENIZER_FOR_DOC = "GPT2Tokenizer"
|
38 |
+
|
39 |
+
|
40 |
+
GPT2_START_DOCSTRING = r"""
|
41 |
+
|
42 |
+
This model inherits from :class:`~transformers.FlaxPreTrainedModel`. Check the superclass documentation for the
|
43 |
+
generic methods the library implements for all its model (such as downloading or saving, resizing the input
|
44 |
+
embeddings, pruning heads etc.)
|
45 |
+
|
46 |
+
This model is also a Flax Linen `flax.nn.Module
|
47 |
+
<https://flax.readthedocs.io/en/latest/_autosummary/flax.nn.module.html>`__ subclass. Use it as a regular Flax
|
48 |
+
Module and refer to the Flax documentation for all matter related to general usage and behavior.
|
49 |
+
|
50 |
+
Finally, this model supports inherent JAX features such as:
|
51 |
+
|
52 |
+
- `Just-In-Time (JIT) compilation <https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit>`__
|
53 |
+
- `Automatic Differentiation <https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation>`__
|
54 |
+
- `Vectorization <https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap>`__
|
55 |
+
- `Parallelization <https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap>`__
|
56 |
+
|
57 |
+
Parameters:
|
58 |
+
config (:class:`~transformers.GPT2Config`): Model configuration class with all the parameters of the model.
|
59 |
+
Initializing with a config file does not load the weights associated with the model, only the
|
60 |
+
configuration. Check out the :meth:`~transformers.FlaxPreTrainedModel.from_pretrained` method to load the
|
61 |
+
model weights.
|
62 |
+
"""
|
63 |
+
|
64 |
+
GPT2_INPUTS_DOCSTRING = r"""
|
65 |
+
Args:
|
66 |
+
input_ids (:obj:`numpy.ndarray` of shape :obj:`(batch_size, input_ids_length)`):
|
67 |
+
:obj:`input_ids_length` = ``sequence_length``. Indices of input sequence tokens in the vocabulary.
|
68 |
+
|
69 |
+
Indices can be obtained using :class:`~transformers.GPT2Tokenizer`. See
|
70 |
+
:meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` for
|
71 |
+
details.
|
72 |
+
|
73 |
+
`What are input IDs? <../glossary.html#input-ids>`__
|
74 |
+
attention_mask (:obj:`numpy.ndarray` of shape :obj:`(batch_size, sequence_length)`, `optional`):
|
75 |
+
Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``:
|
76 |
+
|
77 |
+
- 1 for tokens that are **not masked**,
|
78 |
+
- 0 for tokens that are **masked**.
|
79 |
+
|
80 |
+
`What are attention masks? <../glossary.html#attention-mask>`__
|
81 |
+
position_ids (:obj:`numpy.ndarray` of shape :obj:`(batch_size, sequence_length)`, `optional`):
|
82 |
+
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range ``[0,
|
83 |
+
config.max_position_embeddings - 1]``.
|
84 |
+
past_key_values (:obj:`Dict[str, np.ndarray]`, `optional`, returned by ``init_cache`` or when passing previous ``past_key_values``):
|
85 |
+
Dictionary of pre-computed hidden-states (key and values in the attention blocks) that can be used for fast
|
86 |
+
auto-regressive decoding. Pre-computed key and value hidden-states are of shape `[batch_size, max_length]`.
|
87 |
+
output_attentions (:obj:`bool`, `optional`):
|
88 |
+
Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under returned
|
89 |
+
tensors for more detail.
|
90 |
+
output_hidden_states (:obj:`bool`, `optional`):
|
91 |
+
Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors for
|
92 |
+
more detail.
|
93 |
+
return_dict (:obj:`bool`, `optional`):
|
94 |
+
Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple.
|
95 |
+
"""
|
96 |
+
|
97 |
+
|
98 |
+
class FlaxConv1D(nn.Module):
|
99 |
+
features: int
|
100 |
+
use_bias: bool = True
|
101 |
+
dtype: Any = jnp.float32
|
102 |
+
precision: Any = None
|
103 |
+
|
104 |
+
@nn.compact
|
105 |
+
def __call__(self, inputs):
|
106 |
+
inputs = jnp.asarray(inputs, self.dtype)
|
107 |
+
kernel = self.param("kernel", jax.nn.initializers.normal(stddev=0.02), (self.features, inputs.shape[-1]))
|
108 |
+
kernel = jnp.asarray(kernel.transpose(), self.dtype)
|
109 |
+
y = lax.dot_general(inputs, kernel, (((inputs.ndim - 1,), (0,)), ((), ())), precision=self.precision)
|
110 |
+
if self.use_bias:
|
111 |
+
bias = self.param("bias", jax.nn.initializers.zeros, (self.features,))
|
112 |
+
bias = jnp.asarray(bias, self.dtype)
|
113 |
+
y = y + bias
|
114 |
+
return y
|
115 |
+
|
116 |
+
|
117 |
+
class FlaxGPT2Attention(nn.Module):
|
118 |
+
config: GPT2Config
|
119 |
+
dtype: jnp.dtype = jnp.float32
|
120 |
+
causal: bool = True
|
121 |
+
|
122 |
+
def setup(self):
|
123 |
+
config = self.config
|
124 |
+
self.embed_dim = config.hidden_size
|
125 |
+
self.num_heads = config.num_attention_heads
|
126 |
+
self.head_dim = self.embed_dim // self.num_heads
|
127 |
+
|
128 |
+
self.c_attn = FlaxConv1D(features=3 * self.embed_dim, dtype=self.dtype)
|
129 |
+
self.c_proj = FlaxConv1D(self.embed_dim, dtype=self.dtype)
|
130 |
+
|
131 |
+
self.c_attn_for_k_v = FlaxConv1D(features=2 * self.embed_dim, dtype=self.dtype)
|
132 |
+
|
133 |
+
self.resid_dropout = nn.Dropout(rate=config.resid_pdrop)
|
134 |
+
|
135 |
+
if self.causal:
|
136 |
+
self.causal_mask = make_causal_mask(jnp.ones((1, config.max_position_embeddings), dtype="bool"), dtype="bool")
|
137 |
+
|
138 |
+
def _split_heads(self, hidden_states):
|
139 |
+
return hidden_states.reshape(hidden_states.shape[:2] + (self.num_heads, self.head_dim))
|
140 |
+
|
141 |
+
def _merge_heads(self, hidden_states):
|
142 |
+
return hidden_states.reshape(hidden_states.shape[:2] + (self.embed_dim,))
|
143 |
+
|
144 |
+
@nn.compact
|
145 |
+
def _concatenate_to_cache(self, key, value, query, attention_mask):
|
146 |
+
"""
|
147 |
+
This function takes projected key, value states from a single input token and concatenates the states to cached
|
148 |
+
states from previous steps. This function is slighly adapted from the official Flax repository:
|
149 |
+
https://github.com/google/flax/blob/491ce18759622506588784b4fca0e4bf05f8c8cd/flax/linen/attention.py#L252
|
150 |
+
"""
|
151 |
+
# detect if we're initializing by absence of existing cache data.
|
152 |
+
is_initialized = self.has_variable("cache", "cached_key")
|
153 |
+
cached_key = self.variable("cache", "cached_key", jnp.zeros, key.shape, key.dtype)
|
154 |
+
cached_value = self.variable("cache", "cached_value", jnp.zeros, value.shape, value.dtype)
|
155 |
+
cache_index = self.variable("cache", "cache_index", lambda: jnp.array(0, dtype=jnp.int32))
|
156 |
+
|
157 |
+
if is_initialized:
|
158 |
+
*batch_dims, max_length, num_heads, depth_per_head = cached_key.value.shape
|
159 |
+
# update key, value caches with our new 1d spatial slices
|
160 |
+
cur_index = cache_index.value
|
161 |
+
indices = (0,) * len(batch_dims) + (cur_index, 0, 0)
|
162 |
+
key = lax.dynamic_update_slice(cached_key.value, key, indices)
|
163 |
+
value = lax.dynamic_update_slice(cached_value.value, value, indices)
|
164 |
+
cached_key.value = key
|
165 |
+
cached_value.value = value
|
166 |
+
num_updated_cache_vectors = query.shape[1]
|
167 |
+
cache_index.value = cache_index.value + num_updated_cache_vectors
|
168 |
+
# causal mask for cached decoder self-attention: our single query position should only attend to those key positions that have already been generated and cached, not the remaining zero elements.
|
169 |
+
pad_mask = jnp.broadcast_to(
|
170 |
+
jnp.arange(max_length) < cur_index + num_updated_cache_vectors,
|
171 |
+
tuple(batch_dims) + (1, num_updated_cache_vectors, max_length),
|
172 |
+
)
|
173 |
+
attention_mask = combine_masks(pad_mask, attention_mask)
|
174 |
+
return key, value, attention_mask
|
175 |
+
|
176 |
+
def __call__(
|
177 |
+
self,
|
178 |
+
hidden_states,
|
179 |
+
key_value_states: Optional[jnp.ndarray] = None,
|
180 |
+
attention_mask=None,
|
181 |
+
deterministic: bool = True,
|
182 |
+
init_cache: bool = False,
|
183 |
+
output_attentions: bool = False,
|
184 |
+
):
|
185 |
+
|
186 |
+
# if key_value_states are provided this layer is used as a cross-attention layer
|
187 |
+
# for the decoder
|
188 |
+
is_cross_attention = key_value_states is not None
|
189 |
+
|
190 |
+
qkv_out = self.c_attn(hidden_states)
|
191 |
+
query, key, value = jnp.split(qkv_out, 3, axis=2)
|
192 |
+
|
193 |
+
if is_cross_attention:
|
194 |
+
_qkv_out = self.c_attn_for_k_v(key_value_states)
|
195 |
+
key, value = jnp.split(_qkv_out, 2, axis=2)
|
196 |
+
|
197 |
+
query = self._split_heads(query)
|
198 |
+
key = self._split_heads(key)
|
199 |
+
value = self._split_heads(value)
|
200 |
+
|
201 |
+
query_length, key_length = query.shape[1], key.shape[1]
|
202 |
+
|
203 |
+
if self.causal:
|
204 |
+
if self.has_variable("cache", "cached_key"):
|
205 |
+
mask_shift = self.variables["cache"]["cache_index"]
|
206 |
+
max_decoder_length = self.variables["cache"]["cached_key"].shape[1]
|
207 |
+
causal_mask = lax.dynamic_slice(
|
208 |
+
self.causal_mask, (0, 0, mask_shift, 0), (1, 1, query_length, max_decoder_length)
|
209 |
+
)
|
210 |
+
else:
|
211 |
+
causal_mask = self.causal_mask[:, :, :query_length, :key_length]
|
212 |
+
|
213 |
+
batch_size = hidden_states.shape[0]
|
214 |
+
causal_mask = jnp.broadcast_to(causal_mask, (batch_size,) + causal_mask.shape[1:])
|
215 |
+
|
216 |
+
# combine masks if needed
|
217 |
+
if attention_mask is not None and self.causal:
|
218 |
+
attention_mask = jnp.broadcast_to(jnp.expand_dims(attention_mask, axis=(-3, -2)), causal_mask.shape)
|
219 |
+
attention_mask = combine_masks(attention_mask, causal_mask)
|
220 |
+
elif self.causal:
|
221 |
+
attention_mask = causal_mask
|
222 |
+
elif attention_mask is not None:
|
223 |
+
attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2))
|
224 |
+
|
225 |
+
dropout_rng = None
|
226 |
+
if not deterministic and self.config.attn_pdrop > 0.0:
|
227 |
+
dropout_rng = self.make_rng("dropout")
|
228 |
+
|
229 |
+
# During fast autoregressive decoding, we feed one position at a time,
|
230 |
+
# and cache the keys and values step by step.
|
231 |
+
if self.causal and (self.has_variable("cache", "cached_key") or init_cache):
|
232 |
+
key, value, attention_mask = self._concatenate_to_cache(key, value, query, attention_mask)
|
233 |
+
|
234 |
+
# transform boolean mask into float mask
|
235 |
+
if attention_mask is not None:
|
236 |
+
attention_bias = lax.select(
|
237 |
+
attention_mask > 0,
|
238 |
+
jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
|
239 |
+
jnp.full(attention_mask.shape, -1e4).astype(self.dtype),
|
240 |
+
)
|
241 |
+
else:
|
242 |
+
attention_bias = None
|
243 |
+
|
244 |
+
# usual dot product attention
|
245 |
+
attn_weights = dot_product_attention_weights(
|
246 |
+
query,
|
247 |
+
key,
|
248 |
+
bias=attention_bias,
|
249 |
+
dropout_rng=dropout_rng,
|
250 |
+
dropout_rate=self.config.attn_pdrop,
|
251 |
+
deterministic=deterministic,
|
252 |
+
dtype=self.dtype,
|
253 |
+
precision=None,
|
254 |
+
)
|
255 |
+
|
256 |
+
attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value)
|
257 |
+
attn_output = self._merge_heads(attn_output)
|
258 |
+
attn_output = self.c_proj(attn_output)
|
259 |
+
attn_output = self.resid_dropout(attn_output, deterministic=deterministic)
|
260 |
+
|
261 |
+
outputs = (attn_output, attn_weights) if output_attentions else (attn_output,)
|
262 |
+
return outputs
|
263 |
+
|
264 |
+
|
265 |
+
class FlaxGPT2MLP(nn.Module):
|
266 |
+
config: GPT2Config
|
267 |
+
intermediate_size: int
|
268 |
+
dtype: jnp.dtype = jnp.float32
|
269 |
+
|
270 |
+
def setup(self):
|
271 |
+
embed_dim = self.config.hidden_size
|
272 |
+
self.c_fc = FlaxConv1D(self.intermediate_size, dtype=self.dtype)
|
273 |
+
self.c_proj = FlaxConv1D(embed_dim, dtype=self.dtype)
|
274 |
+
self.act = ACT2FN[self.config.activation_function]
|
275 |
+
self.dropout = nn.Dropout(rate=self.config.resid_pdrop)
|
276 |
+
|
277 |
+
def __call__(self, hidden_states, deterministic: bool = True):
|
278 |
+
hidden_states = self.c_fc(hidden_states)
|
279 |
+
hidden_states = self.act(hidden_states)
|
280 |
+
hidden_states = self.c_proj(hidden_states)
|
281 |
+
hidden_states = self.dropout(hidden_states, deterministic=deterministic)
|
282 |
+
return hidden_states
|
283 |
+
|
284 |
+
|
285 |
+
class FlaxGPT2Block(nn.Module):
|
286 |
+
config: GPT2Config
|
287 |
+
dtype: jnp.dtype = jnp.float32
|
288 |
+
|
289 |
+
def setup(self):
|
290 |
+
hidden_size = self.config.hidden_size
|
291 |
+
inner_dim = self.config.n_inner if self.config.n_inner is not None else 4 * hidden_size
|
292 |
+
|
293 |
+
self.ln_1 = nn.LayerNorm(epsilon=self.config.layer_norm_epsilon, dtype=self.dtype)
|
294 |
+
self.attn = FlaxGPT2Attention(self.config, dtype=self.dtype)
|
295 |
+
self.ln_3 = nn.LayerNorm(epsilon=self.config.layer_norm_epsilon, dtype=self.dtype)
|
296 |
+
self.encoder_attn = FlaxGPT2Attention(config=self.config, dtype=self.dtype)
|
297 |
+
self.ln_2 = nn.LayerNorm(epsilon=self.config.layer_norm_epsilon, dtype=self.dtype)
|
298 |
+
self.mlp = FlaxGPT2MLP(self.config, inner_dim, dtype=self.dtype)
|
299 |
+
|
300 |
+
def __call__(
|
301 |
+
self,
|
302 |
+
hidden_states,
|
303 |
+
attention_mask=None,
|
304 |
+
encoder_hidden_states: Optional[jnp.ndarray] = None,
|
305 |
+
encoder_attention_mask: Optional[jnp.ndarray] = None,
|
306 |
+
deterministic: bool = True,
|
307 |
+
init_cache: bool = False,
|
308 |
+
output_attentions: bool = False,
|
309 |
+
):
|
310 |
+
residual = hidden_states
|
311 |
+
hidden_states = self.ln_1(hidden_states)
|
312 |
+
outputs = self.attn(
|
313 |
+
hidden_states,
|
314 |
+
attention_mask=attention_mask,
|
315 |
+
deterministic=deterministic,
|
316 |
+
init_cache=init_cache,
|
317 |
+
output_attentions=output_attentions,
|
318 |
+
)
|
319 |
+
# residual connection
|
320 |
+
attn_output = outputs[0]
|
321 |
+
hidden_states = attn_output + residual
|
322 |
+
|
323 |
+
# Cross-Attention Block
|
324 |
+
if encoder_hidden_states is not None:
|
325 |
+
|
326 |
+
residual = hidden_states
|
327 |
+
hidden_states = self.ln_3(hidden_states)
|
328 |
+
|
329 |
+
cross_attn_outputs = self.encoder_attn(
|
330 |
+
hidden_states=hidden_states,
|
331 |
+
key_value_states=encoder_hidden_states,
|
332 |
+
attention_mask=encoder_attention_mask,
|
333 |
+
deterministic=deterministic,
|
334 |
+
output_attentions=output_attentions,
|
335 |
+
)
|
336 |
+
|
337 |
+
# residual connection
|
338 |
+
cross_attn_output = cross_attn_outputs[0]
|
339 |
+
hidden_states = cross_attn_output + residual
|
340 |
+
|
341 |
+
residual = hidden_states
|
342 |
+
hidden_states = self.ln_2(hidden_states)
|
343 |
+
feed_forward_hidden_states = self.mlp(hidden_states, deterministic=deterministic)
|
344 |
+
# residual connection
|
345 |
+
hidden_states = residual + feed_forward_hidden_states
|
346 |
+
|
347 |
+
output = (hidden_states,) + outputs[1:]
|
348 |
+
if encoder_hidden_states is not None:
|
349 |
+
output = output + cross_attn_outputs[1:]
|
350 |
+
|
351 |
+
return output
|
352 |
+
|
353 |
+
|
354 |
+
class FlaxGPT2PreTrainedModel(FlaxPreTrainedModel):
|
355 |
+
"""
|
356 |
+
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
|
357 |
+
models.
|
358 |
+
"""
|
359 |
+
|
360 |
+
config_class = GPT2Config
|
361 |
+
base_model_prefix = "transformer"
|
362 |
+
module_class: nn.Module = None
|
363 |
+
|
364 |
+
def __init__(
|
365 |
+
self,
|
366 |
+
config: GPT2Config,
|
367 |
+
input_shape: Tuple = (1, 1),
|
368 |
+
seed: int = 0,
|
369 |
+
dtype: jnp.dtype = jnp.float32,
|
370 |
+
**kwargs,
|
371 |
+
):
|
372 |
+
module = self.module_class(config=config, dtype=dtype, **kwargs)
|
373 |
+
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)
|
374 |
+
|
375 |
+
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:
|
376 |
+
# init input tensors
|
377 |
+
input_ids = jnp.zeros(input_shape, dtype="i4")
|
378 |
+
attention_mask = jnp.ones_like(input_ids)
|
379 |
+
position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_shape)
|
380 |
+
params_rng, dropout_rng = jax.random.split(rng)
|
381 |
+
rngs = {"params": params_rng, "dropout": dropout_rng}
|
382 |
+
|
383 |
+
if self.config.add_cross_attention:
|
384 |
+
encoder_hidden_states = jnp.zeros(input_shape + (self.config.n_embd,))
|
385 |
+
encoder_attention_mask = attention_mask
|
386 |
+
module_init_outputs = self.module.init(rngs, input_ids, attention_mask, position_ids, encoder_hidden_states, encoder_attention_mask, return_dict=False)
|
387 |
+
else:
|
388 |
+
module_init_outputs = self.module.init(rngs, input_ids, attention_mask, position_ids, return_dict=False)
|
389 |
+
|
390 |
+
return module_init_outputs["params"]
|
391 |
+
|
392 |
+
@classmethod
|
393 |
+
def _from_config(cls, config, **kwargs):
|
394 |
+
return super()._from_config(config, **kwargs)
|
395 |
+
|
396 |
+
def init_cache(self, batch_size, max_length):
|
397 |
+
r"""
|
398 |
+
Args:
|
399 |
+
batch_size (:obj:`int`):
|
400 |
+
batch_size used for fast auto-regressive decoding. Defines the batch size of the initialized cache.
|
401 |
+
max_length (:obj:`int`):
|
402 |
+
maximum possible length for auto-regressive decoding. Defines the sequence length of the initialized
|
403 |
+
cache.
|
404 |
+
"""
|
405 |
+
# init input variables to retrieve cache
|
406 |
+
input_ids = jnp.ones((batch_size, max_length))
|
407 |
+
attention_mask = jnp.ones_like(input_ids)
|
408 |
+
position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape)
|
409 |
+
|
410 |
+
init_variables = self.module.init(
|
411 |
+
jax.random.PRNGKey(0), input_ids, attention_mask, position_ids, return_dict=False, init_cache=True
|
412 |
+
)
|
413 |
+
return init_variables["cache"]
|
414 |
+
|
415 |
+
@add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING)
|
416 |
+
def __call__(
|
417 |
+
self,
|
418 |
+
input_ids,
|
419 |
+
attention_mask=None,
|
420 |
+
position_ids=None,
|
421 |
+
encoder_hidden_states: Optional[jnp.ndarray] = None,
|
422 |
+
encoder_attention_mask: Optional[jnp.ndarray] = None,
|
423 |
+
params: dict = None,
|
424 |
+
past_key_values: dict = None,
|
425 |
+
dropout_rng: jax.random.PRNGKey = None,
|
426 |
+
train: bool = False,
|
427 |
+
output_attentions: Optional[bool] = None,
|
428 |
+
output_hidden_states: Optional[bool] = None,
|
429 |
+
return_dict: Optional[bool] = None,
|
430 |
+
):
|
431 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
432 |
+
output_hidden_states = (
|
433 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
434 |
+
)
|
435 |
+
return_dict = return_dict if return_dict is not None else self.config.return_dict
|
436 |
+
|
437 |
+
if encoder_hidden_states is not None and encoder_attention_mask is None:
|
438 |
+
batch_size, sequence_length = encoder_hidden_states.shape[:2]
|
439 |
+
encoder_attention_mask = jnp.ones((batch_size, sequence_length))
|
440 |
+
|
441 |
+
batch_size, sequence_length = input_ids.shape
|
442 |
+
|
443 |
+
if position_ids is None:
|
444 |
+
if past_key_values is not None:
|
445 |
+
raise ValueError("Make sure to provide `position_ids` when passing `past_key_values`.")
|
446 |
+
|
447 |
+
position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length))
|
448 |
+
|
449 |
+
if attention_mask is None:
|
450 |
+
attention_mask = jnp.ones((batch_size, sequence_length))
|
451 |
+
|
452 |
+
# Handle any PRNG if needed
|
453 |
+
rngs = {}
|
454 |
+
if dropout_rng is not None:
|
455 |
+
rngs["dropout"] = dropout_rng
|
456 |
+
|
457 |
+
inputs = {"params": params or self.params}
|
458 |
+
|
459 |
+
# if past_key_values are passed then cache is already initialized a private flag init_cache has to be passed down to ensure cache is used. It has to be made sure that cache is marked as mutable so that it can be changed by FlaxGPT2Attention module
|
460 |
+
if past_key_values:
|
461 |
+
inputs["cache"] = past_key_values
|
462 |
+
mutable = ["cache"]
|
463 |
+
else:
|
464 |
+
mutable = False
|
465 |
+
|
466 |
+
outputs = self.module.apply(
|
467 |
+
inputs,
|
468 |
+
jnp.array(input_ids, dtype="i4"),
|
469 |
+
jnp.array(attention_mask, dtype="i4"),
|
470 |
+
jnp.array(position_ids, dtype="i4"),
|
471 |
+
encoder_hidden_states,
|
472 |
+
encoder_attention_mask,
|
473 |
+
not train,
|
474 |
+
False,
|
475 |
+
output_attentions,
|
476 |
+
output_hidden_states,
|
477 |
+
return_dict,
|
478 |
+
rngs=rngs,
|
479 |
+
mutable=mutable,
|
480 |
+
)
|
481 |
+
|
482 |
+
# add updated cache to model output
|
483 |
+
if past_key_values is not None and return_dict:
|
484 |
+
outputs, past_key_values = outputs
|
485 |
+
outputs["past_key_values"] = unfreeze(past_key_values["cache"])
|
486 |
+
return outputs
|
487 |
+
elif past_key_values is not None and not return_dict:
|
488 |
+
outputs, past_key_values = outputs
|
489 |
+
outputs = outputs[:1] + (unfreeze(past_key_values["cache"]),) + outputs[1:]
|
490 |
+
|
491 |
+
return outputs
|
492 |
+
|
493 |
+
|
494 |
+
class FlaxGPT2BlockCollection(nn.Module):
|
495 |
+
config: GPT2Config
|
496 |
+
dtype: jnp.dtype = jnp.float32
|
497 |
+
|
498 |
+
def setup(self):
|
499 |
+
self.blocks = [
|
500 |
+
FlaxGPT2Block(self.config, name=str(i), dtype=self.dtype) for i in range(self.config.num_hidden_layers)
|
501 |
+
]
|
502 |
+
|
503 |
+
def __call__(
|
504 |
+
self,
|
505 |
+
hidden_states,
|
506 |
+
attention_mask=None,
|
507 |
+
encoder_hidden_states: Optional[jnp.ndarray] = None,
|
508 |
+
encoder_attention_mask: Optional[jnp.ndarray] = None,
|
509 |
+
deterministic: bool = True,
|
510 |
+
init_cache: bool = False,
|
511 |
+
output_attentions: bool = False,
|
512 |
+
output_hidden_states: bool = False,
|
513 |
+
return_dict: bool = True,
|
514 |
+
):
|
515 |
+
all_attentions = () if output_attentions else None
|
516 |
+
all_hidden_states = () if output_hidden_states else None
|
517 |
+
all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None
|
518 |
+
|
519 |
+
for block in self.blocks:
|
520 |
+
if output_hidden_states:
|
521 |
+
all_hidden_states += (hidden_states,)
|
522 |
+
|
523 |
+
layer_outputs = block(
|
524 |
+
hidden_states,
|
525 |
+
attention_mask,
|
526 |
+
encoder_hidden_states=encoder_hidden_states,
|
527 |
+
encoder_attention_mask=encoder_attention_mask,
|
528 |
+
deterministic=deterministic,
|
529 |
+
init_cache=init_cache,
|
530 |
+
output_attentions=output_attentions,
|
531 |
+
)
|
532 |
+
hidden_states = layer_outputs[0]
|
533 |
+
|
534 |
+
if output_attentions:
|
535 |
+
all_attentions += (layer_outputs[1],)
|
536 |
+
if encoder_hidden_states is not None:
|
537 |
+
all_cross_attentions += (layer_outputs[2],)
|
538 |
+
|
539 |
+
if output_hidden_states:
|
540 |
+
all_hidden_states += (hidden_states,)
|
541 |
+
|
542 |
+
outputs = [hidden_states, all_hidden_states, all_attentions, all_cross_attentions]
|
543 |
+
|
544 |
+
if not return_dict:
|
545 |
+
return tuple(v for v in outputs if v is not None)
|
546 |
+
|
547 |
+
if encoder_hidden_states is None:
|
548 |
+
return FlaxBaseModelOutputWithPast(
|
549 |
+
last_hidden_state=hidden_states,
|
550 |
+
past_key_values=None,
|
551 |
+
hidden_states=all_hidden_states,
|
552 |
+
attentions=all_attentions,
|
553 |
+
)
|
554 |
+
else:
|
555 |
+
return FlaxBaseModelOutputWithPastAndCrossAttentions(
|
556 |
+
last_hidden_state=hidden_states,
|
557 |
+
past_key_values=None,
|
558 |
+
hidden_states=all_hidden_states,
|
559 |
+
attentions=all_attentions,
|
560 |
+
cross_attentions=all_cross_attentions,
|
561 |
+
)
|
562 |
+
|
563 |
+
class FlaxGPT2Module(nn.Module):
|
564 |
+
config: GPT2Config
|
565 |
+
dtype: jnp.dtype = jnp.float32
|
566 |
+
|
567 |
+
def setup(self):
|
568 |
+
self.embed_dim = self.config.hidden_size
|
569 |
+
|
570 |
+
self.wte = nn.Embed(
|
571 |
+
self.config.vocab_size,
|
572 |
+
self.embed_dim,
|
573 |
+
embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
|
574 |
+
dtype=self.dtype,
|
575 |
+
)
|
576 |
+
self.wpe = nn.Embed(
|
577 |
+
self.config.max_position_embeddings,
|
578 |
+
self.embed_dim,
|
579 |
+
embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
|
580 |
+
dtype=self.dtype,
|
581 |
+
)
|
582 |
+
self.dropout = nn.Dropout(rate=self.config.embd_pdrop)
|
583 |
+
self.h = FlaxGPT2BlockCollection(self.config, dtype=self.dtype)
|
584 |
+
self.ln_f = nn.LayerNorm(epsilon=self.config.layer_norm_epsilon, dtype=self.dtype)
|
585 |
+
|
586 |
+
def __call__(
|
587 |
+
self,
|
588 |
+
input_ids,
|
589 |
+
attention_mask,
|
590 |
+
position_ids,
|
591 |
+
encoder_hidden_states: Optional[jnp.ndarray] = None,
|
592 |
+
encoder_attention_mask: Optional[jnp.ndarray] = None,
|
593 |
+
deterministic=True,
|
594 |
+
init_cache: bool = False,
|
595 |
+
output_attentions: bool = False,
|
596 |
+
output_hidden_states: bool = False,
|
597 |
+
return_dict: bool = True,
|
598 |
+
):
|
599 |
+
input_embeds = self.wte(input_ids.astype("i4"))
|
600 |
+
position_embeds = self.wpe(position_ids.astype("i4"))
|
601 |
+
|
602 |
+
hidden_states = input_embeds + position_embeds
|
603 |
+
hidden_states = self.dropout(hidden_states, deterministic=deterministic)
|
604 |
+
|
605 |
+
outputs = self.h(
|
606 |
+
hidden_states,
|
607 |
+
attention_mask,
|
608 |
+
encoder_hidden_states,
|
609 |
+
encoder_attention_mask,
|
610 |
+
deterministic=deterministic,
|
611 |
+
init_cache=init_cache,
|
612 |
+
output_attentions=output_attentions,
|
613 |
+
output_hidden_states=output_hidden_states,
|
614 |
+
return_dict=return_dict,
|
615 |
+
)
|
616 |
+
|
617 |
+
hidden_states = outputs[0]
|
618 |
+
hidden_states = self.ln_f(hidden_states)
|
619 |
+
|
620 |
+
if not return_dict:
|
621 |
+
return (hidden_states,) + outputs[1:]
|
622 |
+
|
623 |
+
if encoder_hidden_states is None:
|
624 |
+
return FlaxBaseModelOutput(
|
625 |
+
last_hidden_state=hidden_states,
|
626 |
+
hidden_states=outputs.hidden_states,
|
627 |
+
attentions=outputs.attentions,
|
628 |
+
)
|
629 |
+
else:
|
630 |
+
return FlaxBaseModelOutputWithPastAndCrossAttentions(
|
631 |
+
last_hidden_state=hidden_states,
|
632 |
+
hidden_states=outputs.hidden_states,
|
633 |
+
attentions=outputs.attentions,
|
634 |
+
cross_attentions=outputs.cross_attentions,
|
635 |
+
)
|
636 |
+
|
637 |
+
@add_start_docstrings(
|
638 |
+
"The bare GPT2 Model transformer outputting raw hidden-states without any specific head on top.",
|
639 |
+
GPT2_START_DOCSTRING,
|
640 |
+
)
|
641 |
+
class FlaxGPT2Model(FlaxGPT2PreTrainedModel):
|
642 |
+
module_class = FlaxGPT2Module
|
643 |
+
|
644 |
+
|
645 |
+
append_call_sample_docstring(
|
646 |
+
FlaxGPT2Model, _TOKENIZER_FOR_DOC, _CHECKPOINT_FOR_DOC, FlaxBaseModelOutput, _CONFIG_FOR_DOC
|
647 |
+
)
|
648 |
+
|
649 |
+
|
650 |
+
class FlaxGPT2LMHeadModule(nn.Module):
|
651 |
+
config: GPT2Config
|
652 |
+
dtype: jnp.dtype = jnp.float32
|
653 |
+
|
654 |
+
def setup(self):
|
655 |
+
self.transformer = FlaxGPT2Module(self.config, dtype=self.dtype)
|
656 |
+
self.lm_head = nn.Dense(
|
657 |
+
self.config.vocab_size,
|
658 |
+
use_bias=False,
|
659 |
+
dtype=self.dtype,
|
660 |
+
kernel_init=jax.nn.initializers.normal(stddev=self.config.initializer_range, dtype=self.dtype),
|
661 |
+
)
|
662 |
+
|
663 |
+
def __call__(
|
664 |
+
self,
|
665 |
+
input_ids,
|
666 |
+
attention_mask,
|
667 |
+
position_ids,
|
668 |
+
encoder_hidden_states: Optional[jnp.ndarray] = None,
|
669 |
+
encoder_attention_mask: Optional[jnp.ndarray] = None,
|
670 |
+
deterministic: bool = True,
|
671 |
+
init_cache: bool = False,
|
672 |
+
output_attentions: bool = False,
|
673 |
+
output_hidden_states: bool = False,
|
674 |
+
return_dict: bool = True,
|
675 |
+
):
|
676 |
+
outputs = self.transformer(
|
677 |
+
input_ids,
|
678 |
+
attention_mask,
|
679 |
+
position_ids,
|
680 |
+
encoder_hidden_states,
|
681 |
+
encoder_attention_mask,
|
682 |
+
deterministic=deterministic,
|
683 |
+
init_cache=init_cache,
|
684 |
+
output_attentions=output_attentions,
|
685 |
+
output_hidden_states=output_hidden_states,
|
686 |
+
return_dict=return_dict,
|
687 |
+
)
|
688 |
+
|
689 |
+
hidden_states = outputs[0]
|
690 |
+
|
691 |
+
if self.config.tie_word_embeddings:
|
692 |
+
shared_kernel = self.transformer.variables["params"]["wte"]["embedding"].T
|
693 |
+
lm_logits = self.lm_head.apply({"params": {"kernel": shared_kernel}}, hidden_states)
|
694 |
+
else:
|
695 |
+
lm_logits = self.lm_head(hidden_states)
|
696 |
+
|
697 |
+
if not return_dict:
|
698 |
+
return (lm_logits,) + outputs[1:]
|
699 |
+
|
700 |
+
if encoder_hidden_states is None:
|
701 |
+
return FlaxCausalLMOutput(logits=lm_logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions)
|
702 |
+
else:
|
703 |
+
return FlaxSeq2SeqLMOutput(
|
704 |
+
logits=lm_logits,
|
705 |
+
decoder_hidden_states=outputs.hidden_states,
|
706 |
+
decoder_attentions=outputs.attentions,
|
707 |
+
cross_attentions=outputs.cross_attentions,
|
708 |
+
encoder_last_hidden_state=encoder_hidden_states,
|
709 |
+
encoder_hidden_states=None,
|
710 |
+
encoder_attentions=None,
|
711 |
+
)
|
712 |
+
|
713 |
+
@add_start_docstrings(
|
714 |
+
"""
|
715 |
+
The GPT2 Model transformer with a language modeling head on top (linear layer with weights tied to the input
|
716 |
+
embeddings).
|
717 |
+
""",
|
718 |
+
GPT2_START_DOCSTRING,
|
719 |
+
)
|
720 |
+
class FlaxGPT2LMHeadModel(FlaxGPT2PreTrainedModel):
|
721 |
+
module_class = FlaxGPT2LMHeadModule
|
722 |
+
|
723 |
+
def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jnp.DeviceArray] = None):
|
724 |
+
# initializing the cache
|
725 |
+
batch_size, seq_length = input_ids.shape
|
726 |
+
|
727 |
+
past_key_values = self.init_cache(batch_size, max_length)
|
728 |
+
# Note that usually one would have to put 0's in the attention_mask for x > input_ids.shape[-1] and x < cache_length.
|
729 |
+
# But since GPT2 uses a causal mask, those positions are masked anyways.
|
730 |
+
# Thus we can create a single static attention_mask here, which is more efficient for compilation
|
731 |
+
extended_attention_mask = jnp.ones((batch_size, max_length), dtype="i4")
|
732 |
+
if attention_mask is not None:
|
733 |
+
position_ids = attention_mask.cumsum(axis=-1) - 1
|
734 |
+
extended_attention_mask = lax.dynamic_update_slice(extended_attention_mask, attention_mask, (0, 0))
|
735 |
+
else:
|
736 |
+
position_ids = jnp.broadcast_to(jnp.arange(seq_length, dtype="i4")[None, :], (batch_size, seq_length))
|
737 |
+
|
738 |
+
return {
|
739 |
+
"past_key_values": past_key_values,
|
740 |
+
"attention_mask": extended_attention_mask,
|
741 |
+
"position_ids": position_ids,
|
742 |
+
}
|
743 |
+
|
744 |
+
def update_inputs_for_generation(self, model_outputs, model_kwargs):
|
745 |
+
model_kwargs["past_key_values"] = model_outputs.past_key_values
|
746 |
+
model_kwargs["position_ids"] = model_kwargs["position_ids"][:, -1:] + 1
|
747 |
+
return model_kwargs
|
748 |
+
|
749 |
+
|
750 |
+
append_call_sample_docstring(
|
751 |
+
FlaxGPT2LMHeadModel, _TOKENIZER_FOR_DOC, _CHECKPOINT_FOR_DOC, FlaxCausalLMOutput, _CONFIG_FOR_DOC
|
752 |
+
)
|
vit_gpt2/modeling_flax_vit_gpt2.py
ADDED
@@ -0,0 +1,704 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Callable, Optional, Tuple
|
2 |
+
|
3 |
+
import flax.linen as nn
|
4 |
+
import jax
|
5 |
+
import jax.numpy as jnp
|
6 |
+
from flax.core.frozen_dict import FrozenDict, unfreeze
|
7 |
+
from jax import lax
|
8 |
+
from jax.random import PRNGKey
|
9 |
+
from transformers import GPT2Config, FlaxViTModel, ViTConfig
|
10 |
+
from transformers.modeling_flax_outputs import (
|
11 |
+
FlaxCausalLMOutputWithCrossAttentions,
|
12 |
+
FlaxSeq2SeqLMOutput,
|
13 |
+
FlaxSeq2SeqModelOutput,
|
14 |
+
)
|
15 |
+
from transformers.models.bart.modeling_flax_bart import (
|
16 |
+
shift_tokens_right,
|
17 |
+
)
|
18 |
+
from .modeling_flax_gpt2 import (
|
19 |
+
FlaxGPT2Module,
|
20 |
+
FlaxGPT2Model,
|
21 |
+
FlaxPreTrainedModel
|
22 |
+
)
|
23 |
+
from transformers.models.vit.modeling_flax_vit import FlaxViTModule
|
24 |
+
|
25 |
+
from .configuration_vit_gpt2 import ViTGPT2Config
|
26 |
+
|
27 |
+
|
28 |
+
class FlaxViTGPT2Module(nn.Module):
|
29 |
+
config: ViTGPT2Config
|
30 |
+
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
|
31 |
+
|
32 |
+
def setup(self):
|
33 |
+
|
34 |
+
self.encoder = FlaxViTModule(self.config.vit_config, dtype=self.dtype)
|
35 |
+
self.decoder = FlaxGPT2Module(self.config.gpt2_config, dtype=self.dtype)
|
36 |
+
|
37 |
+
def _get_encoder_module(self):
|
38 |
+
return self.encoder
|
39 |
+
|
40 |
+
def _get_decoder_module(self):
|
41 |
+
return self.decoder
|
42 |
+
|
43 |
+
def __call__(
|
44 |
+
self,
|
45 |
+
pixel_values,
|
46 |
+
input_ids,
|
47 |
+
attention_mask,
|
48 |
+
position_ids,
|
49 |
+
encoder_attention_mask: Optional[jnp.ndarray] = None,
|
50 |
+
output_attentions: bool = False,
|
51 |
+
output_hidden_states: bool = False,
|
52 |
+
return_dict: bool = True,
|
53 |
+
deterministic: bool = True,
|
54 |
+
):
|
55 |
+
encoder_outputs = self.encoder(
|
56 |
+
pixel_values=pixel_values,
|
57 |
+
deterministic=deterministic,
|
58 |
+
output_attentions=output_attentions,
|
59 |
+
output_hidden_states=output_hidden_states,
|
60 |
+
return_dict=return_dict,
|
61 |
+
)
|
62 |
+
|
63 |
+
decoder_outputs = self.decoder(
|
64 |
+
input_ids=input_ids,
|
65 |
+
attention_mask=attention_mask,
|
66 |
+
position_ids=position_ids,
|
67 |
+
encoder_hidden_states=encoder_outputs[0],
|
68 |
+
encoder_attention_mask=encoder_attention_mask,
|
69 |
+
deterministic=deterministic,
|
70 |
+
output_attentions=output_attentions,
|
71 |
+
output_hidden_states=output_hidden_states,
|
72 |
+
return_dict=return_dict
|
73 |
+
)
|
74 |
+
|
75 |
+
return FlaxSeq2SeqModelOutput(
|
76 |
+
last_hidden_state=decoder_outputs.last_hidden_state,
|
77 |
+
decoder_hidden_states=decoder_outputs.hidden_states,
|
78 |
+
decoder_attentions=decoder_outputs.attentions,
|
79 |
+
cross_attentions=decoder_outputs.cross_attentions,
|
80 |
+
encoder_last_hidden_state=encoder_outputs.last_hidden_state,
|
81 |
+
encoder_hidden_states=encoder_outputs.hidden_states,
|
82 |
+
encoder_attentions=encoder_outputs.attentions,
|
83 |
+
)
|
84 |
+
|
85 |
+
|
86 |
+
class FlaxViTGPT2ForConditionalGenerationModule(nn.Module):
|
87 |
+
config: ViTGPT2Config
|
88 |
+
dtype: jnp.dtype = jnp.float32
|
89 |
+
bias_init: Callable[..., jnp.ndarray] = jax.nn.initializers.zeros
|
90 |
+
|
91 |
+
def setup(self):
|
92 |
+
self.model = FlaxViTGPT2Module(config=self.config, dtype=self.dtype)
|
93 |
+
self.lm_head = nn.Dense(
|
94 |
+
self.model.decoder.embed_dim,
|
95 |
+
use_bias=False,
|
96 |
+
dtype=self.dtype,
|
97 |
+
kernel_init=jax.nn.initializers.normal(
|
98 |
+
self.config.gpt2_config.initializer_range, self.dtype
|
99 |
+
),
|
100 |
+
)
|
101 |
+
self.final_logits_bias = self.param(
|
102 |
+
"final_logits_bias", self.bias_init, (1, self.model.decoder.embed_dim)
|
103 |
+
)
|
104 |
+
|
105 |
+
def _get_encoder_module(self):
|
106 |
+
return self.model.encoder
|
107 |
+
|
108 |
+
def _get_decoder_module(self):
|
109 |
+
return self.model.decoder
|
110 |
+
|
111 |
+
def __call__(
|
112 |
+
self,
|
113 |
+
pixel_values,
|
114 |
+
input_ids,
|
115 |
+
attention_mask,
|
116 |
+
position_ids,
|
117 |
+
encoder_attention_mask: Optional[jnp.ndarray] = None,
|
118 |
+
output_attentions: bool = False,
|
119 |
+
output_hidden_states: bool = False,
|
120 |
+
return_dict: bool = True,
|
121 |
+
deterministic: bool = True,
|
122 |
+
):
|
123 |
+
outputs = self.model(
|
124 |
+
pixel_values=pixel_values,
|
125 |
+
input_ids=input_ids,
|
126 |
+
attention_mask=attention_mask,
|
127 |
+
position_ids=position_ids,
|
128 |
+
encoder_attention_mask=encoder_attention_mask,
|
129 |
+
output_attentions=output_attentions,
|
130 |
+
output_hidden_states=output_hidden_states,
|
131 |
+
return_dict=return_dict,
|
132 |
+
deterministic=deterministic,
|
133 |
+
)
|
134 |
+
|
135 |
+
hidden_states = outputs[0]
|
136 |
+
lm_logits = self.lm_head(hidden_states)
|
137 |
+
lm_logits += self.final_logits_bias
|
138 |
+
|
139 |
+
if not return_dict:
|
140 |
+
output = (lm_logits,) + outputs[1:]
|
141 |
+
return output
|
142 |
+
|
143 |
+
return FlaxSeq2SeqLMOutput(
|
144 |
+
logits=lm_logits,
|
145 |
+
decoder_hidden_states=outputs.decoder_hidden_states,
|
146 |
+
decoder_attentions=outputs.decoder_attentions,
|
147 |
+
cross_attentions=outputs.cross_attentions,
|
148 |
+
encoder_last_hidden_state=outputs.encoder_last_hidden_state,
|
149 |
+
encoder_hidden_states=outputs.encoder_hidden_states,
|
150 |
+
encoder_attentions=outputs.encoder_attentions,
|
151 |
+
)
|
152 |
+
|
153 |
+
class FlaxViTGPT2PreTrainedModel(FlaxPreTrainedModel):
|
154 |
+
config_class = ViTGPT2Config
|
155 |
+
base_model_prefix: str = "model"
|
156 |
+
module_class: nn.Module = None
|
157 |
+
|
158 |
+
def __init__(
|
159 |
+
self,
|
160 |
+
config: ViTGPT2Config,
|
161 |
+
input_shape: Tuple = None,
|
162 |
+
seed: int = 0,
|
163 |
+
dtype: jnp.dtype = jnp.float32,
|
164 |
+
**kwargs,
|
165 |
+
):
|
166 |
+
if input_shape is None:
|
167 |
+
input_shape = (
|
168 |
+
(1, config.vit_config.image_size, config.vit_config.image_size, 3),
|
169 |
+
(1, 1),
|
170 |
+
)
|
171 |
+
|
172 |
+
module = self.module_class(config=config, dtype=dtype, **kwargs)
|
173 |
+
super().__init__(
|
174 |
+
config, module, input_shape=input_shape, seed=seed, dtype=dtype
|
175 |
+
)
|
176 |
+
|
177 |
+
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:
|
178 |
+
# init input tensors
|
179 |
+
pixel_values = jax.random.normal(rng, input_shape[0])
|
180 |
+
# # make sure initialization pass will work for FlaxBartForSequenceClassificationModule
|
181 |
+
# input_ids = jax.ops.index_update(input_ids, (..., -1), self.config.eos_token_id)
|
182 |
+
|
183 |
+
input_ids = jnp.zeros(input_shape[1], dtype="i4")
|
184 |
+
attention_mask = jnp.ones_like(input_ids)
|
185 |
+
|
186 |
+
batch_size, sequence_length = input_ids.shape
|
187 |
+
position_ids = jnp.broadcast_to(
|
188 |
+
jnp.arange(sequence_length)[None, :], (batch_size, sequence_length)
|
189 |
+
)
|
190 |
+
|
191 |
+
params_rng, dropout_rng = jax.random.split(rng)
|
192 |
+
rngs = {"params": params_rng, "dropout": dropout_rng}
|
193 |
+
|
194 |
+
return self.module.init(
|
195 |
+
rngs,
|
196 |
+
pixel_values,
|
197 |
+
input_ids,
|
198 |
+
attention_mask,
|
199 |
+
position_ids,
|
200 |
+
)["params"]
|
201 |
+
|
202 |
+
def init_cache(self, batch_size, max_length, encoder_outputs):
|
203 |
+
|
204 |
+
input_ids = jnp.ones((batch_size, max_length), dtype="i4")
|
205 |
+
attention_mask = jnp.ones_like(input_ids)
|
206 |
+
position_ids = jnp.broadcast_to(
|
207 |
+
jnp.arange(jnp.atleast_2d(input_ids).shape[-1]),
|
208 |
+
input_ids.shape,
|
209 |
+
)
|
210 |
+
|
211 |
+
def _decoder_forward(
|
212 |
+
module,
|
213 |
+
input_ids,
|
214 |
+
attention_mask,
|
215 |
+
position_ids,
|
216 |
+
**kwargs,
|
217 |
+
):
|
218 |
+
decoder_module = module._get_decoder_module()
|
219 |
+
return decoder_module(
|
220 |
+
input_ids,
|
221 |
+
attention_mask,
|
222 |
+
position_ids,
|
223 |
+
**kwargs,
|
224 |
+
)
|
225 |
+
|
226 |
+
init_variables = self.module.init(
|
227 |
+
jax.random.PRNGKey(0),
|
228 |
+
input_ids=input_ids,
|
229 |
+
attention_mask=attention_mask,
|
230 |
+
position_ids=position_ids,
|
231 |
+
encoder_hidden_states=encoder_outputs[0],
|
232 |
+
init_cache=True,
|
233 |
+
method=_decoder_forward, # we only need to call the decoder to init the cache
|
234 |
+
)
|
235 |
+
return unfreeze(init_variables["cache"])
|
236 |
+
|
237 |
+
def encode(
|
238 |
+
self,
|
239 |
+
pixel_values: jnp.ndarray,
|
240 |
+
output_attentions: Optional[bool] = None,
|
241 |
+
output_hidden_states: Optional[bool] = None,
|
242 |
+
return_dict: Optional[bool] = None,
|
243 |
+
train: bool = False,
|
244 |
+
params: dict = None,
|
245 |
+
dropout_rng: PRNGKey = None,
|
246 |
+
):
|
247 |
+
output_attentions = (
|
248 |
+
output_attentions
|
249 |
+
if output_attentions is not None
|
250 |
+
else self.config.output_attentions
|
251 |
+
)
|
252 |
+
output_hidden_states = (
|
253 |
+
output_hidden_states
|
254 |
+
if output_hidden_states is not None
|
255 |
+
else self.config.output_hidden_states
|
256 |
+
)
|
257 |
+
return_dict = (
|
258 |
+
return_dict if return_dict is not None else self.config.return_dict
|
259 |
+
)
|
260 |
+
|
261 |
+
# Handle any PRNG if needed
|
262 |
+
rngs = {}
|
263 |
+
if dropout_rng is not None:
|
264 |
+
rngs["dropout"] = dropout_rng
|
265 |
+
|
266 |
+
def _encoder_forward(module, pixel_values, **kwargs):
|
267 |
+
encode_module = module._get_encoder_module()
|
268 |
+
return encode_module(pixel_values, **kwargs)
|
269 |
+
|
270 |
+
return self.module.apply(
|
271 |
+
{"params": params or self.params},
|
272 |
+
pixel_values=jnp.array(pixel_values, dtype="i4"),
|
273 |
+
output_attentions=output_attentions,
|
274 |
+
output_hidden_states=output_hidden_states,
|
275 |
+
return_dict=return_dict,
|
276 |
+
deterministic=not train,
|
277 |
+
rngs=rngs,
|
278 |
+
method=_encoder_forward,
|
279 |
+
)
|
280 |
+
|
281 |
+
def decode(
|
282 |
+
self,
|
283 |
+
input_ids,
|
284 |
+
encoder_outputs,
|
285 |
+
encoder_attention_mask: Optional[jnp.ndarray] = None,
|
286 |
+
attention_mask: Optional[jnp.ndarray] = None,
|
287 |
+
position_ids: Optional[jnp.ndarray] = None,
|
288 |
+
past_key_values: dict = None,
|
289 |
+
output_attentions: Optional[bool] = None,
|
290 |
+
output_hidden_states: Optional[bool] = None,
|
291 |
+
return_dict: Optional[bool] = None,
|
292 |
+
train: bool = False,
|
293 |
+
params: dict = None,
|
294 |
+
dropout_rng: PRNGKey = None,
|
295 |
+
):
|
296 |
+
|
297 |
+
output_attentions = (
|
298 |
+
output_attentions
|
299 |
+
if output_attentions is not None
|
300 |
+
else self.config.output_attentions
|
301 |
+
)
|
302 |
+
output_hidden_states = (
|
303 |
+
output_hidden_states
|
304 |
+
if output_hidden_states is not None
|
305 |
+
else self.config.output_hidden_states
|
306 |
+
)
|
307 |
+
return_dict = (
|
308 |
+
return_dict if return_dict is not None else self.config.return_dict
|
309 |
+
)
|
310 |
+
|
311 |
+
encoder_hidden_states = encoder_outputs[0]
|
312 |
+
if encoder_attention_mask is None:
|
313 |
+
batch_size, sequence_length = encoder_hidden_states.shape[:2]
|
314 |
+
encoder_attention_mask = jnp.ones((batch_size, sequence_length))
|
315 |
+
|
316 |
+
batch_size, sequence_length = input_ids.shape
|
317 |
+
if attention_mask is None:
|
318 |
+
attention_mask = jnp.ones((batch_size, sequence_length))
|
319 |
+
|
320 |
+
if position_ids is None:
|
321 |
+
if past_key_values is not None:
|
322 |
+
raise ValueError(
|
323 |
+
"Make sure to provide `position_ids` when passing `past_key_values`."
|
324 |
+
)
|
325 |
+
|
326 |
+
position_ids = jnp.broadcast_to(
|
327 |
+
jnp.arange(sequence_length)[None, :], (batch_size, sequence_length)
|
328 |
+
)
|
329 |
+
|
330 |
+
# Handle any PRNG if needed
|
331 |
+
rngs = {}
|
332 |
+
if dropout_rng is not None:
|
333 |
+
rngs["dropout"] = dropout_rng
|
334 |
+
|
335 |
+
inputs = {"params": params or self.params}
|
336 |
+
|
337 |
+
# if past_key_values are passed then cache is already initialized a private flag init_cache has to be
|
338 |
+
# passed down to ensure cache is used. It has to be made sure that cache is marked as mutable so that
|
339 |
+
# it can be changed by FlaxGPT2Attention module
|
340 |
+
if past_key_values:
|
341 |
+
inputs["cache"] = past_key_values
|
342 |
+
mutable = ["cache"]
|
343 |
+
else:
|
344 |
+
mutable = False
|
345 |
+
|
346 |
+
def _decoder_forward(
|
347 |
+
module,
|
348 |
+
input_ids,
|
349 |
+
attention_mask,
|
350 |
+
position_ids,
|
351 |
+
**kwargs,
|
352 |
+
):
|
353 |
+
decoder_module = module._get_decoder_module()
|
354 |
+
return decoder_module(
|
355 |
+
input_ids,
|
356 |
+
attention_mask,
|
357 |
+
position_ids,
|
358 |
+
**kwargs,
|
359 |
+
)
|
360 |
+
|
361 |
+
outputs = self.module.apply(
|
362 |
+
inputs,
|
363 |
+
input_ids=jnp.array(input_ids, dtype="i4"),
|
364 |
+
attention_mask=jnp.array(attention_mask, dtype="i4"),
|
365 |
+
position_ids=jnp.array(position_ids, dtype="i4"),
|
366 |
+
encoder_hidden_states=encoder_hidden_states,
|
367 |
+
encoder_attention_mask=jnp.array(encoder_attention_mask, dtype="i4"),
|
368 |
+
output_attentions=output_attentions,
|
369 |
+
output_hidden_states=output_hidden_states,
|
370 |
+
return_dict=return_dict,
|
371 |
+
deterministic=not train,
|
372 |
+
rngs=rngs,
|
373 |
+
mutable=mutable,
|
374 |
+
method=_decoder_forward,
|
375 |
+
)
|
376 |
+
|
377 |
+
# add updated cache to model output
|
378 |
+
if past_key_values is not None and return_dict:
|
379 |
+
outputs, past = outputs
|
380 |
+
outputs["past_key_values"] = unfreeze(past["cache"])
|
381 |
+
return outputs
|
382 |
+
elif past_key_values is not None and not return_dict:
|
383 |
+
outputs, past = outputs
|
384 |
+
outputs = outputs[:1] + (unfreeze(past["cache"]),) + outputs[1:]
|
385 |
+
|
386 |
+
return outputs
|
387 |
+
|
388 |
+
def __call__(
|
389 |
+
self,
|
390 |
+
pixel_values: jnp.ndarray,
|
391 |
+
input_ids: Optional[jnp.ndarray] = None,
|
392 |
+
attention_mask: Optional[jnp.ndarray] = None,
|
393 |
+
position_ids: Optional[jnp.ndarray] = None,
|
394 |
+
output_attentions: Optional[bool] = None,
|
395 |
+
output_hidden_states: Optional[bool] = None,
|
396 |
+
return_dict: Optional[bool] = None,
|
397 |
+
train: bool = False,
|
398 |
+
params: dict = None,
|
399 |
+
dropout_rng: PRNGKey = None,
|
400 |
+
):
|
401 |
+
output_attentions = (
|
402 |
+
output_attentions
|
403 |
+
if output_attentions is not None
|
404 |
+
else self.config.output_attentions
|
405 |
+
)
|
406 |
+
output_hidden_states = (
|
407 |
+
output_hidden_states
|
408 |
+
if output_hidden_states is not None
|
409 |
+
else self.config.output_hidden_states
|
410 |
+
)
|
411 |
+
return_dict = (
|
412 |
+
return_dict if return_dict is not None else self.config.return_dict
|
413 |
+
)
|
414 |
+
|
415 |
+
pixel_values = jnp.transpose(pixel_values, (0, 2, 3, 1))
|
416 |
+
|
417 |
+
# # prepare encoder inputs
|
418 |
+
# if encoder_attention_mask is None:
|
419 |
+
# encoder_attention_mask = jnp.ones_like(input_ids)
|
420 |
+
|
421 |
+
# if position_ids is None:
|
422 |
+
# batch_size, sequence_length = input_ids.shape
|
423 |
+
# position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length))
|
424 |
+
|
425 |
+
# prepare decoder inputs
|
426 |
+
# if decoder_input_ids is None:
|
427 |
+
# decoder_input_ids = shift_tokens_right(
|
428 |
+
# input_ids, self.config.pad_token_id, decoder_start_token_id=self.config.decoder_start_token_id
|
429 |
+
# ) # TODO: Check how to use this
|
430 |
+
|
431 |
+
if attention_mask is None:
|
432 |
+
attention_mask = jnp.ones_like(input_ids)
|
433 |
+
if position_ids is None:
|
434 |
+
batch_size, sequence_length = input_ids.shape
|
435 |
+
position_ids = jnp.broadcast_to(
|
436 |
+
jnp.arange(sequence_length)[None, :], (batch_size, sequence_length)
|
437 |
+
)
|
438 |
+
|
439 |
+
# Handle any PRNG if needed
|
440 |
+
rngs = {"dropout": dropout_rng} if dropout_rng is not None else {}
|
441 |
+
|
442 |
+
return self.module.apply(
|
443 |
+
{"params": params or self.params},
|
444 |
+
pixel_values=jnp.array(pixel_values, dtype=jnp.float32),
|
445 |
+
input_ids=jnp.array(input_ids, dtype="i4"),
|
446 |
+
attention_mask=jnp.array(attention_mask, dtype="i4"),
|
447 |
+
position_ids=jnp.array(position_ids, dtype="i4"),
|
448 |
+
output_attentions=output_attentions,
|
449 |
+
output_hidden_states=output_hidden_states,
|
450 |
+
return_dict=return_dict,
|
451 |
+
deterministic=not train,
|
452 |
+
rngs=rngs,
|
453 |
+
)
|
454 |
+
|
455 |
+
|
456 |
+
class FlaxViTGPT2ForConditionalGeneration(FlaxViTGPT2PreTrainedModel):
|
457 |
+
module_class = FlaxViTGPT2ForConditionalGenerationModule
|
458 |
+
dtype: jnp.dtype = jnp.float32
|
459 |
+
|
460 |
+
def decode(
|
461 |
+
self,
|
462 |
+
input_ids,
|
463 |
+
encoder_outputs,
|
464 |
+
encoder_attention_mask: Optional[jnp.ndarray] = None,
|
465 |
+
attention_mask: Optional[jnp.ndarray] = None,
|
466 |
+
position_ids: Optional[jnp.ndarray] = None,
|
467 |
+
past_key_values: dict = None,
|
468 |
+
output_attentions: Optional[bool] = None,
|
469 |
+
output_hidden_states: Optional[bool] = None,
|
470 |
+
return_dict: Optional[bool] = None,
|
471 |
+
deterministic: bool = True,
|
472 |
+
params: dict = None,
|
473 |
+
dropout_rng: PRNGKey = None,
|
474 |
+
):
|
475 |
+
output_attentions = (
|
476 |
+
output_attentions
|
477 |
+
if output_attentions is not None
|
478 |
+
else self.config.output_attentions
|
479 |
+
)
|
480 |
+
output_hidden_states = (
|
481 |
+
output_hidden_states
|
482 |
+
if output_hidden_states is not None
|
483 |
+
else self.config.output_hidden_states
|
484 |
+
)
|
485 |
+
return_dict = (
|
486 |
+
return_dict if return_dict is not None else self.config.return_dict
|
487 |
+
)
|
488 |
+
|
489 |
+
encoder_hidden_states = encoder_outputs[0]
|
490 |
+
if encoder_attention_mask is None:
|
491 |
+
batch_size, sequence_length = encoder_hidden_states.shape[:2]
|
492 |
+
encoder_attention_mask = jnp.ones((batch_size, sequence_length))
|
493 |
+
|
494 |
+
batch_size, sequence_length = input_ids.shape
|
495 |
+
if attention_mask is None:
|
496 |
+
attention_mask = jnp.ones((batch_size, sequence_length))
|
497 |
+
|
498 |
+
if position_ids is None:
|
499 |
+
if past_key_values is not None:
|
500 |
+
raise ValueError(
|
501 |
+
"Make sure to provide `position_ids` when passing `past_key_values`."
|
502 |
+
)
|
503 |
+
|
504 |
+
position_ids = jnp.broadcast_to(
|
505 |
+
jnp.arange(sequence_length)[None, :], (batch_size, sequence_length)
|
506 |
+
)
|
507 |
+
|
508 |
+
# Handle any PRNG if needed
|
509 |
+
rngs = {}
|
510 |
+
if dropout_rng is not None:
|
511 |
+
rngs["dropout"] = dropout_rng
|
512 |
+
|
513 |
+
inputs = {"params": params or self.params}
|
514 |
+
|
515 |
+
# if past_key_values are passed then cache is already initialized a private flag init_cache has to be
|
516 |
+
# passed down to ensure cache is used. It has to be made sure that cache is marked as mutable so that
|
517 |
+
# it can be changed by FlaxGPT2Attention module
|
518 |
+
if past_key_values:
|
519 |
+
inputs["cache"] = past_key_values
|
520 |
+
mutable = ["cache"]
|
521 |
+
else:
|
522 |
+
mutable = False
|
523 |
+
|
524 |
+
def _decoder_forward(
|
525 |
+
module,
|
526 |
+
input_ids,
|
527 |
+
attention_mask,
|
528 |
+
position_ids,
|
529 |
+
**kwargs,
|
530 |
+
):
|
531 |
+
decoder_module = module._get_decoder_module()
|
532 |
+
outputs = decoder_module(
|
533 |
+
input_ids,
|
534 |
+
attention_mask,
|
535 |
+
position_ids,
|
536 |
+
**kwargs,
|
537 |
+
)
|
538 |
+
hidden_states = outputs[0]
|
539 |
+
|
540 |
+
if self.config.tie_word_embeddings:
|
541 |
+
shared_embedding = module.model.variables["params"]["shared"][
|
542 |
+
"embedding"
|
543 |
+
]
|
544 |
+
lm_logits = module.lm_head.apply(
|
545 |
+
{"params": {"kernel": shared_embedding.T}}, hidden_states
|
546 |
+
)
|
547 |
+
else:
|
548 |
+
lm_logits = module.lm_head(hidden_states)
|
549 |
+
|
550 |
+
lm_logits += module.final_logits_bias
|
551 |
+
return lm_logits, outputs
|
552 |
+
|
553 |
+
outputs = self.module.apply(
|
554 |
+
inputs,
|
555 |
+
input_ids=jnp.array(input_ids, dtype="i4"),
|
556 |
+
attention_mask=jnp.array(attention_mask, dtype="i4"),
|
557 |
+
position_ids=jnp.array(position_ids, dtype="i4"),
|
558 |
+
encoder_hidden_states=encoder_hidden_states,
|
559 |
+
encoder_attention_mask=jnp.array(encoder_attention_mask, dtype="i4"),
|
560 |
+
output_attentions=output_attentions,
|
561 |
+
output_hidden_states=output_hidden_states,
|
562 |
+
return_dict=return_dict,
|
563 |
+
deterministic=deterministic,
|
564 |
+
rngs=rngs,
|
565 |
+
mutable=mutable,
|
566 |
+
method=_decoder_forward,
|
567 |
+
)
|
568 |
+
|
569 |
+
if past_key_values is None:
|
570 |
+
lm_logits, outputs = outputs
|
571 |
+
else:
|
572 |
+
(lm_logits, outputs), past = outputs
|
573 |
+
|
574 |
+
if return_dict:
|
575 |
+
outputs = FlaxCausalLMOutputWithCrossAttentions(
|
576 |
+
logits=lm_logits,
|
577 |
+
hidden_states=outputs.hidden_states,
|
578 |
+
attentions=outputs.attentions,
|
579 |
+
cross_attentions=outputs.cross_attentions,
|
580 |
+
)
|
581 |
+
else:
|
582 |
+
outputs = (lm_logits,) + outputs[1:]
|
583 |
+
|
584 |
+
# add updated cache to model output
|
585 |
+
if past_key_values is not None and return_dict:
|
586 |
+
outputs["past_key_values"] = unfreeze(past["cache"])
|
587 |
+
return outputs
|
588 |
+
elif past_key_values is not None and not return_dict:
|
589 |
+
outputs = outputs[:1] + (unfreeze(past["cache"]),) + outputs[1:]
|
590 |
+
|
591 |
+
return outputs
|
592 |
+
|
593 |
+
def prepare_inputs_for_generation(
|
594 |
+
self,
|
595 |
+
input_ids,
|
596 |
+
max_length,
|
597 |
+
encoder_attention_mask: Optional[jnp.DeviceArray] = None,
|
598 |
+
attention_mask: Optional[jnp.DeviceArray] = None,
|
599 |
+
encoder_outputs=None,
|
600 |
+
**kwargs,
|
601 |
+
):
|
602 |
+
# initializing the cache
|
603 |
+
batch_size, seq_length = input_ids.shape
|
604 |
+
|
605 |
+
past_key_values = self.init_cache(batch_size, max_length, encoder_outputs)
|
606 |
+
# Note that usually one would have to put 0's in the attention_mask for x > input_ids.shape[-1] and x < cache_length.
|
607 |
+
# But since the decoder uses a causal mask, those positions are masked anyways.
|
608 |
+
# Thus we can create a single static attention_mask here, which is more efficient for compilation
|
609 |
+
extended_attention_mask = jnp.ones((batch_size, max_length), dtype="i4")
|
610 |
+
if attention_mask is not None:
|
611 |
+
position_ids = attention_mask.cumsum(axis=-1) - 1
|
612 |
+
extended_attention_mask = lax.dynamic_update_slice(
|
613 |
+
extended_attention_mask, attention_mask, (0, 0)
|
614 |
+
)
|
615 |
+
else:
|
616 |
+
position_ids = jnp.broadcast_to(
|
617 |
+
jnp.arange(seq_length, dtype="i4")[None, :], (batch_size, seq_length)
|
618 |
+
)
|
619 |
+
|
620 |
+
return {
|
621 |
+
"past_key_values": past_key_values,
|
622 |
+
"encoder_outputs": encoder_outputs,
|
623 |
+
"encoder_attention_mask": encoder_attention_mask,
|
624 |
+
"attention_mask": extended_attention_mask,
|
625 |
+
"position_ids": position_ids,
|
626 |
+
}
|
627 |
+
|
628 |
+
def update_inputs_for_generation(self, model_outputs, model_kwargs):
|
629 |
+
model_kwargs["past_key_values"] = model_outputs.past_key_values
|
630 |
+
model_kwargs["position_ids"] = (
|
631 |
+
model_kwargs["position_ids"][:, -1:] + 1
|
632 |
+
)
|
633 |
+
return model_kwargs
|
634 |
+
|
635 |
+
@classmethod
|
636 |
+
def from_vit_gpt2_pretrained(
|
637 |
+
cls,
|
638 |
+
vit_model_name_or_path: str = None,
|
639 |
+
gpt2_model_name_or_path: str = None,
|
640 |
+
*model_args,
|
641 |
+
**kwargs,
|
642 |
+
) -> FlaxViTGPT2PreTrainedModel:
|
643 |
+
|
644 |
+
kwargs_gpt2 = {
|
645 |
+
argument[len("gpt2_") :]: value
|
646 |
+
for argument, value in kwargs.items()
|
647 |
+
if argument.startswith("gpt2_")
|
648 |
+
}
|
649 |
+
|
650 |
+
kwargs_vit = {
|
651 |
+
argument[len("vit_") :]: value
|
652 |
+
for argument, value in kwargs.items()
|
653 |
+
if argument.startswith("vit_")
|
654 |
+
}
|
655 |
+
|
656 |
+
# remove gpt2, vit kwargs from kwargs
|
657 |
+
for key in kwargs_gpt2.keys():
|
658 |
+
del kwargs["gpt2_" + key]
|
659 |
+
for key in kwargs_vit.keys():
|
660 |
+
del kwargs["vit_" + key]
|
661 |
+
|
662 |
+
# Load and initialize the gpt2 and vit model
|
663 |
+
gpt2_model = kwargs_gpt2.pop("model", None)
|
664 |
+
if gpt2_model is None:
|
665 |
+
assert (
|
666 |
+
gpt2_model_name_or_path is not None
|
667 |
+
), "If `model` is not defined as an argument, a `gpt2_model_name_or_path` has to be defined"
|
668 |
+
|
669 |
+
if "config" not in kwargs_gpt2:
|
670 |
+
gpt2_config = GPT2Config.from_pretrained(gpt2_model_name_or_path)
|
671 |
+
kwargs_gpt2["config"] = gpt2_config
|
672 |
+
|
673 |
+
kwargs_gpt2["config"].add_cross_attention = True
|
674 |
+
gpt2_model = FlaxGPT2Model.from_pretrained(
|
675 |
+
gpt2_model_name_or_path, *model_args, **kwargs_gpt2
|
676 |
+
)
|
677 |
+
|
678 |
+
vit_model = kwargs_vit.pop("model", None)
|
679 |
+
if vit_model is None:
|
680 |
+
assert (
|
681 |
+
vit_model_name_or_path is not None
|
682 |
+
), "If `model` is not defined as an argument, a `vit_model_name_or_path` has to be defined"
|
683 |
+
|
684 |
+
if "config" not in kwargs_vit:
|
685 |
+
vit_config = ViTConfig.from_pretrained(vit_model_name_or_path)
|
686 |
+
kwargs_vit["config"] = vit_config
|
687 |
+
|
688 |
+
vit_model = FlaxViTModel.from_pretrained(
|
689 |
+
vit_model_name_or_path, *model_args, **kwargs_vit
|
690 |
+
)
|
691 |
+
|
692 |
+
# instantiate config with corresponding kwargs
|
693 |
+
dtype = kwargs.pop("dtype", jnp.float32)
|
694 |
+
config = ViTGPT2Config.from_vit_gpt2_configs(
|
695 |
+
vit_model.config, gpt2_model.config, **kwargs
|
696 |
+
)
|
697 |
+
|
698 |
+
# init model
|
699 |
+
model = cls(config, *model_args, dtype=dtype, **kwargs)
|
700 |
+
model.params["model"]["encoder"] = vit_model.params
|
701 |
+
model.params["model"]["decoder"] = gpt2_model.params
|
702 |
+
|
703 |
+
return model
|
704 |
+
|
vit_gpt2/modeling_flax_vit_gpt2_lm.py
ADDED
@@ -0,0 +1,684 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Callable, Optional, Tuple
|
2 |
+
|
3 |
+
import flax.linen as nn
|
4 |
+
import jax
|
5 |
+
import jax.numpy as jnp
|
6 |
+
from flax.core.frozen_dict import FrozenDict, unfreeze
|
7 |
+
from jax import lax
|
8 |
+
from jax.random import PRNGKey
|
9 |
+
from transformers import GPT2Config, FlaxViTModel, ViTConfig
|
10 |
+
from transformers.modeling_flax_outputs import (
|
11 |
+
FlaxCausalLMOutputWithCrossAttentions,
|
12 |
+
FlaxSeq2SeqLMOutput,
|
13 |
+
FlaxSeq2SeqModelOutput,
|
14 |
+
)
|
15 |
+
from transformers.models.bart.modeling_flax_bart import (
|
16 |
+
shift_tokens_right,
|
17 |
+
)
|
18 |
+
from .modeling_flax_gpt2 import (
|
19 |
+
FlaxGPT2Module,
|
20 |
+
FlaxGPT2Model,
|
21 |
+
FlaxGPT2LMHeadModule,
|
22 |
+
FlaxGPT2LMHeadModel,
|
23 |
+
FlaxPreTrainedModel
|
24 |
+
)
|
25 |
+
from transformers.models.vit.modeling_flax_vit import FlaxViTModule
|
26 |
+
|
27 |
+
from .configuration_vit_gpt2 import ViTGPT2Config
|
28 |
+
|
29 |
+
|
30 |
+
def shift_tokens_right(input_ids: jnp.ndarray, pad_token_id: int, decoder_start_token_id: int) -> jnp.ndarray:
|
31 |
+
"""
|
32 |
+
Shift input ids one token to the right.
|
33 |
+
"""
|
34 |
+
shifted_input_ids = jnp.roll(input_ids, 1, axis=-1)
|
35 |
+
shifted_input_ids = jax.ops.index_update(shifted_input_ids, (..., 0), decoder_start_token_id)
|
36 |
+
# replace possible -100 values in labels by `pad_token_id`
|
37 |
+
shifted_input_ids = jnp.where(shifted_input_ids == -100, pad_token_id, shifted_input_ids)
|
38 |
+
|
39 |
+
return shifted_input_ids
|
40 |
+
|
41 |
+
class FlaxViTGPT2LMModule(nn.Module):
|
42 |
+
config: ViTGPT2Config
|
43 |
+
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
|
44 |
+
|
45 |
+
def setup(self):
|
46 |
+
|
47 |
+
self.encoder = FlaxViTModule(self.config.vit_config, dtype=self.dtype)
|
48 |
+
self.decoder = FlaxGPT2LMHeadModule(self.config.gpt2_config, dtype=self.dtype)
|
49 |
+
|
50 |
+
def _get_encoder_module(self):
|
51 |
+
return self.encoder
|
52 |
+
|
53 |
+
def _get_decoder_module(self):
|
54 |
+
return self.decoder
|
55 |
+
|
56 |
+
def __call__(
|
57 |
+
self,
|
58 |
+
pixel_values,
|
59 |
+
input_ids,
|
60 |
+
attention_mask,
|
61 |
+
position_ids,
|
62 |
+
encoder_attention_mask: Optional[jnp.ndarray] = None,
|
63 |
+
output_attentions: bool = False,
|
64 |
+
output_hidden_states: bool = False,
|
65 |
+
return_dict: bool = True,
|
66 |
+
deterministic: bool = True,
|
67 |
+
):
|
68 |
+
encoder_outputs = self.encoder(
|
69 |
+
pixel_values=pixel_values,
|
70 |
+
deterministic=deterministic,
|
71 |
+
output_attentions=output_attentions,
|
72 |
+
output_hidden_states=output_hidden_states,
|
73 |
+
return_dict=return_dict,
|
74 |
+
)
|
75 |
+
|
76 |
+
decoder_outputs = self.decoder(
|
77 |
+
input_ids=input_ids,
|
78 |
+
attention_mask=attention_mask,
|
79 |
+
position_ids=position_ids,
|
80 |
+
encoder_hidden_states=encoder_outputs[0],
|
81 |
+
encoder_attention_mask=encoder_attention_mask,
|
82 |
+
deterministic=deterministic,
|
83 |
+
output_attentions=output_attentions,
|
84 |
+
output_hidden_states=output_hidden_states,
|
85 |
+
return_dict=return_dict
|
86 |
+
)
|
87 |
+
|
88 |
+
if not return_dict:
|
89 |
+
return decoder_outputs + encoder_outputs
|
90 |
+
|
91 |
+
return FlaxSeq2SeqLMOutput(
|
92 |
+
logits=decoder_outputs.logits,
|
93 |
+
decoder_hidden_states=decoder_outputs.decoder_hidden_states,
|
94 |
+
decoder_attentions=decoder_outputs.decoder_attentions,
|
95 |
+
cross_attentions=decoder_outputs.cross_attentions,
|
96 |
+
encoder_last_hidden_state=encoder_outputs.last_hidden_state,
|
97 |
+
encoder_hidden_states=encoder_outputs.hidden_states,
|
98 |
+
encoder_attentions=encoder_outputs.attentions,
|
99 |
+
)
|
100 |
+
|
101 |
+
class FlaxViTGPT2LMForConditionalGenerationModule(nn.Module):
|
102 |
+
config: ViTGPT2Config
|
103 |
+
dtype: jnp.dtype = jnp.float32
|
104 |
+
bias_init: Callable[..., jnp.ndarray] = jax.nn.initializers.zeros
|
105 |
+
|
106 |
+
def setup(self):
|
107 |
+
self.model = FlaxViTGPT2LMModule(config=self.config, dtype=self.dtype)
|
108 |
+
|
109 |
+
def _get_encoder_module(self):
|
110 |
+
return self.model.encoder
|
111 |
+
|
112 |
+
def _get_decoder_module(self):
|
113 |
+
return self.model.decoder
|
114 |
+
|
115 |
+
def __call__(
|
116 |
+
self,
|
117 |
+
pixel_values,
|
118 |
+
input_ids,
|
119 |
+
attention_mask,
|
120 |
+
position_ids,
|
121 |
+
encoder_attention_mask: Optional[jnp.ndarray] = None,
|
122 |
+
output_attentions: bool = False,
|
123 |
+
output_hidden_states: bool = False,
|
124 |
+
return_dict: bool = True,
|
125 |
+
deterministic: bool = True,
|
126 |
+
):
|
127 |
+
outputs = self.model(
|
128 |
+
pixel_values=pixel_values,
|
129 |
+
input_ids=input_ids,
|
130 |
+
attention_mask=attention_mask,
|
131 |
+
position_ids=position_ids,
|
132 |
+
encoder_attention_mask=encoder_attention_mask,
|
133 |
+
output_attentions=output_attentions,
|
134 |
+
output_hidden_states=output_hidden_states,
|
135 |
+
return_dict=return_dict,
|
136 |
+
deterministic=deterministic,
|
137 |
+
)
|
138 |
+
|
139 |
+
return outputs
|
140 |
+
|
141 |
+
|
142 |
+
class FlaxViTGPT2LMPreTrainedModel(FlaxPreTrainedModel):
|
143 |
+
config_class = ViTGPT2Config
|
144 |
+
base_model_prefix: str = "model"
|
145 |
+
module_class: nn.Module = None
|
146 |
+
|
147 |
+
def __init__(
|
148 |
+
self,
|
149 |
+
config: ViTGPT2Config,
|
150 |
+
input_shape: Tuple = None,
|
151 |
+
seed: int = 0,
|
152 |
+
dtype: jnp.dtype = jnp.float32,
|
153 |
+
**kwargs,
|
154 |
+
):
|
155 |
+
if input_shape is None:
|
156 |
+
input_shape = (
|
157 |
+
(1, config.vit_config.image_size, config.vit_config.image_size, 3),
|
158 |
+
(1, 1),
|
159 |
+
)
|
160 |
+
|
161 |
+
module = self.module_class(config=config, dtype=dtype, **kwargs)
|
162 |
+
super().__init__(
|
163 |
+
config, module, input_shape=input_shape, seed=seed, dtype=dtype
|
164 |
+
)
|
165 |
+
|
166 |
+
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:
|
167 |
+
# init input tensors
|
168 |
+
pixel_values = jax.random.normal(rng, input_shape[0])
|
169 |
+
# # make sure initialization pass will work for FlaxBartForSequenceClassificationModule
|
170 |
+
# input_ids = jax.ops.index_update(input_ids, (..., -1), self.config.eos_token_id)
|
171 |
+
|
172 |
+
input_ids = jnp.zeros(input_shape[1], dtype="i4")
|
173 |
+
attention_mask = jnp.ones_like(input_ids)
|
174 |
+
|
175 |
+
batch_size, sequence_length = input_ids.shape
|
176 |
+
position_ids = jnp.broadcast_to(
|
177 |
+
jnp.arange(sequence_length)[None, :], (batch_size, sequence_length)
|
178 |
+
)
|
179 |
+
|
180 |
+
params_rng, dropout_rng = jax.random.split(rng)
|
181 |
+
rngs = {"params": params_rng, "dropout": dropout_rng}
|
182 |
+
|
183 |
+
return self.module.init(
|
184 |
+
rngs,
|
185 |
+
pixel_values,
|
186 |
+
input_ids,
|
187 |
+
attention_mask,
|
188 |
+
position_ids,
|
189 |
+
)["params"]
|
190 |
+
|
191 |
+
def init_cache(self, batch_size, max_length, encoder_outputs):
|
192 |
+
|
193 |
+
input_ids = jnp.ones((batch_size, max_length), dtype="i4")
|
194 |
+
attention_mask = jnp.ones_like(input_ids)
|
195 |
+
position_ids = jnp.broadcast_to(
|
196 |
+
jnp.arange(jnp.atleast_2d(input_ids).shape[-1]),
|
197 |
+
input_ids.shape,
|
198 |
+
)
|
199 |
+
|
200 |
+
def _decoder_forward(
|
201 |
+
module,
|
202 |
+
input_ids,
|
203 |
+
attention_mask,
|
204 |
+
position_ids,
|
205 |
+
**kwargs,
|
206 |
+
):
|
207 |
+
decoder_module = module._get_decoder_module()
|
208 |
+
return decoder_module(
|
209 |
+
input_ids,
|
210 |
+
attention_mask,
|
211 |
+
position_ids,
|
212 |
+
**kwargs,
|
213 |
+
)
|
214 |
+
|
215 |
+
init_variables = self.module.init(
|
216 |
+
jax.random.PRNGKey(0),
|
217 |
+
input_ids=input_ids,
|
218 |
+
attention_mask=attention_mask,
|
219 |
+
position_ids=position_ids,
|
220 |
+
encoder_hidden_states=encoder_outputs[0],
|
221 |
+
init_cache=True,
|
222 |
+
method=_decoder_forward, # we only need to call the decoder to init the cache
|
223 |
+
)
|
224 |
+
return unfreeze(init_variables["cache"])
|
225 |
+
|
226 |
+
def encode(
|
227 |
+
self,
|
228 |
+
pixel_values: jnp.ndarray,
|
229 |
+
attention_mask: Optional[jnp.ndarray] = None,
|
230 |
+
output_attentions: Optional[bool] = None,
|
231 |
+
output_hidden_states: Optional[bool] = None,
|
232 |
+
return_dict: Optional[bool] = None,
|
233 |
+
train: bool = False,
|
234 |
+
params: dict = None,
|
235 |
+
dropout_rng: PRNGKey = None,
|
236 |
+
):
|
237 |
+
output_attentions = (
|
238 |
+
output_attentions
|
239 |
+
if output_attentions is not None
|
240 |
+
else self.config.output_attentions
|
241 |
+
)
|
242 |
+
output_hidden_states = (
|
243 |
+
output_hidden_states
|
244 |
+
if output_hidden_states is not None
|
245 |
+
else self.config.output_hidden_states
|
246 |
+
)
|
247 |
+
return_dict = (
|
248 |
+
return_dict if return_dict is not None else self.config.return_dict
|
249 |
+
)
|
250 |
+
|
251 |
+
pixel_values = jnp.transpose(pixel_values, (0, 2, 3, 1))
|
252 |
+
|
253 |
+
# Handle any PRNG if needed
|
254 |
+
rngs = {}
|
255 |
+
if dropout_rng is not None:
|
256 |
+
rngs["dropout"] = dropout_rng
|
257 |
+
|
258 |
+
def _encoder_forward(module, pixel_values, **kwargs):
|
259 |
+
encode_module = module._get_encoder_module()
|
260 |
+
return encode_module(pixel_values, **kwargs)
|
261 |
+
|
262 |
+
return self.module.apply(
|
263 |
+
{"params": params or self.params},
|
264 |
+
pixel_values=jnp.array(pixel_values, dtype="i4"),
|
265 |
+
output_attentions=output_attentions,
|
266 |
+
output_hidden_states=output_hidden_states,
|
267 |
+
return_dict=return_dict,
|
268 |
+
deterministic=not train,
|
269 |
+
rngs=rngs,
|
270 |
+
method=_encoder_forward,
|
271 |
+
)
|
272 |
+
|
273 |
+
def decode(
|
274 |
+
self,
|
275 |
+
input_ids,
|
276 |
+
encoder_outputs,
|
277 |
+
encoder_attention_mask: Optional[jnp.ndarray] = None,
|
278 |
+
attention_mask: Optional[jnp.ndarray] = None,
|
279 |
+
position_ids: Optional[jnp.ndarray] = None,
|
280 |
+
past_key_values: dict = None,
|
281 |
+
output_attentions: Optional[bool] = None,
|
282 |
+
output_hidden_states: Optional[bool] = None,
|
283 |
+
return_dict: Optional[bool] = None,
|
284 |
+
train: bool = False,
|
285 |
+
params: dict = None,
|
286 |
+
dropout_rng: PRNGKey = None,
|
287 |
+
):
|
288 |
+
|
289 |
+
output_attentions = (
|
290 |
+
output_attentions
|
291 |
+
if output_attentions is not None
|
292 |
+
else self.config.output_attentions
|
293 |
+
)
|
294 |
+
output_hidden_states = (
|
295 |
+
output_hidden_states
|
296 |
+
if output_hidden_states is not None
|
297 |
+
else self.config.output_hidden_states
|
298 |
+
)
|
299 |
+
return_dict = (
|
300 |
+
return_dict if return_dict is not None else self.config.return_dict
|
301 |
+
)
|
302 |
+
|
303 |
+
encoder_hidden_states = encoder_outputs[0]
|
304 |
+
if encoder_attention_mask is None:
|
305 |
+
batch_size, sequence_length = encoder_hidden_states.shape[:2]
|
306 |
+
encoder_attention_mask = jnp.ones((batch_size, sequence_length))
|
307 |
+
|
308 |
+
batch_size, sequence_length = input_ids.shape
|
309 |
+
if attention_mask is None:
|
310 |
+
attention_mask = jnp.ones((batch_size, sequence_length))
|
311 |
+
|
312 |
+
if position_ids is None:
|
313 |
+
if past_key_values is not None:
|
314 |
+
raise ValueError(
|
315 |
+
"Make sure to provide `position_ids` when passing `past_key_values`."
|
316 |
+
)
|
317 |
+
|
318 |
+
position_ids = jnp.broadcast_to(
|
319 |
+
jnp.arange(sequence_length)[None, :], (batch_size, sequence_length)
|
320 |
+
)
|
321 |
+
|
322 |
+
# Handle any PRNG if needed
|
323 |
+
rngs = {}
|
324 |
+
if dropout_rng is not None:
|
325 |
+
rngs["dropout"] = dropout_rng
|
326 |
+
|
327 |
+
inputs = {"params": params or self.params}
|
328 |
+
|
329 |
+
# if past_key_values are passed then cache is already initialized a private flag init_cache has to be
|
330 |
+
# passed down to ensure cache is used. It has to be made sure that cache is marked as mutable so that
|
331 |
+
# it can be changed by FlaxGPT2Attention module
|
332 |
+
if past_key_values:
|
333 |
+
inputs["cache"] = past_key_values
|
334 |
+
mutable = ["cache"]
|
335 |
+
else:
|
336 |
+
mutable = False
|
337 |
+
|
338 |
+
def _decoder_forward(
|
339 |
+
module,
|
340 |
+
input_ids,
|
341 |
+
attention_mask,
|
342 |
+
position_ids,
|
343 |
+
**kwargs,
|
344 |
+
):
|
345 |
+
decoder_module = module._get_decoder_module()
|
346 |
+
return decoder_module(
|
347 |
+
input_ids,
|
348 |
+
attention_mask,
|
349 |
+
position_ids,
|
350 |
+
**kwargs,
|
351 |
+
)
|
352 |
+
|
353 |
+
outputs = self.module.apply(
|
354 |
+
inputs,
|
355 |
+
input_ids=jnp.array(input_ids, dtype="i4"),
|
356 |
+
attention_mask=jnp.array(attention_mask, dtype="i4"),
|
357 |
+
position_ids=jnp.array(position_ids, dtype="i4"),
|
358 |
+
encoder_hidden_states=encoder_hidden_states,
|
359 |
+
encoder_attention_mask=jnp.array(encoder_attention_mask, dtype="i4"),
|
360 |
+
output_attentions=output_attentions,
|
361 |
+
output_hidden_states=output_hidden_states,
|
362 |
+
return_dict=return_dict,
|
363 |
+
deterministic=not train,
|
364 |
+
rngs=rngs,
|
365 |
+
mutable=mutable,
|
366 |
+
method=_decoder_forward,
|
367 |
+
)
|
368 |
+
|
369 |
+
# add updated cache to model output
|
370 |
+
if past_key_values is not None and return_dict:
|
371 |
+
outputs, past = outputs
|
372 |
+
outputs["past_key_values"] = unfreeze(past["cache"])
|
373 |
+
return outputs
|
374 |
+
elif past_key_values is not None and not return_dict:
|
375 |
+
outputs, past = outputs
|
376 |
+
outputs = outputs[:1] + (unfreeze(past["cache"]),) + outputs[1:]
|
377 |
+
|
378 |
+
return outputs
|
379 |
+
|
380 |
+
def __call__(
|
381 |
+
self,
|
382 |
+
pixel_values: jnp.ndarray,
|
383 |
+
input_ids: Optional[jnp.ndarray] = None,
|
384 |
+
attention_mask: Optional[jnp.ndarray] = None,
|
385 |
+
position_ids: Optional[jnp.ndarray] = None,
|
386 |
+
output_attentions: Optional[bool] = None,
|
387 |
+
output_hidden_states: Optional[bool] = None,
|
388 |
+
return_dict: Optional[bool] = None,
|
389 |
+
train: bool = False,
|
390 |
+
params: dict = None,
|
391 |
+
dropout_rng: PRNGKey = None,
|
392 |
+
):
|
393 |
+
output_attentions = (
|
394 |
+
output_attentions
|
395 |
+
if output_attentions is not None
|
396 |
+
else self.config.output_attentions
|
397 |
+
)
|
398 |
+
output_hidden_states = (
|
399 |
+
output_hidden_states
|
400 |
+
if output_hidden_states is not None
|
401 |
+
else self.config.output_hidden_states
|
402 |
+
)
|
403 |
+
return_dict = (
|
404 |
+
return_dict if return_dict is not None else self.config.return_dict
|
405 |
+
)
|
406 |
+
|
407 |
+
pixel_values = jnp.transpose(pixel_values, (0, 2, 3, 1))
|
408 |
+
|
409 |
+
# # prepare encoder inputs
|
410 |
+
# if encoder_attention_mask is None:
|
411 |
+
# encoder_attention_mask = jnp.ones_like(input_ids)
|
412 |
+
|
413 |
+
# if position_ids is None:
|
414 |
+
# batch_size, sequence_length = input_ids.shape
|
415 |
+
# position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length))
|
416 |
+
|
417 |
+
# prepare decoder inputs
|
418 |
+
# if decoder_input_ids is None:
|
419 |
+
# decoder_input_ids = shift_tokens_right(
|
420 |
+
# input_ids, self.config.pad_token_id, decoder_start_token_id=self.config.decoder_start_token_id
|
421 |
+
# ) # TODO: Check how to use this
|
422 |
+
|
423 |
+
if attention_mask is None:
|
424 |
+
attention_mask = jnp.ones_like(input_ids)
|
425 |
+
if position_ids is None:
|
426 |
+
batch_size, sequence_length = input_ids.shape
|
427 |
+
position_ids = jnp.broadcast_to(
|
428 |
+
jnp.arange(sequence_length)[None, :], (batch_size, sequence_length)
|
429 |
+
)
|
430 |
+
|
431 |
+
# Handle any PRNG if needed
|
432 |
+
rngs = {"dropout": dropout_rng} if dropout_rng is not None else {}
|
433 |
+
|
434 |
+
return self.module.apply(
|
435 |
+
{"params": params or self.params},
|
436 |
+
pixel_values=jnp.array(pixel_values, dtype=jnp.float32),
|
437 |
+
input_ids=jnp.array(input_ids, dtype="i4"),
|
438 |
+
attention_mask=jnp.array(attention_mask, dtype="i4"),
|
439 |
+
position_ids=jnp.array(position_ids, dtype="i4"),
|
440 |
+
output_attentions=output_attentions,
|
441 |
+
output_hidden_states=output_hidden_states,
|
442 |
+
return_dict=return_dict,
|
443 |
+
deterministic=not train,
|
444 |
+
rngs=rngs,
|
445 |
+
)
|
446 |
+
|
447 |
+
|
448 |
+
class FlaxViTGPT2LMForConditionalGeneration(FlaxViTGPT2LMPreTrainedModel):
|
449 |
+
module_class = FlaxViTGPT2LMForConditionalGenerationModule
|
450 |
+
dtype: jnp.dtype = jnp.float32
|
451 |
+
|
452 |
+
def decode(
|
453 |
+
self,
|
454 |
+
input_ids,
|
455 |
+
encoder_outputs,
|
456 |
+
encoder_attention_mask: Optional[jnp.ndarray] = None,
|
457 |
+
attention_mask: Optional[jnp.ndarray] = None,
|
458 |
+
position_ids: Optional[jnp.ndarray] = None,
|
459 |
+
past_key_values: dict = None,
|
460 |
+
output_attentions: Optional[bool] = None,
|
461 |
+
output_hidden_states: Optional[bool] = None,
|
462 |
+
return_dict: Optional[bool] = None,
|
463 |
+
deterministic: bool = True,
|
464 |
+
params: dict = None,
|
465 |
+
dropout_rng: PRNGKey = None,
|
466 |
+
):
|
467 |
+
output_attentions = (
|
468 |
+
output_attentions
|
469 |
+
if output_attentions is not None
|
470 |
+
else self.config.output_attentions
|
471 |
+
)
|
472 |
+
output_hidden_states = (
|
473 |
+
output_hidden_states
|
474 |
+
if output_hidden_states is not None
|
475 |
+
else self.config.output_hidden_states
|
476 |
+
)
|
477 |
+
return_dict = (
|
478 |
+
return_dict if return_dict is not None else self.config.return_dict
|
479 |
+
)
|
480 |
+
|
481 |
+
encoder_hidden_states = encoder_outputs[0]
|
482 |
+
if encoder_attention_mask is None:
|
483 |
+
batch_size, sequence_length = encoder_hidden_states.shape[:2]
|
484 |
+
encoder_attention_mask = jnp.ones((batch_size, sequence_length))
|
485 |
+
|
486 |
+
batch_size, sequence_length = input_ids.shape
|
487 |
+
if attention_mask is None:
|
488 |
+
attention_mask = jnp.ones((batch_size, sequence_length))
|
489 |
+
|
490 |
+
if position_ids is None:
|
491 |
+
if past_key_values is not None:
|
492 |
+
raise ValueError(
|
493 |
+
"Make sure to provide `position_ids` when passing `past_key_values`."
|
494 |
+
)
|
495 |
+
|
496 |
+
position_ids = jnp.broadcast_to(
|
497 |
+
jnp.arange(sequence_length)[None, :], (batch_size, sequence_length)
|
498 |
+
)
|
499 |
+
|
500 |
+
# Handle any PRNG if needed
|
501 |
+
rngs = {}
|
502 |
+
if dropout_rng is not None:
|
503 |
+
rngs["dropout"] = dropout_rng
|
504 |
+
|
505 |
+
inputs = {"params": params or self.params}
|
506 |
+
|
507 |
+
# if past_key_values are passed then cache is already initialized a private flag init_cache has to be
|
508 |
+
# passed down to ensure cache is used. It has to be made sure that cache is marked as mutable so that
|
509 |
+
# it can be changed by FlaxGPT2Attention module
|
510 |
+
if past_key_values:
|
511 |
+
inputs["cache"] = past_key_values
|
512 |
+
mutable = ["cache"]
|
513 |
+
else:
|
514 |
+
mutable = False
|
515 |
+
|
516 |
+
def _decoder_forward(
|
517 |
+
module,
|
518 |
+
input_ids,
|
519 |
+
attention_mask,
|
520 |
+
position_ids,
|
521 |
+
**kwargs,
|
522 |
+
):
|
523 |
+
decoder_module = module._get_decoder_module()
|
524 |
+
outputs = decoder_module(
|
525 |
+
input_ids,
|
526 |
+
attention_mask,
|
527 |
+
position_ids,
|
528 |
+
**kwargs,
|
529 |
+
)
|
530 |
+
lm_logits = outputs[0]
|
531 |
+
|
532 |
+
return lm_logits, outputs
|
533 |
+
|
534 |
+
outputs = self.module.apply(
|
535 |
+
inputs,
|
536 |
+
input_ids=jnp.array(input_ids, dtype="i4"),
|
537 |
+
attention_mask=jnp.array(attention_mask, dtype="i4"),
|
538 |
+
position_ids=jnp.array(position_ids, dtype="i4"),
|
539 |
+
encoder_hidden_states=encoder_hidden_states,
|
540 |
+
encoder_attention_mask=jnp.array(encoder_attention_mask, dtype="i4"),
|
541 |
+
output_attentions=output_attentions,
|
542 |
+
output_hidden_states=output_hidden_states,
|
543 |
+
return_dict=return_dict,
|
544 |
+
deterministic=deterministic,
|
545 |
+
rngs=rngs,
|
546 |
+
mutable=mutable,
|
547 |
+
method=_decoder_forward,
|
548 |
+
)
|
549 |
+
|
550 |
+
if past_key_values is None:
|
551 |
+
lm_logits, outputs = outputs
|
552 |
+
else:
|
553 |
+
(lm_logits, outputs), past = outputs
|
554 |
+
|
555 |
+
if return_dict:
|
556 |
+
outputs = FlaxCausalLMOutputWithCrossAttentions(
|
557 |
+
logits=lm_logits,
|
558 |
+
hidden_states=outputs.decoder_hidden_states,
|
559 |
+
attentions=outputs.decoder_attentions,
|
560 |
+
cross_attentions=outputs.cross_attentions,
|
561 |
+
)
|
562 |
+
else:
|
563 |
+
outputs = (lm_logits,) + outputs[1:]
|
564 |
+
|
565 |
+
# add updated cache to model output
|
566 |
+
if past_key_values is not None and return_dict:
|
567 |
+
outputs["past_key_values"] = unfreeze(past["cache"])
|
568 |
+
return outputs
|
569 |
+
elif past_key_values is not None and not return_dict:
|
570 |
+
outputs = outputs[:1] + (unfreeze(past["cache"]),) + outputs[1:]
|
571 |
+
|
572 |
+
return outputs
|
573 |
+
|
574 |
+
def prepare_inputs_for_generation(
|
575 |
+
self,
|
576 |
+
input_ids,
|
577 |
+
max_length,
|
578 |
+
encoder_attention_mask: Optional[jnp.DeviceArray] = None,
|
579 |
+
attention_mask: Optional[jnp.DeviceArray] = None,
|
580 |
+
encoder_outputs=None,
|
581 |
+
**kwargs,
|
582 |
+
):
|
583 |
+
# initializing the cache
|
584 |
+
batch_size, seq_length = input_ids.shape
|
585 |
+
|
586 |
+
past_key_values = self.init_cache(batch_size, max_length, encoder_outputs)
|
587 |
+
# Note that usually one would have to put 0's in the attention_mask for x > input_ids.shape[-1] and x < cache_length.
|
588 |
+
# But since the decoder uses a causal mask, those positions are masked anyways.
|
589 |
+
# Thus we can create a single static attention_mask here, which is more efficient for compilation
|
590 |
+
extended_attention_mask = jnp.ones((batch_size, max_length), dtype="i4")
|
591 |
+
if attention_mask is not None:
|
592 |
+
position_ids = attention_mask.cumsum(axis=-1) - 1
|
593 |
+
extended_attention_mask = lax.dynamic_update_slice(
|
594 |
+
extended_attention_mask, attention_mask, (0, 0)
|
595 |
+
)
|
596 |
+
else:
|
597 |
+
position_ids = jnp.broadcast_to(
|
598 |
+
jnp.arange(seq_length, dtype="i4")[None, :], (batch_size, seq_length)
|
599 |
+
)
|
600 |
+
|
601 |
+
return {
|
602 |
+
"past_key_values": past_key_values,
|
603 |
+
"encoder_outputs": encoder_outputs,
|
604 |
+
"encoder_attention_mask": encoder_attention_mask,
|
605 |
+
"attention_mask": extended_attention_mask,
|
606 |
+
"position_ids": position_ids,
|
607 |
+
}
|
608 |
+
|
609 |
+
def update_inputs_for_generation(self, model_outputs, model_kwargs):
|
610 |
+
model_kwargs["past_key_values"] = model_outputs.past_key_values
|
611 |
+
model_kwargs["position_ids"] = (
|
612 |
+
model_kwargs["position_ids"][:, -1:] + 1
|
613 |
+
)
|
614 |
+
return model_kwargs
|
615 |
+
|
616 |
+
@classmethod
|
617 |
+
def from_vit_gpt2_pretrained(
|
618 |
+
cls,
|
619 |
+
vit_model_name_or_path: str = None,
|
620 |
+
gpt2_model_name_or_path: str = None,
|
621 |
+
*model_args,
|
622 |
+
**kwargs,
|
623 |
+
) -> FlaxViTGPT2LMPreTrainedModel:
|
624 |
+
|
625 |
+
kwargs_gpt2 = {
|
626 |
+
argument[len("gpt2_") :]: value
|
627 |
+
for argument, value in kwargs.items()
|
628 |
+
if argument.startswith("gpt2_")
|
629 |
+
}
|
630 |
+
|
631 |
+
kwargs_vit = {
|
632 |
+
argument[len("vit_") :]: value
|
633 |
+
for argument, value in kwargs.items()
|
634 |
+
if argument.startswith("vit_")
|
635 |
+
}
|
636 |
+
|
637 |
+
# remove gpt2, vit kwargs from kwargs
|
638 |
+
for key in kwargs_gpt2.keys():
|
639 |
+
del kwargs["gpt2_" + key]
|
640 |
+
for key in kwargs_vit.keys():
|
641 |
+
del kwargs["vit_" + key]
|
642 |
+
|
643 |
+
# Load and initialize the gpt2 and vit model
|
644 |
+
gpt2_model = kwargs_gpt2.pop("model", None)
|
645 |
+
if gpt2_model is None:
|
646 |
+
assert (
|
647 |
+
gpt2_model_name_or_path is not None
|
648 |
+
), "If `model` is not defined as an argument, a `gpt2_model_name_or_path` has to be defined"
|
649 |
+
|
650 |
+
if "config" not in kwargs_gpt2:
|
651 |
+
gpt2_config = GPT2Config.from_pretrained(gpt2_model_name_or_path)
|
652 |
+
kwargs_gpt2["config"] = gpt2_config
|
653 |
+
|
654 |
+
kwargs_gpt2["config"].add_cross_attention = True
|
655 |
+
gpt2_model = FlaxGPT2LMHeadModel.from_pretrained(
|
656 |
+
gpt2_model_name_or_path, *model_args, **kwargs_gpt2
|
657 |
+
)
|
658 |
+
|
659 |
+
vit_model = kwargs_vit.pop("model", None)
|
660 |
+
if vit_model is None:
|
661 |
+
assert (
|
662 |
+
vit_model_name_or_path is not None
|
663 |
+
), "If `model` is not defined as an argument, a `vit_model_name_or_path` has to be defined"
|
664 |
+
|
665 |
+
if "config" not in kwargs_vit:
|
666 |
+
vit_config = ViTConfig.from_pretrained(vit_model_name_or_path)
|
667 |
+
kwargs_vit["config"] = vit_config
|
668 |
+
|
669 |
+
vit_model = FlaxViTModel.from_pretrained(
|
670 |
+
vit_model_name_or_path, *model_args, **kwargs_vit
|
671 |
+
)
|
672 |
+
|
673 |
+
# instantiate config with corresponding kwargs
|
674 |
+
dtype = kwargs.pop("dtype", jnp.float32)
|
675 |
+
config = ViTGPT2Config.from_vit_gpt2_configs(
|
676 |
+
vit_model.config, gpt2_model.config, **kwargs
|
677 |
+
)
|
678 |
+
|
679 |
+
# init model
|
680 |
+
model = cls(config, *model_args, dtype=dtype, **kwargs)
|
681 |
+
model.params["model"]["encoder"] = vit_model.params
|
682 |
+
model.params["model"]["decoder"] = gpt2_model.params
|
683 |
+
|
684 |
+
return model
|
wit_data_dir/dev/dev.tsv
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:ef1ecdcd132885a8f29c8707fad649431c6ff3d9bbd295d56b8520e7046c0eb7
|
3 |
+
size 1418232
|
wit_data_dir/test/test.tsv
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:f0517292749005808b1d1d75343c76b8b16c3ed74fde030f7af8b611ad7b4d5d
|
3 |
+
size 1406997
|
wit_dataset_script.py
ADDED
@@ -0,0 +1,145 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import csv
|
2 |
+
import json
|
3 |
+
import os
|
4 |
+
|
5 |
+
import datasets
|
6 |
+
import pandas as pd
|
7 |
+
import numpy as np
|
8 |
+
|
9 |
+
|
10 |
+
# TODO: Add BibTeX citation
|
11 |
+
# Find for instance the citation on arxiv or on the dataset repo/website
|
12 |
+
_CITATION = """\
|
13 |
+
@InProceedings{huggingface:dataset,
|
14 |
+
title = {A great new dataset},
|
15 |
+
author={huggingface, Inc.
|
16 |
+
},
|
17 |
+
year={2020}
|
18 |
+
}
|
19 |
+
"""
|
20 |
+
|
21 |
+
# TODO: Add description of the dataset here
|
22 |
+
# You can copy an official description
|
23 |
+
_DESCRIPTION = """\
|
24 |
+
This new dataset is designed to solve this great NLP task and is crafted with a lot of care.
|
25 |
+
"""
|
26 |
+
|
27 |
+
# TODO: Add a link to an official homepage for the dataset here
|
28 |
+
_HOMEPAGE = ""
|
29 |
+
|
30 |
+
# TODO: Add the licence for the dataset here if you can find it
|
31 |
+
_LICENSE = ""
|
32 |
+
|
33 |
+
# TODO: Add link to the official dataset URLs here
|
34 |
+
# The HuggingFace dataset library don't host the datasets but only point to the original files
|
35 |
+
# This can be an arbitrary nested dict/list of URLs (see below in `_split_generators` method)
|
36 |
+
_URLs = {
|
37 |
+
}
|
38 |
+
|
39 |
+
|
40 |
+
# TODO: Name of the dataset usually match the script name with CamelCase instead of snake_case
|
41 |
+
class WITDataset(datasets.GeneratorBasedBuilder):
|
42 |
+
"""TODO: Short description of my dataset."""
|
43 |
+
|
44 |
+
VERSION = datasets.Version("1.1.0")
|
45 |
+
|
46 |
+
DEFAULT_CONFIG_NAME = "en"
|
47 |
+
|
48 |
+
def _info(self):
|
49 |
+
# TODO: This method specifies the datasets.DatasetInfo object which contains informations and typings for the dataset
|
50 |
+
|
51 |
+
features = datasets.Features(
|
52 |
+
{
|
53 |
+
"id": datasets.Value("int64"),
|
54 |
+
"lang": datasets.Value("string"),
|
55 |
+
"caption": datasets.Value("string"),
|
56 |
+
"context": datasets.Value("string"),
|
57 |
+
"image_url": datasets.Value("string"),
|
58 |
+
"page_url": datasets.Value("string"),
|
59 |
+
"image_file": datasets.Value("string"),
|
60 |
+
"pixels_file": datasets.Value("string")
|
61 |
+
# These are the features of your dataset like images, labels ...
|
62 |
+
}
|
63 |
+
)
|
64 |
+
|
65 |
+
return datasets.DatasetInfo(
|
66 |
+
# This is the description that will appear on the datasets page.
|
67 |
+
description=_DESCRIPTION,
|
68 |
+
# This defines the different columns of the dataset and their types
|
69 |
+
features=features, # Here we define them above because they are different between the two configurations
|
70 |
+
# If there's a common (input, target) tuple from the features,
|
71 |
+
# specify them here. They'll be used if as_supervised=True in
|
72 |
+
# builder.as_dataset.
|
73 |
+
supervised_keys=None,
|
74 |
+
# Homepage of the dataset for documentation
|
75 |
+
homepage=_HOMEPAGE,
|
76 |
+
# License for the dataset if available
|
77 |
+
license=_LICENSE,
|
78 |
+
# Citation for the dataset
|
79 |
+
citation=_CITATION,
|
80 |
+
)
|
81 |
+
|
82 |
+
def _split_generators(self, dl_manager):
|
83 |
+
"""Returns SplitGenerators."""
|
84 |
+
# TODO: This method is tasked with downloading/extracting the data and defining the splits depending on the configuration
|
85 |
+
# If several configurations are possible (listed in BUILDER_CONFIGS), the configuration selected by the user is in self.config.name
|
86 |
+
|
87 |
+
data_dir = self.config.data_dir
|
88 |
+
|
89 |
+
return [
|
90 |
+
datasets.SplitGenerator(
|
91 |
+
name=datasets.Split.TRAIN,
|
92 |
+
# These kwargs will be passed to _generate_examples
|
93 |
+
gen_kwargs={
|
94 |
+
"data_dir": os.path.join(data_dir, "train"),
|
95 |
+
"split": "train",
|
96 |
+
},
|
97 |
+
),
|
98 |
+
datasets.SplitGenerator(
|
99 |
+
name=datasets.Split.TEST,
|
100 |
+
# These kwargs will be passed to _generate_examples
|
101 |
+
gen_kwargs={
|
102 |
+
"data_dir": os.path.join(data_dir, "test"),
|
103 |
+
"split": "test"
|
104 |
+
},
|
105 |
+
),
|
106 |
+
datasets.SplitGenerator(
|
107 |
+
name=datasets.Split.VALIDATION,
|
108 |
+
# These kwargs will be passed to _generate_examples
|
109 |
+
gen_kwargs={
|
110 |
+
"data_dir": os.path.join(data_dir, "dev"),
|
111 |
+
"split": "dev",
|
112 |
+
},
|
113 |
+
),
|
114 |
+
]
|
115 |
+
|
116 |
+
def _generate_examples(
|
117 |
+
self, data_dir, split # method parameters are unpacked from `gen_kwargs` as given in `_split_generators`
|
118 |
+
):
|
119 |
+
""" Yields examples as (key, example) tuples. """
|
120 |
+
# This method handles input defined in _split_generators to yield (key, example) tuples from the dataset.
|
121 |
+
# The `key` is here for legacy reason (tfds) and is not important in itself.
|
122 |
+
|
123 |
+
df = pd.read_csv(os.path.join(data_dir, f'{split}.tsv'), sep='\t')
|
124 |
+
|
125 |
+
for id_, row in df.iterrows():
|
126 |
+
|
127 |
+
_id = row[0]
|
128 |
+
|
129 |
+
# null caption and context
|
130 |
+
if type(row[4]) != str or type(row[5]) != str:
|
131 |
+
continue
|
132 |
+
|
133 |
+
image_file = os.path.join(data_dir, 'images', f'{_id}.jpg')
|
134 |
+
pixels_file = os.path.join(data_dir, 'numpy', f'{_id}.npy')
|
135 |
+
|
136 |
+
yield id_, {
|
137 |
+
"id": row[0],
|
138 |
+
"lang": row[1],
|
139 |
+
"caption": row[4],
|
140 |
+
"context": row[5],
|
141 |
+
"image_url": row[2],
|
142 |
+
"page_url": row[3],
|
143 |
+
"image_file": image_file,
|
144 |
+
"pixels_file": pixels_file
|
145 |
+
}
|