dar-tau commited on
Commit
3f8ef2d
1 Parent(s): 8ee84e0

Update interpret.py

Browse files
Files changed (1) hide show
  1. interpret.py +2 -2
interpret.py CHANGED
@@ -90,10 +90,10 @@ class InterpretationPrompt:
90
  else:
91
  raise NotImplementedError
92
 
93
- def generate(self, model, embeds, k, layer_format='model.layers.{k}', **generation_kwargs):
94
  num_seqs = len(embeds[0]) # assumes the placeholder 0 exists
95
  tokens_batch = torch.tensor([self.tokens[:] for _ in range(num_seqs)]).to(model.device)
96
- module = model.get_submodule(layer_format.format(k=k))
97
  with SubstitutionHook(module, positions_dict=self.placeholders, values_dict=embeds):
98
  generated = model.generate(tokens_batch, **generation_kwargs)
99
  return generated
 
90
  else:
91
  raise NotImplementedError
92
 
93
+ def generate(self, model, embeds, k, layers_format='model.layers.{k}', **generation_kwargs):
94
  num_seqs = len(embeds[0]) # assumes the placeholder 0 exists
95
  tokens_batch = torch.tensor([self.tokens[:] for _ in range(num_seqs)]).to(model.device)
96
+ module = model.get_submodule(layers_format.format(k=k))
97
  with SubstitutionHook(module, positions_dict=self.placeholders, values_dict=embeds):
98
  generated = model.generate(tokens_batch, **generation_kwargs)
99
  return generated