How do I get fixed embeddings from esm2_t33_650M_UR50D to use in a downstream task?
#4, opened by xigua666
Here is how I'm using the model, but I'm not sure it's correct. Can someone help me check it?
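```python
import torch

def forward(self, encoded_inputs):
    # encoded_inputs comes from the ESM-2 tokenizer, e.g.:
    # encoded_inputs = self.tokenizer(seqs, max_length=65, padding=True, truncation=True, return_tensors='pt')
    # last_hidden_state has shape (batch, seq_len, hidden); .mean(0) averages over dim 0
    embedded_data = self.model(**encoded_inputs).last_hidden_state.mean(0)  # the line I'm unsure about
    print("embedded_data shape ==>", embedded_data.shape)
    print(embedded_data)
    # pooling options I considered: .mean(0) or [:, 0, :] (first token)
    output = torch.squeeze(self.main(embedded_data))  # self.main is the downstream head
    return output
```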
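For comparison, here is a minimal sketch of one common way to get fixed (frozen) per-sequence embeddings. It assumes `self.model` is an ESM-2 model loaded through Hugging Face transformers, whose output exposes `last_hidden_state` with shape `(batch, seq_len, hidden)`, and that the tokenizer returns an `attention_mask`. The names `masked_mean_pool` and `FrozenEsmHead` are hypothetical. The two ideas it shows are: average over the sequence dimension (dim 1) while ignoring padding, rather than over dim 0 (the batch), and keep the encoder weights fixed by disabling gradients and dropout.

```python
import torch


def masked_mean_pool(last_hidden_state: torch.Tensor,
                     attention_mask: torch.Tensor) -> torch.Tensor:
    """Average token embeddings over the sequence dimension, skipping padding.

    last_hidden_state: (batch, seq_len, hidden)
    attention_mask:    (batch, seq_len), 1 for real tokens, 0 for padding
    returns:           (batch, hidden), one vector per sequence
    """
    mask = attention_mask.unsqueeze(-1).to(last_hidden_state.dtype)  # (batch, seq_len, 1)
    summed = (last_hidden_state * mask).sum(dim=1)                   # (batch, hidden)
    counts = mask.sum(dim=1).clamp(min=1.0)                          # avoid division by zero
    return summed / counts


class FrozenEsmHead(torch.nn.Module):
    """Hypothetical wrapper: frozen encoder + trainable downstream head."""

    def __init__(self, encoder: torch.nn.Module, head: torch.nn.Module):
        super().__init__()
        self.model = encoder  # assumed: returns an object with .last_hidden_state
        self.main = head      # e.g. torch.nn.Linear(hidden_size, 1)
        for p in self.model.parameters():
            p.requires_grad_(False)  # encoder weights stay fixed
        self.model.eval()

    def train(self, mode: bool = True):
        super().train(mode)
        self.model.eval()  # keep dropout off in the encoder even while training the head
        return self

    def forward(self, encoded_inputs):
        with torch.no_grad():  # no graph is built through the frozen encoder
            hidden = self.model(**encoded_inputs).last_hidden_state
        embedded = masked_mean_pool(hidden, encoded_inputs["attention_mask"])
        return self.main(embedded).squeeze(-1)  # (batch,) even when batch == 1
```

A possible usage, assuming the transformers `EsmModel` class: wrap `EsmModel.from_pretrained(...)` together with `torch.nn.Linear(model.config.hidden_size, 1)`, and pass only the head's parameters to the optimizer. Note that the ESM tokenizer adds `<cls>` and `<eos>` tokens with mask value 1, so this pooling includes them. Some people instead average only over the residue positions, or take the `[:, 0, :]` (`<cls>`) vector. Using `squeeze(-1)` instead of a bare `torch.squeeze` avoids dropping the batch dimension for a batch of one.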