# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""A collection of common Haiku modules for use in protein folding."""
import haiku as hk
import jax.numpy as jnp
class Linear(hk.Module):
  """Protein folding specific Linear Module.

  This differs from the standard Haiku Linear in a few ways:
    * It supports inputs of arbitrary rank.
    * Initializers are specified by strings.
  """
  def __init__(self,
               num_output: int,
               initializer: str = 'linear',
               use_bias: bool = True,
               bias_init: float = 0.,
               name: str = 'linear'):
    """Constructs Linear Module.

    Args:
      num_output: Number of output channels.
      initializer: Which initializer to use; should be one of {'linear',
        'relu', 'zeros'}.
      use_bias: Whether to include a trainable bias.
      bias_init: Value used to initialize the bias.
      name: Name of the module, used for name scopes.
    """
    super().__init__(name=name)

    self.num_output = num_output
    self.initializer = initializer
    self.use_bias = use_bias
    self.bias_init = bias_init
  def __call__(self, inputs: jnp.ndarray) -> jnp.ndarray:
    """Connects Module.

    Args:
      inputs: Tensor of shape [..., num_channel].

    Returns:
      Tensor of shape [..., num_output].
    """
    n_channels = int(inputs.shape[-1])

    weight_shape = [n_channels, self.num_output]
    if self.initializer == 'linear':
      weight_init = hk.initializers.VarianceScaling(mode='fan_in', scale=1.)
    elif self.initializer == 'relu':
      weight_init = hk.initializers.VarianceScaling(mode='fan_in', scale=2.)
    elif self.initializer == 'zeros':
      weight_init = hk.initializers.Constant(0.0)
    else:
      raise ValueError(f'Invalid initializer: {self.initializer}')

    weights = hk.get_parameter('weights', weight_shape, inputs.dtype,
                               weight_init)

    # This is equivalent to einsum('...c,cd->...d', inputs, weights) but
    # turns out to be slightly faster: the contracted channel axis is moved
    # off the last position. Note this requires inputs of rank >= 2.
    inputs = jnp.swapaxes(inputs, -1, -2)
    output = jnp.einsum('...cb,cd->...db', inputs, weights)
    output = jnp.swapaxes(output, -1, -2)

    if self.use_bias:
      bias = hk.get_parameter('bias', [self.num_output], inputs.dtype,
                              hk.initializers.Constant(self.bias_init))
      output += bias

    return output
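

# A quick sketch (not part of the original module) verifying that the
# swapaxes/einsum trick in __call__ matches the direct contraction. The
# helper name and the shapes below are illustrative assumptions.
def _check_fast_einsum_equivalence():
  x = jnp.arange(2 * 3 * 5, dtype=jnp.float32).reshape(2, 3, 5)
  w = jnp.arange(5 * 7, dtype=jnp.float32).reshape(5, 7)
  # Direct contraction over the last (channel) axis.
  direct = jnp.einsum('...c,cd->...d', x, w)
  # Same contraction with the channel axis moved off the last position,
  # as done in Linear.__call__.
  swapped = jnp.swapaxes(
      jnp.einsum('...cb,cd->...db', jnp.swapaxes(x, -1, -2), w), -1, -2)
  assert jnp.allclose(direct, swapped)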
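

# A minimal usage sketch, assuming only the haiku/jax imports above; the
# module name, shapes, and values here are illustrative, not part of the
# original file.
if __name__ == '__main__':
  import jax

  def forward(x):
    return Linear(num_output=8, initializer='relu', name='example_linear')(x)

  # Haiku modules must be created and applied inside a transform.
  model = hk.without_apply_rng(hk.transform(forward))
  x = jnp.ones((4, 16))  # [batch, num_channel]
  params = model.init(jax.random.PRNGKey(0), x)
  out = model.apply(params, x)
  print(out.shape)  # (4, 8)

  _check_fast_einsum_equivalence()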