# coding=utf-8
# Copyright 2022 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# An implementation of SM3 from:
#
# Memory-Efficient Adaptive Optimization, https://arxiv.org/pdf/1901.11150.pdf
# Rohan Anil, Vineet Gupta, Tomer Koren, Yoram Singer
#
# Author: Rohan Anil (rohananil at google dot com)

"""SM3 Implementation."""

import functools
from typing import Any, NamedTuple

import chex
import jax
import jax.numpy as jnp
import optax

from .quantization_utils import QuantizedValue


class SM3State(NamedTuple):
    count: chex.Array
    stats: Any


# Per-parameter optimizer state used in data-parallel training.
class ParameterStats(NamedTuple):
    """State associated with each parameter of the model being trained."""

    diagonal_statistics: chex.Array  # Accumulator for diagonal preconditioner
    diagonal_momentum: QuantizedValue  # Momentum for the diagonal preconditioner
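
# For a parameter of shape [M, N], `diagonal_statistics` holds one accumulator
# vector per dimension ([M] and [N]), i.e. M + N entries rather than the M * N
# second-moment entries a dense diagonal method such as Adam keeps, and
# `diagonal_momentum` is stored int8-quantized via `QuantizedValue`
# (see `_quantize_momentum` below).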


def sm3(
    learning_rate, beta1=0.9, beta2=0.999, diagonal_epsilon=1e-10, normalize_grads=False
):
    """SM3 optimizer.

    Memory-Efficient Adaptive Optimization, Rohan Anil, Vineet Gupta, Tomer Koren,
    Yoram Singer

    https://arxiv.org/abs/1901.11150

    Args:
      learning_rate: the step size used to update the parameters; either a float
        or a callable schedule mapping the current step count to a step size.
      beta1: momentum parameter.
      beta2: second-moment averaging parameter.
      diagonal_epsilon: epsilon added to the diagonal statistics before computing
        the inverse-square-root preconditioner, for numerical stability.
      normalize_grads: whether to normalize each gradient to unit L2 norm before
        the update. The author finds this useful when gradients have high
        variance.

    Returns:
      An `optax.GradientTransformation`.
    """

    def _quantize_momentum(momentum_statistics):
        return QuantizedValue.from_float_value(momentum_statistics, jnp.int8)

    def init_fn(params):
        """Initialise the optimiser's state."""

        def _init(param):
            accumulators = [jnp.zeros([s]) for s in param.shape]
            momentum = _quantize_momentum(jnp.zeros_like(param))
            return ParameterStats(accumulators, momentum)

        return SM3State(
            count=jnp.zeros([], jnp.int32), stats=jax.tree_map(_init, params)
        )

    def _get_expanded_shape(shape, i):
        rank = len(shape)
        # Replaces a `shape` of [M, N, K] with 1 in all dimensions except for i.
        # E.g. i = 1 returns [1, N, 1].
        return [1] * i + [shape[i]] + [1] * (rank - i - 1)

    def _moving_averages(grad, accumulators):
        w = (1.0 - beta2) if beta2 != 1.0 else 1.0
        if grad.ndim < 2:
            return beta2 * accumulators[0] + w * grad**2
        else:
            min_accumulator = functools.reduce(jnp.minimum, accumulators)
            return beta2 * min_accumulator + w * grad**2
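
    # Note: for rank >= 2 tensors, the per-dimension accumulators arrive here
    # already reshaped so they broadcast against each other (see update_fn), so
    # the elementwise minimum above is SM3's covering estimate of the full
    # diagonal second moment at a memory cost of O(sum of dims) rather than
    # O(their product).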

    def _moving_averages_momentum(grad, momentum):
        w = (1.0 - beta1) if beta1 != 1.0 else 1.0
        return beta1 * momentum.to_float() + w * grad

    def _sketch_diagonal_statistics(grad, updated_diagonal_statistics):
        all_diagonal_statistics = []
        for i in range(grad.ndim):
            axes = list(range(i)) + list(range(i + 1, grad.ndim))
            dim_diagonal_statistics = jnp.max(updated_diagonal_statistics, axis=axes)
            all_diagonal_statistics.append(dim_diagonal_statistics)
        if grad.ndim == 1:
            all_diagonal_statistics[0] = updated_diagonal_statistics
        return all_diagonal_statistics
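
    # Note: the max-reduction above folds the full-shape estimate back into one
    # vector per dimension, so each stored entry stays an upper bound on the
    # second-moment estimates it covers; rank-1 parameters keep the estimate
    # unchanged.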

    def update_fn(updates, state, params=None):
        del params
        stats = state.stats
        if normalize_grads:
            updates = jax.tree_map(lambda g: g / (jnp.linalg.norm(g) + 1e-16), updates)

        # Reshape all vectors into N-d tensors to compute min over them.
        # [n], [m] -> [n, 1], [1, m]
        expanded_diagonal_statistics = jax.tree_map(
            lambda grad, state: [  # pylint:disable=g-long-lambda
                jnp.reshape(
                    state.diagonal_statistics[i], _get_expanded_shape(grad.shape, i)
                )
                for i in range(grad.ndim)
            ],
            updates,
            stats,
        )

        # Compute new diagonal statistics.
        new_diagonal_statistics = jax.tree_map(
            _moving_averages, updates, expanded_diagonal_statistics
        )

        # Compute preconditioners (1/sqrt(s)) where s is the statistics.
        new_preconditioners = jax.tree_map(
            lambda t: 1.0 / jnp.sqrt(t + diagonal_epsilon), new_diagonal_statistics
        )
        preconditioned_grads = jax.tree_map(
            lambda g, p: g * p, updates, new_preconditioners
        )

        # Compute updated momentum (also handle quantization).
        updated_momentum = jax.tree_map(
            lambda preconditioned_grad, state: _moving_averages_momentum(  # pylint:disable=g-long-lambda
                preconditioned_grad, state.diagonal_momentum
            ),
            preconditioned_grads,
            stats,
        )

        # Update diagonal statistics.
        updated_diagonal_statistics = jax.tree_map(
            _sketch_diagonal_statistics, updates, new_diagonal_statistics
        )

        # Update momentum.
        new_sm3_stats = jax.tree_map(
            lambda momentum, diagonal_stats: ParameterStats(  # pylint:disable=g-long-lambda
                diagonal_stats, _quantize_momentum(momentum)
            ),
            updated_momentum,
            updated_diagonal_statistics,
        )

        lr = learning_rate
        if callable(learning_rate):
            lr = learning_rate(state.count)
        new_updates = jax.tree_map(lambda pg: -lr * pg, updated_momentum)
        return new_updates, SM3State(count=state.count + 1, stats=new_sm3_stats)

    return optax.GradientTransformation(init_fn, update_fn)
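

# Example usage (a minimal sketch): assuming this module is imported as part of
# its package (so the relative `quantization_utils` import resolves) and given a
# hypothetical parameter pytree `params` and scalar-valued `loss_fn`, the
# returned transformation plugs into the standard optax update loop:
#
#   tx = sm3(learning_rate=0.1, beta1=0.9, beta2=0.999)
#   opt_state = tx.init(params)
#   grads = jax.grad(loss_fn)(params)
#   updates, opt_state = tx.update(grads, opt_state)
#   params = optax.apply_updates(params, updates)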