Upload 7 files
Browse files- LICENSE +21 -0
- README.md +74 -0
- config.json +44 -0
- configuration_indictrans.py +307 -0
- generation_config.json +8 -0
- modeling_indictrans.py +1267 -0
- pytorch_model.bin +3 -0
LICENSE
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
MIT License
|
2 |
+
|
3 |
+
Copyright (c) AI4Bharat.
|
4 |
+
|
5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
6 |
+
of this software and associated documentation files (the "Software"), to deal
|
7 |
+
in the Software without restriction, including without limitation the rights
|
8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
9 |
+
copies of the Software, and to permit persons to whom the Software is
|
10 |
+
furnished to do so, subject to the following conditions:
|
11 |
+
|
12 |
+
The above copyright notice and this permission notice shall be included in all
|
13 |
+
copies or substantial portions of the Software.
|
14 |
+
|
15 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
16 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
17 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
18 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
19 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
20 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
21 |
+
SOFTWARE
|
README.md
ADDED
@@ -0,0 +1,74 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
---
|
2 |
+
language:
|
3 |
+
- as
|
4 |
+
- bn
|
5 |
+
- brx
|
6 |
+
- doi
|
7 |
+
- en
|
8 |
+
- gom
|
9 |
+
- gu
|
10 |
+
- hi
|
11 |
+
- kn
|
12 |
+
- ks
|
13 |
+
- kas
|
14 |
+
- mai
|
15 |
+
- ml
|
16 |
+
- mr
|
17 |
+
- mni
|
18 |
+
- mnb
|
19 |
+
- ne
|
20 |
+
- or
|
21 |
+
- pa
|
22 |
+
- sa
|
23 |
+
- sat
|
24 |
+
- sd
|
25 |
+
- snd
|
26 |
+
- ta
|
27 |
+
- te
|
28 |
+
- ur
|
29 |
+
language_details: >-
|
30 |
+
asm_Beng, ben_Beng, brx_Deva, doi_Deva, eng_Latn, gom_Deva, guj_Gujr,
|
31 |
+
hin_Deva, kan_Knda, kas_Arab, kas_Deva, mai_Deva, mal_Mlym, mar_Deva,
|
32 |
+
mni_Beng, mni_Mtei, npi_Deva, ory_Orya, pan_Guru, san_Deva, sat_Olck,
|
33 |
+
snd_Arab, snd_Deva, tam_Taml, tel_Telu, urd_Arab
|
34 |
+
tags:
|
35 |
+
- indictrans2
|
36 |
+
- translation
|
37 |
+
- ai4bharat
|
38 |
+
- multilingual
|
39 |
+
license: mit
|
40 |
+
datasets:
|
41 |
+
- flores-200
|
42 |
+
- IN22-Gen
|
43 |
+
- IN22-Conv
|
44 |
+
metrics:
|
45 |
+
- bleu
|
46 |
+
- chrf
|
47 |
+
- chrf++
|
48 |
+
- comet
|
49 |
+
inference: false
|
50 |
+
---
|
51 |
+
|
52 |
+
# IndicTrans2
|
53 |
+
|
54 |
+
This is the model card of IndicTrans2 Indic-En Distilled 200M variant.
|
55 |
+
|
56 |
+
Please refer to [section 7.6: Distilled Models](https://openreview.net/forum?id=vfT4YuzAYA) in the TMLR submission for further details on model training, data and metrics.
|
57 |
+
|
58 |
+
### Usage Instructions
|
59 |
+
|
60 |
+
Please refer to the [github repository](https://github.com/AI4Bharat/IndicTrans2/tree/main/huggingface_inference) for a detail description on how to use HF compatible IndicTrans2 models for inference.
|
61 |
+
|
62 |
+
|
63 |
+
### Citation
|
64 |
+
|
65 |
+
If you consider using our work then please cite using:
|
66 |
+
|
67 |
+
```
|
68 |
+
@article{ai4bharat2023indictrans2,
|
69 |
+
title = {IndicTrans2: Towards High-Quality and Accessible Machine Translation Models for all 22 Scheduled Indian Languages},
|
70 |
+
author = {AI4Bharat and Jay Gala and Pranjal A. Chitale and Raghavan AK and Sumanth Doddapaneni and Varun Gumma and Aswanth Kumar and Janki Nawale and Anupama Sujatha and Ratish Puduppully and Vivek Raghavan and Pratyush Kumar and Mitesh M. Khapra and Raj Dabre and Anoop Kunchukuttan},
|
71 |
+
year = {2023},
|
72 |
+
journal = {arXiv preprint arXiv: 2305.16307}
|
73 |
+
}
|
74 |
+
```
|
config.json
ADDED
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"_name_or_path": "ai4bharat/indictrans2-indic-en-dist-200M",
|
3 |
+
"activation_dropout": 0.0,
|
4 |
+
"activation_function": "gelu",
|
5 |
+
"architectures": [
|
6 |
+
"IndicTransForConditionalGeneration"
|
7 |
+
],
|
8 |
+
"auto_map": {
|
9 |
+
"AutoConfig": "configuration_indictrans.IndicTransConfig",
|
10 |
+
"AutoModelForSeq2SeqLM": "modeling_indictrans.IndicTransForConditionalGeneration"
|
11 |
+
},
|
12 |
+
"attention_dropout": 0.0,
|
13 |
+
"bos_token_id": 0,
|
14 |
+
"decoder_attention_heads": 8,
|
15 |
+
"decoder_embed_dim": 512,
|
16 |
+
"decoder_ffn_dim": 2048,
|
17 |
+
"decoder_layerdrop": 0,
|
18 |
+
"decoder_layers": 18,
|
19 |
+
"decoder_normalize_before": true,
|
20 |
+
"decoder_start_token_id": 2,
|
21 |
+
"decoder_vocab_size": 32296,
|
22 |
+
"dropout": 0.2,
|
23 |
+
"encoder_attention_heads": 8,
|
24 |
+
"encoder_embed_dim": 512,
|
25 |
+
"encoder_ffn_dim": 2048,
|
26 |
+
"encoder_layerdrop": 0,
|
27 |
+
"encoder_layers": 18,
|
28 |
+
"encoder_normalize_before": true,
|
29 |
+
"encoder_vocab_size": 122706,
|
30 |
+
"eos_token_id": 2,
|
31 |
+
"init_std": 0.02,
|
32 |
+
"is_encoder_decoder": true,
|
33 |
+
"layernorm_embedding": true,
|
34 |
+
"max_source_positions": 256,
|
35 |
+
"max_target_positions": 256,
|
36 |
+
"model_type": "IndicTrans",
|
37 |
+
"num_hidden_layers": 18,
|
38 |
+
"pad_token_id": 1,
|
39 |
+
"scale_embedding": true,
|
40 |
+
"share_decoder_input_output_embed": true,
|
41 |
+
"torch_dtype": "float32",
|
42 |
+
"transformers_version": "4.32.1",
|
43 |
+
"use_cache": true
|
44 |
+
}
|
configuration_indictrans.py
ADDED
@@ -0,0 +1,307 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2023 The IndicTrans2 Authors and AI4Bharat team. All rights reserved.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
""" PyTorch IndicTrans config."""
|
16 |
+
|
17 |
+
|
18 |
+
from collections import OrderedDict
|
19 |
+
from typing import Any, Mapping, Optional
|
20 |
+
|
21 |
+
from transformers import PreTrainedTokenizer
|
22 |
+
from transformers.configuration_utils import PretrainedConfig
|
23 |
+
from transformers.onnx import OnnxConfig, OnnxSeq2SeqConfigWithPast
|
24 |
+
from transformers.onnx.utils import compute_effective_axis_dimension
|
25 |
+
from transformers.utils import TensorType, is_torch_available
|
26 |
+
|
27 |
+
|
28 |
+
# Copied from transformers.models.m2m_100.configuration_m2m_100.M2M100Config->IndicTrans
|
29 |
+
class IndicTransConfig(PretrainedConfig):
|
30 |
+
r"""
|
31 |
+
This is the configuration class to store the configuration of a [`IT2Model`]. It is used to instantiate an
|
32 |
+
IT2 model according to the specified arguments, defining the model architecture. Instantiating a configuration
|
33 |
+
with the defaults will yield a similar configuration to that of the IT2
|
34 |
+
|
35 |
+
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
36 |
+
documentation from [`PretrainedConfig`] for more information.
|
37 |
+
|
38 |
+
|
39 |
+
Args:
|
40 |
+
vocab_size (`int`, *optional*, defaults to 50265):
|
41 |
+
Vocabulary size of the IT2 model. Defines the number of different tokens that can be represented by the
|
42 |
+
`inputs_ids` passed when calling [`IT2Model`] or
|
43 |
+
d_model (`int`, *optional*, defaults to 1024):
|
44 |
+
Dimensionality of the layers and the pooler layer.
|
45 |
+
encoder_layers (`int`, *optional*, defaults to 12):
|
46 |
+
Number of encoder layers.
|
47 |
+
decoder_layers (`int`, *optional*, defaults to 12):
|
48 |
+
Number of decoder layers.
|
49 |
+
encoder_attention_heads (`int`, *optional*, defaults to 16):
|
50 |
+
Number of attention heads for each attention layer in the Transformer encoder.
|
51 |
+
decoder_attention_heads (`int`, *optional*, defaults to 16):
|
52 |
+
Number of attention heads for each attention layer in the Transformer decoder.
|
53 |
+
decoder_ffn_dim (`int`, *optional*, defaults to 4096):
|
54 |
+
Dimensionality of the "intermediate" (often named feed-forward) layer in decoder.
|
55 |
+
encoder_ffn_dim (`int`, *optional*, defaults to 4096):
|
56 |
+
Dimensionality of the "intermediate" (often named feed-forward) layer in decoder.
|
57 |
+
activation_function (`str` or `function`, *optional*, defaults to `"gelu"`):
|
58 |
+
The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
|
59 |
+
`"relu"`, `"silu"` and `"gelu_new"` are supported.
|
60 |
+
dropout (`float`, *optional*, defaults to 0.1):
|
61 |
+
The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
|
62 |
+
attention_dropout (`float`, *optional*, defaults to 0.0):
|
63 |
+
The dropout ratio for the attention probabilities.
|
64 |
+
activation_dropout (`float`, *optional*, defaults to 0.0):
|
65 |
+
The dropout ratio for activations inside the fully connected layer.
|
66 |
+
classifier_dropout (`float`, *optional*, defaults to 0.0):
|
67 |
+
The dropout ratio for classifier.
|
68 |
+
max_position_embeddings (`int`, *optional*, defaults to 1024):
|
69 |
+
The maximum sequence length that this model might ever be used with. Typically set this to something large
|
70 |
+
just in case (e.g., 512 or 1024 or 2048).
|
71 |
+
init_std (`float`, *optional*, defaults to 0.02):
|
72 |
+
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
|
73 |
+
encoder_layerdrop (`float`, *optional*, defaults to 0.0):
|
74 |
+
The LayerDrop probability for the encoder. See the [LayerDrop paper](see https://arxiv.org/abs/1909.11556)
|
75 |
+
for more details.
|
76 |
+
decoder_layerdrop (`float`, *optional*, defaults to 0.0):
|
77 |
+
The LayerDrop probability for the decoder. See the [LayerDrop paper](see https://arxiv.org/abs/1909.11556)
|
78 |
+
for more details.
|
79 |
+
use_cache (`bool`, *optional*, defaults to `True`):
|
80 |
+
Whether or not the model should return the last key/values attentions (not used by all models).
|
81 |
+
```"""
|
82 |
+
model_type = "IndicTrans"
|
83 |
+
keys_to_ignore_at_inference = ["past_key_values"]
|
84 |
+
attribute_map = {
|
85 |
+
"num_attention_heads": "encoder_attention_heads",
|
86 |
+
"hidden_size": "d_model",
|
87 |
+
}
|
88 |
+
|
89 |
+
def __init__(
|
90 |
+
self,
|
91 |
+
encoder_vocab_size=None,
|
92 |
+
decoder_vocab_size=None,
|
93 |
+
encoder_embed_dim=512,
|
94 |
+
decoder_embed_dim=512,
|
95 |
+
max_source_positions=210,
|
96 |
+
max_target_positions=210,
|
97 |
+
encoder_layers=6,
|
98 |
+
encoder_ffn_dim=2048,
|
99 |
+
encoder_attention_heads=8,
|
100 |
+
decoder_layers=6,
|
101 |
+
decoder_ffn_dim=2048,
|
102 |
+
decoder_attention_heads=8,
|
103 |
+
encoder_layerdrop=0.00,
|
104 |
+
decoder_layerdrop=0.00,
|
105 |
+
use_cache=True,
|
106 |
+
is_encoder_decoder=True,
|
107 |
+
activation_function="relu",
|
108 |
+
encoder_normalize_before=False,
|
109 |
+
decoder_normalize_before=False,
|
110 |
+
layernorm_embedding=False,
|
111 |
+
share_decoder_input_output_embed=False,
|
112 |
+
dropout=0.1,
|
113 |
+
attention_dropout=0.0,
|
114 |
+
activation_dropout=0.0,
|
115 |
+
init_std=0.02,
|
116 |
+
scale_embedding=True,
|
117 |
+
decoder_start_token_id=2,
|
118 |
+
pad_token_id=1,
|
119 |
+
bos_token_id=0,
|
120 |
+
eos_token_id=2,
|
121 |
+
**kwargs,
|
122 |
+
):
|
123 |
+
self.encoder_vocab_size = encoder_vocab_size
|
124 |
+
self.decoder_vocab_size = decoder_vocab_size
|
125 |
+
self.encoder_normalize_before = encoder_normalize_before
|
126 |
+
self.decoder_normalize_before = decoder_normalize_before
|
127 |
+
self.layernorm_embedding = layernorm_embedding
|
128 |
+
self.max_source_positions = max_source_positions
|
129 |
+
self.max_target_positions = max_target_positions
|
130 |
+
self.encoder_embed_dim = encoder_embed_dim
|
131 |
+
self.decoder_embed_dim = decoder_embed_dim
|
132 |
+
self.encoder_ffn_dim = encoder_ffn_dim
|
133 |
+
self.encoder_layers = encoder_layers
|
134 |
+
self.encoder_attention_heads = encoder_attention_heads
|
135 |
+
self.decoder_ffn_dim = decoder_ffn_dim
|
136 |
+
self.decoder_layers = decoder_layers
|
137 |
+
self.decoder_attention_heads = decoder_attention_heads
|
138 |
+
self.dropout = dropout
|
139 |
+
self.attention_dropout = attention_dropout
|
140 |
+
self.activation_dropout = activation_dropout
|
141 |
+
self.activation_function = activation_function
|
142 |
+
self.init_std = init_std
|
143 |
+
self.encoder_layerdrop = encoder_layerdrop
|
144 |
+
self.decoder_layerdrop = decoder_layerdrop
|
145 |
+
self.use_cache = use_cache
|
146 |
+
self.num_hidden_layers = encoder_layers
|
147 |
+
self.scale_embedding = scale_embedding
|
148 |
+
self.share_decoder_input_output_embed = share_decoder_input_output_embed
|
149 |
+
|
150 |
+
super().__init__(
|
151 |
+
pad_token_id=pad_token_id,
|
152 |
+
bos_token_id=bos_token_id,
|
153 |
+
eos_token_id=eos_token_id,
|
154 |
+
is_encoder_decoder=is_encoder_decoder,
|
155 |
+
decoder_start_token_id=decoder_start_token_id,
|
156 |
+
**kwargs,
|
157 |
+
)
|
158 |
+
|
159 |
+
|
160 |
+
class IndicTransOnnxConfig(OnnxSeq2SeqConfigWithPast):
|
161 |
+
@property
|
162 |
+
def inputs(self) -> Mapping[str, Mapping[int, str]]:
|
163 |
+
common_inputs = OrderedDict(
|
164 |
+
[
|
165 |
+
("input_ids", {0: "batch", 1: "encoder_sequence"}),
|
166 |
+
("attention_mask", {0: "batch", 1: "encoder_sequence"}),
|
167 |
+
]
|
168 |
+
)
|
169 |
+
|
170 |
+
if self.use_past:
|
171 |
+
common_inputs["decoder_input_ids"] = {0: "batch"}
|
172 |
+
common_inputs["decoder_attention_mask"] = {
|
173 |
+
0: "batch",
|
174 |
+
1: "past_decoder_sequence + sequence",
|
175 |
+
}
|
176 |
+
else:
|
177 |
+
common_inputs["decoder_input_ids"] = {0: "batch", 1: "decoder_sequence"}
|
178 |
+
common_inputs["decoder_attention_mask"] = {
|
179 |
+
0: "batch",
|
180 |
+
1: "decoder_sequence",
|
181 |
+
}
|
182 |
+
|
183 |
+
if self.use_past:
|
184 |
+
self.fill_with_past_key_values_(common_inputs, direction="inputs")
|
185 |
+
return common_inputs
|
186 |
+
|
187 |
+
# Copied from BartOnnxConfig._generate_dummy_inputs_for_sequence_classification_and_question_answering
|
188 |
+
# A better name would be _generate_dummy_inputs_for_encoder_and_decoder because sequence classification and question
|
189 |
+
# answering are not supported for IT2, but this name is preserved to be able to check that the copy matches what
|
190 |
+
# was done for BART so that it can be updated if need be.
|
191 |
+
def _generate_dummy_inputs_for_sequence_classification_and_question_answering(
|
192 |
+
self,
|
193 |
+
tokenizer: PreTrainedTokenizer,
|
194 |
+
batch_size: int = -1,
|
195 |
+
seq_length: int = -1,
|
196 |
+
is_pair: bool = False,
|
197 |
+
framework: Optional[TensorType] = None,
|
198 |
+
) -> Mapping[str, Any]:
|
199 |
+
# Copied from OnnxConfig.generate_dummy_inputs
|
200 |
+
# Did not use super(OnnxConfigWithPast, self).generate_dummy_inputs for code clarity.
|
201 |
+
# If dynamic axis (-1) we forward with a fixed dimension of 2 samples to avoid optimizations made by ONNX
|
202 |
+
batch_size = compute_effective_axis_dimension(
|
203 |
+
batch_size,
|
204 |
+
fixed_dimension=OnnxConfig.default_fixed_batch,
|
205 |
+
num_token_to_add=0,
|
206 |
+
)
|
207 |
+
|
208 |
+
# If dynamic axis (-1) we forward with a fixed dimension of 8 tokens to avoid optimizations made by ONNX
|
209 |
+
token_to_add = tokenizer.num_special_tokens_to_add(is_pair)
|
210 |
+
seq_length = compute_effective_axis_dimension(
|
211 |
+
seq_length,
|
212 |
+
fixed_dimension=OnnxConfig.default_fixed_sequence,
|
213 |
+
num_token_to_add=token_to_add,
|
214 |
+
)
|
215 |
+
|
216 |
+
# Generate dummy inputs according to compute batch and sequence
|
217 |
+
dummy_input = [" ".join([tokenizer.unk_token]) * seq_length] * batch_size
|
218 |
+
common_inputs = dict(tokenizer(dummy_input, return_tensors=framework))
|
219 |
+
return common_inputs
|
220 |
+
|
221 |
+
# Copied from transformers.models.bart.configuration_bart.BartOnnxConfig._generate_dummy_inputs_for_default_and_seq2seq_lm
|
222 |
+
def _generate_dummy_inputs_for_default_and_seq2seq_lm(
|
223 |
+
self,
|
224 |
+
tokenizer: PreTrainedTokenizer,
|
225 |
+
batch_size: int = -1,
|
226 |
+
seq_length: int = -1,
|
227 |
+
is_pair: bool = False,
|
228 |
+
framework: Optional[TensorType] = None,
|
229 |
+
) -> Mapping[str, Any]:
|
230 |
+
encoder_inputs = self._generate_dummy_inputs_for_sequence_classification_and_question_answering(
|
231 |
+
tokenizer, batch_size, seq_length, is_pair, framework
|
232 |
+
)
|
233 |
+
|
234 |
+
# Generate decoder inputs
|
235 |
+
decoder_seq_length = seq_length if not self.use_past else 1
|
236 |
+
decoder_inputs = self._generate_dummy_inputs_for_sequence_classification_and_question_answering(
|
237 |
+
tokenizer, batch_size, decoder_seq_length, is_pair, framework
|
238 |
+
)
|
239 |
+
decoder_inputs = {
|
240 |
+
f"decoder_{name}": tensor for name, tensor in decoder_inputs.items()
|
241 |
+
}
|
242 |
+
common_inputs = dict(**encoder_inputs, **decoder_inputs)
|
243 |
+
|
244 |
+
if self.use_past:
|
245 |
+
if not is_torch_available():
|
246 |
+
raise ValueError(
|
247 |
+
"Cannot generate dummy past_keys inputs without PyTorch installed."
|
248 |
+
)
|
249 |
+
else:
|
250 |
+
import torch
|
251 |
+
batch, encoder_seq_length = common_inputs["input_ids"].shape
|
252 |
+
decoder_seq_length = common_inputs["decoder_input_ids"].shape[1]
|
253 |
+
(
|
254 |
+
num_encoder_attention_heads,
|
255 |
+
num_decoder_attention_heads,
|
256 |
+
) = self.num_attention_heads
|
257 |
+
encoder_shape = (
|
258 |
+
batch,
|
259 |
+
num_encoder_attention_heads,
|
260 |
+
encoder_seq_length,
|
261 |
+
self._config.hidden_size // num_encoder_attention_heads,
|
262 |
+
)
|
263 |
+
decoder_past_length = decoder_seq_length + 3
|
264 |
+
decoder_shape = (
|
265 |
+
batch,
|
266 |
+
num_decoder_attention_heads,
|
267 |
+
decoder_past_length,
|
268 |
+
self._config.hidden_size // num_decoder_attention_heads,
|
269 |
+
)
|
270 |
+
|
271 |
+
common_inputs["decoder_attention_mask"] = torch.cat(
|
272 |
+
[
|
273 |
+
common_inputs["decoder_attention_mask"],
|
274 |
+
torch.ones(batch, decoder_past_length),
|
275 |
+
],
|
276 |
+
dim=1,
|
277 |
+
)
|
278 |
+
|
279 |
+
common_inputs["past_key_values"] = []
|
280 |
+
# If the number of encoder and decoder layers are present in the model configuration, both are considered
|
281 |
+
num_encoder_layers, num_decoder_layers = self.num_layers
|
282 |
+
min_num_layers = min(num_encoder_layers, num_decoder_layers)
|
283 |
+
max_num_layers = (
|
284 |
+
max(num_encoder_layers, num_decoder_layers) - min_num_layers
|
285 |
+
)
|
286 |
+
remaining_side_name = (
|
287 |
+
"encoder" if num_encoder_layers > num_decoder_layers else "decoder"
|
288 |
+
)
|
289 |
+
|
290 |
+
for _ in range(min_num_layers):
|
291 |
+
common_inputs["past_key_values"].append(
|
292 |
+
(
|
293 |
+
torch.zeros(decoder_shape),
|
294 |
+
torch.zeros(decoder_shape),
|
295 |
+
torch.zeros(encoder_shape),
|
296 |
+
torch.zeros(encoder_shape),
|
297 |
+
)
|
298 |
+
)
|
299 |
+
# TODO: test this.
|
300 |
+
shape = encoder_shape if remaining_side_name == "encoder" else decoder_shape
|
301 |
+
for _ in range(min_num_layers, max_num_layers):
|
302 |
+
common_inputs["past_key_values"].append(
|
303 |
+
(torch.zeros(shape), torch.zeros(shape))
|
304 |
+
)
|
305 |
+
return common_inputs
|
306 |
+
|
307 |
+
generate_dummy_inputs = _generate_dummy_inputs_for_default_and_seq2seq_lm
|
generation_config.json
ADDED
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"_from_model_config": true,
|
3 |
+
"bos_token_id": 0,
|
4 |
+
"decoder_start_token_id": 2,
|
5 |
+
"eos_token_id": 2,
|
6 |
+
"pad_token_id": 1,
|
7 |
+
"transformers_version": "4.32.1"
|
8 |
+
}
|
modeling_indictrans.py
ADDED
@@ -0,0 +1,1267 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2023 The IndicTrans2 Authors and AI4Bharat team. All rights reserved.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
""" PyTorch IndicTrans model."""
|
16 |
+
|
17 |
+
|
18 |
+
import math
|
19 |
+
from typing import List, Optional, Tuple, Union
|
20 |
+
|
21 |
+
import torch
|
22 |
+
import torch.nn as nn
|
23 |
+
from torch.nn import functional as F
|
24 |
+
|
25 |
+
from transformers.activations import ACT2FN
|
26 |
+
from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
|
27 |
+
from transformers.modeling_outputs import (
|
28 |
+
BaseModelOutput,
|
29 |
+
BaseModelOutputWithPastAndCrossAttentions,
|
30 |
+
Seq2SeqLMOutput,
|
31 |
+
Seq2SeqModelOutput,
|
32 |
+
)
|
33 |
+
|
34 |
+
from transformers.utils import logging
|
35 |
+
from transformers.modeling_utils import PreTrainedModel
|
36 |
+
|
37 |
+
from .configuration_indictrans import IndicTransConfig
|
38 |
+
|
39 |
+
|
40 |
+
logger = logging.get_logger(__name__)
|
41 |
+
|
42 |
+
_CONFIG_FOR_DOC = "IndicTransConfig"
|
43 |
+
|
44 |
+
INDICTRANS_PRETRAINED_MODEL_ARCHIVE_LIST = [""]
|
45 |
+
|
46 |
+
|
47 |
+
# Copied from transformers.models.bart.modeling_bart.shift_tokens_right
|
48 |
+
def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int):
|
49 |
+
"""
|
50 |
+
Shift input ids one token to the right.
|
51 |
+
"""
|
52 |
+
shifted_input_ids = input_ids.new_zeros(input_ids.shape)
|
53 |
+
shifted_input_ids[:, 1:] = input_ids[:, :-1].clone()
|
54 |
+
shifted_input_ids[:, 0] = decoder_start_token_id
|
55 |
+
|
56 |
+
if pad_token_id is None:
|
57 |
+
raise ValueError("self.model.config.pad_token_id has to be defined.")
|
58 |
+
# replace possible -100 values in labels by `pad_token_id`
|
59 |
+
shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)
|
60 |
+
|
61 |
+
return shifted_input_ids
|
62 |
+
|
63 |
+
|
64 |
+
# Copied from transformers.models.bart.modeling_bart._make_causal_mask
|
65 |
+
def _make_causal_mask(
|
66 |
+
input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0
|
67 |
+
):
|
68 |
+
"""
|
69 |
+
Make causal mask used for bi-directional self-attention.
|
70 |
+
"""
|
71 |
+
bsz, tgt_len = input_ids_shape
|
72 |
+
mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device)
|
73 |
+
mask_cond = torch.arange(mask.size(-1), device=device)
|
74 |
+
mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
|
75 |
+
mask = mask.to(dtype)
|
76 |
+
|
77 |
+
if past_key_values_length > 0:
|
78 |
+
mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
|
79 |
+
return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
|
80 |
+
|
81 |
+
|
82 |
+
# Copied from transformers.models.bart.modeling_bart._expand_mask
|
83 |
+
def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
|
84 |
+
"""
|
85 |
+
Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
|
86 |
+
"""
|
87 |
+
bsz, src_len = mask.size()
|
88 |
+
tgt_len = tgt_len if tgt_len is not None else src_len
|
89 |
+
|
90 |
+
expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
|
91 |
+
|
92 |
+
inverted_mask = 1.0 - expanded_mask
|
93 |
+
|
94 |
+
return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
|
95 |
+
|
96 |
+
|
97 |
+
def create_position_ids_from_input_ids(input_ids, padding_idx, past_key_values_length=0):
|
98 |
+
"""
|
99 |
+
Replace non-padding symbols with their position numbers. Position numbers begin at padding_idx+1. Padding symbols
|
100 |
+
are ignored. This is modified from fairseq's `utils.make_positions`.
|
101 |
+
"""
|
102 |
+
# The series of casts and type-conversions here are carefully balanced to both work with ONNX export and XLA.
|
103 |
+
mask = input_ids.ne(padding_idx).int()
|
104 |
+
incremental_indices = (torch.cumsum(mask, dim=1).type_as(mask) + past_key_values_length) * mask
|
105 |
+
return incremental_indices.long() + padding_idx
|
106 |
+
|
107 |
+
|
108 |
+
# Copied from transformers.models.m2m_100.modeling_m2m_100.M2M100SinusoidalPositionalEmbedding->IndicTrans
|
109 |
+
class IndicTransSinusoidalPositionalEmbedding(nn.Module):
|
110 |
+
"""This module produces sinusoidal positional embeddings of any length."""
|
111 |
+
|
112 |
+
def __init__(self, num_positions: int, embedding_dim: int, padding_idx: Optional[int] = None):
|
113 |
+
super().__init__()
|
114 |
+
self.offset = 2
|
115 |
+
self.embedding_dim = embedding_dim
|
116 |
+
self.padding_idx = padding_idx
|
117 |
+
self.make_weights(num_positions + self.offset, embedding_dim, padding_idx)
|
118 |
+
|
119 |
+
def make_weights(self, num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None):
|
120 |
+
emb_weights = self.get_embedding(num_embeddings, embedding_dim, padding_idx)
|
121 |
+
if hasattr(self, "weights"):
|
122 |
+
# in forward put the weights on the correct dtype and device of the param
|
123 |
+
emb_weights = emb_weights.to(dtype=self.weights.dtype, device=self.weights.device)
|
124 |
+
|
125 |
+
self.register_buffer("weights", emb_weights, persistent=False)
|
126 |
+
|
127 |
+
@staticmethod
|
128 |
+
def get_embedding(num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None):
|
129 |
+
"""
|
130 |
+
Build sinusoidal embeddings.
|
131 |
+
|
132 |
+
This matches the implementation in tensor2tensor, but differs slightly from the description in Section 3.5 of
|
133 |
+
"Attention Is All You Need".
|
134 |
+
"""
|
135 |
+
half_dim = embedding_dim // 2
|
136 |
+
emb = math.log(10000) / (half_dim - 1)
|
137 |
+
emb = torch.exp(torch.arange(half_dim, dtype=torch.float) * -emb)
|
138 |
+
emb = torch.arange(num_embeddings, dtype=torch.float).unsqueeze(1) * emb.unsqueeze(0)
|
139 |
+
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1).view(num_embeddings, -1)
|
140 |
+
if embedding_dim % 2 == 1:
|
141 |
+
# zero pad
|
142 |
+
emb = torch.cat([emb, torch.zeros(num_embeddings, 1)], dim=1)
|
143 |
+
if padding_idx is not None:
|
144 |
+
emb[padding_idx, :] = 0
|
145 |
+
|
146 |
+
return emb.to(torch.get_default_dtype())
|
147 |
+
|
148 |
+
@torch.no_grad()
|
149 |
+
def forward(
|
150 |
+
self, input_ids: torch.Tensor = None, inputs_embeds: torch.Tensor = None, past_key_values_length: int = 0
|
151 |
+
):
|
152 |
+
if input_ids is not None:
|
153 |
+
bsz, seq_len = input_ids.size()
|
154 |
+
# Create the position ids from the input token ids. Any padded tokens remain padded.
|
155 |
+
position_ids = create_position_ids_from_input_ids(input_ids, self.padding_idx, past_key_values_length).to(
|
156 |
+
input_ids.device
|
157 |
+
)
|
158 |
+
else:
|
159 |
+
bsz, seq_len = inputs_embeds.size()[:-1]
|
160 |
+
position_ids = self.create_position_ids_from_inputs_embeds(inputs_embeds, past_key_values_length)
|
161 |
+
|
162 |
+
# expand embeddings if needed
|
163 |
+
max_pos = self.padding_idx + 1 + seq_len + past_key_values_length
|
164 |
+
if max_pos > self.weights.size(0):
|
165 |
+
self.make_weights(max_pos + self.offset, self.embedding_dim, self.padding_idx)
|
166 |
+
|
167 |
+
return self.weights.index_select(0, position_ids.view(-1)).view(bsz, seq_len, self.weights.shape[-1]).detach()
|
168 |
+
|
169 |
+
def create_position_ids_from_inputs_embeds(self, inputs_embeds, past_key_values_length):
|
170 |
+
"""
|
171 |
+
We are provided embeddings directly. We cannot infer which are padded so just generate sequential position ids.
|
172 |
+
|
173 |
+
Args:
|
174 |
+
inputs_embeds: torch.Tensor
|
175 |
+
|
176 |
+
Returns: torch.Tensor
|
177 |
+
"""
|
178 |
+
input_shape = inputs_embeds.size()[:-1]
|
179 |
+
sequence_length = input_shape[1]
|
180 |
+
|
181 |
+
position_ids = torch.arange(
|
182 |
+
self.padding_idx + 1, sequence_length + self.padding_idx + 1, dtype=torch.long, device=inputs_embeds.device
|
183 |
+
)
|
184 |
+
return position_ids.unsqueeze(0).expand(input_shape).contiguous() + past_key_values_length
|
185 |
+
|
186 |
+
|
187 |
+
# Copied from transformers.models.bart.modeling_bart.BartAttention with Bart->IndicTrans
|
188 |
+
class IndicTransAttention(nn.Module):
|
189 |
+
"""Multi-headed attention from 'Attention Is All You Need' paper"""
|
190 |
+
|
191 |
+
def __init__(
|
192 |
+
self,
|
193 |
+
embed_dim: int,
|
194 |
+
num_heads: int,
|
195 |
+
dropout: float = 0.0,
|
196 |
+
is_decoder: bool = False,
|
197 |
+
bias: bool = True,
|
198 |
+
):
|
199 |
+
super().__init__()
|
200 |
+
self.embed_dim = embed_dim
|
201 |
+
self.num_heads = num_heads
|
202 |
+
self.dropout = dropout
|
203 |
+
self.head_dim = embed_dim // num_heads
|
204 |
+
|
205 |
+
if (self.head_dim * num_heads) != self.embed_dim:
|
206 |
+
raise ValueError(
|
207 |
+
f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
|
208 |
+
f" and `num_heads`: {num_heads})."
|
209 |
+
)
|
210 |
+
self.scaling = self.head_dim**-0.5
|
211 |
+
self.is_decoder = is_decoder
|
212 |
+
|
213 |
+
self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
|
214 |
+
self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
|
215 |
+
self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
|
216 |
+
self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
|
217 |
+
|
218 |
+
def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
|
219 |
+
return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
|
220 |
+
|
221 |
+
def forward(
|
222 |
+
self,
|
223 |
+
hidden_states: torch.Tensor,
|
224 |
+
key_value_states: Optional[torch.Tensor] = None,
|
225 |
+
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
226 |
+
attention_mask: Optional[torch.Tensor] = None,
|
227 |
+
layer_head_mask: Optional[torch.Tensor] = None,
|
228 |
+
output_attentions: bool = False,
|
229 |
+
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
230 |
+
"""Input shape: Batch x Time x Channel"""
|
231 |
+
|
232 |
+
# if key_value_states are provided this layer is used as a cross-attention layer
|
233 |
+
# for the decoder
|
234 |
+
is_cross_attention = key_value_states is not None
|
235 |
+
|
236 |
+
bsz, tgt_len, _ = hidden_states.size()
|
237 |
+
|
238 |
+
# get query proj
|
239 |
+
query_states = self.q_proj(hidden_states) * self.scaling
|
240 |
+
# get key, value proj
|
241 |
+
# `past_key_value[0].shape[2] == key_value_states.shape[1]`
|
242 |
+
# is checking that the `sequence_length` of the `past_key_value` is the same as
|
243 |
+
# the provided `key_value_states` to support prefix tuning
|
244 |
+
if (
|
245 |
+
is_cross_attention
|
246 |
+
and past_key_value is not None
|
247 |
+
and past_key_value[0].shape[2] == key_value_states.shape[1]
|
248 |
+
):
|
249 |
+
# reuse k,v, cross_attentions
|
250 |
+
key_states = past_key_value[0]
|
251 |
+
value_states = past_key_value[1]
|
252 |
+
elif is_cross_attention:
|
253 |
+
# cross_attentions
|
254 |
+
key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
|
255 |
+
value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
|
256 |
+
elif past_key_value is not None:
|
257 |
+
# reuse k, v, self_attention
|
258 |
+
key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
|
259 |
+
value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
|
260 |
+
key_states = torch.cat([past_key_value[0], key_states], dim=2)
|
261 |
+
value_states = torch.cat([past_key_value[1], value_states], dim=2)
|
262 |
+
else:
|
263 |
+
# self_attention
|
264 |
+
key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
|
265 |
+
value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
|
266 |
+
|
267 |
+
if self.is_decoder:
|
268 |
+
# if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
|
269 |
+
# Further calls to cross_attention layer can then reuse all cross-attention
|
270 |
+
# key/value_states (first "if" case)
|
271 |
+
# if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
|
272 |
+
# all previous decoder key/value_states. Further calls to uni-directional self-attention
|
273 |
+
# can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
|
274 |
+
# if encoder bi-directional self-attention `past_key_value` is always `None`
|
275 |
+
past_key_value = (key_states, value_states)
|
276 |
+
|
277 |
+
proj_shape = (bsz * self.num_heads, -1, self.head_dim)
|
278 |
+
query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
|
279 |
+
key_states = key_states.reshape(*proj_shape)
|
280 |
+
value_states = value_states.reshape(*proj_shape)
|
281 |
+
|
282 |
+
src_len = key_states.size(1)
|
283 |
+
attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
|
284 |
+
|
285 |
+
if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
|
286 |
+
raise ValueError(
|
287 |
+
f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
|
288 |
+
f" {attn_weights.size()}"
|
289 |
+
)
|
290 |
+
|
291 |
+
if attention_mask is not None:
|
292 |
+
if attention_mask.size() != (bsz, 1, tgt_len, src_len):
|
293 |
+
raise ValueError(
|
294 |
+
f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
|
295 |
+
)
|
296 |
+
attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
|
297 |
+
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
|
298 |
+
|
299 |
+
attn_weights = F.softmax(attn_weights, dim=-1)
|
300 |
+
|
301 |
+
if layer_head_mask is not None:
|
302 |
+
if layer_head_mask.size() != (self.num_heads,):
|
303 |
+
raise ValueError(
|
304 |
+
f"Head mask for a single layer should be of size {(self.num_heads,)}, but is"
|
305 |
+
f" {layer_head_mask.size()}"
|
306 |
+
)
|
307 |
+
attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
|
308 |
+
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
|
309 |
+
|
310 |
+
if output_attentions:
|
311 |
+
# this operation is a bit awkward, but it's required to
|
312 |
+
# make sure that attn_weights keeps its gradient.
|
313 |
+
# In order to do so, attn_weights have to be reshaped
|
314 |
+
# twice and have to be reused in the following
|
315 |
+
attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
|
316 |
+
attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
|
317 |
+
else:
|
318 |
+
attn_weights_reshaped = None
|
319 |
+
|
320 |
+
attn_probs = F.dropout(attn_weights, p=self.dropout, training=self.training)
|
321 |
+
|
322 |
+
attn_output = torch.bmm(attn_probs, value_states)
|
323 |
+
|
324 |
+
if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
|
325 |
+
raise ValueError(
|
326 |
+
f"`attn_output` should be of size {(bsz * self.num_heads, tgt_len, self.head_dim)}, but is"
|
327 |
+
f" {attn_output.size()}"
|
328 |
+
)
|
329 |
+
|
330 |
+
attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
|
331 |
+
attn_output = attn_output.transpose(1, 2)
|
332 |
+
|
333 |
+
# Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
|
334 |
+
# partitioned across GPUs when using tensor-parallelism.
|
335 |
+
attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
|
336 |
+
|
337 |
+
attn_output = self.out_proj(attn_output)
|
338 |
+
|
339 |
+
return attn_output, attn_weights_reshaped, past_key_value
|
340 |
+
|
341 |
+
|
342 |
+
# Copied from transformers.models.mbart.modeling_mbart.MBartEncoderLayer with MBart->IndicTrans
|
343 |
+
class IndicTransEncoderLayer(nn.Module):
|
344 |
+
def __init__(self, config: IndicTransConfig):
|
345 |
+
super().__init__()
|
346 |
+
self.embed_dim = config.encoder_embed_dim
|
347 |
+
self.self_attn = IndicTransAttention(
|
348 |
+
embed_dim=self.embed_dim,
|
349 |
+
num_heads=config.encoder_attention_heads,
|
350 |
+
dropout=config.attention_dropout,
|
351 |
+
)
|
352 |
+
self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
|
353 |
+
self.dropout = config.dropout
|
354 |
+
self.activation_fn = ACT2FN[config.activation_function]
|
355 |
+
self.activation_dropout = config.activation_dropout
|
356 |
+
self.fc1 = nn.Linear(self.embed_dim, config.encoder_ffn_dim)
|
357 |
+
self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim)
|
358 |
+
self.final_layer_norm = nn.LayerNorm(self.embed_dim)
|
359 |
+
self.normalize_before = config.encoder_normalize_before
|
360 |
+
|
361 |
+
def forward(
|
362 |
+
self,
|
363 |
+
hidden_states: torch.Tensor,
|
364 |
+
attention_mask: torch.Tensor,
|
365 |
+
layer_head_mask: torch.Tensor,
|
366 |
+
output_attentions: bool = False,
|
367 |
+
) -> torch.Tensor:
|
368 |
+
"""
|
369 |
+
Args:
|
370 |
+
hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
|
371 |
+
attention_mask (`torch.FloatTensor`): attention mask of size
|
372 |
+
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
|
373 |
+
layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size
|
374 |
+
`(encoder_attention_heads,)`.
|
375 |
+
output_attentions (`bool`, *optional*):
|
376 |
+
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
|
377 |
+
returned tensors for more detail.
|
378 |
+
"""
|
379 |
+
residual = hidden_states
|
380 |
+
if self.normalize_before:
|
381 |
+
hidden_states = self.self_attn_layer_norm(hidden_states)
|
382 |
+
hidden_states, attn_weights, _ = self.self_attn(
|
383 |
+
hidden_states=hidden_states,
|
384 |
+
attention_mask=attention_mask,
|
385 |
+
layer_head_mask=layer_head_mask,
|
386 |
+
output_attentions=output_attentions,
|
387 |
+
)
|
388 |
+
hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training)
|
389 |
+
hidden_states = residual + hidden_states
|
390 |
+
if not self.normalize_before:
|
391 |
+
hidden_states = self.self_attn_layer_norm(hidden_states)
|
392 |
+
|
393 |
+
residual = hidden_states
|
394 |
+
if self.normalize_before:
|
395 |
+
hidden_states = self.final_layer_norm(hidden_states)
|
396 |
+
hidden_states = self.activation_fn(self.fc1(hidden_states))
|
397 |
+
hidden_states = F.dropout(hidden_states, p=self.activation_dropout, training=self.training)
|
398 |
+
hidden_states = self.fc2(hidden_states)
|
399 |
+
hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training)
|
400 |
+
hidden_states = residual + hidden_states
|
401 |
+
if not self.normalize_before:
|
402 |
+
hidden_states = self.final_layer_norm(hidden_states)
|
403 |
+
|
404 |
+
if hidden_states.dtype == torch.float16 and (
|
405 |
+
torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any()
|
406 |
+
):
|
407 |
+
clamp_value = torch.finfo(hidden_states.dtype).max - 1000
|
408 |
+
hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
|
409 |
+
|
410 |
+
outputs = (hidden_states,)
|
411 |
+
|
412 |
+
if output_attentions:
|
413 |
+
outputs += (attn_weights,)
|
414 |
+
|
415 |
+
return outputs
|
416 |
+
|
417 |
+
|
418 |
+
# Copied from transformers.models.mbart.modeling_mbart.MBartDecoderLayer with MBart->IndicTrans
|
419 |
+
class IndicTransDecoderLayer(nn.Module):
|
420 |
+
def __init__(self, config: IndicTransConfig):
|
421 |
+
super().__init__()
|
422 |
+
self.embed_dim = config.decoder_embed_dim
|
423 |
+
|
424 |
+
self.self_attn = IndicTransAttention(
|
425 |
+
embed_dim=self.embed_dim,
|
426 |
+
num_heads=config.decoder_attention_heads,
|
427 |
+
dropout=config.attention_dropout,
|
428 |
+
is_decoder=True,
|
429 |
+
)
|
430 |
+
self.dropout = config.dropout
|
431 |
+
self.activation_fn = ACT2FN[config.activation_function]
|
432 |
+
self.activation_dropout = config.activation_dropout
|
433 |
+
|
434 |
+
self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
|
435 |
+
self.encoder_attn = IndicTransAttention(
|
436 |
+
self.embed_dim,
|
437 |
+
config.decoder_attention_heads,
|
438 |
+
dropout=config.attention_dropout,
|
439 |
+
is_decoder=True,
|
440 |
+
)
|
441 |
+
self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim)
|
442 |
+
self.fc1 = nn.Linear(self.embed_dim, config.decoder_ffn_dim)
|
443 |
+
self.fc2 = nn.Linear(config.decoder_ffn_dim, self.embed_dim)
|
444 |
+
self.final_layer_norm = nn.LayerNorm(self.embed_dim)
|
445 |
+
self.normalize_before = config.decoder_normalize_before
|
446 |
+
|
447 |
+
def forward(
|
448 |
+
self,
|
449 |
+
hidden_states: torch.Tensor,
|
450 |
+
attention_mask: Optional[torch.Tensor] = None,
|
451 |
+
encoder_hidden_states: Optional[torch.Tensor] = None,
|
452 |
+
encoder_attention_mask: Optional[torch.Tensor] = None,
|
453 |
+
layer_head_mask: Optional[torch.Tensor] = None,
|
454 |
+
cross_attn_layer_head_mask: Optional[torch.Tensor] = None,
|
455 |
+
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
456 |
+
output_attentions: Optional[bool] = False,
|
457 |
+
use_cache: Optional[bool] = True,
|
458 |
+
) -> torch.Tensor:
|
459 |
+
"""
|
460 |
+
Args:
|
461 |
+
hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
|
462 |
+
attention_mask (`torch.FloatTensor`): attention mask of size
|
463 |
+
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
|
464 |
+
encoder_hidden_states (`torch.FloatTensor`):
|
465 |
+
cross attention input to the layer of shape `(batch, seq_len, embed_dim)`
|
466 |
+
encoder_attention_mask (`torch.FloatTensor`): encoder attention mask of size
|
467 |
+
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
|
468 |
+
layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size
|
469 |
+
`(encoder_attention_heads,)`.
|
470 |
+
cross_attn_layer_head_mask (`torch.FloatTensor`): mask for cross-attention heads in a given layer of
|
471 |
+
size `(decoder_attention_heads,)`.
|
472 |
+
past_key_value (`Tuple(torch.FloatTensor)`): cached past key and value projection states
|
473 |
+
output_attentions (`bool`, *optional*):
|
474 |
+
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
|
475 |
+
returned tensors for more detail.
|
476 |
+
"""
|
477 |
+
residual = hidden_states
|
478 |
+
if self.normalize_before:
|
479 |
+
hidden_states = self.self_attn_layer_norm(hidden_states)
|
480 |
+
|
481 |
+
# Self Attention
|
482 |
+
# decoder uni-directional self-attention cached key/values tuple is at positions 1,2
|
483 |
+
self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
|
484 |
+
# add present self-attn cache to positions 1,2 of present_key_value tuple
|
485 |
+
hidden_states, self_attn_weights, present_key_value = self.self_attn(
|
486 |
+
hidden_states=hidden_states,
|
487 |
+
past_key_value=self_attn_past_key_value,
|
488 |
+
attention_mask=attention_mask,
|
489 |
+
layer_head_mask=layer_head_mask,
|
490 |
+
output_attentions=output_attentions,
|
491 |
+
)
|
492 |
+
hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training)
|
493 |
+
hidden_states = residual + hidden_states
|
494 |
+
if not self.normalize_before:
|
495 |
+
hidden_states = self.self_attn_layer_norm(hidden_states)
|
496 |
+
|
497 |
+
# Cross-Attention Block
|
498 |
+
cross_attn_present_key_value = None
|
499 |
+
cross_attn_weights = None
|
500 |
+
if encoder_hidden_states is not None:
|
501 |
+
residual = hidden_states
|
502 |
+
if self.normalize_before:
|
503 |
+
hidden_states = self.encoder_attn_layer_norm(hidden_states)
|
504 |
+
|
505 |
+
# cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple
|
506 |
+
cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
|
507 |
+
hidden_states, cross_attn_weights, cross_attn_present_key_value = self.encoder_attn(
|
508 |
+
hidden_states=hidden_states,
|
509 |
+
key_value_states=encoder_hidden_states,
|
510 |
+
attention_mask=encoder_attention_mask,
|
511 |
+
layer_head_mask=cross_attn_layer_head_mask,
|
512 |
+
past_key_value=cross_attn_past_key_value,
|
513 |
+
output_attentions=output_attentions,
|
514 |
+
)
|
515 |
+
hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training)
|
516 |
+
hidden_states = residual + hidden_states
|
517 |
+
if not self.normalize_before:
|
518 |
+
hidden_states = self.encoder_attn_layer_norm(hidden_states)
|
519 |
+
|
520 |
+
# add cross-attn to positions 3,4 of present_key_value tuple
|
521 |
+
present_key_value = present_key_value + cross_attn_present_key_value
|
522 |
+
|
523 |
+
# Fully Connected
|
524 |
+
residual = hidden_states
|
525 |
+
if self.normalize_before:
|
526 |
+
hidden_states = self.final_layer_norm(hidden_states)
|
527 |
+
hidden_states = self.activation_fn(self.fc1(hidden_states))
|
528 |
+
hidden_states = F.dropout(hidden_states, p=self.activation_dropout, training=self.training)
|
529 |
+
hidden_states = self.fc2(hidden_states)
|
530 |
+
hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training)
|
531 |
+
hidden_states = residual + hidden_states
|
532 |
+
if not self.normalize_before:
|
533 |
+
hidden_states = self.final_layer_norm(hidden_states)
|
534 |
+
|
535 |
+
outputs = (hidden_states,)
|
536 |
+
|
537 |
+
if output_attentions:
|
538 |
+
outputs += (self_attn_weights, cross_attn_weights)
|
539 |
+
|
540 |
+
if use_cache:
|
541 |
+
outputs += (present_key_value,)
|
542 |
+
|
543 |
+
return outputs
|
544 |
+
|
545 |
+
|
546 |
+
# Copied from transformers.models.m2m_100.modeling_m2m_100.M2M100PretrainedModel->IndicTrans
|
547 |
+
class IndicTransPreTrainedModel(PreTrainedModel):
|
548 |
+
config_class = IndicTransConfig
|
549 |
+
base_model_prefix = "model"
|
550 |
+
supports_gradient_checkpointing = True
|
551 |
+
_no_split_modules = ["IndicTransAttention"]
|
552 |
+
|
553 |
+
def _init_weights(self, module):
|
554 |
+
std = self.config.init_std
|
555 |
+
if isinstance(module, nn.Linear):
|
556 |
+
module.weight.data.normal_(mean=0.0, std=std)
|
557 |
+
if module.bias is not None:
|
558 |
+
module.bias.data.zero_()
|
559 |
+
elif isinstance(module, nn.Embedding):
|
560 |
+
module.weight.data.normal_(mean=0.0, std=std)
|
561 |
+
if module.padding_idx is not None:
|
562 |
+
module.weight.data[module.padding_idx].zero_()
|
563 |
+
|
564 |
+
def _set_gradient_checkpointing(self, module, value=False):
|
565 |
+
if isinstance(module, (IndicTransDecoder, IndicTransEncoder)):
|
566 |
+
module.gradient_checkpointing = value
|
567 |
+
|
568 |
+
|
569 |
+
# Copied from transformers.models.m2m_100.modeling_m2m_100.M2M100EncoderLayer->IndicTrans
|
570 |
+
class IndicTransEncoder(IndicTransPreTrainedModel):
|
571 |
+
"""
|
572 |
+
Transformer encoder consisting of *config.encoder_layers* self attention layers. Each layer is a
|
573 |
+
[`IndicTransEncoderLayer`].
|
574 |
+
|
575 |
+
Args:
|
576 |
+
config: IndicTransConfig
|
577 |
+
embed_tokens (nn.Embedding): output embedding
|
578 |
+
"""
|
579 |
+
|
580 |
+
def __init__(self, config: IndicTransConfig, embed_tokens: Optional[nn.Embedding] = None):
|
581 |
+
super().__init__(config)
|
582 |
+
|
583 |
+
self.dropout = config.dropout
|
584 |
+
self.layerdrop = config.encoder_layerdrop
|
585 |
+
|
586 |
+
embed_dim = config.encoder_embed_dim
|
587 |
+
self.padding_idx = config.pad_token_id
|
588 |
+
self.max_source_positions = config.max_source_positions
|
589 |
+
self.embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0
|
590 |
+
|
591 |
+
self.embed_tokens = nn.Embedding(config.encoder_vocab_size, embed_dim, self.padding_idx)
|
592 |
+
|
593 |
+
if embed_tokens is not None:
|
594 |
+
self.embed_tokens.weight = embed_tokens.weight
|
595 |
+
|
596 |
+
self.embed_positions = IndicTransSinusoidalPositionalEmbedding(
|
597 |
+
config.max_source_positions,
|
598 |
+
embed_dim,
|
599 |
+
self.padding_idx,
|
600 |
+
)
|
601 |
+
self.layers = nn.ModuleList([IndicTransEncoderLayer(config) for _ in range(config.encoder_layers)])
|
602 |
+
self.layer_norm = nn.LayerNorm(embed_dim) if config.encoder_normalize_before else None
|
603 |
+
self.layernorm_embedding = nn.LayerNorm(embed_dim) if config.layernorm_embedding else None
|
604 |
+
|
605 |
+
self.gradient_checkpointing = False
|
606 |
+
# Initialize weights and apply final processing
|
607 |
+
self.post_init()
|
608 |
+
|
609 |
+
def forward(
|
610 |
+
self,
|
611 |
+
input_ids: Optional[torch.Tensor] = None,
|
612 |
+
attention_mask: Optional[torch.Tensor] = None,
|
613 |
+
head_mask: Optional[torch.Tensor] = None,
|
614 |
+
inputs_embeds: Optional[torch.Tensor] = None,
|
615 |
+
output_attentions: Optional[bool] = None,
|
616 |
+
output_hidden_states: Optional[bool] = None,
|
617 |
+
return_dict: Optional[bool] = None,
|
618 |
+
):
|
619 |
+
r"""
|
620 |
+
Args:
|
621 |
+
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
|
622 |
+
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
|
623 |
+
provide it.
|
624 |
+
|
625 |
+
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
|
626 |
+
[`PreTrainedTokenizer.__call__`] for details.
|
627 |
+
|
628 |
+
[What are input IDs?](../glossary#input-ids)
|
629 |
+
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
|
630 |
+
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
|
631 |
+
|
632 |
+
- 1 for tokens that are **not masked**,
|
633 |
+
- 0 for tokens that are **masked**.
|
634 |
+
|
635 |
+
[What are attention masks?](../glossary#attention-mask)
|
636 |
+
head_mask (`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*):
|
637 |
+
Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:
|
638 |
+
|
639 |
+
- 1 indicates the head is **not masked**,
|
640 |
+
- 0 indicates the head is **masked**.
|
641 |
+
|
642 |
+
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
|
643 |
+
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
|
644 |
+
This is useful if you want more control over how to convert `input_ids` indices into associated vectors
|
645 |
+
than the model's internal embedding lookup matrix.
|
646 |
+
output_attentions (`bool`, *optional*):
|
647 |
+
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
|
648 |
+
returned tensors for more detail.
|
649 |
+
output_hidden_states (`bool`, *optional*):
|
650 |
+
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
|
651 |
+
for more detail.
|
652 |
+
return_dict (`bool`, *optional*):
|
653 |
+
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
654 |
+
"""
|
655 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
656 |
+
output_hidden_states = (
|
657 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
658 |
+
)
|
659 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
660 |
+
|
661 |
+
# retrieve input_ids and inputs_embeds
|
662 |
+
if input_ids is not None and inputs_embeds is not None:
|
663 |
+
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
664 |
+
elif input_ids is not None:
|
665 |
+
self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
|
666 |
+
input_shape = input_ids.size()
|
667 |
+
input_ids = input_ids.view(-1, input_shape[-1])
|
668 |
+
elif inputs_embeds is not None:
|
669 |
+
input_shape = inputs_embeds.size()[:-1]
|
670 |
+
else:
|
671 |
+
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
672 |
+
|
673 |
+
if inputs_embeds is None:
|
674 |
+
inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
|
675 |
+
|
676 |
+
embed_pos = self.embed_positions(input_ids, inputs_embeds)
|
677 |
+
embed_pos = embed_pos.to(inputs_embeds.device)
|
678 |
+
|
679 |
+
hidden_states = inputs_embeds + embed_pos
|
680 |
+
if self.layernorm_embedding is not None:
|
681 |
+
x = self.layernorm_embedding(hidden_states)
|
682 |
+
hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training)
|
683 |
+
|
684 |
+
# expand attention_mask
|
685 |
+
if attention_mask is not None:
|
686 |
+
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
687 |
+
attention_mask = _expand_mask(attention_mask, inputs_embeds.dtype)
|
688 |
+
|
689 |
+
encoder_states = () if output_hidden_states else None
|
690 |
+
all_attentions = () if output_attentions else None
|
691 |
+
|
692 |
+
# check if head_mask has a correct number of layers specified if desired
|
693 |
+
if head_mask is not None:
|
694 |
+
if head_mask.size()[0] != len(self.layers):
|
695 |
+
raise ValueError(
|
696 |
+
f"The head_mask should be specified for {len(self.layers)} layers, but it is for"
|
697 |
+
f" {head_mask.size()[0]}."
|
698 |
+
)
|
699 |
+
deepspeed_zero3_is_enabled = is_deepspeed_zero3_enabled()
|
700 |
+
|
701 |
+
for idx, encoder_layer in enumerate(self.layers):
|
702 |
+
if output_hidden_states:
|
703 |
+
encoder_states = encoder_states + (hidden_states,)
|
704 |
+
|
705 |
+
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
|
706 |
+
dropout_probability = torch.rand([])
|
707 |
+
|
708 |
+
skip_the_layer = True if self.training and (dropout_probability < self.layerdrop) else False
|
709 |
+
if not skip_the_layer or deepspeed_zero3_is_enabled:
|
710 |
+
# under deepspeed zero3 all gpus must run in sync
|
711 |
+
|
712 |
+
if self.gradient_checkpointing and self.training:
|
713 |
+
# create gradient checkpointing function
|
714 |
+
def create_custom_forward(module):
|
715 |
+
def custom_forward(*inputs):
|
716 |
+
return module(*inputs, output_attentions)
|
717 |
+
|
718 |
+
return custom_forward
|
719 |
+
|
720 |
+
layer_outputs = torch.utils.checkpoint.checkpoint(
|
721 |
+
create_custom_forward(encoder_layer),
|
722 |
+
hidden_states,
|
723 |
+
attention_mask,
|
724 |
+
(head_mask[idx] if head_mask is not None else None),
|
725 |
+
)
|
726 |
+
else:
|
727 |
+
layer_outputs = encoder_layer(
|
728 |
+
hidden_states,
|
729 |
+
attention_mask,
|
730 |
+
layer_head_mask=(head_mask[idx] if head_mask is not None else None),
|
731 |
+
output_attentions=output_attentions,
|
732 |
+
)
|
733 |
+
|
734 |
+
hidden_states = layer_outputs[0]
|
735 |
+
|
736 |
+
if skip_the_layer:
|
737 |
+
layer_outputs = (None, None)
|
738 |
+
|
739 |
+
if output_attentions:
|
740 |
+
all_attentions = all_attentions + (layer_outputs[1],)
|
741 |
+
|
742 |
+
if self.layer_norm is not None:
|
743 |
+
hidden_states = self.layer_norm(hidden_states)
|
744 |
+
|
745 |
+
if output_hidden_states:
|
746 |
+
encoder_states = encoder_states + (hidden_states,)
|
747 |
+
|
748 |
+
if not return_dict:
|
749 |
+
return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
|
750 |
+
return BaseModelOutput(
|
751 |
+
last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
|
752 |
+
)
|
753 |
+
|
754 |
+
|
755 |
+
# Copied from transformers.models.m2m_100.modeling_m2m_100.M2M100DecoderLayer->IndicTrans
|
756 |
+
class IndicTransDecoder(IndicTransPreTrainedModel):
|
757 |
+
"""
|
758 |
+
Transformer decoder consisting of *config.decoder_layers* layers. Each layer is a [`IndicTransDecoderLayer`]
|
759 |
+
|
760 |
+
Args:
|
761 |
+
config: IndicTransConfig
|
762 |
+
embed_tokens (nn.Embedding): output embedding
|
763 |
+
"""
|
764 |
+
|
765 |
+
def __init__(self, config: IndicTransConfig, embed_tokens: Optional[nn.Embedding] = None):
|
766 |
+
super().__init__(config)
|
767 |
+
self.dropout = config.dropout
|
768 |
+
self.layerdrop = config.decoder_layerdrop
|
769 |
+
|
770 |
+
embed_dim = config.encoder_embed_dim
|
771 |
+
self.padding_idx = config.pad_token_id
|
772 |
+
self.max_target_positions = config.max_target_positions
|
773 |
+
self.embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0
|
774 |
+
|
775 |
+
self.embed_tokens = nn.Embedding(config.decoder_vocab_size, embed_dim, self.padding_idx)
|
776 |
+
|
777 |
+
if embed_tokens is not None:
|
778 |
+
self.embed_tokens.weight = embed_tokens.weight
|
779 |
+
|
780 |
+
self.embed_positions = IndicTransSinusoidalPositionalEmbedding(
|
781 |
+
config.max_target_positions,
|
782 |
+
embed_dim,
|
783 |
+
self.padding_idx,
|
784 |
+
)
|
785 |
+
self.layers = nn.ModuleList([IndicTransDecoderLayer(config) for _ in range(config.decoder_layers)])
|
786 |
+
self.layer_norm = nn.LayerNorm(embed_dim) if config.decoder_normalize_before else None
|
787 |
+
self.layernorm_embedding = nn.LayerNorm(embed_dim) if config.layernorm_embedding else None
|
788 |
+
|
789 |
+
self.gradient_checkpointing = False
|
790 |
+
# Initialize weights and apply final processing
|
791 |
+
self.post_init()
|
792 |
+
|
793 |
+
def forward(
|
794 |
+
self,
|
795 |
+
input_ids: Optional[torch.Tensor] = None,
|
796 |
+
attention_mask: Optional[torch.Tensor] = None,
|
797 |
+
encoder_hidden_states: Optional[torch.Tensor] = None,
|
798 |
+
encoder_attention_mask: Optional[torch.Tensor] = None,
|
799 |
+
head_mask: Optional[torch.Tensor] = None,
|
800 |
+
cross_attn_head_mask: Optional[torch.Tensor] = None,
|
801 |
+
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
802 |
+
inputs_embeds: Optional[torch.Tensor] = None,
|
803 |
+
use_cache: Optional[bool] = None,
|
804 |
+
output_attentions: Optional[bool] = None,
|
805 |
+
output_hidden_states: Optional[bool] = None,
|
806 |
+
return_dict: Optional[bool] = None,
|
807 |
+
):
|
808 |
+
r"""
|
809 |
+
Args:
|
810 |
+
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
|
811 |
+
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
|
812 |
+
provide it.
|
813 |
+
|
814 |
+
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
|
815 |
+
[`PreTrainedTokenizer.__call__`] for details.
|
816 |
+
|
817 |
+
[What are input IDs?](../glossary#input-ids)
|
818 |
+
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
|
819 |
+
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
|
820 |
+
|
821 |
+
- 1 for tokens that are **not masked**,
|
822 |
+
- 0 for tokens that are **masked**.
|
823 |
+
|
824 |
+
[What are attention masks?](../glossary#attention-mask)
|
825 |
+
encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, encoder_sequence_length, hidden_size)`, *optional*):
|
826 |
+
Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention
|
827 |
+
of the decoder.
|
828 |
+
encoder_attention_mask (`torch.LongTensor` of shape `(batch_size, encoder_sequence_length)`, *optional*):
|
829 |
+
Mask to avoid performing cross-attention on padding tokens indices of encoder input_ids. Mask values
|
830 |
+
selected in `[0, 1]`:
|
831 |
+
|
832 |
+
- 1 for tokens that are **not masked**,
|
833 |
+
- 0 for tokens that are **masked**.
|
834 |
+
|
835 |
+
[What are attention masks?](../glossary#attention-mask)
|
836 |
+
head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
|
837 |
+
Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:
|
838 |
+
|
839 |
+
- 1 indicates the head is **not masked**,
|
840 |
+
- 0 indicates the head is **masked**.
|
841 |
+
|
842 |
+
cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
|
843 |
+
Mask to nullify selected heads of the cross-attention modules in the decoder to avoid performing
|
844 |
+
cross-attention on hidden heads. Mask values selected in `[0, 1]`:
|
845 |
+
|
846 |
+
- 1 indicates the head is **not masked**,
|
847 |
+
- 0 indicates the head is **masked**.
|
848 |
+
|
849 |
+
past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
|
850 |
+
Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
|
851 |
+
shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of
|
852 |
+
shape `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
|
853 |
+
|
854 |
+
Contains pre-computed hidden-states (key and values in the self-attention blocks and in the
|
855 |
+
cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
|
856 |
+
|
857 |
+
If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those
|
858 |
+
that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of
|
859 |
+
all `decoder_input_ids` of shape `(batch_size, sequence_length)`. inputs_embeds (`torch.FloatTensor` of
|
860 |
+
shape `(batch_size, sequence_length, hidden_size)`, *optional*): Optionally, instead of passing
|
861 |
+
`input_ids` you can choose to directly pass an embedded representation. This is useful if you want more
|
862 |
+
control over how to convert `input_ids` indices into associated vectors than the model's internal
|
863 |
+
embedding lookup matrix.
|
864 |
+
output_attentions (`bool`, *optional*):
|
865 |
+
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
|
866 |
+
returned tensors for more detail.
|
867 |
+
output_hidden_states (`bool`, *optional*):
|
868 |
+
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
|
869 |
+
for more detail.
|
870 |
+
return_dict (`bool`, *optional*):
|
871 |
+
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
872 |
+
"""
|
873 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
874 |
+
output_hidden_states = (
|
875 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
876 |
+
)
|
877 |
+
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
878 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
879 |
+
|
880 |
+
# retrieve input_ids and inputs_embeds
|
881 |
+
if input_ids is not None and inputs_embeds is not None:
|
882 |
+
raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
|
883 |
+
elif input_ids is not None:
|
884 |
+
input_shape = input_ids.size()
|
885 |
+
input_ids = input_ids.view(-1, input_shape[-1])
|
886 |
+
elif inputs_embeds is not None:
|
887 |
+
input_shape = inputs_embeds.size()[:-1]
|
888 |
+
else:
|
889 |
+
raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
|
890 |
+
|
891 |
+
# past_key_values_length
|
892 |
+
past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
|
893 |
+
|
894 |
+
if inputs_embeds is None:
|
895 |
+
inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
|
896 |
+
|
897 |
+
# create causal mask
|
898 |
+
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
899 |
+
combined_attention_mask = None
|
900 |
+
if input_shape[-1] > 1:
|
901 |
+
combined_attention_mask = _make_causal_mask(
|
902 |
+
input_shape,
|
903 |
+
inputs_embeds.dtype,
|
904 |
+
device=inputs_embeds.device,
|
905 |
+
past_key_values_length=past_key_values_length,
|
906 |
+
)
|
907 |
+
|
908 |
+
if attention_mask is not None and combined_attention_mask is not None:
|
909 |
+
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
910 |
+
combined_attention_mask = combined_attention_mask + _expand_mask(
|
911 |
+
attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
|
912 |
+
)
|
913 |
+
|
914 |
+
# expand encoder attention mask
|
915 |
+
if encoder_hidden_states is not None and encoder_attention_mask is not None:
|
916 |
+
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
917 |
+
encoder_attention_mask = _expand_mask(encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1])
|
918 |
+
|
919 |
+
# embed positions
|
920 |
+
positions = self.embed_positions(input_ids, inputs_embeds, past_key_values_length)
|
921 |
+
positions = positions.to(inputs_embeds.device)
|
922 |
+
|
923 |
+
hidden_states = inputs_embeds + positions
|
924 |
+
if self.layernorm_embedding is not None:
|
925 |
+
hidden_states = self.layernorm_embedding(hidden_states)
|
926 |
+
|
927 |
+
hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training)
|
928 |
+
|
929 |
+
if self.gradient_checkpointing and self.training:
|
930 |
+
if use_cache:
|
931 |
+
logger.warning_once(
|
932 |
+
"`use_cache=True` is incompatible with gradient checkpointing. Setting" " `use_cache=False`..."
|
933 |
+
)
|
934 |
+
use_cache = False
|
935 |
+
|
936 |
+
# decoder layers
|
937 |
+
all_hidden_states = () if output_hidden_states else None
|
938 |
+
all_self_attns = () if output_attentions else None
|
939 |
+
all_cross_attentions = () if output_attentions else None
|
940 |
+
next_decoder_cache = () if use_cache else None
|
941 |
+
|
942 |
+
# check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired
|
943 |
+
for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]):
|
944 |
+
if attn_mask is not None:
|
945 |
+
if attn_mask.size()[0] != len(self.layers):
|
946 |
+
raise ValueError(
|
947 |
+
f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for"
|
948 |
+
f" {head_mask.size()[0]}."
|
949 |
+
)
|
950 |
+
deepspeed_zero3_is_enabled = is_deepspeed_zero3_enabled()
|
951 |
+
|
952 |
+
for idx, decoder_layer in enumerate(self.layers):
|
953 |
+
if output_hidden_states:
|
954 |
+
all_hidden_states += (hidden_states,)
|
955 |
+
|
956 |
+
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
|
957 |
+
dropout_probability = torch.rand([])
|
958 |
+
|
959 |
+
skip_the_layer = True if self.training and (dropout_probability < self.layerdrop) else False
|
960 |
+
if not skip_the_layer or deepspeed_zero3_is_enabled:
|
961 |
+
# under deepspeed zero3 all gpus must run in sync
|
962 |
+
|
963 |
+
past_key_value = past_key_values[idx] if past_key_values is not None else None
|
964 |
+
|
965 |
+
if self.gradient_checkpointing and self.training:
|
966 |
+
|
967 |
+
def create_custom_forward(module):
|
968 |
+
def custom_forward(*inputs):
|
969 |
+
# None for past_key_value
|
970 |
+
return module(*inputs, output_attentions, use_cache)
|
971 |
+
|
972 |
+
return custom_forward
|
973 |
+
|
974 |
+
layer_outputs = torch.utils.checkpoint.checkpoint(
|
975 |
+
create_custom_forward(decoder_layer),
|
976 |
+
hidden_states,
|
977 |
+
combined_attention_mask,
|
978 |
+
encoder_hidden_states,
|
979 |
+
encoder_attention_mask,
|
980 |
+
head_mask[idx] if head_mask is not None else None,
|
981 |
+
cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None,
|
982 |
+
None,
|
983 |
+
)
|
984 |
+
else:
|
985 |
+
layer_outputs = decoder_layer(
|
986 |
+
hidden_states,
|
987 |
+
attention_mask=combined_attention_mask,
|
988 |
+
encoder_hidden_states=encoder_hidden_states,
|
989 |
+
encoder_attention_mask=encoder_attention_mask,
|
990 |
+
layer_head_mask=(head_mask[idx] if head_mask is not None else None),
|
991 |
+
cross_attn_layer_head_mask=(
|
992 |
+
cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None
|
993 |
+
),
|
994 |
+
past_key_value=past_key_value,
|
995 |
+
output_attentions=output_attentions,
|
996 |
+
use_cache=use_cache,
|
997 |
+
)
|
998 |
+
|
999 |
+
hidden_states = layer_outputs[0]
|
1000 |
+
|
1001 |
+
if skip_the_layer:
|
1002 |
+
continue
|
1003 |
+
|
1004 |
+
if use_cache:
|
1005 |
+
next_decoder_cache += (layer_outputs[3 if output_attentions else 1],)
|
1006 |
+
|
1007 |
+
if output_attentions:
|
1008 |
+
all_self_attns += (layer_outputs[1],)
|
1009 |
+
all_cross_attentions += (layer_outputs[2],)
|
1010 |
+
|
1011 |
+
if self.layer_norm is not None:
|
1012 |
+
hidden_states = self.layer_norm(hidden_states)
|
1013 |
+
|
1014 |
+
# add hidden states from the last decoder layer
|
1015 |
+
if output_hidden_states:
|
1016 |
+
all_hidden_states += (hidden_states,)
|
1017 |
+
|
1018 |
+
next_cache = next_decoder_cache if use_cache else None
|
1019 |
+
if not return_dict:
|
1020 |
+
return tuple(
|
1021 |
+
v
|
1022 |
+
for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_cross_attentions]
|
1023 |
+
if v is not None
|
1024 |
+
)
|
1025 |
+
return BaseModelOutputWithPastAndCrossAttentions(
|
1026 |
+
last_hidden_state=hidden_states,
|
1027 |
+
past_key_values=next_cache,
|
1028 |
+
hidden_states=all_hidden_states,
|
1029 |
+
attentions=all_self_attns,
|
1030 |
+
cross_attentions=all_cross_attentions,
|
1031 |
+
)
|
1032 |
+
|
1033 |
+
|
1034 |
+
# Copied from transformers.models.m2m_100.modeling_m2m_100.M2M100Model->IndicTrans
|
1035 |
+
class IndicTransModel(IndicTransPreTrainedModel):
|
1036 |
+
_tied_weights_keys = None
|
1037 |
+
|
1038 |
+
def __init__(self, config: IndicTransConfig):
|
1039 |
+
super().__init__(config)
|
1040 |
+
|
1041 |
+
self.encoder = IndicTransEncoder(config)
|
1042 |
+
self.decoder = IndicTransDecoder(config)
|
1043 |
+
|
1044 |
+
# Initialize weights and apply final processing
|
1045 |
+
self.post_init()
|
1046 |
+
|
1047 |
+
def get_encoder(self):
|
1048 |
+
return self.encoder
|
1049 |
+
|
1050 |
+
def get_decoder(self):
|
1051 |
+
return self.decoder
|
1052 |
+
|
1053 |
+
def forward(
|
1054 |
+
self,
|
1055 |
+
input_ids: Optional[torch.LongTensor] = None,
|
1056 |
+
attention_mask: Optional[torch.Tensor] = None,
|
1057 |
+
decoder_input_ids: Optional[torch.LongTensor] = None,
|
1058 |
+
decoder_attention_mask: Optional[torch.LongTensor] = None,
|
1059 |
+
head_mask: Optional[torch.Tensor] = None,
|
1060 |
+
decoder_head_mask: Optional[torch.Tensor] = None,
|
1061 |
+
cross_attn_head_mask: Optional[torch.Tensor] = None,
|
1062 |
+
encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
|
1063 |
+
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
|
1064 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
1065 |
+
decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
|
1066 |
+
use_cache: Optional[bool] = None,
|
1067 |
+
output_attentions: Optional[bool] = None,
|
1068 |
+
output_hidden_states: Optional[bool] = None,
|
1069 |
+
return_dict: Optional[bool] = None,
|
1070 |
+
) -> Union[Tuple[torch.Tensor], Seq2SeqModelOutput]:
|
1071 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
1072 |
+
output_hidden_states = (
|
1073 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
1074 |
+
)
|
1075 |
+
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
1076 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
1077 |
+
|
1078 |
+
if encoder_outputs is None:
|
1079 |
+
encoder_outputs = self.encoder(
|
1080 |
+
input_ids=input_ids,
|
1081 |
+
attention_mask=attention_mask,
|
1082 |
+
head_mask=head_mask,
|
1083 |
+
inputs_embeds=inputs_embeds,
|
1084 |
+
output_attentions=output_attentions,
|
1085 |
+
output_hidden_states=output_hidden_states,
|
1086 |
+
return_dict=return_dict,
|
1087 |
+
)
|
1088 |
+
# If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput when return_dict=True
|
1089 |
+
elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
|
1090 |
+
encoder_outputs = BaseModelOutput(
|
1091 |
+
last_hidden_state=encoder_outputs[0],
|
1092 |
+
hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
|
1093 |
+
attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
|
1094 |
+
)
|
1095 |
+
|
1096 |
+
# decoder outputs consists of (dec_features, past_key_value, dec_hidden, dec_attn)
|
1097 |
+
decoder_outputs = self.decoder(
|
1098 |
+
input_ids=decoder_input_ids,
|
1099 |
+
attention_mask=decoder_attention_mask,
|
1100 |
+
encoder_hidden_states=encoder_outputs[0],
|
1101 |
+
encoder_attention_mask=attention_mask,
|
1102 |
+
head_mask=decoder_head_mask,
|
1103 |
+
cross_attn_head_mask=cross_attn_head_mask,
|
1104 |
+
past_key_values=past_key_values,
|
1105 |
+
inputs_embeds=decoder_inputs_embeds,
|
1106 |
+
use_cache=use_cache,
|
1107 |
+
output_attentions=output_attentions,
|
1108 |
+
output_hidden_states=output_hidden_states,
|
1109 |
+
return_dict=return_dict,
|
1110 |
+
)
|
1111 |
+
|
1112 |
+
if not return_dict:
|
1113 |
+
return decoder_outputs + encoder_outputs
|
1114 |
+
|
1115 |
+
return Seq2SeqModelOutput(
|
1116 |
+
last_hidden_state=decoder_outputs.last_hidden_state,
|
1117 |
+
past_key_values=decoder_outputs.past_key_values,
|
1118 |
+
decoder_hidden_states=decoder_outputs.hidden_states,
|
1119 |
+
decoder_attentions=decoder_outputs.attentions,
|
1120 |
+
cross_attentions=decoder_outputs.cross_attentions,
|
1121 |
+
encoder_last_hidden_state=encoder_outputs.last_hidden_state,
|
1122 |
+
encoder_hidden_states=encoder_outputs.hidden_states,
|
1123 |
+
encoder_attentions=encoder_outputs.attentions,
|
1124 |
+
)
|
1125 |
+
|
1126 |
+
|
1127 |
+
# Copied from transformers.models.m2m_100.modeling_m2m_100.M2M100ForConditionalGeneration->IndicTrans
|
1128 |
+
class IndicTransForConditionalGeneration(IndicTransPreTrainedModel):
|
1129 |
+
base_model_prefix = "model"
|
1130 |
+
_tied_weights_keys = None
|
1131 |
+
|
1132 |
+
def __init__(self, config: IndicTransConfig):
|
1133 |
+
super().__init__(config)
|
1134 |
+
self.model = IndicTransModel(config)
|
1135 |
+
self.lm_head = nn.Linear(config.decoder_embed_dim, config.decoder_vocab_size, bias=False)
|
1136 |
+
|
1137 |
+
if config.share_decoder_input_output_embed:
|
1138 |
+
self.lm_head.weight = self.model.decoder.embed_tokens.weight
|
1139 |
+
|
1140 |
+
self.post_init()
|
1141 |
+
|
1142 |
+
def tie_weights(self):
|
1143 |
+
pass
|
1144 |
+
|
1145 |
+
def get_encoder(self):
|
1146 |
+
return self.model.get_encoder()
|
1147 |
+
|
1148 |
+
def get_decoder(self):
|
1149 |
+
return self.model.get_decoder()
|
1150 |
+
|
1151 |
+
def get_output_embeddings(self):
|
1152 |
+
return self.lm_head
|
1153 |
+
|
1154 |
+
def set_output_embeddings(self, new_embeddings):
|
1155 |
+
self.lm_head = new_embeddings
|
1156 |
+
|
1157 |
+
def forward(
|
1158 |
+
self,
|
1159 |
+
input_ids: Optional[torch.LongTensor] = None,
|
1160 |
+
attention_mask: Optional[torch.Tensor] = None,
|
1161 |
+
decoder_input_ids: Optional[torch.LongTensor] = None,
|
1162 |
+
decoder_attention_mask: Optional[torch.LongTensor] = None,
|
1163 |
+
head_mask: Optional[torch.Tensor] = None,
|
1164 |
+
decoder_head_mask: Optional[torch.Tensor] = None,
|
1165 |
+
cross_attn_head_mask: Optional[torch.Tensor] = None,
|
1166 |
+
encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
|
1167 |
+
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
|
1168 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
1169 |
+
decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
|
1170 |
+
labels: Optional[torch.LongTensor] = None,
|
1171 |
+
use_cache: Optional[bool] = None,
|
1172 |
+
output_attentions: Optional[bool] = None,
|
1173 |
+
output_hidden_states: Optional[bool] = None,
|
1174 |
+
return_dict: Optional[bool] = None,
|
1175 |
+
) -> Union[Tuple[torch.Tensor], Seq2SeqLMOutput]:
|
1176 |
+
r"""
|
1177 |
+
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
1178 |
+
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
1179 |
+
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
1180 |
+
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
|
1181 |
+
|
1182 |
+
Returns:
|
1183 |
+
"""
|
1184 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
1185 |
+
|
1186 |
+
if labels is not None:
|
1187 |
+
if decoder_input_ids is None:
|
1188 |
+
decoder_input_ids = shift_tokens_right(
|
1189 |
+
labels, self.config.pad_token_id, self.config.decoder_start_token_id
|
1190 |
+
)
|
1191 |
+
|
1192 |
+
outputs = self.model(
|
1193 |
+
input_ids,
|
1194 |
+
attention_mask=attention_mask,
|
1195 |
+
decoder_input_ids=decoder_input_ids,
|
1196 |
+
encoder_outputs=encoder_outputs,
|
1197 |
+
decoder_attention_mask=decoder_attention_mask,
|
1198 |
+
head_mask=head_mask,
|
1199 |
+
decoder_head_mask=decoder_head_mask,
|
1200 |
+
cross_attn_head_mask=cross_attn_head_mask,
|
1201 |
+
past_key_values=past_key_values,
|
1202 |
+
inputs_embeds=inputs_embeds,
|
1203 |
+
decoder_inputs_embeds=decoder_inputs_embeds,
|
1204 |
+
use_cache=use_cache,
|
1205 |
+
output_attentions=output_attentions,
|
1206 |
+
output_hidden_states=output_hidden_states,
|
1207 |
+
return_dict=return_dict,
|
1208 |
+
)
|
1209 |
+
lm_logits = self.lm_head(outputs[0])
|
1210 |
+
|
1211 |
+
masked_lm_loss = None
|
1212 |
+
if labels is not None:
|
1213 |
+
# move labels to the correct device to enable PP
|
1214 |
+
labels = labels.to(lm_logits.device)
|
1215 |
+
loss_fct = nn.CrossEntropyLoss()
|
1216 |
+
masked_lm_loss = loss_fct(lm_logits.view(-1, self.config.decoder_vocab_size), labels.view(-1))
|
1217 |
+
|
1218 |
+
if not return_dict:
|
1219 |
+
output = (lm_logits,) + outputs[1:]
|
1220 |
+
return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
|
1221 |
+
|
1222 |
+
return Seq2SeqLMOutput(
|
1223 |
+
loss=masked_lm_loss,
|
1224 |
+
logits=lm_logits,
|
1225 |
+
past_key_values=outputs.past_key_values,
|
1226 |
+
decoder_hidden_states=outputs.decoder_hidden_states,
|
1227 |
+
decoder_attentions=outputs.decoder_attentions,
|
1228 |
+
cross_attentions=outputs.cross_attentions,
|
1229 |
+
encoder_last_hidden_state=outputs.encoder_last_hidden_state,
|
1230 |
+
encoder_hidden_states=outputs.encoder_hidden_states,
|
1231 |
+
encoder_attentions=outputs.encoder_attentions,
|
1232 |
+
)
|
1233 |
+
|
1234 |
+
def prepare_inputs_for_generation(
|
1235 |
+
self,
|
1236 |
+
decoder_input_ids,
|
1237 |
+
past_key_values=None,
|
1238 |
+
attention_mask=None,
|
1239 |
+
head_mask=None,
|
1240 |
+
decoder_head_mask=None,
|
1241 |
+
cross_attn_head_mask=None,
|
1242 |
+
use_cache=None,
|
1243 |
+
encoder_outputs=None,
|
1244 |
+
**kwargs,
|
1245 |
+
):
|
1246 |
+
# cut decoder_input_ids if past is used
|
1247 |
+
if past_key_values is not None:
|
1248 |
+
decoder_input_ids = decoder_input_ids[:, -1:]
|
1249 |
+
|
1250 |
+
return {
|
1251 |
+
"input_ids": None, # encoder_outputs is defined. input_ids not needed
|
1252 |
+
"encoder_outputs": encoder_outputs,
|
1253 |
+
"past_key_values": past_key_values,
|
1254 |
+
"decoder_input_ids": decoder_input_ids,
|
1255 |
+
"attention_mask": attention_mask,
|
1256 |
+
"head_mask": head_mask,
|
1257 |
+
"decoder_head_mask": decoder_head_mask,
|
1258 |
+
"cross_attn_head_mask": cross_attn_head_mask,
|
1259 |
+
"use_cache": use_cache, # change this to avoid caching (presumably for debugging)
|
1260 |
+
}
|
1261 |
+
|
1262 |
+
@staticmethod
|
1263 |
+
def _reorder_cache(past_key_values, beam_idx):
|
1264 |
+
reordered_past = ()
|
1265 |
+
for layer_past in past_key_values:
|
1266 |
+
reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
|
1267 |
+
return reordered_past
|
pytorch_model.bin
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:3a1ffde8f165ef552456958dcb162905f6b380c7873e5c05226106a5c26145f1
|
3 |
+
size 913515337
|