|
import torch |
|
from torch import Tensor, LongTensor |
|
from transformers import T5ForConditionalGeneration, T5Config |
|
from transformers import TextIteratorStreamer |
|
from transformers.generation.configuration_utils import GenerationConfig |
|
|
|
class TextToTextModel(T5ForConditionalGeneration):
    """``T5ForConditionalGeneration`` subclass adding a convenience decoding helper."""

    # Per-strategy GenerationConfig overrides used by ``my_generate``.
    # They are applied after the shared settings, in the order listed.
    # Treat this table as read-only.
    _SEARCH_PRESETS = {
        'greedy': {
            'num_beams': 1,
            'do_sample': False,
        },
        # NOTE: do_sample=True combined with num_beams > 1 selects beam-search
        # *multinomial sampling* in transformers, not deterministic beam search.
        # This is kept as-is to preserve the existing outputs of the default strategy.
        'beam': {
            'top_k': 50,
            'num_beams': 5,
            'do_sample': True,
            'top_p': 0.95,
            'no_repeat_ngram_size': 4,
            # A negative length_penalty biases beam scores toward shorter sequences.
            'length_penalty': -2.0,
            'early_stopping': True,
        },
        'sampling': {
            'num_beams': 1,
            'do_sample': True,
            'top_k': 50,
            'temperature': 0.98,
            'top_p': 0.80,
            'no_repeat_ngram_size': 4,
        },
        # penalty_alpha > 0 together with top_k > 1 selects contrastive search.
        'contrastive': {
            'penalty_alpha': 0.5,
            'top_k': 50,
        },
    }

    def __init__(self, config: T5Config) -> None:
        """Build the model from ``config``; adds no state beyond the parent class."""
        super().__init__(config)

    @torch.no_grad()
    def my_generate(self,
                    input_ids: LongTensor,
                    attention_mask: LongTensor,
                    max_seq_len: int = 256,
                    search_type: str = 'beam',
                    streamer: TextIteratorStreamer = None,
                    ) -> Tensor:
        """Generate output token ids with one of several preset decoding strategies.

        Convenience wrapper around ``self.generate`` for quick calls and tests.

        Args:
            input_ids: Encoder input token ids.
            attention_mask: Attention mask matching ``input_ids``.
            max_seq_len: Maximum number of *new* tokens to generate
                (passed as ``max_new_tokens``; the prompt length is not counted).
            search_type: One of ``'greedy'``, ``'beam'``, ``'sampling'``,
                ``'contrastive'``. See ``_SEARCH_PRESETS`` for the exact
                settings of each.
            streamer: Optional streamer forwarded to ``generate``.
                NOTE(review): transformers' ``generate`` rejects a streamer
                when ``num_beams > 1``. The ``'beam'`` strategy therefore cannot
                be streamed; confirm against the installed transformers version.

        Returns:
            Whatever ``self.generate`` returns for this config. With
            ``return_dict_in_generate`` left at its default, that is the
            tensor of generated token ids.

        Raises:
            ValueError: If ``search_type`` is not a known strategy. Previously,
                unknown values silently fell back to greedy decoding.
        """
        preset = self._SEARCH_PRESETS.get(search_type)
        if preset is None:
            raise ValueError(
                f"search_type must be one of {list(self._SEARCH_PRESETS)}, "
                f"got {search_type!r}"
            )

        generation_config = GenerationConfig()
        # Drop nan/inf logits instead of letting generation crash on them.
        generation_config.remove_invalid_values = True
        # Hard-coded ids. These equal T5Config's defaults (eos=1, pad=0);
        # presumably the tokenizer uses the same ids -- TODO confirm.
        generation_config.eos_token_id = 1
        generation_config.pad_token_id = 0
        generation_config.decoder_start_token_id = self.config.decoder_start_token_id
        generation_config.max_new_tokens = max_seq_len

        # Apply the strategy-specific overrides on top of the shared settings.
        for option, value in preset.items():
            setattr(generation_config, option, value)

        return self.generate(
            inputs=input_ids,
            attention_mask=attention_mask,
            generation_config=generation_config,
            streamer=streamer,
        )
|
|