# ReNO / training / optim.py
# Uploaded by fffiloni (commit ca25718, "Upload 24 files").
import torch
def get_optimizer(
    optimizer_name: str,
    latents: torch.Tensor,
    lr: float,
    nesterov: bool,
    *,
    adam_eps: float = 1e-2,
    sgd_momentum: float = 0.9,
    lbfgs_max_iter: int = 10,
    lbfgs_history_size: int = 3,
) -> torch.optim.Optimizer:
    """Build a PyTorch optimizer whose only parameter is ``latents``.

    Args:
        optimizer_name: One of ``"adam"``, ``"sgd"`` or ``"lbfgs"``
            (exact, case-sensitive match).
        latents: Tensor to optimize; it is wrapped as the optimizer's single
            parameter. PyTorch requires a leaf tensor here; presumably it
            should also have ``requires_grad=True`` so gradients reach it --
            confirm at the call site.
        lr: Learning rate handed to whichever optimizer is built.
        nesterov: Enable Nesterov momentum. Only used for ``"sgd"``; ignored
            by the other optimizers.
        adam_eps: ``eps`` for Adam (default 1e-2, far larger than PyTorch's
            own 1e-8 default).
        sgd_momentum: Momentum factor for SGD (default 0.9). PyTorch rejects
            ``nesterov=True`` unless this is > 0.
        lbfgs_max_iter: Maximum iterations per ``LBFGS.step`` call
            (default 10).
        lbfgs_history_size: L-BFGS update-history size (default 3).

    Returns:
        A newly constructed ``torch.optim`` optimizer.

    Raises:
        ValueError: If ``optimizer_name`` is not one of the supported names.

    Note:
        ``torch.optim.LBFGS.step`` must be called with a closure that
        re-evaluates the loss; the other two optimizers do not need one.
    """
    params = [latents]
    if optimizer_name == "adam":
        # NOTE(review): eps is much larger than PyTorch's default; presumably
        # deliberate to damp Adam's per-element step normalisation -- confirm.
        return torch.optim.Adam(params, lr=lr, eps=adam_eps)
    if optimizer_name == "sgd":
        return torch.optim.SGD(
            params, lr=lr, nesterov=nesterov, momentum=sgd_momentum
        )
    if optimizer_name == "lbfgs":
        return torch.optim.LBFGS(
            params,
            lr=lr,
            max_iter=lbfgs_max_iter,
            history_size=lbfgs_history_size,
            line_search_fn="strong_wolfe",
        )
    raise ValueError(f"Unknown optimizer {optimizer_name}")