|
""" |
|
This module defines the DeTiME model together with the CNN encoder/decoder pair that forms its nested autoencoder.
|
""" |
|
|
|
import os
import random

import datasets
import torch
import torch.nn.functional as F
from torch import nn
from transformers import PreTrainedModel
from transformers import T5ForConditionalGeneration, AutoModelForSeq2SeqLM
from transformers.modeling_outputs import BaseModelOutput

from .configuration_detime import DeTiMEAutoConfig
|
# Process-wide side effect on import: set TOKENIZERS_PARALLELISM=false, an
# environment variable read by the Hugging Face `tokenizers` library.
# NOTE(review): presumably done to silence the tokenizers fork warning / avoid
# fork-after-parallelism deadlocks in DataLoader workers — confirm.
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
|
|
|
|
|
|
|
|
class CNNEncoder(nn.Module):
    """Three-layer 1-D convolutional encoder: the compressing half of the bottleneck.

    Maps ``hidden_size1`` input channels to ``hidden_size3`` output channels
    through ``mid_channels`` and then ``inner_channels`` intermediate channels,
    with a ReLU after each of the first two convolutions. Every convolution uses
    kernel_size=3, stride=1, padding=1, so the length dimension is preserved.
    """

    def __init__(self, hidden_size1, hidden_size3, *, mid_channels=128, inner_channels=16) -> None:
        """Build the convolution stack.

        Args:
            hidden_size1: Number of input channels.
            hidden_size3: Number of output (code) channels.
            mid_channels: Output channels of the first convolution
                (default 128, the previously hard-coded width).
            inner_channels: Output channels of the second convolution
                (default 16, the previously hard-coded width).
        """
        super().__init__()
        # The layer order inside the Sequential is unchanged, so state_dict keys
        # (encoder.0 / encoder.2 / encoder.4) match existing checkpoints.
        self.encoder = nn.Sequential(
            nn.Conv1d(in_channels=hidden_size1, out_channels=mid_channels, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.Conv1d(in_channels=mid_channels, out_channels=inner_channels, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.Conv1d(in_channels=inner_channels, out_channels=hidden_size3, kernel_size=3, stride=1, padding=1),
        )

    def forward(self, x):
        """Encode ``x``.

        Args:
            x: Tensor laid out as ``nn.Conv1d`` expects, i.e.
                ``(batch, hidden_size1, length)``.

        Returns:
            Tensor of shape ``(batch, hidden_size3, length)``.
        """
        return self.encoder(x)
|
|
|
class CNNDecoder(nn.Module):
    """Three-layer 1-D convolutional decoder: the reconstructing half of the bottleneck.

    Mirrors :class:`CNNEncoder`. It maps ``hidden_size3`` channels back to
    ``hidden_size1`` channels through ``inner_channels`` and then
    ``mid_channels`` intermediate channels. Every convolution uses
    kernel_size=3, stride=1, padding=1, so the length dimension is preserved.
    A final sigmoid squashes every output value into [0, 1].
    """

    def __init__(self, hidden_size1, hidden_size3, *, mid_channels=128, inner_channels=16) -> None:
        """Build the convolution stack.

        Args:
            hidden_size1: Number of output (reconstructed) channels.
            hidden_size3: Number of input (code) channels.
            mid_channels: Input channels of the last convolution
                (default 128, the previously hard-coded width).
            inner_channels: Output channels of the first convolution
                (default 16, the previously hard-coded width).
        """
        super().__init__()
        # The layer order inside the Sequential is unchanged, so state_dict keys
        # (decoder.0 / decoder.2 / decoder.4) match existing checkpoints.
        self.decoder = nn.Sequential(
            nn.Conv1d(in_channels=hidden_size3, out_channels=inner_channels, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.Conv1d(in_channels=inner_channels, out_channels=mid_channels, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.Conv1d(in_channels=mid_channels, out_channels=hidden_size1, kernel_size=3, stride=1, padding=1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Decode ``x``.

        Args:
            x: Tensor laid out as ``nn.Conv1d`` expects, i.e.
                ``(batch, hidden_size3, length)``.

        Returns:
            Tensor of shape ``(batch, hidden_size1, length)`` with values in [0, 1].
        """
        return self.decoder(x)
|
|
|
|
|
|
|
class DeTiME(PreTrainedModel):
    """Seq2seq model whose encoder states pass through a CNN autoencoder bottleneck.

    The pretrained seq2seq model named by ``config.model`` encodes the input.
    Its ``last_hidden_state`` is compressed by ``CNNEncoder`` and reconstructed
    by ``CNNDecoder``. The reconstruction is handed to the wrapped model's
    decoder as ``encoder_outputs``, for both training (``forward``) and
    inference (``generate``).
    """

    config_class = DeTiMEAutoConfig

    def __init__(self, config):
        """Load the wrapped seq2seq model and build the CNN bottleneck.

        Args:
            config: The model configuration. This constructor reads ``model``
                (passed to ``AutoModelForSeq2SeqLM.from_pretrained``),
                ``hidden_size1`` and ``hidden_size3``.
        """
        super().__init__(config)
        self.model = AutoModelForSeq2SeqLM.from_pretrained(config.model)
        # Only the CNN bottleneck is implemented. The attribute is kept because
        # forward()/generate() dispatch on it and other code may read it.
        self.config_model = 'CNN'
        if self.config_model == 'CNN':
            self.encoder = CNNEncoder(config.hidden_size1, config.hidden_size3)
            self.decoder = CNNDecoder(config.hidden_size1, config.hidden_size3)
            self.encoder.main_input_name = self.model.main_input_name
        self.main_input_name = self.model.main_input_name

    def _reconstructed_encoder_states(self, input_ids, attention_mask):
        """Encode the input and pass the hidden states through the bottleneck.

        Shared by ``forward`` and ``generate`` so both use the same path.
        Returns a contiguous tensor to be used as the seq2seq decoder's encoder
        hidden states. When ``config_model`` is not ``'CNN'``, the wrapped
        encoder's ``last_hidden_state`` is returned unchanged.
        """
        hidden_states = self.model.encoder(
            input_ids=input_ids, attention_mask=attention_mask
        ).last_hidden_state
        if self.config_model == 'CNN':
            # NOTE(review): nn.Conv1d treats dim 1 as channels. If the hidden
            # state is (batch, seq_len, d_model), that is the sequence axis, so
            # seq_len presumably must equal config.hidden_size1 — confirm.
            hidden_states = self.decoder(self.encoder(hidden_states))
        return hidden_states.contiguous()

    def forward(self, input_ids, attention_mask=None, labels=None, **kwargs):
        """Run the wrapped seq2seq model on the bottlenecked encoder states.

        Args:
            input_ids: Token ids for the wrapped encoder.
            attention_mask: Attention mask for the wrapped encoder only (see
                the note below); ``None`` lets the encoder attend everywhere.
            labels: Target token ids forwarded to the wrapped model. May be
                ``None`` when decoder inputs are supplied through ``kwargs``.
            **kwargs: Passed through unchanged to the wrapped model.

        Returns:
            Whatever the wrapped model returns for these arguments.
        """
        hidden_states = self._reconstructed_encoder_states(input_ids, attention_mask)
        if labels is not None:
            labels = labels.contiguous()
        # Call the module rather than `.forward` so registered hooks run.
        # encoder_outputs stays a plain tuple (as before) so that callers
        # passing return_dict=False keep working.
        # NOTE(review): attention_mask is not forwarded here, so the decoder
        # cross-attends over every reconstructed position — confirm intended.
        return self.model(
            input_ids=input_ids.contiguous(),
            encoder_outputs=(hidden_states,),
            labels=labels,
            **kwargs,
        )

    def generate(self, input_ids, attention_mask=None, **kwargs):
        """Generate sequences from the bottlenecked encoder states.

        Args:
            input_ids: Token ids for the wrapped encoder.
            attention_mask: Attention mask for the wrapped encoder only; it is
                not forwarded to ``generate`` (same as ``forward``).
            **kwargs: Generation options passed through to the wrapped model's
                ``generate`` (e.g. ``max_length``, ``num_beams``).

        Returns:
            Whatever the wrapped model's ``generate`` returns.
        """
        hidden_states = self._reconstructed_encoder_states(input_ids, attention_mask)
        # Wrap in BaseModelOutput rather than a bare tuple: generate() expands
        # encoder_outputs by key for beam search / num_return_sequences > 1,
        # which fails on a tuple.
        return self.model.generate(
            input_ids=input_ids.contiguous(),
            encoder_outputs=BaseModelOutput(last_hidden_state=hidden_states),
            **kwargs,
        )
|
|
|
|