Spaces:
Runtime error
Runtime error
import tensorflow as tf | |
from tensorflow import keras | |
class OrthogonalRegularizer(keras.regularizers.Regularizer):
    """Penalize deviation of predicted square matrices from orthogonality.

    Each sample of the input is reshaped to a ``(num_features, num_features)``
    matrix ``A``. The returned penalty is
    ``l2reg * sum((A @ A^T - I) ** 2)``, summed over all samples in the batch.
    In the referenced PointNet example this is attached as an activity
    regularizer to the layer that predicts the T-Net transform.

    Reference: https://keras.io/examples/vision/pointnet/#build-a-model

    Args:
        num_features: Side length of each square matrix.
        l2reg: Scale factor applied to the squared error.
    """

    def __init__(self, num_features, l2reg=0.001):
        self.num_features = num_features
        self.l2reg = l2reg
        # Built once in the default float dtype; cast to the input's dtype
        # on every call so mixed-precision inputs still work.
        self.identity = tf.eye(num_features)

    def __call__(self, x):
        """Return the scalar orthogonality penalty for the batch ``x``.

        ``x`` must hold ``num_features ** 2`` values per sample (first axis
        is the batch); it is reshaped to
        ``(batch, num_features, num_features)``.
        """
        identity = tf.cast(self.identity, x.dtype)
        x = tf.reshape(x, (tf.shape(x)[0], self.num_features, self.num_features))
        # Per-sample Gram matrix A @ A^T, shape (batch, F, F). Do NOT use
        # tf.tensordot(x, x, axes=(2, 2)) here: it contracts across every
        # pair of samples, giving a (batch, F, batch, F) tensor whose
        # reshape is not the per-sample product once batch > 1.
        xxt = tf.matmul(x, x, transpose_b=True)
        return tf.reduce_sum(self.l2reg * tf.square(xxt - identity))

    def get_config(self):
        """Return the constructor arguments so the regularizer can be rebuilt."""
        return {"num_features": self.num_features, "l2reg": self.l2reg}