# coding=utf-8
# Copyright 2022 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# An implementation of SM3 from:
#
# Memory-Efficient Adaptive Optimization, https://arxiv.org/pdf/1901.11150.pdf
# Rohan Anil, Vineet Gupta, Tomer Koren, Yoram Singer
#
# Author: Rohan Anil (rohananil at google dot com)
#
"""SM3 Implementation."""
import functools
from typing import Any, NamedTuple
import chex
import jax
import jax.numpy as jnp
import optax
from .quantization_utils import QuantizedValue


class SM3State(NamedTuple):
    count: chex.Array
    stats: Any


# Per parameter optimizer state used in data-parallel training.
class ParameterStats(NamedTuple):
    """State associated to each parameter of the model being trained."""

    diagonal_statistics: chex.Array  # Accumulator for diagonal preconditioner
    diagonal_momentum: QuantizedValue  # Momentum for the diagonal preconditioner
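

# Memory layout: SM3 keeps one 1-D accumulator per tensor dimension instead of a
# full-shape second-moment tensor. For example, a parameter of shape (1024, 4096)
# needs 1024 + 4096 accumulator entries rather than 1024 * 4096, and the momentum
# buffer is stored as an int8 QuantizedValue to further reduce optimizer memory.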
def sm3(
    learning_rate, beta1=0.9, beta2=0.999, diagonal_epsilon=1e-10, normalize_grads=False
):
    """SM3 optimizer.

    Memory-Efficient Adaptive Optimization, Rohan Anil, Vineet Gupta, Tomer Koren,
    Yoram Singer. https://arxiv.org/abs/1901.11150

    Args:
      learning_rate: the step size used to update the parameters; either a float
        or a callable schedule mapping the step count to a step size.
      beta1: momentum parameter.
      beta2: second moment averaging parameter.
      diagonal_epsilon: small constant added to the accumulated statistics before
        taking the inverse square root, for numerical stability.
      normalize_grads: whether to normalize each gradient to unit norm before the
        update. The authors find this useful when gradients have high variance.

    Returns:
      An optax.GradientTransformation.
    """

    def _quantize_momentum(momentum_statistics):
        return QuantizedValue.from_float_value(momentum_statistics, jnp.int8)
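
    # init_fn builds one 1-D zero accumulator per tensor dimension (a rank-k
    # parameter gets k vectors, one per axis) and an int8-quantized zero momentum
    # buffer for every parameter in the tree.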
    def init_fn(params):
        """Initialise the optimiser's state."""

        def _init(param):
            accumulators = [jnp.zeros([s]) for s in param.shape]
            momentum = _quantize_momentum(jnp.zeros_like(param))
            return ParameterStats(accumulators, momentum)

        return SM3State(
            count=jnp.zeros([], jnp.int32), stats=jax.tree_map(_init, params)
        )

    def _get_expanded_shape(shape, i):
        rank = len(shape)
        # Replaces a `shape` of [M, N, K] with 1 in all dimensions except for i.
        # For eg: i = 1 returns [1, N, 1].
        return [1] * i + [shape[i]] + [1] * (rank - i - 1)
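
    # _moving_averages computes the updated SM3 statistic: for rank >= 2 tensors
    # the current statistic is the elementwise minimum of the per-dimension
    # accumulators (broadcast to the full gradient shape), to which the squared
    # gradient is added via an exponential moving average controlled by beta2.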
    def _moving_averages(grad, accumulators):
        w = (1.0 - beta2) if beta2 != 1.0 else 1.0
        if grad.ndim < 2:
            return beta2 * accumulators[0] + w * grad**2
        else:
            min_accumulator = functools.reduce(jnp.minimum, accumulators)
            return beta2 * min_accumulator + w * grad**2

    def _moving_averages_momentum(grad, momentum):
        w = (1.0 - beta1) if beta1 != 1.0 else 1.0
        return beta1 * momentum.to_float() + w * grad
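
    # _sketch_diagonal_statistics compresses the full-shape statistic back into
    # one vector per dimension by taking the maximum over all other axes; rank-1
    # parameters keep their statistic unchanged.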
    def _sketch_diagonal_statistics(grad, updated_diagonal_statistics):
        all_diagonal_statistics = []
        for i in range(grad.ndim):
            axes = list(range(i)) + list(range(i + 1, grad.ndim))
            dim_diagonal_statistics = jnp.max(updated_diagonal_statistics, axis=axes)
            all_diagonal_statistics.append(dim_diagonal_statistics)
        if grad.ndim == 1:
            all_diagonal_statistics[0] = updated_diagonal_statistics
        return all_diagonal_statistics
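
    # update_fn applies the steps above in order: expand the stored per-dimension
    # statistics, fold in the new squared gradient, precondition the gradient with
    # 1/sqrt(statistic + epsilon), update the (quantized) momentum, and sketch the
    # statistics back to per-dimension vectors for the next step.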
    def update_fn(updates, state, params=None):
        del params
        stats = state.stats
        if normalize_grads:
            updates = jax.tree_map(lambda g: g / (jnp.linalg.norm(g) + 1e-16), updates)
        # Reshape all vectors into N-d tensors to compute min over them.
        # [n], [m] -> [n, 1], [1, m]
        expanded_diagonal_statistics = jax.tree_multimap(
            lambda grad, state: [  # pylint:disable=g-long-lambda
                jnp.reshape(
                    state.diagonal_statistics[i], _get_expanded_shape(grad.shape, i)
                )
                for i in range(grad.ndim)
            ],
            updates,
            stats,
        )

        # Compute new diagonal statistics.
        new_diagonal_statistics = jax.tree_multimap(
            _moving_averages, updates, expanded_diagonal_statistics
        )

        # Compute preconditioners (1/sqrt(s)) where s is the statistics.
        new_preconditioners = jax.tree_map(
            lambda t: 1.0 / jnp.sqrt(t + diagonal_epsilon), new_diagonal_statistics
        )
        preconditioned_grads = jax.tree_multimap(
            lambda g, p: g * p, updates, new_preconditioners
        )

        # Compute updated momentum (also handle quantization).
        updated_momentum = jax.tree_multimap(
            lambda preconditioned_grad, state: _moving_averages_momentum(  # pylint:disable=g-long-lambda
                preconditioned_grad, state.diagonal_momentum
            ),
            preconditioned_grads,
            stats,
        )

        # Update diagonal statistics.
        updated_diagonal_statistics = jax.tree_multimap(
            _sketch_diagonal_statistics, updates, new_diagonal_statistics
        )

        # Update momentum.
        new_sm3_stats = jax.tree_multimap(
            lambda momentum, diagonal_stats: ParameterStats(  # pylint:disable=g-long-lambda
                diagonal_stats, _quantize_momentum(momentum)
            ),
            updated_momentum,
            updated_diagonal_statistics,
        )

        lr = learning_rate
        if callable(learning_rate):
            lr = learning_rate(state.count)
        new_updates = jax.tree_map(lambda pg: -lr * pg, updated_momentum)
        return new_updates, SM3State(count=state.count + 1, stats=new_sm3_stats)

    return optax.GradientTransformation(init_fn, update_fn)
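

# Minimal usage sketch (illustrative only): the toy parameters, loss function,
# and hyperparameter values below are not part of this module; they just show
# the standard optax pattern of init() / update() / apply_updates(). Because of
# the relative import above, run this as a module (e.g. `python -m <package>.sm3`)
# rather than as a standalone script.
if __name__ == "__main__":
    tx = sm3(learning_rate=0.1, beta1=0.9, beta2=0.999)
    params = {"w": jnp.ones((4, 3)), "b": jnp.zeros((3,))}
    opt_state = tx.init(params)

    def loss_fn(p):
        # Toy quadratic loss, used only to produce gradients of the right shapes.
        return jnp.sum(p["w"] ** 2) + jnp.sum(p["b"] ** 2)

    for _ in range(3):
        grads = jax.grad(loss_fn)(params)
        updates, opt_state = tx.update(grads, opt_state)
        params = optax.apply_updates(params, updates)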