zRzRzRzRzRzRzR
commited on
Commit
•
f308259
1
Parent(s):
37fe000
add set_input_embeddings(self, value):
Browse files- modeling_chatglm.py +3 -0
modeling_chatglm.py
CHANGED
@@ -769,6 +769,9 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
|
|
769 |
def get_input_embeddings(self):
|
770 |
return self.embedding.word_embeddings
|
771 |
|
|
|
|
|
|
|
772 |
def get_prompt(self, batch_size, device, dtype=torch.half):
|
773 |
prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1).to(device)
|
774 |
past_key_values = self.prefix_encoder(prefix_tokens).type(dtype)
|
|
|
769 |
def get_input_embeddings(self):
|
770 |
return self.embedding.word_embeddings
|
771 |
|
772 |
+
def set_input_embeddings(self, value):
|
773 |
+
self.embedding.word_embeddings = value
|
774 |
+
|
775 |
def get_prompt(self, batch_size, device, dtype=torch.half):
|
776 |
prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1).to(device)
|
777 |
past_key_values = self.prefix_encoder(prefix_tokens).type(dtype)
|