import logging
import os
import re
from collections import defaultdict
from contextlib import AbstractContextManager

import numpy as np
import torch
from torch import nn

logger = logging.getLogger(__name__)


# helper functions

def item(x):
    """Return the single element of the array-like ``x`` as a Python scalar.

    NumPy raises ``ValueError`` if ``x`` does not contain exactly one element.
    """
    return np.array(x).item()


def _prompt_to_parts(prompt, repeat=5):
    """Split a prompt template into text parts and integer placeholder slots.

    Every ``[X]`` in ``prompt`` (together with any spaces directly before it) is
    replaced by ``repeat`` placeholders ``0``; the text around it is kept as
    strings. For example, with ``repeat=5``::

        "[INST] [X] [/INST] Sure, I'll summarize this"
        -> ["[INST]", 0, 0, 0, 0, 0, " [/INST] Sure, I'll summarize this"]

    Repeating each interpreted token follows the SELFIE paper, which inserts it
    5 times -- presumably so the model is less likely to ignore the injected
    representation.
    """
    split_prompt = re.split(r' *\[X\]', prompt)
    parts = []
    last = len(split_prompt) - 1
    for i, cur_part in enumerate(split_prompt):
        # Consecutive [X] markers leave empty strings between them; skip those.
        if cur_part:
            parts.append(cur_part)
        if i < last:
            parts.extend([0] * repeat)
    logger.debug('Prompt parts: %s', parts)
    return parts


class Hook(AbstractContextManager):
    """Context manager that registers a forward hook on ``module`` and removes it on exit.

    Hook could be absorbed into SubstitutionHook, but keeping the generic
    register/remove logic separate reads better.
    """

    def __init__(self, module, fn):
        self.registered_hook = module.register_forward_hook(fn)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()

    def close(self):
        """Remove the registered forward hook.

        May be called more than once (torch's RemovableHandle.remove ignores an
        already-removed handle).
        """
        self.registered_hook.remove()


class SubstitutionHook(Hook):
    """Forward hook that overwrites hidden states at given positions, once.

    This is where the substitution used by InterpretationPrompt takes place.

    Args:
        module: module whose forward output is patched.
        positions_dict: maps a key to the list of sequence positions to overwrite.
        values_dict: maps the same keys to a (batch, hidden_dim) tensor; each row
            is written to every position listed for that key.

    After the first forward call the hook removes itself: during generation with
    ``use_cache=True`` only the first (full-prompt) pass contains the placeholder
    positions, and later steps process one token at a time.
    NOTE(review): with ``use_cache=False`` later steps re-run the whole sequence
    without the substitution. The values' batch size must broadcast against the
    output batch (e.g. beam search enlarges the batch).
    """

    def __init__(self, module, positions_dict, values_dict):
        assert set(positions_dict.keys()) == set(values_dict.keys())
        keys = list(positions_dict.keys())

        def fn(_module, _inputs, output):
            # Decoder layers may return a tuple whose first element is the hidden
            # states, or the hidden-states tensor itself; patch it in place either way.
            hidden = output if isinstance(output, torch.Tensor) else output[0]
            for key in keys:
                positions = positions_dict[key]
                # batch_size x num_positions x hidden_dim
                values = values_dict[key].unsqueeze(1).expand(-1, len(positions), -1)
                logger.debug('positions=%s values.shape=%s output.shape=%s',
                             positions, tuple(values.shape), tuple(hidden.shape))
                hidden[:, positions, :] = values.to(device=hidden.device, dtype=hidden.dtype)
            # Only the first forward pass is patched (see class docstring).
            self.close()
            return output

        super().__init__(module, fn)


# functions

class InterpretationPrompt:
    """Tokenized prompt template with placeholder slots for injected hidden states.

    Args:
        tokenizer: object providing ``encode(text, add_special_tokens=False)`` and
            ``eos_token_id`` (e.g. a Hugging Face tokenizer).
        prompt: template string in which each ``[X]`` marks a slot.
        placeholder_token: text whose single token id fills the slots before
            substitution; ``None`` uses ``tokenizer.eos_token_id``.
        repeat: number of placeholder tokens inserted per ``[X]`` (default 5).

    Attributes:
        tokens: token ids of the prompt, with the placeholder id at every slot.
        placeholders: maps the placeholder key (``0``) to the positions in
            ``tokens`` it occupies.
    """

    def __init__(self, tokenizer, prompt, placeholder_token=' ', repeat=5):
        prompt_parts = _prompt_to_parts(prompt, repeat=repeat)
        if placeholder_token is None:
            placeholder_token_id = tokenizer.eos_token_id
        else:
            placeholder_token_id = item(tokenizer.encode(placeholder_token, add_special_tokens=False))
            # NOTE(review): presumably keeps placeholders distinguishable from EOS.
            assert placeholder_token_id != tokenizer.eos_token_id

        self.tokens = []
        self.placeholders = defaultdict(list)
        for part in prompt_parts:
            if isinstance(part, str):
                self.tokens.extend(tokenizer.encode(part, add_special_tokens=False))
            elif isinstance(part, int):
                self.placeholders[part].append(len(self.tokens))
                self.tokens.append(placeholder_token_id)
            else:
                raise NotImplementedError

    def generate(self, model, embeds, k, layers_format='model.layers.{k}', **generation_kwargs):
        """Generate from the prompt while injecting ``embeds`` at the placeholder slots.

        Args:
            model: model exposing ``device``, ``get_submodule`` and ``generate``
                (e.g. a Hugging Face causal LM).
            embeds: dict mapping placeholder key to a (num_seqs, hidden_dim)
                tensor; must contain key ``0``. One sequence is generated per row.
            k: index of the layer whose output is patched.
            layers_format: format string giving the submodule path of layer ``k``.
            **generation_kwargs: forwarded to ``model.generate``. An all-ones
                ``attention_mask`` is supplied unless one is given.

        Returns:
            Whatever ``model.generate`` returns.
        """
        num_seqs = len(embeds[0])  # assumes the placeholder 0 exists
        tokens_batch = torch.tensor(self.tokens, device=model.device).unsqueeze(0).repeat(num_seqs, 1)
        # The batch has no padding, so every position is attended to; passing the
        # mask explicitly avoids it being inferred from the pad token id.
        generation_kwargs.setdefault('attention_mask', torch.ones_like(tokens_batch))
        module = model.get_submodule(layers_format.format(k=k))
        with SubstitutionHook(module, positions_dict=self.placeholders, values_dict=embeds):
            return model.generate(tokens_batch, **generation_kwargs)