aamirshakir
commited on
Commit
•
e4614dd
1
Parent(s):
2a9c989
Update README.md
Browse files
README.md
CHANGED
@@ -2684,7 +2684,7 @@ def pooling(outputs: torch.Tensor, inputs: Dict, strategy: str = 'cls') -> np.n
|
|
2684 |
outputs = outputs[:, 0]
|
2685 |
elif strategy == 'mean':
|
2686 |
outputs = torch.sum(
|
2687 |
-
outputs * inputs["attention_mask"][:, :, None], dim=1) / torch.sum(inputs["attention_mask"])
|
2688 |
else:
|
2689 |
raise NotImplementedError
|
2690 |
return outputs.detach().cpu().numpy()
|
|
|
2684 |
outputs = outputs[:, 0]
|
2685 |
elif strategy == 'mean':
|
2686 |
outputs = torch.sum(
|
2687 |
+
outputs * inputs["attention_mask"][:, :, None], dim=1) / torch.sum(inputs["attention_mask"], dim=1, keepdim=True)
|
2688 |
else:
|
2689 |
raise NotImplementedError
|
2690 |
return outputs.detach().cpu().numpy()
|