"""t5mod_helper.py""" |
|
import torch |
|
from torch import nn |
|
from model.t5mod import T5DecoderYMT3, MultiChannelT5Decoder |
|
from typing import Optional, Callable, Union, Literal |
|
|
|
|
|
@torch.no_grad()
def task_cond_dec_generate(decoder: Union[T5DecoderYMT3, MultiChannelT5Decoder],
                           decoder_type: Literal["t5", "multi-t5"],
                           embed_tokens: nn.Embedding,
                           lm_head: nn.Module,
                           encoder_hidden_states: torch.FloatTensor,
                           shift_right_fn: Callable,
                           prefix_ids: Optional[torch.LongTensor] = None,
                           max_length: int = 1024,
                           stop_at_eos: bool = True,
                           eos_id: Optional[int] = 1,
                           pad_id: Optional[int] = 0,
                           decoder_start_token_id: Optional[int] = 0,
                           debug: bool = False) -> torch.LongTensor:
""" |
|
Generate sequence by task conditioning on the decoder side |
|
:An extension of transofrmers.generate() function for the model with |
|
conditioning only on the decoder side. |
|
|
|
Args: |
|
decoder: T5DecoderYMT3 or MultiChannelT5Decoder, any decoder model with T5Stack architecture |
|
decoder_type: Literal["t5", "multi-t5"], type of decoder |
|
embed_tokens: nn.Embedding, embedding layer for the decoder |
|
lm_head: nn.Module, language model head |
|
encoder_hidden_states: torch.FloatTensor, (B, T, D) or (B, K, T, D) last hidden states |
|
shift_right_fn: Callable, shift_right function of the decoder |
|
prefix_ids: torch.LongTensor, (B, prefix_len) prefix ids typically used as task conditioning to decoder. |
|
max_length: int, max token length to generate (default is 1024) |
|
stop_at_eos: bool, whether to early-stop when all predictions in the batch are the <eos> token. |
|
eos_id: int, the id of the <eos> token (default is 1) |
|
pad_id: int, the id of the <pad> token (default is 0) |
|
decoder_start_token_id: int, the id of the <bos> token (default is 0) |
|
debug: bool, whether to print debug information |
|
|
|
Returns: |
|
pred_ids: torch.LongTensor, (B, task_len + N) or (B, C, task_len + N) predicted token ids |
|
""" |
|
    bsz = encoder_hidden_states.shape[0]
    device = encoder_hidden_states.device

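    # Per-step decoder input shape: (B, 1) for the single-channel "t5"
    # decoder, (B, C, 1) for "multi-t5" with C = decoder.num_channels.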
    if decoder_type == "t5":
        dec_input_shape = (bsz, 1)
    elif decoder_type == "multi-t5":
        dec_input_shape = (bsz, decoder.num_channels, 1)
    else:
        raise ValueError(f"decoder_type {decoder_type} is not supported.")

    if prefix_ids is not None and prefix_ids.numel() > 0:
        # Task conditioning: shift the prefix right (e.g. [t1, t2, t3] ->
        # [<bos>, t1, t2] for a T5-style shift_right_fn) and decode it in a
        # single forward pass.
        dec_input_ids = shift_right_fn(prefix_ids)
        prefix_length = prefix_ids.shape[-1]
    else:
        # No prefix: start generation from the decoder start token alone.
        dec_input_ids = torch.tile(torch.LongTensor([decoder_start_token_id]).to(device), dec_input_shape)
        prefix_length = 0
    dec_inputs_embeds = embed_tokens(dec_input_ids)

""" |
|
- initial inputs_embeds can be a sequence, without using past_key_values |
|
- dec_hs: (B, 1, D) |
|
- past_key_values: Tuple of length M for M layers of decoder |
|
- pred_ids: (B, prefix_len) where N is the length of prefix_ids |
|
""" |
|
dec_hs, past_key_values = decoder(inputs_embeds=dec_inputs_embeds, |
|
encoder_hidden_states=encoder_hidden_states, |
|
return_dict=False) |
|
logits = lm_head(dec_hs) |
|
pred_ids = logits.argmax(-1) |
|
|
|
|
|
    # 1 while a sequence is still generating, 0 once it has emitted <eos>.
    unfinished_sequences = torch.ones(dec_input_shape, dtype=torch.long, device=device)

    # Seed the autoregressive loop with the last predicted token; inside the
    # loop, only one new token is fed per step while the KV cache carries the
    # rest of the context.
    if decoder_type == "t5":
        dec_input_ids = pred_ids[:, -1].unsqueeze(-1)
    elif decoder_type == "multi-t5":
        dec_input_ids = pred_ids[:, :, -1].unsqueeze(-1)

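    # Note: each past_key_values entry is assumed to follow the usual
    # transformers T5 layout, i.e. (self_attn_key, self_attn_value,
    # cross_attn_key, cross_attn_value) with keys/values of shape
    # (B, n_heads, past_length, d_kv); the debug branch below reads
    # past_length from dim 2 of the first entry.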
    for i in range(max_length - prefix_length - 1):
        if debug:
            past_key_values_length = past_key_values[0][0].shape[2]
            print(f'i = {i}, past_key_values_length = {past_key_values_length}, pred_ids.shape = {pred_ids.shape}')

        # One-step decode reusing the cached keys/values.
        dec_inputs_embeds = embed_tokens(dec_input_ids)
        dec_hs, _past_key_values = decoder(inputs_embeds=dec_inputs_embeds,
                                           encoder_hidden_states=encoder_hidden_states,
                                           past_key_values=past_key_values,
                                           return_dict=False)
        logits = lm_head(dec_hs)
        _pred_ids = logits.argmax(-1)

        # The new prediction becomes the next step's input.
        dec_input_ids = _pred_ids.clone()
        past_key_values = _past_key_values

        if eos_id is not None:
            if pad_id is None:
                raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.")
            # Overwrite predictions of already-finished sequences with <pad>.
            _pred_ids = _pred_ids * unfinished_sequences + pad_id * (1 - unfinished_sequences)

        pred_ids = torch.cat((pred_ids, _pred_ids), dim=-1)

        if eos_id is not None:
            # Mark sequences that just emitted <eos> as finished.
            unfinished_sequences = unfinished_sequences * _pred_ids.ne(eos_id).long()

        # Stop early once every sequence in the batch has finished.
        if stop_at_eos and unfinished_sequences.max() == 0:
            break

    return pred_ids
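

# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the original module). The decoder,
# embedding, and head below stand in for the real YMT3 components built
# elsewhere in this repo; `decoder.shift_right` is assumed to be the
# decoder's T5-style shift-right function, and all sizes are made up.
#
#     decoder = T5DecoderYMT3(...)                      # trained decoder
#     embed_tokens = nn.Embedding(vocab_size, d_model)  # shared with training
#     lm_head = nn.Linear(d_model, vocab_size, bias=False)
#     enc_hs = encoder(audio_features)                  # (B, T, D)
#     task_prefix = torch.tensor([[5, 6, 7]] * bsz)     # (B, prefix_len)
#
#     pred_ids = task_cond_dec_generate(decoder,
#                                       decoder_type="t5",
#                                       embed_tokens=embed_tokens,
#                                       lm_head=lm_head,
#                                       encoder_hidden_states=enc_hs,
#                                       shift_right_fn=decoder.shift_right,
#                                       prefix_ids=task_prefix,
#                                       max_length=1024)
#     # pred_ids: (B, prefix_len + N), padded with pad_id after <eos>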