from typing import Optional, Tuple, Union
import torch
import torch.utils.checkpoint
from torch import nn
from torch.nn import CrossEntropyLoss
from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions
from transformers.models.gpt2.configuration_gpt2 import GPT2Config
from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel


class GPT2MultiHeadConfig(GPT2Config):
    """GPT-2 configuration extended with multiple language-modeling heads.

    Args:
        head_locations: indices of the hidden-state layers that feed the additional LM heads.
        head_weights: loss weights, one per additional head plus one for the final head.
        tie_additional_weights: if True, tie the additional heads to the input embeddings.
        average_logits: if True, return the weighted average of all head logits instead of the
            logits of the final head.
    """

    model_type = "gpt2-multi-head"

    def __init__(
        self,
        head_locations=None,
        head_weights=None,
        tie_additional_weights=False,
        average_logits=False,
        *args,
        **kwargs,
    ):
        super().__init__(*args, **kwargs)
        self.head_locations = head_locations
        self.head_weights = head_weights
        self.tie_additional_weights = tie_additional_weights
        self.average_logits = average_logits


class GPT2LMMultiHeadModel(GPT2LMHeadModel):
    config_class = GPT2MultiHeadConfig

    def __init__(self, config: GPT2MultiHeadConfig):
        super().__init__(config)
        if config.head_locations is not None:
            if len(config.head_locations) + 1 != len(config.head_weights):
                raise ValueError("The number of head locations must equal the number of head weights minus one")
            self.head_locations = config.head_locations
            self.additional_lm_heads = nn.ModuleList(
                [nn.Linear(config.n_embd, config.vocab_size, bias=False) for _ in config.head_locations]
            )
            self.head_weights = config.head_weights
        else:
            self.head_locations = []
            self.additional_lm_heads = nn.ModuleList([])
            self.head_weights = [1.0]
        self.post_init()

    def tie_weights(self):
        """
        Tie the weights between the input embeddings and the output embeddings, including the additional LM heads
        when `tie_additional_weights` is set.

        If the `torchscript` flag is set in the configuration, parameter sharing cannot be used, so the weights are
        cloned instead.
        """
        super().tie_weights()
        if hasattr(self, "additional_lm_heads") and getattr(self.config, "tie_additional_weights", False):
            input_embeddings = self.get_input_embeddings()
            for classifier in self.additional_lm_heads:
                if self.config.torchscript:
                    classifier.weight = nn.Parameter(input_embeddings.weight.clone())
                else:
                    classifier.weight = input_embeddings.weight
                if getattr(classifier, "bias", None) is not None:
                    classifier.bias.data = nn.functional.pad(
                        classifier.bias.data,
                        (
                            0,
                            classifier.weight.shape[0] - classifier.bias.shape[0],
                        ),
                        "constant",
                        0,
                    )
                if hasattr(classifier, "out_features") and hasattr(input_embeddings, "num_embeddings"):
                    classifier.out_features = input_embeddings.num_embeddings

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, CausalLMOutputWithCrossAttentions]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
            `labels = input_ids`. Indices are selected in `[-100, 0, ..., config.vocab_size]`. All labels set to
            `-100` are ignored (masked); the loss is only computed for labels in `[0, ..., config.vocab_size]`.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        transformer_outputs = self.transformer(
            input_ids,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            use_cache=use_cache,
            output_attentions=output_attentions,
            # Hidden states of every layer are always needed so that the additional heads can read the
            # intermediate layers selected by `head_locations`.
            output_hidden_states=True,
            return_dict=return_dict,
        )
        # Tuple of hidden states: embedding output plus the output of every transformer layer.
        hidden_states = transformer_outputs[2]

        # Set device for model parallelism (the hidden states are a tuple of tensors, so move each one).
        if self.model_parallel:
            torch.cuda.set_device(self.transformer.first_device)
            hidden_states = tuple(h.to(self.lm_head.weight.device) for h in hidden_states)

        # Logits of the standard LM head, computed from the last layer.
        lm_logits = self.lm_head(hidden_states[-1])
        loss = None
        if labels is not None:
            loss = torch.tensor(0.0, device=hidden_states[-1].device)
            lm_logits = []
            loss_fct = CrossEntropyLoss()
            # One cross-entropy term per head: each additional head reads the hidden states of the layer
            # given by `head_locations`, and the standard head reads the last layer (index -1).
            for index, lm_head, lm_weight in zip(
                [*self.head_locations, -1],
                [*self.additional_lm_heads, self.lm_head],
                self.head_weights,
            ):
                lm_logits.append(lm_head(hidden_states[index]))
                # Shift so that tokens < n predict n
                shift_logits = lm_logits[-1][..., :-1, :].contiguous()
                shift_labels = labels[..., 1:].contiguous()
                # Flatten the tokens
                loss += lm_weight * loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))

            if self.config.average_logits:
                # Weighted mean of the per-head logits: stack to (num_heads, batch, seq, vocab) and broadcast
                # the head weights over the remaining dimensions.
                stacked_logits = torch.stack(lm_logits, dim=0)
                head_weights = torch.tensor(
                    self.head_weights, device=stacked_logits.device, dtype=stacked_logits.dtype
                ).view(-1, 1, 1, 1)
                lm_logits = (stacked_logits * head_weights).mean(dim=0)
            else:
                lm_logits = lm_logits[-1]
        if not return_dict:
            output = (lm_logits,) + transformer_outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return CausalLMOutputWithCrossAttentions(
            loss=loss,
            logits=lm_logits,
            past_key_values=transformer_outputs.past_key_values,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
            cross_attentions=transformer_outputs.cross_attentions,
        )
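

# Minimal usage sketch (not part of the original module): it builds a small randomly initialised model with
# illustrative, assumed values for the model size, `head_locations`, and `head_weights`; adapt them to your setup.
if __name__ == "__main__":
    config = GPT2MultiHeadConfig(
        n_layer=6,
        n_head=4,
        n_embd=128,
        head_locations=[2, 4],          # attach extra LM heads after layers 2 and 4
        head_weights=[0.3, 0.3, 1.0],   # one weight per extra head, plus one for the final head
        tie_additional_weights=True,
        average_logits=False,
    )
    model = GPT2LMMultiHeadModel(config)

    # Causal LM training step on random token ids; the labels are shifted inside the model.
    input_ids = torch.randint(0, config.vocab_size, (2, 16))
    outputs = model(input_ids=input_ids, labels=input_ids)
    print(outputs.loss, outputs.logits.shape)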