# YourMT3 / amt / src / extras / t5_dev.py
# Uploaded by mimbres, revision a03c9b4 (1.34 kB).
import torch
from transformers import T5Config
from model.t5mod import T5ForConditionalGeneration
# Development scratch script: build a T5ForConditionalGeneration from an
# inline config, print it, and push random tensors through the encoder and
# through the full model (with labels) as a smoke test.

# Keyword arguments for T5Config. Commented-out entries are intentionally
# left unset here.
a = {
    "architectures": ["T5ForConditionalGeneration"],
    "d_ff": 1024,  # size of the intermediate feed forward layer in each T5Block
    # NOTE(review): the previous note claimed d_kv must equal
    # d_model // num_heads, but with num_heads=6 and a 512-wide model that
    # does not hold (512 // 6 != 64). In Hugging Face T5 the attention inner
    # dimension is num_heads * d_kv -- confirm model.t5mod behaves the same.
    "d_kv": 64,
    # "d_model": 512, # encoder hidden size, defined by model_cfg
    "decoder_start_token_id": 0,
    "dense_act_fn": "gelu_new",
    # "dropout_rate": 0.05, # can be overwritten by args in ymt3
    "eos_token_id": 1,
    "feed_forward_proj": "gated-gelu",
    "initializer_factor": 1.0,
    # "is_encoder_decoder": True,
    "is_gated_act": True,
    "layer_norm_epsilon": 1e-06,
    "model_type": "t5",
    # "num_decoder_layers": 8,
    "num_heads": 6,
    "num_layers": 8,
    "output_past": True,
    "pad_token_id": 0,
    "relative_attention_num_buckets": 32,
    "use_cache": True,
    "vocab_size": 1391,  # vocab_size is automatically set by the task manager...
}

cfg = T5Config(**a)
# Layer counts are overridden after construction: four decoder layers and
# zero encoder layers (replacing the num_layers=8 given in the dict above).
cfg.num_decoder_layers = 4
cfg.num_layers = 0

model = T5ForConditionalGeneration(cfg)
print(model)

# Dummy batch geometry shared by every tensor below.
batch_size, seq_len = 2, 256

# Encoder pass on random embeddings. The feature width is read from
# cfg.d_model (not set in `a`, so the T5Config default applies; the original
# script hard-coded 512 -- TODO confirm they match in your transformers
# version). The module is called directly rather than via .forward() so that
# nn.Module.__call__ runs any registered hooks.
x = torch.rand((batch_size, seq_len, cfg.d_model))
out = model.encoder(inputs_embeds=x)

# Full-model pass with externally supplied encoder hidden states and random
# target token ids drawn from [0, vocab_size).
enc_hs = torch.rand((batch_size, seq_len, cfg.d_model))
labels = torch.randint(0, cfg.vocab_size, (batch_size, seq_len))
# encoder_outputs is passed as a one-element tuple: the trailing comma in
# (enc_hs,) is required, otherwise a bare tensor is passed instead.
pred = model(encoder_outputs=(enc_hs,), labels=labels)