File size: 648 Bytes
ca25718
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
import torch


def get_optimizer(
    optimizer_name: str, latents: torch.Tensor, lr: float, nesterov: bool
):
    """Build the torch optimizer selected by ``optimizer_name`` for ``latents``.

    The optimizer is given a single parameter, ``latents``.

    Args:
        optimizer_name: One of ``"adam"``, ``"sgd"`` or ``"lbfgs"``
            (compared exactly, case-sensitive).
        latents: The tensor handed to the optimizer as its only parameter.
        lr: Learning rate forwarded to the chosen optimizer.
        nesterov: Forwarded to ``torch.optim.SGD`` only; the ``"adam"`` and
            ``"lbfgs"`` branches do not read it.

    Returns:
        A ``torch.optim.Adam`` (``eps=1e-2``), a ``torch.optim.SGD``
        (``momentum=0.9``), or a ``torch.optim.LBFGS`` (``max_iter=10``,
        ``history_size=3``, ``line_search_fn="strong_wolfe"``) instance.

    Raises:
        ValueError: If ``optimizer_name`` is not one of the supported names.
    """
    params = [latents]

    if optimizer_name == "adam":
        return torch.optim.Adam(params, lr=lr, eps=1e-2)

    if optimizer_name == "sgd":
        return torch.optim.SGD(params, lr=lr, nesterov=nesterov, momentum=0.9)

    if optimizer_name == "lbfgs":
        lbfgs_settings = {
            "lr": lr,
            "max_iter": 10,
            "history_size": 3,
            "line_search_fn": "strong_wolfe",
        }
        return torch.optim.LBFGS(params, **lbfgs_settings)

    # Nothing matched: reject the name rather than fall back to a default.
    raise ValueError(f"Unknown optimizer {optimizer_name}")