# Spaces: Runtime error
# File size: 3,295 Bytes
# 733aa30
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import unittest
import torch
from fairseq import utils
class TestUtils(unittest.TestCase):
def test_convert_padding_direction(self):
pad = 1
left_pad = torch.LongTensor(
[
[2, 3, 4, 5, 6],
[1, 7, 8, 9, 10],
[1, 1, 1, 11, 12],
]
)
right_pad = torch.LongTensor(
[
[2, 3, 4, 5, 6],
[7, 8, 9, 10, 1],
[11, 12, 1, 1, 1],
]
)
self.assertAlmostEqual(
right_pad,
utils.convert_padding_direction(
left_pad,
pad,
left_to_right=True,
),
)
self.assertAlmostEqual(
left_pad,
utils.convert_padding_direction(
right_pad,
pad,
right_to_left=True,
),
)
def test_make_positions(self):
pad = 1
left_pad_input = torch.LongTensor(
[
[9, 9, 9, 9, 9],
[1, 9, 9, 9, 9],
[1, 1, 1, 9, 9],
]
)
left_pad_output = torch.LongTensor(
[
[2, 3, 4, 5, 6],
[1, 2, 3, 4, 5],
[1, 1, 1, 2, 3],
]
)
right_pad_input = torch.LongTensor(
[
[9, 9, 9, 9, 9],
[9, 9, 9, 9, 1],
[9, 9, 1, 1, 1],
]
)
right_pad_output = torch.LongTensor(
[
[2, 3, 4, 5, 6],
[2, 3, 4, 5, 1],
[2, 3, 1, 1, 1],
]
)
self.assertAlmostEqual(
left_pad_output,
utils.make_positions(left_pad_input, pad),
)
self.assertAlmostEqual(
right_pad_output,
utils.make_positions(right_pad_input, pad),
)
def test_clip_grad_norm_(self):
params = torch.nn.Parameter(torch.zeros(5)).requires_grad_(False)
grad_norm = utils.clip_grad_norm_(params, 1.0)
self.assertTrue(torch.is_tensor(grad_norm))
self.assertEqual(grad_norm, 0.0)
params = [torch.nn.Parameter(torch.zeros(5)) for i in range(3)]
for p in params:
p.grad = torch.full((5,), fill_value=2.0)
grad_norm = utils.clip_grad_norm_(params, 1.0)
exp_grad_norm = torch.full((15,), fill_value=2.0).norm()
self.assertTrue(torch.is_tensor(grad_norm))
self.assertEqual(grad_norm, exp_grad_norm)
grad_norm = utils.clip_grad_norm_(params, 1.0)
self.assertAlmostEqual(grad_norm, torch.tensor(1.0))
def test_resolve_max_positions_with_tuple(self):
resolved = utils.resolve_max_positions(None, (2000, 100, 2000), 12000)
self.assertEqual(resolved, (2000, 100, 2000))
def assertAlmostEqual(self, t1, t2):
self.assertEqual(t1.size(), t2.size(), "size mismatch")
self.assertLess(utils.item((t1 - t2).abs().max()), 1e-4)
# Run the test suite when this module is executed directly as a script.
if __name__ == "__main__":
    unittest.main()
|