# LingConv / model.py
import types
import torch
import torch.nn.functional as F
import numpy as np
from torch import nn
from transformers import T5Tokenizer, T5ForConditionalGeneration, T5EncoderModel, AutoModel, LogitsProcessor, LogitsProcessorList
from peft import LoraConfig, get_peft_model
from functools import partial
from compute_lng import compute_lng
from undecorate import unwrap
from types import MethodType
from utils import *
from ling_disc import DebertaReplacedTokenizer
from const import *
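
# Reparameterization trick: z = mu + sigma * eps with eps ~ N(0, I), keeping the sampling
# step differentiable w.r.t. mu and logvar.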
def vae_sample(mu, logvar):
std = torch.exp(0.5 * logvar)
eps = torch.randn_like(std)
return eps * std + mu
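
# Small MLP VAE over fixed-size linguistic feature vectors.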
class VAE(nn.Module):
def __init__(self, args):
super().__init__()
self.encoder = nn.Sequential(
nn.Linear(args.input_dim, args.hidden_dim),
nn.ReLU(),
nn.Linear(args.hidden_dim, args.hidden_dim),
nn.ReLU(),
)
self.decoder = nn.Sequential(
nn.Linear(args.latent_dim, args.hidden_dim),
nn.ReLU(),
nn.Linear(args.hidden_dim, args.hidden_dim),
nn.ReLU(),
nn.Linear(args.hidden_dim, args.input_dim),
)
self.fc_mu = nn.Linear(args.hidden_dim, args.latent_dim)
self.fc_var = nn.Linear(args.hidden_dim, args.latent_dim)
def forward(self, x):
h = self.encoder(x)
mu = self.fc_mu(h)
logvar = self.fc_var(h)
x = vae_sample(mu, logvar)
o = self.decoder(x)
return o, (mu, logvar)
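
# Generates a target linguistic-feature vector from sentence1 with a T5 encoder; the head is
# either deterministic ('det') or a VAE ('vae'), and the 's+l' input option additionally
# injects sentence1's own feature embedding into the encoder inputs.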
class LingGenerator(nn.Module):
def __init__(self, args, hidden_dim=1000):
super().__init__()
self.gen = T5EncoderModel.from_pretrained('google/flan-t5-small')
self.hidden_size = self.gen.config.d_model
self.ling_embed = nn.Linear(args.lng_dim, self.hidden_size)
# self.gen = nn.Sequential(
# nn.Linear(args.lng_dim, 2*hidden_dim),
# nn.ReLU(),
# nn.BatchNorm1d(2*hidden_dim),
# nn.Linear(2*hidden_dim, 2*hidden_dim),
# nn.ReLU(),
# nn.BatchNorm1d(2*hidden_dim),
# nn.Linear(2*hidden_dim, hidden_dim),
# nn.ReLU(),
# )
self.gen_type = args.linggen_type
self.gen_input = args.linggen_input
        if self.gen_type == 'vae':
            # The T5 encoder output has width d_model (self.hidden_size), so the VAE heads
            # take that size rather than the old MLP's hidden_dim.
            self.gen_mu = nn.Linear(self.hidden_size, args.lng_dim)
            self.gen_logvar = nn.Linear(self.hidden_size, args.lng_dim)
elif self.gen_type == 'det':
self.projection = nn.Linear(self.hidden_size, args.lng_dim)
def forward(self, batch):
inputs_embeds = self.gen.shared(batch['sentence1_input_ids'])
inputs_att_mask = batch['sentence1_attention_mask']
bs = inputs_embeds.shape[0]
if self.gen_input == 's+l':
sent1_ling = self.ling_embed(batch['sentence1_ling'])
sent1_ling = sent1_ling.view(bs, 1, -1)
inputs_embeds = inputs_embeds + sent1_ling
gen = self.gen(inputs_embeds=inputs_embeds,
attention_mask=inputs_att_mask).last_hidden_state.mean(1)
# gen = self.gen(batch['sentence1_ling'])
cache = {}
if self.gen_type == 'vae':
mu = self.gen_mu(gen)
logvar = self.gen_logvar(gen)
output = vae_sample(mu, logvar)
cache['linggen_mu'] = mu
cache['linggen_logvar'] = logvar
elif self.gen_type == 'det':
output = self.projection(gen)
return output, cache
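
# Linguistic-feature discriminator: predicts the feature vector of a sentence from either
# token ids or soft generator logits (via a straight-through one-hot).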
class LingDisc(nn.Module):
def __init__(self,
model_name,
disc_type,
disc_ckpt,
lng_dim=40,
quant_nbins=1,
disc_lng_dim=None,
lng_ids=None,
**kwargs):
super().__init__()
if disc_type == 't5':
self.encoder = T5EncoderModel.from_pretrained(model_name)
hidden_dim = self.encoder.config.d_model
self.dropout = nn.Dropout(0.2)
self.lng_dim = disc_lng_dim if disc_lng_dim else lng_dim
            self.quant = quant_nbins > 1
            self.quant = False  # quantized prediction is currently disabled regardless of quant_nbins
if self.quant:
self.ling_classifier = nn.Linear(hidden_dim, self.lng_dim * quant_nbins)
else:
self.ling_classifier = nn.Linear(hidden_dim, self.lng_dim)
lng_ids = torch.tensor(lng_ids) if lng_ids is not None else None
# from const import used_indices
# lng_ids = torch.tensor(used_indices)
self.register_buffer('lng_ids', lng_ids)
elif disc_type == 'deberta':
            self.encoder = DebertaReplacedTokenizer.from_pretrained(
                pretrained_model_name_or_path=disc_ckpt,
                tok_model_name=model_name,
                problem_type='regression', num_labels=40)
self.quant = False
self.disc_type = disc_type
def forward(self, **batch):
        if 'attention_mask' not in batch:
if 'input_ids' in batch:
att_mask = torch.ones_like(batch['input_ids'])
else:
att_mask = torch.ones_like(batch['logits'])[:,:,0]
else:
att_mask = batch['attention_mask']
if 'input_ids' in batch:
enc_output = self.encoder(input_ids=batch['input_ids'],
attention_mask=att_mask)
elif 'logits' in batch:
logits = batch['logits']
scores = F.softmax(logits, dim = -1)
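            # Straight-through estimator: the forward pass uses the hard one-hot argmax,
            # while gradients flow through the softmax scores.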
onehot = F.one_hot(logits.argmax(-1), num_classes=logits.shape[2]).float().to(logits.device)
onehot_ = scores - scores.detach() + onehot
embed_layer = self.encoder.get_input_embeddings()
if isinstance(embed_layer, nn.Sequential):
for i, module in enumerate(embed_layer):
if i == 0:
embeds = torch.matmul(onehot_, module.weight)
else:
embeds = module(embeds)
else:
                embeds = torch.matmul(onehot_, embed_layer.weight)
enc_output = self.encoder(inputs_embeds=embeds,
attention_mask=att_mask)
if self.disc_type == 't5':
sent_emb = self.dropout(enc_output.last_hidden_state.mean(1))
bs = sent_emb.shape[0]
output = self.ling_classifier(sent_emb)
if self.quant:
output = output.reshape(bs, -1, self.lng_dim)
if self.lng_ids is not None:
output = torch.index_select(output, 1, self.lng_ids)
elif self.disc_type == 'deberta':
output = enc_output.logits
return output
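
# Semantic-similarity scorer: encodes "sentence1 </s> sentence2" with a T5 encoder and maps
# the mean-pooled hidden states to a single logit.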
class SemEmb(nn.Module):
def __init__(self, backbone, sep_token_id):
super().__init__()
self.backbone = backbone
self.sep_token_id = sep_token_id
hidden_dim = self.backbone.config.d_model
self.projection = nn.Sequential(nn.ReLU(),
nn.Dropout(0.2),
nn.Linear(hidden_dim, 1))
def forward(self, **batch):
bs = batch['sentence1_attention_mask'].shape[0]
ones = torch.ones((bs, 1), device=batch['sentence1_attention_mask'].device)
sep = torch.ones((bs, 1), dtype=torch.long,
device=batch['sentence1_attention_mask'].device) * self.sep_token_id
att_mask = torch.cat([batch['sentence1_attention_mask'], ones, batch['sentence2_attention_mask']], dim=1)
if 'logits' in batch:
input_ids = torch.cat([batch['sentence1_input_ids'], sep], dim=1)
embeds1 = self.backbone.shared(input_ids)
logits = batch['logits']
scores = F.softmax(logits, dim = -1)
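            # Straight-through one-hot over the generator logits, as in LingDisc.forward.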
onehot = F.one_hot(logits.argmax(-1), num_classes=logits.shape[2]).float().to(logits.device)
onehot_ = scores - scores.detach() + onehot
embeds2 = onehot_ @ self.backbone.shared.weight
embeds1_2 = torch.cat([embeds1, embeds2], dim=1)
hidden_units = self.backbone(inputs_embeds=embeds1_2,
attention_mask=att_mask).last_hidden_state.mean(1)
elif 'sentence2_input_ids' in batch:
input_ids = torch.cat([batch['sentence1_input_ids'], sep, batch['sentence2_input_ids']], dim=1)
hidden_units = self.backbone(input_ids=input_ids,
attention_mask=att_mask).last_hidden_state.mean(1)
probs = self.projection(hidden_units)
return probs
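
# Replacement for the backbone's prepare_inputs_for_generation, bound with functools.partial
# and types.MethodType so that linguistic embeddings can be concatenated with or added to the
# decoder input embeddings at each decoding step.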
def prepare_inputs_for_generation(
combine_method,
ling2_only,
self,
input_ids,
past_key_values=None,
attention_mask=None,
head_mask=None,
decoder_head_mask=None,
cross_attn_head_mask=None,
use_cache=None,
encoder_outputs=None,
sent1_ling=None,
sent2_ling=None,
**kwargs
):
# cut decoder_input_ids if past is used
if past_key_values is not None:
input_ids = input_ids[:, -1:]
input_ids = input_ids.clone()
decoder_inputs_embeds = self.shared(input_ids)
if combine_method == 'decoder_add_first':
sent2_ling = torch.cat([sent2_ling,
torch.repeat_interleave(torch.zeros_like(sent2_ling), input_ids.shape[1] - 1, dim=1)], dim = 1)
if combine_method == 'decoder_concat':
if ling2_only:
decoder_inputs_embeds = torch.cat([sent2_ling, decoder_inputs_embeds], dim=1)
else:
decoder_inputs_embeds = torch.cat([sent1_ling, sent2_ling, decoder_inputs_embeds], dim=1)
    elif combine_method == 'decoder_add' or (past_key_values is None and combine_method == 'decoder_add_first'):
if ling2_only:
decoder_inputs_embeds = decoder_inputs_embeds + sent2_ling
else:
decoder_inputs_embeds = decoder_inputs_embeds + sent1_ling + sent2_ling
return {
"decoder_inputs_embeds": decoder_inputs_embeds,
"past_key_values": past_key_values,
"encoder_outputs": encoder_outputs,
"attention_mask": attention_mask,
"head_mask": head_mask,
"decoder_head_mask": decoder_head_mask,
"cross_attn_head_mask": cross_attn_head_mask,
"use_cache": use_cache,
}
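
# LogitsProcessor that adds a vocab-sized linguistic embedding to the scores at every
# decoding step (used by the 'logits_add' combine method).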
class LogitsAdd(LogitsProcessor):
def __init__(self, sent2_ling):
super().__init__()
self.sent2_ling = sent2_ling
def __call__(self, input_ids, scores):
return scores + self.sent2_ling
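
# Main seq2seq model: a T5 backbone whose encoder inputs, decoder inputs, or output logits
# are combined with linguistic-feature embeddings according to args.combine_method.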
class EncoderDecoderVAE(nn.Module):
def __init__(self, args, pad_token_id, sepeos_token_id, vocab_size = 32128):
super().__init__()
self.backbone = T5ForConditionalGeneration.from_pretrained(args.model_name)
self.backbone.prepare_inputs_for_generation = types.MethodType(
partial(prepare_inputs_for_generation, args.combine_method, args.ling2_only),
self.backbone)
self.args = args
self.pad_token_id = pad_token_id
self.eos_token_id = sepeos_token_id
        hidden_dim = self.backbone.config.d_model if 'logits' not in args.combine_method else vocab_size
if args.combine_method == 'fusion1':
self.fusion = nn.Sequential(
nn.Linear(hidden_dim + 2 * args.lng_dim, hidden_dim),
)
elif args.combine_method == 'fusion2':
self.fusion = nn.Sequential(
nn.Linear(hidden_dim + 2 * args.lng_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, hidden_dim),
)
elif 'concat' in args.combine_method or 'add' in args.combine_method:
if args.ling_embed_type == 'two-layer':
self.ling_embed = nn.Sequential(
nn.Linear(args.lng_dim, args.lng_dim),
nn.ReLU(),
nn.Linear(args.lng_dim, hidden_dim),
)
else:
self.ling_embed = nn.Linear(args.lng_dim, hidden_dim)
self.ling_dropout = nn.Dropout(args.ling_dropout)
if args.ling_vae:
self.ling_mu = nn.Linear(hidden_dim, hidden_dim)
self.ling_logvar = nn.Linear(hidden_dim, hidden_dim)
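                # Note: this initialization assumes ling_embed is a single nn.Linear
                # (i.e. ling_embed_type != 'two-layer').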
nn.init.xavier_uniform_(self.ling_embed.weight)
nn.init.xavier_uniform_(self.ling_mu.weight)
nn.init.xavier_uniform_(self.ling_logvar.weight)
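        # Expose a generate() variant with its decorators (e.g. torch.no_grad) stripped so
        # that gradients can flow through decoding; used by infer_with_feedback_BP.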
generate_with_grad = unwrap(self.backbone.generate)
self.backbone.generate_with_grad = MethodType(generate_with_grad, self.backbone)
def get_fusion_layer(self):
if 'fusion' in self.args.combine_method:
return self.fusion
elif 'concat' in self.args.combine_method or 'add' in self.args.combine_method:
return self.ling_embed
else:
return None
def sample(self, mu, logvar):
std = torch.exp(0.5 * logvar)
return mu + std * torch.randn_like(std)
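
    # Encodes sentence1, optionally injecting the sentence1/sentence2 linguistic embeddings
    # into the encoder inputs ('input_concat' / 'input_add').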
def encode(self, batch):
if 'inputs_embeds' in batch:
inputs_embeds = batch['inputs_embeds']
else:
inputs_embeds = self.backbone.shared(batch['sentence1_input_ids'])
inputs_att_mask = batch['sentence1_attention_mask']
bs = inputs_embeds.shape[0]
cache = {}
if self.args.combine_method in ('input_concat', 'input_add'):
if 'sent1_ling_embed' in batch:
sent1_ling = batch['sent1_ling_embed']
else:
sent1_ling = self.ling_embed(self.ling_dropout(batch['sentence1_ling']))
if 'sent2_ling_embed' in batch:
sent2_ling = batch['sent2_ling_embed']
else:
sent2_ling = self.ling_embed(self.ling_dropout(batch['sentence2_ling']))
if self.args.ling_vae:
sent1_ling = F.leaky_relu(sent1_ling)
sent1_mu, sent1_logvar = self.ling_mu(sent1_ling), self.ling_logvar(sent1_ling)
sent1_ling = self.sample(sent1_mu, sent1_logvar)
sent2_ling = F.leaky_relu(sent2_ling)
sent2_mu, sent2_logvar = self.ling_mu(sent2_ling), self.ling_logvar(sent2_ling)
sent2_ling = self.sample(sent2_mu, sent2_logvar)
cache.update({'sent1_mu': sent1_mu, 'sent1_logvar': sent1_logvar,
'sent2_mu': sent2_mu, 'sent2_logvar': sent2_logvar,
'sent1_ling': sent1_ling, 'sent2_ling': sent2_ling})
else:
cache.update({'sent1_ling': sent1_ling, 'sent2_ling': sent2_ling})
sent1_ling = sent1_ling.view(bs, 1, -1)
sent2_ling = sent2_ling.view(bs, 1, -1)
if self.args.combine_method == 'input_concat':
if self.args.ling2_only:
inputs_embeds = torch.cat([inputs_embeds, sent2_ling], dim=1)
inputs_att_mask = torch.cat([inputs_att_mask,
torch.ones((bs, 1)).to(inputs_embeds.device)], dim=1)
else:
inputs_embeds = torch.cat([inputs_embeds, sent1_ling, sent2_ling], dim=1)
inputs_att_mask = torch.cat([inputs_att_mask,
torch.ones((bs, 2)).to(inputs_embeds.device)], dim=1)
elif self.args.combine_method == 'input_add':
if self.args.ling2_only:
inputs_embeds = inputs_embeds + sent2_ling
else:
inputs_embeds = inputs_embeds + sent1_ling + sent2_ling
return self.backbone.encoder(inputs_embeds=inputs_embeds,
attention_mask=inputs_att_mask), inputs_att_mask, cache
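
    # Decodes either by free-running generation (generate=True) or by teacher forcing on
    # sentence2, combining linguistic embeddings according to args.combine_method.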
def decode(self, batch, enc_output, inputs_att_mask, generate):
bs = inputs_att_mask.shape[0]
cache = {}
if self.args.combine_method in ('embed_concat', 'decoder_concat', 'decoder_add', 'logits_add', 'decoder_add_first'):
if 'sent1_ling_embed' in batch:
sent1_ling = batch['sent1_ling_embed']
elif 'sentence1_ling' in batch:
sent1_ling = self.ling_embed(self.ling_dropout(batch['sentence1_ling']))
else:
sent1_ling = None
if 'sent2_ling_embed' in batch:
sent2_ling = batch['sent2_ling_embed']
else:
sent2_ling = self.ling_embed(self.ling_dropout(batch['sentence2_ling']))
if self.args.ling_vae:
sent1_ling = F.leaky_relu(sent1_ling)
sent1_mu, sent1_logvar = self.ling_mu(sent1_ling), self.ling_logvar(sent1_ling)
sent1_ling = self.sample(sent1_mu, sent1_logvar)
sent2_ling = F.leaky_relu(sent2_ling)
sent2_mu, sent2_logvar = self.ling_mu(sent2_ling), self.ling_logvar(sent2_ling)
sent2_ling = self.sample(sent2_mu, sent2_logvar)
cache.update({'sent1_mu': sent1_mu, 'sent1_logvar': sent1_logvar,
'sent2_mu': sent2_mu, 'sent2_logvar': sent2_logvar,
'sent1_ling': sent1_ling, 'sent2_ling': sent2_ling})
else:
cache.update({'sent2_ling': sent2_ling})
if sent1_ling is not None:
cache.update({'sent1_ling': sent1_ling})
if sent1_ling is not None:
sent1_ling = sent1_ling.view(bs, 1, -1)
sent2_ling = sent2_ling.view(bs, 1, -1)
if self.args.combine_method == 'decoder_add_first' and not generate:
sent2_ling = torch.cat([sent2_ling,
torch.repeat_interleave(torch.zeros_like(sent2_ling), batch['sentence2_input_ids'].shape[1] - 1, dim=1)], dim = 1)
else:
sent1_ling, sent2_ling = None, None
if self.args.combine_method == 'embed_concat':
enc_output.last_hidden_state = torch.cat([enc_output.last_hidden_state,
sent1_ling, sent2_ling], dim=1)
inputs_att_mask = torch.cat([inputs_att_mask,
torch.ones((bs, 2)).to(inputs_att_mask.device)], dim=1)
elif 'fusion' in self.args.combine_method:
sent1_ling = batch['sentence1_ling'].unsqueeze(1)\
.expand(-1, enc_output.last_hidden_state.shape[1], -1)
sent2_ling = batch['sentence2_ling'].unsqueeze(1)\
.expand(-1, enc_output.last_hidden_state.shape[1], -1)
if self.args.ling2_only:
combined_embedding = torch.cat([enc_output.last_hidden_state, sent2_ling], dim=2)
else:
combined_embedding = torch.cat([enc_output.last_hidden_state, sent1_ling, sent2_ling], dim=2)
enc_output.last_hidden_state = self.fusion(combined_embedding)
if generate:
if self.args.combine_method == 'logits_add':
logits_processor = LogitsProcessorList([LogitsAdd(sent2_ling.view(bs, -1))])
else:
logits_processor = LogitsProcessorList()
dec_output = self.backbone.generate_with_grad(
attention_mask=inputs_att_mask,
encoder_outputs=enc_output,
sent1_ling=sent1_ling,
sent2_ling=sent2_ling,
return_dict_in_generate=True,
output_scores=True,
logits_processor = logits_processor,
# renormalize_logits=True,
# do_sample=True,
# top_p=0.8,
eos_token_id=self.eos_token_id,
# min_new_tokens=3,
# repetition_penalty=1.2,
max_length=self.args.max_length,
)
scores = torch.stack(dec_output.scores, 1)
cache.update({'scores': scores})
return dec_output.sequences, cache
decoder_input_ids = self.backbone._shift_right(batch['sentence2_input_ids'])
decoder_inputs_embeds = self.backbone.shared(decoder_input_ids)
decoder_att_mask = batch['sentence2_attention_mask']
labels = batch['sentence2_input_ids'].clone()
labels[labels == self.pad_token_id] = -100
if self.args.combine_method == 'decoder_concat':
if self.args.ling2_only:
decoder_inputs_embeds = torch.cat([sent2_ling, decoder_inputs_embeds], dim=1)
decoder_att_mask = torch.cat([torch.ones((bs, 1)).to(decoder_inputs_embeds.device), decoder_att_mask], dim=1)
labels = torch.cat([torch.ones((bs, 1), dtype=torch.int64).to(decoder_inputs_embeds.device) * self.pad_token_id,
labels], dim=1)
else:
decoder_inputs_embeds = torch.cat([sent1_ling, sent2_ling, decoder_inputs_embeds], dim=1)
decoder_att_mask = torch.cat([torch.ones((bs, 2)).to(decoder_inputs_embeds.device), decoder_att_mask], dim=1)
labels = torch.cat([torch.ones((bs, 2), dtype=torch.int64).to(decoder_inputs_embeds.device) * self.pad_token_id,
labels], dim=1)
        elif self.args.combine_method in ('decoder_add', 'decoder_add_first'):
if self.args.ling2_only:
decoder_inputs_embeds = decoder_inputs_embeds + self.args.combine_weight * sent2_ling
else:
decoder_inputs_embeds = decoder_inputs_embeds + sent1_ling + sent2_ling
dec_output = self.backbone(
decoder_inputs_embeds=decoder_inputs_embeds,
decoder_attention_mask=decoder_att_mask,
encoder_outputs=enc_output,
attention_mask=inputs_att_mask,
labels=labels,
)
if self.args.combine_method == 'logits_add':
dec_output.logits = dec_output.logits + self.args.combine_weight * sent2_ling
vocab_size = dec_output.logits.size(-1)
dec_output.loss = F.cross_entropy(dec_output.logits.view(-1, vocab_size), labels.view(-1))
return dec_output, cache
def forward(self, batch, generate=False):
enc_output, enc_att_mask, cache = self.encode(batch)
dec_output, cache2 = self.decode(batch, enc_output, enc_att_mask, generate)
cache.update(cache2)
return dec_output, enc_output, cache
def infer_with_cache(self, batch):
dec_output, _, cache = self(batch, generate = True)
return dec_output, cache
def infer(self, batch):
dec_output, _ = self.infer_with_cache(batch)
return dec_output
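
    # Gradient-feedback inference: repeatedly decode, score the output with the linguistic
    # discriminator, and update the chosen parameter (ling embedding, input embeddings, or
    # logits) via line search until the target features are matched while sem_emb still rates
    # the output as semantically faithful.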
def infer_with_feedback_BP(self, ling_disc, sem_emb, batch, tokenizer):
from torch.autograd import grad
interpolations = []
def line_search():
best_val = None
best_loss = None
eta = 1e3
sem_prob = 1
patience = 4
while patience > 0:
param_ = param - eta * grads
with torch.no_grad():
new_loss, pred = get_loss(param_)
max_len = pred.shape[1]
lens = torch.where(pred == self.eos_token_id, 1, 0).argmax(-1) + 1
batch.update({
'sentence2_input_ids': pred,
'sentence2_attention_mask': sequence_mask(lens, max_len = max_len)
})
sem_prob = torch.sigmoid(sem_emb(**batch)).item()
# if sem_prob <= 0.1:
# patience -= 1
if new_loss < loss and sem_prob >= 0.90 and lens.item() > 1:
return param_
eta *= 2.25
patience -= 1
return False
def get_loss(param):
if self.args.feedback_param == 'l':
batch.update({'sent2_ling_embed': param})
elif self.args.feedback_param == 's':
batch.update({'inputs_embeds': param})
if self.args.feedback_param == 'logits':
logits = param
pred = param.argmax(-1)
else:
pred, cache = self.infer_with_cache(batch)
logits = cache['scores']
out = ling_disc(logits = logits)
probs = F.softmax(out, 1)
if ling_disc.quant:
loss = F.cross_entropy(out, batch['sentence2_discr'])
else:
loss = F.mse_loss(out, batch['sentence2_ling'])
return loss, pred
if self.args.feedback_param == 'l':
ling2_embed = self.ling_embed(batch['sentence2_ling'])
param = torch.nn.Parameter(ling2_embed, requires_grad = True)
elif self.args.feedback_param == 's':
inputs_embeds = self.backbone.shared(batch['sentence1_input_ids'])
param = torch.nn.Parameter(inputs_embeds, requires_grad = True)
elif self.args.feedback_param == 'logits':
logits = self.infer_with_cache(batch)[1]['scores']
param = torch.nn.Parameter(logits, requires_grad = True)
target_np = batch['sentence2_ling'][0].cpu().numpy()
while True:
loss, pred = get_loss(param)
pred_text = tokenizer.batch_decode(pred.cpu().numpy(),
skip_special_tokens=True)[0]
interpolations.append(pred_text)
if loss < 1:
break
self.zero_grad()
grads = grad(loss, param)[0]
param = line_search()
if param is False:
break
return pred, [pred_text, interpolations]
def set_grad(module, state):
if module is not None:
for p in module.parameters():
p.requires_grad = state
def set_grad_except(model, name, state):
for n, p in model.named_parameters():
        if name not in n:
p.requires_grad = state
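
# Thin inference wrapper around SemEmb. Example usage (assuming the default checkpoint exists):
#   sem = SemEmbPipeline()
#   prob = sem('The cat sat.', 'A cat was sitting.')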
class SemEmbPipeline():
def __init__(self,
ckpt = "/data/mohamed/checkpoints/ling_conversion_sem_emb_best.pt"):
self.tokenizer = T5Tokenizer.from_pretrained("google/flan-t5-base")
self.model = SemEmb(T5EncoderModel.from_pretrained('google/flan-t5-base'), self.tokenizer.get_vocab()['</s>'])
state = torch.load(ckpt)
self.model.load_state_dict(state['model'], strict=False)
self.model.eval()
self.model.cuda()
def __call__(self, sentence1, sentence2):
sentence1 = self.tokenizer(sentence1, return_attention_mask = True, return_tensors = 'pt')
sentence2 = self.tokenizer(sentence2, return_attention_mask = True, return_tensors = 'pt')
sem_logit = self.model(
sentence1_input_ids = sentence1.input_ids.cuda(),
sentence1_attention_mask = sentence1.attention_mask.cuda(),
sentence2_input_ids = sentence2.input_ids.cuda(),
sentence2_attention_mask = sentence2.attention_mask.cuda(),
)
sem_prob = torch.sigmoid(sem_logit).item()
return sem_prob
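
# Thin inference wrapper around LingDisc: tokenizes a sentence and returns its predicted
# linguistic-feature vector.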
class LingDiscPipeline():
def __init__(self,
model_name="google/flan-t5-base",
disc_type='deberta',
disc_ckpt='/data/mohamed/checkpoints/ling_disc/deberta-v3-small_flan-t5-base_40',
# disc_type='t5',
# disc_ckpt='/data/mohamed/checkpoints/ling_conversion_ling_disc.pt',
):
self.tokenizer = T5Tokenizer.from_pretrained(model_name)
self.model = LingDisc(model_name, disc_type, disc_ckpt)
self.model.eval()
self.model.cuda()
def __call__(self, sentence):
inputs = self.tokenizer(sentence, return_tensors = 'pt')
with torch.no_grad():
ling_pred = self.model(input_ids=inputs.input_ids.cuda())
return ling_pred
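
# Builds the main model (or the discriminator alone when pretraining it), plus the optional
# LingDisc and SemEmb auxiliaries, according to args.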
def get_model(args, tokenizer, device):
    if args.pretrain_disc or args.disc_loss or args.disc_ckpt:
        ling_disc = LingDisc(args.model_name, args.disc_type, args.disc_ckpt).to(device)
    else:
        ling_disc = None
    if args.linggen_type != 'none':
        ling_gen = LingGenerator(args).to(device)
    if not args.pretrain_disc:
        model = EncoderDecoderVAE(args, tokenizer.pad_token_id, tokenizer.eos_token_id).to(device)
        if args.use_lora:
            target_modules = ["Attention.k", "Attention.q", "Attention.v", "Attention.o", "lm_head", "wi_0", "wi_1", "wo"]
            target_modules = '|'.join(f'(.*{module})' for module in target_modules)
            target_modules = f'backbone.({target_modules})'
            config = LoraConfig(
                r=args.lora_r,
                lora_alpha=args.lora_r * 2,
                target_modules=target_modules,
                lora_dropout=0.1,
                bias="lora_only",
                modules_to_save=['ling_embed'],
            )
            model = get_peft_model(model, config)
            model.print_trainable_parameters()
    else:
        model = ling_disc
    if args.sem_loss or args.sem_ckpt:
        if args.sem_loss_type == 'shared':
            # Share the generator's T5 encoder as the semantic embedder (assumes the
            # EncoderDecoderVAE built above).
            sem_emb = model.backbone.encoder
        elif args.sem_loss_type == 'dedicated':
            sem_emb = SemEmb(T5EncoderModel.from_pretrained('google/flan-t5-base'), tokenizer.eos_token_id).to(device)
        else:
            raise NotImplementedError('Semantic loss type')
    else:
        sem_emb = None
    return model, ling_disc, sem_emb