aamirshakir commited on
Commit
e4614dd
1 Parent(s): 2a9c989

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +1 -1
README.md CHANGED
@@ -2684,7 +2684,7 @@ def pooling(outputs: torch.Tensor, inputs: Dict, strategy: str = 'cls') -> np.n
2684
  outputs = outputs[:, 0]
2685
  elif strategy == 'mean':
2686
  outputs = torch.sum(
2687
- outputs * inputs["attention_mask"][:, :, None], dim=1) / torch.sum(inputs["attention_mask"])
2688
  else:
2689
  raise NotImplementedError
2690
  return outputs.detach().cpu().numpy()
 
2684
  outputs = outputs[:, 0]
2685
  elif strategy == 'mean':
2686
  outputs = torch.sum(
2687
+ outputs * inputs["attention_mask"][:, :, None], dim=1) / torch.sum(inputs["attention_mask"], dim=1, keepdim=True)
2688
  else:
2689
  raise NotImplementedError
2690
  return outputs.detach().cpu().numpy()