# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import itertools
from enum import Enum
from typing import Any, Callable, Dict, Iterable, List, Set, Type, Union

import torch

from detectron2.config import CfgNode
from detectron2.solver.build import maybe_add_gradient_clipping
def match_name_keywords(n, name_keywords):
    """Return True if any keyword in name_keywords is a substring of the name n."""
    out = False
    for b in name_keywords:
        if b in n:
            out = True
            break
    return out
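
# Illustrative behavior (parameter names here are hypothetical, not from any config):
# the match is a plain substring test, e.g.
#   match_name_keywords("roi_heads.cls_score.weight", ["cls_score"])   -> True
#   match_name_keywords("backbone.res2.conv1.weight", ["cls_score"])   -> False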
def build_custom_optimizer(cfg: CfgNode, model: torch.nn.Module) -> torch.optim.Optimizer:
    """
    Build an optimizer from config, with per-parameter learning-rate multipliers
    for backbone parameters and for parameters whose names match
    cfg.SOLVER.CUSTOM_MULTIPLIER_NAME.
    """
    params: List[Dict[str, Any]] = []
    memo: Set[torch.nn.parameter.Parameter] = set()
    custom_multiplier_name = cfg.SOLVER.CUSTOM_MULTIPLIER_NAME
    optimizer_type = cfg.SOLVER.OPTIMIZER
    for key, value in model.named_parameters(recurse=True):
        if not value.requires_grad:
            continue
        # Avoid duplicating parameters
        if value in memo:
            continue
        memo.add(value)
        lr = cfg.SOLVER.BASE_LR
        weight_decay = cfg.SOLVER.WEIGHT_DECAY
        if "backbone" in key:
            lr = lr * cfg.SOLVER.BACKBONE_MULTIPLIER
        if match_name_keywords(key, custom_multiplier_name):
            lr = lr * cfg.SOLVER.CUSTOM_MULTIPLIER
            print('Custom LR', key, lr)
        param = {"params": [value], "lr": lr}
        if optimizer_type != 'ADAMW':
            # For AdamW, weight decay is passed once to the constructor below instead.
            param['weight_decay'] = weight_decay
        params += [param]
    def maybe_add_full_model_gradient_clipping(optim):  # optim: the optimizer class
        # detectron2 doesn't have full model gradient clipping now
        clip_norm_val = cfg.SOLVER.CLIP_GRADIENTS.CLIP_VALUE
        enable = (
            cfg.SOLVER.CLIP_GRADIENTS.ENABLED
            and cfg.SOLVER.CLIP_GRADIENTS.CLIP_TYPE == "full_model"
            and clip_norm_val > 0.0
        )

        class FullModelGradientClippingOptimizer(optim):
            def step(self, closure=None):
                all_params = itertools.chain(*[x["params"] for x in self.param_groups])
                torch.nn.utils.clip_grad_norm_(all_params, clip_norm_val)
                super().step(closure=closure)

        return FullModelGradientClippingOptimizer if enable else optim
    if optimizer_type == 'SGD':
        optimizer = maybe_add_full_model_gradient_clipping(torch.optim.SGD)(
            params, cfg.SOLVER.BASE_LR, momentum=cfg.SOLVER.MOMENTUM,
            nesterov=cfg.SOLVER.NESTEROV
        )
    elif optimizer_type == 'ADAMW':
        optimizer = maybe_add_full_model_gradient_clipping(torch.optim.AdamW)(
            params, cfg.SOLVER.BASE_LR,
            weight_decay=cfg.SOLVER.WEIGHT_DECAY
        )
    else:
        raise NotImplementedError(f"no optimizer type {optimizer_type}")
    if not cfg.SOLVER.CLIP_GRADIENTS.CLIP_TYPE == "full_model":
        # Fall back to detectron2's standard per-parameter gradient clipping.
        optimizer = maybe_add_gradient_clipping(cfg, optimizer)
    return optimizer
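

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original module): wire a toy model and
# a config through build_custom_optimizer. The SOLVER keys added below are the
# ones this file reads but which are not in detectron2's defaults; their values
# here are illustrative only, not the project's real settings.
if __name__ == "__main__":
    from detectron2.config import get_cfg

    cfg = get_cfg()
    cfg.SOLVER.OPTIMIZER = "ADAMW"
    cfg.SOLVER.BACKBONE_MULTIPLIER = 0.1
    cfg.SOLVER.CUSTOM_MULTIPLIER = 10.0
    cfg.SOLVER.CUSTOM_MULTIPLIER_NAME = ["cls_score"]
    cfg.SOLVER.CLIP_GRADIENTS.ENABLED = True
    cfg.SOLVER.CLIP_GRADIENTS.CLIP_TYPE = "full_model"
    cfg.SOLVER.CLIP_GRADIENTS.CLIP_VALUE = 1.0

    # Toy model: one "backbone" module (gets BACKBONE_MULTIPLIER) and one
    # "cls_score" module (gets CUSTOM_MULTIPLIER).
    model = torch.nn.Sequential()
    model.add_module("backbone", torch.nn.Linear(8, 8))
    model.add_module("cls_score", torch.nn.Linear(8, 4))

    optimizer = build_custom_optimizer(cfg, model)
    for group in optimizer.param_groups:
        print(group["lr"], [tuple(p.shape) for p in group["params"]])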