Update interpret.py
Browse files- interpret.py +2 -2
interpret.py
CHANGED
@@ -90,10 +90,10 @@ class InterpretationPrompt:
|
|
90 |
else:
|
91 |
raise NotImplementedError
|
92 |
|
93 |
-
def generate(self, model, embeds, k,
|
94 |
num_seqs = len(embeds[0]) # assumes the placeholder 0 exists
|
95 |
tokens_batch = torch.tensor([self.tokens[:] for _ in range(num_seqs)]).to(model.device)
|
96 |
-
module = model.get_submodule(
|
97 |
with SubstitutionHook(module, positions_dict=self.placeholders, values_dict=embeds):
|
98 |
generated = model.generate(tokens_batch, **generation_kwargs)
|
99 |
return generated
|
|
|
90 |
else:
|
91 |
raise NotImplementedError
|
92 |
|
93 |
+
def generate(self, model, embeds, k, layers_format='model.layers.{k}', **generation_kwargs):
|
94 |
num_seqs = len(embeds[0]) # assumes the placeholder 0 exists
|
95 |
tokens_batch = torch.tensor([self.tokens[:] for _ in range(num_seqs)]).to(model.device)
|
96 |
+
module = model.get_submodule(layers_format.format(k=k))
|
97 |
with SubstitutionHook(module, positions_dict=self.placeholders, values_dict=embeds):
|
98 |
generated = model.generate(tokens_batch, **generation_kwargs)
|
99 |
return generated
|