File size: 2,452 Bytes
d5175d3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import argparse
import logging
import unittest

import torch
from fairseq.optim.adam import FairseqAdam
from fairseq.optim.fp16_optimizer import MemoryEfficientFP16Optimizer
from omegaconf import OmegaConf


@unittest.skipIf(not torch.cuda.is_available(), "test requires a GPU")
class TestMemoryEfficientFP16(unittest.TestCase):
    """Tests for ``MemoryEfficientFP16Optimizer`` state handling.

    The whole class is skipped when no CUDA device is available, because the
    model and its inputs are placed on the GPU in half precision.
    """

    def setUp(self):
        """Suppress all log records (CRITICAL and below) during each test."""
        logging.disable(logging.CRITICAL)

    def tearDown(self):
        """Re-enable the logging that ``setUp`` disabled."""
        logging.disable(logging.NOTSET)

    def test_load_state_dict(self):
        """A state_dict round trip keeps FP16 params alongside FP32 state.

        Runs one optimizer step so that optimizer state exists, saves it with
        ``state_dict()``, and reloads it with ``load_state_dict()``. It then
        checks two things. Every state key (a model parameter) must be
        ``torch.float16``. Every tensor-valued state entry must be
        ``torch.float32``.
        """
        # define simple FP16 model
        model = torch.nn.Linear(5, 5).cuda().half()
        params = list(model.parameters())

        # initialize memory efficient FP16 optimizer
        # with pseudo DictConfigs: each argparse.Namespace is turned into a
        # plain dict via vars() and then wrapped by OmegaConf.create
        optimizer = FairseqAdam(
            cfg=OmegaConf.create(
                vars(
                    argparse.Namespace(
                        adam_betas="(0.9, 0.999)",
                        adam_eps=1e-8,
                        weight_decay=0.0,
                        lr=[0.00001],
                    )
                )
            ),
            params=params,
        )
        me_optimizer = MemoryEfficientFP16Optimizer(
            cfg=OmegaConf.create(
                {
                    "common": vars(
                        argparse.Namespace(
                            fp16_init_scale=1,
                            fp16_scale_window=1,
                            fp16_scale_tolerance=1,
                            threshold_loss_scale=1,
                            min_loss_scale=1e-4,
                        )
                    )
                }
            ),
            params=params,
            optimizer=optimizer,
        )

        # optimizer state is created in the first step
        loss = model(torch.rand(5).cuda().half()).sum()
        me_optimizer.backward(loss)
        me_optimizer.step()

        # reload state
        state = me_optimizer.state_dict()
        me_optimizer.load_state_dict(state)

        opt_state = me_optimizer.optimizer.state
        # Guard against a vacuous pass: with empty state the dtype loop below
        # would check nothing. Both the weight and the bias feed into `loss`,
        # so each is expected to have received a gradient and a state entry.
        self.assertEqual(len(opt_state), len(params))
        for param, param_state in opt_state.items():
            # assertEqual (rather than assertTrue(a == b)) reports the
            # mismatching dtypes on failure.
            self.assertEqual(param.dtype, torch.float16)
            for value in param_state.values():
                # Non-tensor entries (e.g. a step counter) carry no dtype.
                if torch.is_tensor(value):
                    self.assertEqual(value.dtype, torch.float32)


if __name__ == "__main__":
    # Executing this file directly discovers and runs the TestCase classes
    # in this module ("__main__" is unittest.main's default module).
    unittest.main(module="__main__")