|
from typing import Optional, Union |
|
|
|
import torch |
|
import transformers |
|
import streamlit as st |
|
|
|
from plotly import graph_objects as go |
|
|
|
|
|
class Generator:
    """Step-wise text generator that alternates between two next-token samplers.

    Every step runs the causal LM once (reusing its key/value cache) and hands
    the logits to either ``caif_sampler`` or ``ordinary_sampler``.  The choice
    is made either on a fixed period (``caif_period`` in ``sample_sequences``)
    or, when an entropy threshold is set, per sequence by comparing the
    next-token entropy of the LM distribution with that threshold.
    """

    def __init__(self, lm_model_name, device, entropy=None):
        """Load the tokenizer and causal LM and initialise sampler slots.

        Args:
            lm_model_name: name/path passed to ``from_pretrained`` for both the
                tokenizer and the causal language model.
            device: device the model and all generated tensors are moved to.
            entropy: initial entropy threshold.  Note that ``sample_sequences``
                overwrites ``self.entropy`` on every call.
        """
        self.device = device

        self.tokenizer = transformers.AutoTokenizer.from_pretrained(lm_model_name)
        self.lm = transformers.AutoModelForCausalLM.from_pretrained(
            lm_model_name
        ).to(device)
        self.lm.eval()

        # Reuse EOS as the padding token on both the model and tokenizer side.
        self.lm.config.pad_token_id = self.lm.config.eos_token_id
        self.tokenizer.add_special_tokens(
            {"pad_token": self.tokenizer.decode(self.lm.config.eos_token_id)}
        )

        self.caif_sampler = None
        self.ordinary_sampler = None
        # Running totals for entropy-based routing:
        #   "skips"       - sum over steps of the fraction of rows routed to CAIF
        #   "avg_entropy" - sum over steps of the batch-mean next-token entropy
        #   "count"       - number of steps accumulated
        #   "entropy"     - (added lazily) flat list of every per-row entropy
        self.entropy_based_stats = {
            "skips": 0,
            "avg_entropy": 0,
            "count": 0,
        }
        self.entropy = entropy

    def set_caif_sampler(self, sampler):
        """Register the sampler used on CAIF steps / high-entropy rows."""
        self.caif_sampler = sampler

    def set_ordinary_sampler(self, sampler):
        """Register the sampler used on non-CAIF steps / low-entropy rows."""
        self.ordinary_sampler = sampler

    def sample_sequences(
        self,
        num_samples: int,
        input_prompt: Optional[str],
        max_length: int,
        caif_period: int,
        caif_tokens_num: Union[int, None] = None,
        entropy: Optional[float] = None,
        progress_bar=None,
        **sampler_kwargs
    ):
        """Generate ``num_samples`` continuations and stream progress to Streamlit.

        Args:
            num_samples: number of sequences generated in parallel.
            input_prompt: text prompt; ``None`` starts from the BOS token
                (or EOS when the tokenizer defines no BOS).
            max_length: maximum number of new tokens.
            caif_period: when no entropy threshold is active, the CAIF sampler
                is used on steps ``i`` with ``i % caif_period == 0``.
            caif_tokens_num: forwarded unchanged to the samplers.
            entropy: entropy threshold for per-row routing; it replaces the
                value given to ``__init__`` (``None`` disables routing).
            progress_bar: optional object with a ``progress(fraction)`` method
                (presumably a Streamlit progress bar); skipped when ``None``.
            **sampler_kwargs: forwarded to the samplers.  When a CAIF sampler is
                set, ``target_cls_id`` must be present for the live plot.

        Returns:
            ``(texts, input_ids)``: the decoded sequences with special tokens
            removed, and the final token-id tensor (prompt included).
        """
        self.entropy = entropy

        input_ids, past, ended_sequences = self.get_input_ids(
            input_prompt, num_samples
        )

        text = st.empty()
        plot = st.empty()
        gen_history = []
        layout = go.Layout({
            "xaxis": {"title": "# Tokens"},
            "yaxis": {"title": "Desired Attribute"},
            "plot_bgcolor": '#FFFFFF',
            "template": "plotly_white",
            "hovermode": "x",
        })
        inp_len = len(input_ids[0])

        for i in range(max_length):
            is_caif_step = i % caif_period == 0 and self.caif_sampler is not None
            input_ids, past, ended_sequences = self.generation_step(
                input_ids,
                past,
                ended_sequences,
                is_caif_step,
                caif_tokens_num=caif_tokens_num,
                **sampler_kwargs
            )
            if progress_bar is not None:
                progress_bar.progress((i + 1) / max_length)
            if ended_sequences.all():
                break

            # Live view of the first sequence only.
            current_decoded = self.tokenizer.decode(input_ids[0])
            if self.caif_sampler is not None:
                probs = torch.exp(
                    self.caif_sampler.get_classifier_log_probs(
                        current_decoded, target_cls_id=sampler_kwargs["target_cls_id"]
                    )
                ).item()
                gen_history += [probs]
            scatter_data = go.Scatter({
                "x": list(range(len(gen_history))),
                "y": gen_history,
                "hovertext": [self.tokenizer.decode(t) for t in input_ids[0][inp_len:]],
            })
            fig = go.Figure([scatter_data], layout=layout)
            plot.plotly_chart(fig, use_container_width=True)
            text.text(current_decoded)

        return (
            [
                self.tokenizer.decode(sequence, skip_special_tokens=True)
                for sequence in input_ids
            ],
            input_ids,
        )

    def generation_step(
        self,
        input_ids,
        past,
        ended_sequences,
        is_caif_step: bool,
        caif_tokens_num=None,
        **sampler_kwargs
    ):
        """Run the LM once and append one sampled token to every sequence.

        Rows already marked in ``ended_sequences`` receive EOS.  The
        ``ended_sequences`` tensor is updated in place and also returned.

        Returns:
            ``(input_ids, past, ended_sequences)`` for the next step.
        """
        prepared_inputs = self.lm.prepare_inputs_for_generation(
            input_ids, past, use_cache=True
        )
        outputs = self.lm(
            **prepared_inputs,
            output_attentions=False,
            output_hidden_states=False,
            return_dict=True
        )
        past = outputs.past_key_values

        if self.entropy is not None:
            next_tokens = self._entropy_routed_tokens(
                input_ids, outputs.logits, ended_sequences,
                caif_tokens_num, sampler_kwargs,
            )
        else:
            sampler = self.caif_sampler if is_caif_step else self.ordinary_sampler
            next_tokens = self._sample_masked(
                sampler, input_ids, outputs.logits, ended_sequences,
                caif_tokens_num, sampler_kwargs,
            )

        next_tokens = next_tokens.to(self.device)
        input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
        ended_sequences |= next_tokens == self.lm.config.eos_token_id

        return input_ids, past, ended_sequences

    @staticmethod
    def _mask_ended(next_tokens, ended_sequences, eos_token_id):
        """Replace tokens of already-finished rows with ``eos_token_id``."""
        ended = ended_sequences.long()
        return (next_tokens * (1 - ended) + eos_token_id * ended).long()

    def _sample_masked(
        self, sampler, input_ids, logits, ended_sequences, caif_tokens_num, sampler_kwargs
    ):
        """Call ``sampler`` on the given rows and force EOS for finished rows."""
        next_tokens = sampler(
            input_ids, logits, caif_tokens_num=caif_tokens_num, **sampler_kwargs
        )
        return self._mask_ended(
            next_tokens, ended_sequences, self.lm.config.eos_token_id
        )

    def _entropy_routed_tokens(
        self, input_ids, logits, ended_sequences, caif_tokens_num, sampler_kwargs
    ):
        """Pick the sampler per row from the next-token entropy and sample.

        Rows whose last-position entropy is ``>= self.entropy`` go to the CAIF
        sampler; the rest go to the ordinary sampler.  The routing statistics
        in ``self.entropy_based_stats`` are updated as a side effect.
        """
        # Only the last position is needed, so drop the rest before the
        # softmax (on the first step the logits cover the whole prompt).
        log_probs = torch.nn.functional.log_softmax(logits[:, -1], dim=-1)
        output_entropy = (torch.exp(log_probs) * -log_probs).sum(-1)
        batch_size = output_entropy.shape[0]

        caif_mask = torch.ge(output_entropy, self.entropy)
        ordinary_mask = ~caif_mask

        stats = self.entropy_based_stats
        stats["skips"] += caif_mask.sum() / batch_size
        stats["count"] += 1
        stats["avg_entropy"] += output_entropy.sum() / batch_size
        stats.setdefault("entropy", []).extend(output_entropy.view(-1).cpu().tolist())

        num_caif = int(caif_mask.sum())
        if num_caif == 0:
            return self._sample_masked(
                self.ordinary_sampler, input_ids, logits, ended_sequences,
                caif_tokens_num, sampler_kwargs,
            )
        if num_caif == batch_size:
            return self._sample_masked(
                self.caif_sampler, input_ids, logits, ended_sequences,
                caif_tokens_num, sampler_kwargs,
            )

        # Mixed batch: sample each subset separately (CAIF first, as before)
        # and scatter the results back; the two masks cover every row.
        caif_tokens = self._sample_masked(
            self.caif_sampler, input_ids[caif_mask], logits[caif_mask],
            ended_sequences[caif_mask], caif_tokens_num, sampler_kwargs,
        )
        ordinary_tokens = self._sample_masked(
            self.ordinary_sampler, input_ids[ordinary_mask], logits[ordinary_mask],
            ended_sequences[ordinary_mask], caif_tokens_num, sampler_kwargs,
        )
        next_tokens = torch.empty(batch_size, dtype=torch.long, device=self.device)
        next_tokens[caif_mask] = caif_tokens
        next_tokens[ordinary_mask] = ordinary_tokens
        return next_tokens

    def get_input_ids(self, input_prompt, num_samples):
        """Build the initial batch for generation.

        Args:
            input_prompt: text to tokenize, or ``None`` to start from a single
                BOS token (EOS when the tokenizer has no BOS).
            num_samples: number of copies of the prompt in the batch.

        Returns:
            ``(input_ids, past, ended_sequences)``: a ``(num_samples, len)``
            token tensor on ``self.device``, ``None`` for the cache, and an
            all-False bool tensor of length ``num_samples``.
        """
        if input_prompt is not None:
            input_ids = self.tokenizer(input_prompt, return_tensors="pt").input_ids
        else:
            start_id = self.tokenizer.bos_token_id
            if start_id is None:
                start_id = self.lm.config.eos_token_id
            input_ids = torch.tensor([[start_id]], dtype=torch.long)

        input_ids = input_ids.repeat(num_samples, 1).to(self.device)
        past = None
        ended_sequences = torch.zeros(
            input_ids.shape[0], device=self.device
        ).bool()

        return input_ids, past, ended_sequences

    @staticmethod
    def sample(unscaled_probs, values):
        """Draw one index per row from ``unscaled_probs`` and return the matching ``values``."""
        samples = torch.multinomial(unscaled_probs, 1)
        return torch.take_along_dim(values, samples, dim=1)
|
|