# caif/generator.py
# Author: Балаганский Никита Николаевич (Nikita Balaganskiy)
# Commit: e852933 — "everything wrapped in cache" (9.41 kB)
from typing import Optional, Union
import torch
import transformers
import streamlit as st
from plotly import graph_objects as go
from utils import get_lm
class Generator:
    """Token-by-token LM text generator that can alternate between an
    ordinary sampler and a CAIF (classifier-guided) sampler, rendering
    generation progress live through Streamlit/Plotly.

    The CAIF sampler is applied either every ``caif_period`` steps or,
    when ``entropy`` is set, adaptively on high-entropy steps only.
    """

    def __init__(self, lm_model_name, device, entropy=None):
        """Load tokenizer + LM and initialise sampler slots and stats.

        Args:
            lm_model_name: HF hub model name passed to AutoTokenizer/get_lm.
            device: torch device the LM and tensors are moved to.
            entropy: optional entropy threshold for adaptive CAIF steps.
        """
        self.device = device
        self.tokenizer = transformers.AutoTokenizer.from_pretrained(
            lm_model_name
        )
        self.lm = get_lm(lm_model_name).to(device)
        self.lm.eval()
        # Reuse EOS as padding so finished sequences in a batch can be padded.
        self.lm.config.pad_token_id = self.lm.config.eos_token_id
        self.tokenizer.add_special_tokens(
            {"pad_token": self.tokenizer.decode(self.lm.config.eos_token_id)}
        )
        self.caif_sampler = None
        self.ordinary_sampler = None
        # Running counters for the entropy-based sampler selection.
        self.entropy_based_stats = {
            "skips": 0,
            "avg_entropy": 0,
            "count": 0,
        }
        self.entropy = entropy

    def set_caif_sampler(self, sampler):
        """Install the classifier-guided (CAIF) sampler."""
        self.caif_sampler = sampler

    def set_ordinary_sampler(self, sampler):
        """Install the plain LM sampler used on non-CAIF steps."""
        self.ordinary_sampler = sampler

    def sample_sequences(
        self,
        num_samples: int,
        input_prompt: Optional[str],
        max_length: int,
        caif_period: int,
        caif_tokens_num: Union[int, None] = None,
        entropy: Optional[float] = None,
        progress_bar=None,
        **sampler_kwargs
    ):
        """Generate ``num_samples`` continuations of ``input_prompt``.

        Draws up to ``max_length`` new tokens, applying the CAIF sampler
        every ``caif_period`` steps (or adaptively when ``entropy`` is set),
        while plotting the classifier probability of the desired attribute
        after each step.

        Returns:
            Tuple of (list of decoded strings, final ``input_ids`` tensor).
        """
        self.entropy = entropy

        input_ids, past, ended_sequences = self.get_input_ids(
            input_prompt,
            num_samples,
        )
        # Streamlit placeholders updated in-place on every step.
        text = st.empty()
        plot = st.empty()
        gen_history = []
        layout = go.Layout({
            "xaxis": {
                "title": "# Tokens"
            },
            "yaxis": {
                "title": "Desired Attribute"
            },
            "plot_bgcolor": '#FFFFFF',
            "template": "plotly_white",
            "hovermode": "x",
        })
        inp_len = len(input_ids[0])
        if self.caif_sampler is not None:
            # Seed the history with the classifier score of the bare prompt.
            current_decoded = self.tokenizer.decode(input_ids[0])
            probs = torch.exp(
                self.caif_sampler.get_classifier_log_probs(
                    current_decoded, target_cls_id=sampler_kwargs["target_cls_id"]
                )
            ).item()
            gen_history += [probs]
        for i in range(max_length):
            is_caif_step = (
                i % caif_period == 0 and self.caif_sampler is not None
            )
            input_ids, past, ended_sequences = self.generation_step(
                input_ids,
                past,
                ended_sequences,
                is_caif_step,
                caif_tokens_num=caif_tokens_num,
                **sampler_kwargs
            )
            progress_bar.progress((i + 1) / max_length)
            if ended_sequences.all():
                break
            current_decoded = self.tokenizer.decode(input_ids[0])
            if self.caif_sampler is not None:
                # Track how the desired-attribute probability evolves for the
                # first sequence in the batch and re-render the plot.
                probs = torch.exp(
                    self.caif_sampler.get_classifier_log_probs(
                        current_decoded, target_cls_id=sampler_kwargs["target_cls_id"]
                    )
                ).item()
                gen_history += [probs]
                scatter_data = go.Scatter({
                    "x": list(range(len(gen_history))),
                    "y": gen_history,
                    "hovertext": ["[PROMPT]"] + [self.tokenizer.decode(t) for t in input_ids[0][inp_len:]]
                })
                fig = go.Figure([scatter_data], layout=layout)
                plot.plotly_chart(fig, use_container_width=True)
                if i == 0:
                    with st.expander("What is it?"):
                        st.write("You can see how the probability of the desired attribute varies for every generation step.")
            text.text(current_decoded)

        return (
            [
                self.tokenizer.decode(sequence, skip_special_tokens=True)
                for sequence in input_ids
            ],
            input_ids,
        )

    def generation_step(
        self,
        input_ids,
        past,
        ended_sequences,
        is_caif_step: bool,
        caif_tokens_num=None,
        **sampler_kwargs
    ):
        """Sample one next token per sequence and append it.

        When ``self.entropy`` is set, the CAIF sampler is used only for
        batch rows whose next-token distribution entropy is >= the
        threshold; otherwise the choice follows ``is_caif_step``.
        Finished sequences always receive the EOS token.

        Returns:
            Updated (input_ids, past_key_values, ended_sequences).
        """
        prepared_inputs = self.lm.prepare_inputs_for_generation(
            input_ids, past, use_cache=True
        )
        outputs = self.lm(
            **prepared_inputs,
            output_attentions=False,
            output_hidden_states=False,
            return_dict=True
        )
        past = outputs.past_key_values
        if self.entropy is not None:
            # Adaptive mode: compute per-row entropy of the last-position
            # next-token distribution and route rows to the two samplers.
            normalized = torch.nn.functional.log_softmax(
                outputs.logits, dim=-1
            )
            p = torch.exp(normalized)
            output_probs = p
            output_information = -normalized
            output_entropy = (output_probs * output_information).sum(-1)[:, -1]
            batch_size = output_entropy.shape[0]
            caif_mask = torch.ge(output_entropy, self.entropy)
            ordinary_mask = ~caif_mask
            self.entropy_based_stats["skips"] += caif_mask.sum() / batch_size
            self.entropy_based_stats["count"] += 1
            self.entropy_based_stats["avg_entropy"] += (
                output_entropy.sum() / batch_size
            )
            flatten_entropy = output_entropy.view(-1).cpu().tolist()
            if "entropy" not in self.entropy_based_stats.keys():
                self.entropy_based_stats["entropy"] = flatten_entropy
            else:
                self.entropy_based_stats["entropy"] += flatten_entropy

            if caif_mask.sum() == 0:
                # No row exceeds the threshold: plain sampling for everyone.
                next_tokens_sampler = self.ordinary_sampler
                next_tokens = next_tokens_sampler(
                    input_ids,
                    outputs.logits,
                    caif_tokens_num=caif_tokens_num,
                    **sampler_kwargs
                )
                next_tokens = (
                    next_tokens * (1 - ended_sequences.long())
                    + self.lm.config.eos_token_id * ended_sequences.long()
                ).long()
            elif caif_mask.sum() == batch_size:
                # Every row exceeds the threshold: CAIF sampling for everyone.
                next_tokens_sampler = self.caif_sampler
                next_tokens = next_tokens_sampler(
                    input_ids,
                    outputs.logits,
                    caif_tokens_num=caif_tokens_num,
                    **sampler_kwargs
                )
                next_tokens = (
                    next_tokens * (1 - ended_sequences.long())
                    + self.lm.config.eos_token_id * ended_sequences.long()
                ).long()
            else:
                # Mixed batch: sample each partition separately, then scatter
                # the results back into a single next-token tensor.
                next_tokens_caif = self.caif_sampler(
                    input_ids[caif_mask],
                    outputs.logits[caif_mask],
                    caif_tokens_num=caif_tokens_num,
                    **sampler_kwargs
                )
                next_tokens_ordinary = self.ordinary_sampler(
                    input_ids[ordinary_mask],
                    outputs.logits[ordinary_mask],
                    caif_tokens_num=caif_tokens_num,
                    **sampler_kwargs
                )
                next_tokens_caif = (
                    next_tokens_caif * (1 - ended_sequences[caif_mask].long())
                    + self.lm.config.eos_token_id
                    * ended_sequences[caif_mask].long()
                ).long()
                next_tokens_ordinary = (
                    next_tokens_ordinary
                    * (1 - ended_sequences[ordinary_mask].long())
                    + self.lm.config.eos_token_id
                    * ended_sequences[ordinary_mask].long()
                ).long()
                next_tokens = torch.ones(batch_size).long().to(self.device)
                next_tokens[caif_mask] = next_tokens_caif
                next_tokens[ordinary_mask] = next_tokens_ordinary
        else:
            # Periodic mode: sampler choice was decided by the caller.
            if is_caif_step:
                next_tokens_sampler = self.caif_sampler
            else:
                next_tokens_sampler = self.ordinary_sampler

            next_tokens = next_tokens_sampler(
                input_ids,
                outputs.logits,
                caif_tokens_num=caif_tokens_num,
                **sampler_kwargs
            )
            next_tokens = (
                next_tokens * (1 - ended_sequences.long())
                + self.lm.config.eos_token_id * ended_sequences.long()
            ).long()

        input_ids = torch.cat(
            [input_ids, next_tokens[:, None].to(self.device)], dim=-1
        )
        # Logical OR: once a row emits EOS it stays finished.
        ended_sequences += next_tokens == self.lm.config.eos_token_id

        return input_ids, past, ended_sequences

    def get_input_ids(self, input_prompt, num_samples):
        """Encode the prompt and tile it ``num_samples`` times.

        Args:
            input_prompt: prompt text, or None for unconditional generation.
            num_samples: number of identical rows to generate in parallel.

        Returns:
            Tuple of (input_ids of shape (num_samples, prompt_len),
            past=None, ended_sequences bool tensor of zeros).
        """
        if input_prompt is not None:
            input_ids = self.tokenizer(
                input_prompt, return_tensors="pt"
            ).input_ids
        else:
            # Bug fix: input_ids was previously left undefined (NameError)
            # when no prompt was given; start from the BOS token instead.
            input_ids = torch.tensor([[self.lm.config.bos_token_id]])
        input_ids = input_ids.repeat(num_samples, 1).to(self.device)
        past = None
        ended_sequences = torch.zeros(
            input_ids.shape[0], device=self.device
        ).bool()
        return input_ids, past, ended_sequences

    @staticmethod
    def sample(unscaled_probs, values):
        """Draw one index per row from ``unscaled_probs`` (unnormalised
        weights) and gather the corresponding entries of ``values``."""
        samples = torch.multinomial(unscaled_probs, 1)
        return torch.take_along_dim(values, samples, dim=1)