import types
import torch
import torch.nn.functional as F
import numpy as np
from torch import nn
from transformers import T5Tokenizer, T5ForConditionalGeneration, T5EncoderModel, AutoModel, LogitsProcessor, LogitsProcessorList
from peft import LoraConfig, get_peft_model
from functools import partial
from compute_lng import compute_lng
from undecorate import unwrap
from types import MethodType
from utils import *
from ling_disc import DebertaReplacedTokenizer
from const import *


def vae_sample(mu, logvar):
    # Reparameterization trick: z = mu + sigma * eps, with eps ~ N(0, I).
    std = torch.exp(0.5 * logvar)
    eps = torch.randn_like(std)
    return eps * std + mu


class VAE(nn.Module):
    def __init__(self, args):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Linear(args.input_dim, args.hidden_dim),
            nn.ReLU(),
            nn.Linear(args.hidden_dim, args.hidden_dim),
            nn.ReLU(),
        )
        self.decoder = nn.Sequential(
            nn.Linear(args.latent_dim, args.hidden_dim),
            nn.ReLU(),
            nn.Linear(args.hidden_dim, args.hidden_dim),
            nn.ReLU(),
            nn.Linear(args.hidden_dim, args.input_dim),
        )
        self.fc_mu = nn.Linear(args.hidden_dim, args.latent_dim)
        self.fc_var = nn.Linear(args.hidden_dim, args.latent_dim)

    def forward(self, x):
        h = self.encoder(x)
        mu = self.fc_mu(h)
        logvar = self.fc_var(h)
        x = vae_sample(mu, logvar)
        o = self.decoder(x)
        return o, (mu, logvar)
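

# Illustrative sketch (not from the original training pipeline): how the VAE above
# can be constructed and trained with the usual reconstruction + KL objective.
# The args fields and dimensions below are hypothetical.
def _demo_vae_usage():
    from types import SimpleNamespace
    args = SimpleNamespace(input_dim=40, hidden_dim=128, latent_dim=16)  # hypothetical sizes
    vae = VAE(args)
    x = torch.randn(8, args.input_dim)                     # a batch of 8 feature vectors
    recon, (mu, logvar) = vae(x)
    recon_loss = F.mse_loss(recon, x)                      # reconstruction term
    kl = -0.5 * torch.mean(1 + logvar - mu.pow(2) - logvar.exp())  # KL(q(z|x) || N(0, I))
    return recon_loss + kl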


class LingGenerator(nn.Module):
    def __init__(self, args, hidden_dim=1000):
        super().__init__()
        self.gen = T5EncoderModel.from_pretrained('google/flan-t5-small')
        self.hidden_size = self.gen.config.d_model
        self.ling_embed = nn.Linear(args.lng_dim, self.hidden_size)
        # self.gen = nn.Sequential(
        #     nn.Linear(args.lng_dim, 2*hidden_dim),
        #     nn.ReLU(),
        #     nn.BatchNorm1d(2*hidden_dim),
        #     nn.Linear(2*hidden_dim, 2*hidden_dim),
        #     nn.ReLU(),
        #     nn.BatchNorm1d(2*hidden_dim),
        #     nn.Linear(2*hidden_dim, hidden_dim),
        #     nn.ReLU(),
        # )
        self.gen_type = args.linggen_type
        self.gen_input = args.linggen_input
        if self.gen_type == 'vae':
            # The pooled encoder output has size self.hidden_size (the T5 d_model),
            # so the VAE heads take that as their input dimension.
            self.gen_mu = nn.Linear(self.hidden_size, args.lng_dim)
            self.gen_logvar = nn.Linear(self.hidden_size, args.lng_dim)
        elif self.gen_type == 'det':
            self.projection = nn.Linear(self.hidden_size, args.lng_dim)

    def forward(self, batch):
        inputs_embeds = self.gen.shared(batch['sentence1_input_ids'])
        inputs_att_mask = batch['sentence1_attention_mask']
        bs = inputs_embeds.shape[0]
        if self.gen_input == 's+l':
            sent1_ling = self.ling_embed(batch['sentence1_ling'])
            sent1_ling = sent1_ling.view(bs, 1, -1)
            inputs_embeds = inputs_embeds + sent1_ling
        gen = self.gen(inputs_embeds=inputs_embeds,
                       attention_mask=inputs_att_mask).last_hidden_state.mean(1)
        # gen = self.gen(batch['sentence1_ling'])
        cache = {}
        if self.gen_type == 'vae':
            mu = self.gen_mu(gen)
            logvar = self.gen_logvar(gen)
            output = vae_sample(mu, logvar)
            cache['linggen_mu'] = mu
            cache['linggen_logvar'] = logvar
        elif self.gen_type == 'det':
            output = self.projection(gen)
        return output, cache
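

# Illustrative sketch with hypothetical args values: LingGenerator predicts a target
# linguistic-feature vector from sentence1. 'det' gives a point estimate; 'vae'
# samples via vae_sample and also stores mu/logvar in the returned cache.
# Note: constructing it downloads google/flan-t5-small.
def _demo_ling_generator():
    from types import SimpleNamespace
    args = SimpleNamespace(lng_dim=40, linggen_type='det', linggen_input='s')  # hypothetical
    gen = LingGenerator(args)
    batch = {
        'sentence1_input_ids': torch.randint(0, 32100, (2, 12)),
        'sentence1_attention_mask': torch.ones(2, 12, dtype=torch.long),
    }
    with torch.no_grad():
        preds, cache = gen(batch)
    return preds.shape                                     # (2, args.lng_dim)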


class LingDisc(nn.Module):
    def __init__(self,
                 model_name,
                 disc_type,
                 disc_ckpt,
                 lng_dim=40,
                 quant_nbins=1,
                 disc_lng_dim=None,
                 lng_ids=None,
                 **kwargs):
        super().__init__()
        if disc_type == 't5':
            self.encoder = T5EncoderModel.from_pretrained(model_name)
            hidden_dim = self.encoder.config.d_model
            self.dropout = nn.Dropout(0.2)
            self.lng_dim = disc_lng_dim if disc_lng_dim else lng_dim
            self.quant = quant_nbins > 1
            self.quant = False  # NOTE: hard-coded override; quantized classification is disabled.
            if self.quant:
                self.ling_classifier = nn.Linear(hidden_dim, self.lng_dim * quant_nbins)
            else:
                self.ling_classifier = nn.Linear(hidden_dim, self.lng_dim)
            lng_ids = torch.tensor(lng_ids) if lng_ids is not None else None
            # from const import used_indices
            # lng_ids = torch.tensor(used_indices)
            self.register_buffer('lng_ids', lng_ids)
        elif disc_type == 'deberta':
            self.encoder = DebertaReplacedTokenizer.from_pretrained(
                pretrained_model_name_or_path=disc_ckpt,
                tok_model_name=model_name,
                problem_type='regression', num_labels=40)
            self.quant = False
        self.disc_type = disc_type

    def forward(self, **batch):
        if 'attention_mask' not in batch:
            if 'input_ids' in batch:
                att_mask = torch.ones_like(batch['input_ids'])
            else:
                att_mask = torch.ones_like(batch['logits'])[:, :, 0]
        else:
            att_mask = batch['attention_mask']
        if 'input_ids' in batch:
            enc_output = self.encoder(input_ids=batch['input_ids'],
                                      attention_mask=att_mask)
        elif 'logits' in batch:
            # Straight-through estimator: the forward pass uses the hard one-hot
            # prediction, while gradients flow through the soft softmax scores.
            logits = batch['logits']
            scores = F.softmax(logits, dim=-1)
            onehot = F.one_hot(logits.argmax(-1), num_classes=logits.shape[2]).float().to(logits.device)
            onehot_ = scores - scores.detach() + onehot
            embed_layer = self.encoder.get_input_embeddings()
            if isinstance(embed_layer, nn.Sequential):
                for i, module in enumerate(embed_layer):
                    if i == 0:
                        embeds = torch.matmul(onehot_, module.weight)
                    else:
                        embeds = module(embeds)
            else:
                embeds = torch.matmul(onehot_, embed_layer.weight)
            enc_output = self.encoder(inputs_embeds=embeds,
                                      attention_mask=att_mask)
        if self.disc_type == 't5':
            sent_emb = self.dropout(enc_output.last_hidden_state.mean(1))
            bs = sent_emb.shape[0]
            output = self.ling_classifier(sent_emb)
            if self.quant:
                output = output.reshape(bs, -1, self.lng_dim)
            if self.lng_ids is not None:
                output = torch.index_select(output, 1, self.lng_ids)
        elif self.disc_type == 'deberta':
            output = enc_output.logits
        return output
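

# Illustrative sketch of the straight-through trick used in LingDisc.forward (and
# SemEmb.forward): the forward pass sees the hard one-hot prediction, while
# gradients flow back through the soft softmax scores to the logits.
def _demo_straight_through_onehot():
    logits = torch.randn(2, 5, 10, requires_grad=True)     # (batch, seq, vocab)
    scores = F.softmax(logits, dim=-1)
    onehot = F.one_hot(logits.argmax(-1), num_classes=logits.shape[-1]).float()
    onehot_ = scores - scores.detach() + onehot            # hard forward, soft backward
    embed_table = torch.randn(10, 4)                       # toy embedding table
    embeds = onehot_ @ embed_table                         # differentiable "embedding lookup"
    embeds.pow(2).mean().backward()
    return logits.grad is not None                         # True: gradients reach the logits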


class SemEmb(nn.Module):
    def __init__(self, backbone, sep_token_id):
        super().__init__()
        self.backbone = backbone
        self.sep_token_id = sep_token_id
        hidden_dim = self.backbone.config.d_model
        self.projection = nn.Sequential(nn.ReLU(),
                                        nn.Dropout(0.2),
                                        nn.Linear(hidden_dim, 1))

    def forward(self, **batch):
        bs = batch['sentence1_attention_mask'].shape[0]
        ones = torch.ones((bs, 1), device=batch['sentence1_attention_mask'].device)
        sep = torch.ones((bs, 1), dtype=torch.long,
                         device=batch['sentence1_attention_mask'].device) * self.sep_token_id
        att_mask = torch.cat([batch['sentence1_attention_mask'], ones, batch['sentence2_attention_mask']], dim=1)
        if 'logits' in batch:
            input_ids = torch.cat([batch['sentence1_input_ids'], sep], dim=1)
            embeds1 = self.backbone.shared(input_ids)
            logits = batch['logits']
            scores = F.softmax(logits, dim=-1)
            onehot = F.one_hot(logits.argmax(-1), num_classes=logits.shape[2]).float().to(logits.device)
            onehot_ = scores - scores.detach() + onehot
            embeds2 = onehot_ @ self.backbone.shared.weight
            embeds1_2 = torch.cat([embeds1, embeds2], dim=1)
            hidden_units = self.backbone(inputs_embeds=embeds1_2,
                                         attention_mask=att_mask).last_hidden_state.mean(1)
        elif 'sentence2_input_ids' in batch:
            input_ids = torch.cat([batch['sentence1_input_ids'], sep, batch['sentence2_input_ids']], dim=1)
            hidden_units = self.backbone(input_ids=input_ids,
                                         attention_mask=att_mask).last_hidden_state.mean(1)
        probs = self.projection(hidden_units)
        return probs
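

# Illustrative sketch (random ids, untrained projection head): SemEmb scores whether
# sentence2 preserves the meaning of sentence1. Inputs are concatenated as
# [sentence1, </s>, sentence2] and mean-pooled. Downloads google/flan-t5-small.
def _demo_sem_emb():
    backbone = T5EncoderModel.from_pretrained('google/flan-t5-small')
    sem = SemEmb(backbone, sep_token_id=1)                 # 1 is </s> for T5 tokenizers
    batch = {
        'sentence1_input_ids': torch.randint(0, 32100, (2, 8)),
        'sentence1_attention_mask': torch.ones(2, 8),
        'sentence2_input_ids': torch.randint(0, 32100, (2, 9)),
        'sentence2_attention_mask': torch.ones(2, 9),
    }
    with torch.no_grad():
        logit = sem(**batch)                               # (2, 1) raw logits
    return torch.sigmoid(logit)                            # probability of a semantic match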


def prepare_inputs_for_generation(
    combine_method,
    ling2_only,
    self,
    input_ids,
    past_key_values=None,
    attention_mask=None,
    head_mask=None,
    decoder_head_mask=None,
    cross_attn_head_mask=None,
    use_cache=None,
    encoder_outputs=None,
    sent1_ling=None,
    sent2_ling=None,
    **kwargs
):
    # cut decoder_input_ids if past is used
    if past_key_values is not None:
        input_ids = input_ids[:, -1:]
    input_ids = input_ids.clone()
    decoder_inputs_embeds = self.shared(input_ids)
    if combine_method == 'decoder_add_first':
        sent2_ling = torch.cat([sent2_ling,
                                torch.repeat_interleave(torch.zeros_like(sent2_ling), input_ids.shape[1] - 1, dim=1)], dim=1)
    if combine_method == 'decoder_concat':
        if ling2_only:
            decoder_inputs_embeds = torch.cat([sent2_ling, decoder_inputs_embeds], dim=1)
        else:
            decoder_inputs_embeds = torch.cat([sent1_ling, sent2_ling, decoder_inputs_embeds], dim=1)
    elif combine_method == 'decoder_add' or (past_key_values is None and combine_method == 'decoder_add_first'):
        if ling2_only:
            decoder_inputs_embeds = decoder_inputs_embeds + sent2_ling
        else:
            decoder_inputs_embeds = decoder_inputs_embeds + sent1_ling + sent2_ling
    return {
        "decoder_inputs_embeds": decoder_inputs_embeds,
        "past_key_values": past_key_values,
        "encoder_outputs": encoder_outputs,
        "attention_mask": attention_mask,
        "head_mask": head_mask,
        "decoder_head_mask": decoder_head_mask,
        "cross_attn_head_mask": cross_attn_head_mask,
        "use_cache": use_cache,
    }


class LogitsAdd(LogitsProcessor):
    def __init__(self, sent2_ling):
        super().__init__()
        self.sent2_ling = sent2_ling

    def __call__(self, input_ids, scores):
        return scores + self.sent2_ling
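

# Illustrative sketch: LogitsAdd biases every decoding step by a fixed vector,
# plugged into Hugging Face generate() through a LogitsProcessorList. A zero bias
# is used here so the sketch is a no-op; downloads google/flan-t5-small.
def _demo_logits_add():
    model = T5ForConditionalGeneration.from_pretrained('google/flan-t5-small')
    bias = torch.zeros(1, model.config.vocab_size)         # per-vocabulary additive bias
    processors = LogitsProcessorList([LogitsAdd(bias)])
    input_ids = torch.tensor([[363, 19, 8, 1]])            # arbitrary token ids ending in </s>
    return model.generate(input_ids=input_ids,
                          logits_processor=processors,
                          max_length=8)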


class EncoderDecoderVAE(nn.Module):
    def __init__(self, args, pad_token_id, sepeos_token_id, vocab_size=32128):
        super().__init__()
        self.backbone = T5ForConditionalGeneration.from_pretrained(args.model_name)
        self.backbone.prepare_inputs_for_generation = types.MethodType(
            partial(prepare_inputs_for_generation, args.combine_method, args.ling2_only),
            self.backbone)
        self.args = args
        self.pad_token_id = pad_token_id
        self.eos_token_id = sepeos_token_id
        hidden_dim = self.backbone.config.d_model if 'logits' not in args.combine_method else vocab_size
        if args.combine_method == 'fusion1':
            self.fusion = nn.Sequential(
                nn.Linear(hidden_dim + 2 * args.lng_dim, hidden_dim),
            )
        elif args.combine_method == 'fusion2':
            self.fusion = nn.Sequential(
                nn.Linear(hidden_dim + 2 * args.lng_dim, hidden_dim),
                nn.ReLU(),
                nn.Linear(hidden_dim, hidden_dim),
            )
        elif 'concat' in args.combine_method or 'add' in args.combine_method:
            if args.ling_embed_type == 'two-layer':
                self.ling_embed = nn.Sequential(
                    nn.Linear(args.lng_dim, args.lng_dim),
                    nn.ReLU(),
                    nn.Linear(args.lng_dim, hidden_dim),
                )
            else:
                self.ling_embed = nn.Linear(args.lng_dim, hidden_dim)
            self.ling_dropout = nn.Dropout(args.ling_dropout)
            if args.ling_vae:
                self.ling_mu = nn.Linear(hidden_dim, hidden_dim)
                self.ling_logvar = nn.Linear(hidden_dim, hidden_dim)
                # NOTE: the xavier init below assumes ling_embed is a single nn.Linear
                # (i.e. ling_embed_type != 'two-layer').
                nn.init.xavier_uniform_(self.ling_embed.weight)
                nn.init.xavier_uniform_(self.ling_mu.weight)
                nn.init.xavier_uniform_(self.ling_logvar.weight)
        # Expose a generate() variant that keeps gradients, for feedback-based inference.
        generate_with_grad = unwrap(self.backbone.generate)
        self.backbone.generate_with_grad = MethodType(generate_with_grad, self.backbone)

    def get_fusion_layer(self):
        if 'fusion' in self.args.combine_method:
            return self.fusion
        elif 'concat' in self.args.combine_method or 'add' in self.args.combine_method:
            return self.ling_embed
        else:
            return None

    def sample(self, mu, logvar):
        std = torch.exp(0.5 * logvar)
        return mu + std * torch.randn_like(std)

    def encode(self, batch):
        if 'inputs_embeds' in batch:
            inputs_embeds = batch['inputs_embeds']
        else:
            inputs_embeds = self.backbone.shared(batch['sentence1_input_ids'])
        inputs_att_mask = batch['sentence1_attention_mask']
        bs = inputs_embeds.shape[0]
        cache = {}
        if self.args.combine_method in ('input_concat', 'input_add'):
            if 'sent1_ling_embed' in batch:
                sent1_ling = batch['sent1_ling_embed']
            else:
                sent1_ling = self.ling_embed(self.ling_dropout(batch['sentence1_ling']))
            if 'sent2_ling_embed' in batch:
                sent2_ling = batch['sent2_ling_embed']
            else:
                sent2_ling = self.ling_embed(self.ling_dropout(batch['sentence2_ling']))
            if self.args.ling_vae:
                sent1_ling = F.leaky_relu(sent1_ling)
                sent1_mu, sent1_logvar = self.ling_mu(sent1_ling), self.ling_logvar(sent1_ling)
                sent1_ling = self.sample(sent1_mu, sent1_logvar)
                sent2_ling = F.leaky_relu(sent2_ling)
                sent2_mu, sent2_logvar = self.ling_mu(sent2_ling), self.ling_logvar(sent2_ling)
                sent2_ling = self.sample(sent2_mu, sent2_logvar)
                cache.update({'sent1_mu': sent1_mu, 'sent1_logvar': sent1_logvar,
                              'sent2_mu': sent2_mu, 'sent2_logvar': sent2_logvar,
                              'sent1_ling': sent1_ling, 'sent2_ling': sent2_ling})
            else:
                cache.update({'sent1_ling': sent1_ling, 'sent2_ling': sent2_ling})
            sent1_ling = sent1_ling.view(bs, 1, -1)
            sent2_ling = sent2_ling.view(bs, 1, -1)
            if self.args.combine_method == 'input_concat':
                if self.args.ling2_only:
                    inputs_embeds = torch.cat([inputs_embeds, sent2_ling], dim=1)
                    inputs_att_mask = torch.cat([inputs_att_mask,
                                                 torch.ones((bs, 1)).to(inputs_embeds.device)], dim=1)
                else:
                    inputs_embeds = torch.cat([inputs_embeds, sent1_ling, sent2_ling], dim=1)
                    inputs_att_mask = torch.cat([inputs_att_mask,
                                                 torch.ones((bs, 2)).to(inputs_embeds.device)], dim=1)
            elif self.args.combine_method == 'input_add':
                if self.args.ling2_only:
                    inputs_embeds = inputs_embeds + sent2_ling
                else:
                    inputs_embeds = inputs_embeds + sent1_ling + sent2_ling
        return self.backbone.encoder(inputs_embeds=inputs_embeds,
                                     attention_mask=inputs_att_mask), inputs_att_mask, cache

    def decode(self, batch, enc_output, inputs_att_mask, generate):
        bs = inputs_att_mask.shape[0]
        cache = {}
        if self.args.combine_method in ('embed_concat', 'decoder_concat', 'decoder_add', 'logits_add', 'decoder_add_first'):
            if 'sent1_ling_embed' in batch:
                sent1_ling = batch['sent1_ling_embed']
            elif 'sentence1_ling' in batch:
                sent1_ling = self.ling_embed(self.ling_dropout(batch['sentence1_ling']))
            else:
                sent1_ling = None
            if 'sent2_ling_embed' in batch:
                sent2_ling = batch['sent2_ling_embed']
            else:
                sent2_ling = self.ling_embed(self.ling_dropout(batch['sentence2_ling']))
            if self.args.ling_vae:
                sent1_ling = F.leaky_relu(sent1_ling)
                sent1_mu, sent1_logvar = self.ling_mu(sent1_ling), self.ling_logvar(sent1_ling)
                sent1_ling = self.sample(sent1_mu, sent1_logvar)
                sent2_ling = F.leaky_relu(sent2_ling)
                sent2_mu, sent2_logvar = self.ling_mu(sent2_ling), self.ling_logvar(sent2_ling)
                sent2_ling = self.sample(sent2_mu, sent2_logvar)
                cache.update({'sent1_mu': sent1_mu, 'sent1_logvar': sent1_logvar,
                              'sent2_mu': sent2_mu, 'sent2_logvar': sent2_logvar,
                              'sent1_ling': sent1_ling, 'sent2_ling': sent2_ling})
            else:
                cache.update({'sent2_ling': sent2_ling})
                if sent1_ling is not None:
                    cache.update({'sent1_ling': sent1_ling})
            if sent1_ling is not None:
                sent1_ling = sent1_ling.view(bs, 1, -1)
            sent2_ling = sent2_ling.view(bs, 1, -1)
            if self.args.combine_method == 'decoder_add_first' and not generate:
                sent2_ling = torch.cat([sent2_ling,
                                        torch.repeat_interleave(torch.zeros_like(sent2_ling), batch['sentence2_input_ids'].shape[1] - 1, dim=1)], dim=1)
        else:
            sent1_ling, sent2_ling = None, None
        if self.args.combine_method == 'embed_concat':
            enc_output.last_hidden_state = torch.cat([enc_output.last_hidden_state,
                                                      sent1_ling, sent2_ling], dim=1)
            inputs_att_mask = torch.cat([inputs_att_mask,
                                         torch.ones((bs, 2)).to(inputs_att_mask.device)], dim=1)
        elif 'fusion' in self.args.combine_method:
            sent1_ling = batch['sentence1_ling'].unsqueeze(1)\
                .expand(-1, enc_output.last_hidden_state.shape[1], -1)
            sent2_ling = batch['sentence2_ling'].unsqueeze(1)\
                .expand(-1, enc_output.last_hidden_state.shape[1], -1)
            if self.args.ling2_only:
                combined_embedding = torch.cat([enc_output.last_hidden_state, sent2_ling], dim=2)
            else:
                combined_embedding = torch.cat([enc_output.last_hidden_state, sent1_ling, sent2_ling], dim=2)
            enc_output.last_hidden_state = self.fusion(combined_embedding)
        if generate:
            if self.args.combine_method == 'logits_add':
                logits_processor = LogitsProcessorList([LogitsAdd(sent2_ling.view(bs, -1))])
            else:
                logits_processor = LogitsProcessorList()
            dec_output = self.backbone.generate_with_grad(
                attention_mask=inputs_att_mask,
                encoder_outputs=enc_output,
                sent1_ling=sent1_ling,
                sent2_ling=sent2_ling,
                return_dict_in_generate=True,
                output_scores=True,
                logits_processor=logits_processor,
                # renormalize_logits=True,
                # do_sample=True,
                # top_p=0.8,
                eos_token_id=self.eos_token_id,
                # min_new_tokens=3,
                # repetition_penalty=1.2,
                max_length=self.args.max_length,
            )
            scores = torch.stack(dec_output.scores, 1)
            cache.update({'scores': scores})
            return dec_output.sequences, cache
        decoder_input_ids = self.backbone._shift_right(batch['sentence2_input_ids'])
        decoder_inputs_embeds = self.backbone.shared(decoder_input_ids)
        decoder_att_mask = batch['sentence2_attention_mask']
        labels = batch['sentence2_input_ids'].clone()
        labels[labels == self.pad_token_id] = -100
        if self.args.combine_method == 'decoder_concat':
            if self.args.ling2_only:
                decoder_inputs_embeds = torch.cat([sent2_ling, decoder_inputs_embeds], dim=1)
                decoder_att_mask = torch.cat([torch.ones((bs, 1)).to(decoder_inputs_embeds.device), decoder_att_mask], dim=1)
                labels = torch.cat([torch.ones((bs, 1), dtype=torch.int64).to(decoder_inputs_embeds.device) * self.pad_token_id,
                                    labels], dim=1)
            else:
                decoder_inputs_embeds = torch.cat([sent1_ling, sent2_ling, decoder_inputs_embeds], dim=1)
                decoder_att_mask = torch.cat([torch.ones((bs, 2)).to(decoder_inputs_embeds.device), decoder_att_mask], dim=1)
                labels = torch.cat([torch.ones((bs, 2), dtype=torch.int64).to(decoder_inputs_embeds.device) * self.pad_token_id,
                                    labels], dim=1)
        elif self.args.combine_method == 'decoder_add' or self.args.combine_method == 'decoder_add_first':
            if self.args.ling2_only:
                decoder_inputs_embeds = decoder_inputs_embeds + self.args.combine_weight * sent2_ling
            else:
                decoder_inputs_embeds = decoder_inputs_embeds + sent1_ling + sent2_ling
        dec_output = self.backbone(
            decoder_inputs_embeds=decoder_inputs_embeds,
            decoder_attention_mask=decoder_att_mask,
            encoder_outputs=enc_output,
            attention_mask=inputs_att_mask,
            labels=labels,
        )
        if self.args.combine_method == 'logits_add':
            dec_output.logits = dec_output.logits + self.args.combine_weight * sent2_ling
            vocab_size = dec_output.logits.size(-1)
            dec_output.loss = F.cross_entropy(dec_output.logits.view(-1, vocab_size), labels.view(-1))
        return dec_output, cache

    def forward(self, batch, generate=False):
        enc_output, enc_att_mask, cache = self.encode(batch)
        dec_output, cache2 = self.decode(batch, enc_output, enc_att_mask, generate)
        cache.update(cache2)
        return dec_output, enc_output, cache

    def infer_with_cache(self, batch):
        dec_output, _, cache = self(batch, generate=True)
        return dec_output, cache

    def infer(self, batch):
        dec_output, _ = self.infer_with_cache(batch)
        return dec_output

    def infer_with_feedback_BP(self, ling_disc, sem_emb, batch, tokenizer):
        from torch.autograd import grad
        interpolations = []

        def line_search():
            best_val = None
            best_loss = None
            eta = 1e3
            sem_prob = 1
            patience = 4
            while patience > 0:
                param_ = param - eta * grads
                with torch.no_grad():
                    new_loss, pred = get_loss(param_)
                    max_len = pred.shape[1]
                    lens = torch.where(pred == self.eos_token_id, 1, 0).argmax(-1) + 1
                    batch.update({
                        'sentence2_input_ids': pred,
                        'sentence2_attention_mask': sequence_mask(lens, max_len=max_len)
                    })
                    sem_prob = torch.sigmoid(sem_emb(**batch)).item()
                # if sem_prob <= 0.1:
                #     patience -= 1
                if new_loss < loss and sem_prob >= 0.90 and lens.item() > 1:
                    return param_
                eta *= 2.25
                patience -= 1
            return False

        def get_loss(param):
            if self.args.feedback_param == 'l':
                batch.update({'sent2_ling_embed': param})
            elif self.args.feedback_param == 's':
                batch.update({'inputs_embeds': param})
            if self.args.feedback_param == 'logits':
                logits = param
                pred = param.argmax(-1)
            else:
                pred, cache = self.infer_with_cache(batch)
                logits = cache['scores']
            out = ling_disc(logits=logits)
            probs = F.softmax(out, 1)
            if ling_disc.quant:
                loss = F.cross_entropy(out, batch['sentence2_discr'])
            else:
                loss = F.mse_loss(out, batch['sentence2_ling'])
            return loss, pred

        if self.args.feedback_param == 'l':
            ling2_embed = self.ling_embed(batch['sentence2_ling'])
            param = torch.nn.Parameter(ling2_embed, requires_grad=True)
        elif self.args.feedback_param == 's':
            inputs_embeds = self.backbone.shared(batch['sentence1_input_ids'])
            param = torch.nn.Parameter(inputs_embeds, requires_grad=True)
        elif self.args.feedback_param == 'logits':
            logits = self.infer_with_cache(batch)[1]['scores']
            param = torch.nn.Parameter(logits, requires_grad=True)
        target_np = batch['sentence2_ling'][0].cpu().numpy()
        while True:
            loss, pred = get_loss(param)
            pred_text = tokenizer.batch_decode(pred.cpu().numpy(),
                                               skip_special_tokens=True)[0]
            interpolations.append(pred_text)
            if loss < 1:
                break
            self.zero_grad()
            grads = grad(loss, param)[0]
            param = line_search()
            if param is False:
                break
        return pred, [pred_text, interpolations]
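

# Illustrative end-to-end sketch of EncoderDecoderVAE with a hypothetical args
# namespace (the real values come from the project's argument parser). It runs one
# teacher-forced training pass with the 'decoder_add' combine method and random ids.
def _demo_encoder_decoder_vae():
    from types import SimpleNamespace
    args = SimpleNamespace(
        model_name='google/flan-t5-small', combine_method='decoder_add', ling2_only=True,
        lng_dim=40, ling_embed_type='one-layer', ling_dropout=0.1, ling_vae=False,
        combine_weight=1.0, max_length=32)                 # hypothetical settings
    model = EncoderDecoderVAE(args, pad_token_id=0, sepeos_token_id=1)
    batch = {
        'sentence1_input_ids': torch.randint(0, 32100, (2, 10)),
        'sentence1_attention_mask': torch.ones(2, 10, dtype=torch.long),
        'sentence2_input_ids': torch.randint(0, 32100, (2, 10)),
        'sentence2_attention_mask': torch.ones(2, 10, dtype=torch.long),
        'sentence2_ling': torch.randn(2, args.lng_dim),    # target linguistic features
    }
    dec_output, enc_output, cache = model(batch)           # teacher-forced pass
    return dec_output.loss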


def set_grad(module, state):
    if module is not None:
        for p in module.parameters():
            p.requires_grad = state


def set_grad_except(model, name, state):
    for n, p in model.named_parameters():
        if name not in n:
            p.requires_grad = state


class SemEmbPipeline:
    def __init__(self,
                 ckpt="/data/mohamed/checkpoints/ling_conversion_sem_emb_best.pt"):
        self.tokenizer = T5Tokenizer.from_pretrained("google/flan-t5-base")
        self.model = SemEmb(T5EncoderModel.from_pretrained('google/flan-t5-base'), self.tokenizer.get_vocab()['</s>'])
        state = torch.load(ckpt)
        self.model.load_state_dict(state['model'], strict=False)
        self.model.eval()
        self.model.cuda()

    def __call__(self, sentence1, sentence2):
        sentence1 = self.tokenizer(sentence1, return_attention_mask=True, return_tensors='pt')
        sentence2 = self.tokenizer(sentence2, return_attention_mask=True, return_tensors='pt')
        sem_logit = self.model(
            sentence1_input_ids=sentence1.input_ids.cuda(),
            sentence1_attention_mask=sentence1.attention_mask.cuda(),
            sentence2_input_ids=sentence2.input_ids.cuda(),
            sentence2_attention_mask=sentence2.attention_mask.cuda(),
        )
        sem_prob = torch.sigmoid(sem_logit).item()
        return sem_prob
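

# Illustrative usage of SemEmbPipeline; it needs the checkpoint above and a GPU.
def _demo_sem_emb_pipeline():
    sem_pipe = SemEmbPipeline()
    prob = sem_pipe("The cat sat on the mat.", "A cat was sitting on the mat.")
    return prob                                            # float in [0, 1]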


class LingDiscPipeline:
    def __init__(self,
                 model_name="google/flan-t5-base",
                 disc_type='deberta',
                 disc_ckpt='/data/mohamed/checkpoints/ling_disc/deberta-v3-small_flan-t5-base_40',
                 # disc_type='t5',
                 # disc_ckpt='/data/mohamed/checkpoints/ling_conversion_ling_disc.pt',
                 ):
        self.tokenizer = T5Tokenizer.from_pretrained(model_name)
        self.model = LingDisc(model_name, disc_type, disc_ckpt)
        self.model.eval()
        self.model.cuda()

    def __call__(self, sentence):
        inputs = self.tokenizer(sentence, return_tensors='pt')
        with torch.no_grad():
            ling_pred = self.model(input_ids=inputs.input_ids.cuda())
        return ling_pred
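

# Illustrative usage of LingDiscPipeline; it needs the discriminator checkpoint and
# a GPU, and returns a vector of 40 predicted linguistic indices.
def _demo_ling_disc_pipeline():
    disc_pipe = LingDiscPipeline()
    feats = disc_pipe("The quick brown fox jumps over the lazy dog.")
    return feats.shape                                     # (1, 40)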


def get_model(args, tokenizer, device):
    if args.pretrain_disc or args.disc_loss or args.disc_ckpt:
        ling_disc = LingDisc(args.model_name, args.disc_type, args.disc_ckpt).to(device)
    else:
        ling_disc = None
    if args.linggen_type != 'none':
        ling_gen = LingGenerator(args).to(device)
    if args.sem_loss or args.sem_ckpt:
        if args.sem_loss_type == 'shared':
            # NOTE: the shared variant reuses the generator's encoder; this assumes the
            # main model is available (it is constructed below), so this path may need
            # reordering before use.
            sem_emb = model.backbone.encoder
        elif args.sem_loss_type == 'dedicated':
            sem_emb = SemEmb(T5EncoderModel.from_pretrained('google/flan-t5-base'), tokenizer.eos_token_id).to(device)
        else:
            raise NotImplementedError('Semantic loss type')
    else:
        sem_emb = None
    if not args.pretrain_disc:
        model = EncoderDecoderVAE(args, tokenizer.pad_token_id, tokenizer.eos_token_id).to(device)
        if args.use_lora:
            target_modules = ["Attention.k", "Attention.q", "Attention.v", "Attention.o", "lm_head", "wi_0", "wi_1", "wo"]
            target_modules = '|'.join(f'(.*{module})' for module in target_modules)
            target_modules = f'backbone.({target_modules})'
            config = LoraConfig(
                r=args.lora_r,
                lora_alpha=args.lora_r * 2,
                target_modules=target_modules,
                lora_dropout=0.1,
                bias="lora_only",
                modules_to_save=['ling_embed'],
            )
            model = get_peft_model(model, config)
            model.print_trainable_parameters()
    else:
        model = ling_disc
    return model, ling_disc, sem_emb
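

# Illustrative sketch of building the models with get_model. The args fields below
# are hypothetical stand-ins for the project's config; with these flags only the
# EncoderDecoderVAE generator is built (no discriminator, no semantic model).
def _demo_get_model():
    from types import SimpleNamespace
    args = SimpleNamespace(
        model_name='google/flan-t5-small', pretrain_disc=False, disc_loss=False,
        disc_ckpt=None, disc_type='t5', linggen_type='none', sem_loss=False,
        sem_ckpt=None, sem_loss_type='dedicated', combine_method='decoder_add',
        ling2_only=True, lng_dim=40, ling_embed_type='one-layer', ling_dropout=0.1,
        ling_vae=False, combine_weight=1.0, max_length=32, use_lora=False)
    tokenizer = T5Tokenizer.from_pretrained(args.model_name)
    model, ling_disc, sem_emb = get_model(args, tokenizer, 'cpu')
    return model, ling_disc, sem_emb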