akore committed on
Commit 80ef07e
1 Parent(s): d10ca53

Update modeling_atomformer.py

Files changed (1):
  1. modeling_atomformer.py +2 -2
modeling_atomformer.py CHANGED
@@ -2516,7 +2516,7 @@ class AtomformerEncoder(nn.Module):
         for blk in self.blocks:
             input_embeds, pos_embeds = blk(input_embeds, pos_embeds, attention_mask)
 
-        return input_embeds[:, :-1]
+        return input_embeds, pos_embeds
 
 
 class AtomformerPreTrainedModel(PreTrainedModel):  # type: ignore
@@ -2550,7 +2550,7 @@ class AtomformerModel(AtomformerPreTrainedModel):
     ) -> torch.Tensor:
         """Forward function call for the transformer model."""
         output: torch.Tensor = self.encoder(input_ids, coords, attention_mask)
-        return output
+        return output[0][:, :-1]
 
 
 class AtomformerForMaskedAM(AtomformerPreTrainedModel):
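
In effect, AtomformerEncoder.forward now returns the (input_embeds, pos_embeds) tuple instead of a pre-sliced tensor, and AtomformerModel.forward selects element 0 of that tuple and drops the last position itself. Below is a minimal sketch of the resulting shapes; the batch, sequence, and embedding sizes are illustrative assumptions, not values taken from this repository's config.

    import torch

    # Illustrative sizes only; real dimensions come from the model config.
    batch, seq_len, dim = 2, 10, 16

    input_embeds = torch.randn(batch, seq_len, dim)  # per-atom embeddings from the blocks
    pos_embeds = torch.randn(batch, seq_len, 3)      # coordinate embeddings (assumed shape)

    # New encoder return value: the full tuple, with no slicing in the encoder.
    encoder_out = (input_embeds, pos_embeds)

    # AtomformerModel.forward now picks the embeddings and drops the last position.
    model_out = encoder_out[0][:, :-1]
    assert model_out.shape == (batch, seq_len - 1, dim)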