File size: 2,452 Bytes
d5175d3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 |
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import argparse
import logging
import unittest
import torch
from fairseq.optim.adam import FairseqAdam
from fairseq.optim.fp16_optimizer import MemoryEfficientFP16Optimizer
from omegaconf import OmegaConf
@unittest.skipIf(not torch.cuda.is_available(), "test requires a GPU")
class TestMemoryEfficientFP16(unittest.TestCase):
    """Check optimizer-state dtypes of ``MemoryEfficientFP16Optimizer`` after a reload.

    The whole class is skipped when CUDA is unavailable because the model
    under test is moved to the GPU and converted to half precision.
    """

    def setUp(self):
        """Disable log records up to and including CRITICAL for the test."""
        logging.disable(logging.CRITICAL)

    def tearDown(self):
        """Restore normal logging after the test."""
        logging.disable(logging.NOTSET)

    def test_load_state_dict(self):
        """Round-trip the optimizer state dict and verify the resulting dtypes.

        After ``load_state_dict`` is fed the optimizer's own ``state_dict``,
        every key of the wrapped optimizer's state must be a float16 tensor
        and every tensor stored under it must be float32.
        """
        # define simple FP16 model
        model = torch.nn.Linear(5, 5).cuda().half()
        params = list(model.parameters())

        # initialize memory efficient FP16 optimizer
        # with pseudo DictConfigs
        optimizer = FairseqAdam(
            cfg=OmegaConf.create(
                vars(
                    argparse.Namespace(
                        adam_betas="(0.9, 0.999)",
                        adam_eps=1e-8,
                        weight_decay=0.0,
                        lr=[0.00001],
                    )
                )
            ),
            params=params,
        )
        me_optimizer = MemoryEfficientFP16Optimizer(
            cfg=OmegaConf.create(
                {
                    "common": vars(
                        argparse.Namespace(
                            fp16_init_scale=1,
                            fp16_scale_window=1,
                            fp16_scale_tolerance=1,
                            threshold_loss_scale=1,
                            min_loss_scale=1e-4,
                        )
                    )
                }
            ),
            params=params,
            optimizer=optimizer,
        )

        # optimizer state is created in the first step
        loss = model(torch.rand(5).cuda().half()).sum()
        me_optimizer.backward(loss)
        me_optimizer.step()

        # reload state
        state = me_optimizer.state_dict()
        me_optimizer.load_state_dict(state)

        optimizer_state = me_optimizer.optimizer.state
        # Guard against a vacuous pass: with an empty state the dtype checks
        # below would never execute and the test would succeed trivially.
        self.assertGreater(
            len(optimizer_state),
            0,
            "optimizer state is empty after step() and load_state_dict()",
        )

        for param, param_state in optimizer_state.items():
            # assertEqual (rather than assertTrue(a == b)) reports the
            # offending dtype when the check fails.
            self.assertEqual(
                param.dtype, torch.float16, "state key is not an FP16 tensor"
            )
            for name, value in param_state.items():
                if torch.is_tensor(value):
                    self.assertEqual(
                        value.dtype,
                        torch.float32,
                        f"state entry {name!r} is not FP32",
                    )
# Run the unittest command-line runner when this file is executed as a script.
if __name__ == "__main__":
    unittest.main()
|