How to create a keras layer with a custom gradient in TF2.0?

First of all, the “unification” of the APIs (as you call it) under Keras doesn’t prevent you from doing things the way you did in TensorFlow 1.x. Sessions are gone, but you can still define your model as a plain Python function and train it eagerly without Keras (i.e. with tf.GradientTape).
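To make that concrete, here is a minimal sketch of eager training without Keras. The toy data (y = 3x), the variable w and the learning rate are made up for illustration; only tf.Variable and tf.GradientTape come from TensorFlow.

```python
import tensorflow as tf

# Hypothetical toy problem: recover the slope of y = 3x.
w = tf.Variable(0.0)

def model(x):
    # The "model" is just a Python function of the variable w.
    return w * x

x = tf.constant([1.0, 2.0, 3.0])
y = tf.constant([3.0, 6.0, 9.0])
learning_rate = 0.05

for _ in range(100):
    with tf.GradientTape() as tape:
        # Trainable variables are watched by the tape automatically.
        loss = tf.reduce_mean(tf.square(model(x) - y))
    grad = tape.gradient(loss, w)
    w.assign_sub(learning_rate * grad)  # plain gradient descent step

print(float(w.numpy()))  # should end up close to 3.0
```

No session, no compile step and no Keras model are involved; the tape records the operations as they run.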

Now, if you want to build a keras model with a custom layer that performs a custom operation and has a custom gradient, you should do the following:

a) Write a function that performs your custom operation and define your custom gradient. See the tf.custom_gradient documentation for more details.

import tensorflow as tf

@tf.custom_gradient
def custom_op(x):
    result = ... # do forward computation
    def custom_grad(dy):
        # dy is the upstream gradient with respect to result
        grad = ... # compute gradient with respect to x
        return grad
    return result, custom_grad

Note that inside the function you should treat x and dy as Tensors, not NumPy arrays (i.e. use tensor operations).
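As one possible way to fill in that template, here is a sketch of a "straight-through" rounding op. The name round_straight_through and the test values are my own, not something from your question. The forward pass rounds, and the backward pass passes the upstream gradient through unchanged, since the true gradient of rounding is zero almost everywhere.

```python
import tensorflow as tf

@tf.custom_gradient
def round_straight_through(x):
    # Forward pass: ordinary rounding, done with a tensor op.
    result = tf.round(x)

    def grad(dy):
        # Backward pass: pretend the op was the identity.
        return dy

    return result, grad

x = tf.constant([0.2, 1.7, -0.6])
with tf.GradientTape() as tape:
    tape.watch(x)  # x is a constant, so it has to be watched explicitly
    y = tf.reduce_sum(3.0 * round_straight_through(x))

g = tape.gradient(y, x)
print(g.numpy())  # every entry is 3.0, the upstream gradient
```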

b) Create a custom Keras layer that performs your custom_op. For this example I’ll assume that your layer has no trainable parameters and doesn’t change the shape of its input, but it doesn’t make much difference if it does. For that case you can refer to the guide that you posted or to the TensorFlow documentation on custom layers.

class CustomLayer(tf.keras.layers.Layer):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)  # forward standard layer arguments such as name

    def call(self, x):
        # No extra gradient code is needed here: @tf.custom_gradient already
        # attached custom_grad to custom_op.
        return custom_op(x)
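A quick way to convince yourself that the layer really uses your gradient is to call it directly under a tf.GradientTape. The sketch below assumes a hypothetical custom_op that is the identity in the forward pass and halves the gradient in the backward pass; swap in your own op.

```python
import tensorflow as tf

@tf.custom_gradient
def custom_op(x):
    # Hypothetical op: identity forward, gradient scaled by 0.5 backward.
    def custom_grad(dy):
        return 0.5 * dy
    return tf.identity(x), custom_grad

class CustomLayer(tf.keras.layers.Layer):
    def call(self, x):
        return custom_op(x)

layer = CustomLayer()
x = tf.constant([[1.0, 2.0], [3.0, 4.0]])

with tf.GradientTape() as tape:
    tape.watch(x)
    out = layer(x)
    y = tf.reduce_sum(out)

grad = tape.gradient(y, x)
print(out.numpy())   # same values as x
print(grad.numpy())  # 0.5 everywhere instead of 1.0
```

If the custom gradient were ignored, the gradient of a sum through an identity would be 1.0 for every entry.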

Now you can use this layer in a keras model and it will work. For example:

inp = tf.keras.layers.Input(shape=input_shape)  # input_shape excludes the batch dimension
conv = tf.keras.layers.Conv2D(...)(inp)  # add params like the number of filters
cust = CustomLayer()(conv)  # no parameters in custom layer
flat = tf.keras.layers.Flatten()(cust)
fc = tf.keras.layers.Dense(num_classes)(flat)

model = tf.keras.models.Model(inputs=[inp], outputs=[fc])
model.compile(loss=..., optimizer=...)  # add loss function and optimizer
model.fit(...)  # fit the model
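For reference, here is a filled-in version of the whole thing that should run as is. Everything concrete in it is an assumption for illustration: the gradient-halving custom_op, the 8x8x1 input shape, the 3 classes, the 4 filters, the loss, the optimizer and the random data.

```python
import numpy as np
import tensorflow as tf

@tf.custom_gradient
def custom_op(x):
    # Hypothetical op: identity forward, gradient scaled by 0.5 backward.
    def custom_grad(dy):
        return 0.5 * dy
    return tf.identity(x), custom_grad

class CustomLayer(tf.keras.layers.Layer):
    def call(self, x):
        return custom_op(x)

input_shape = (8, 8, 1)  # made-up image size
num_classes = 3          # made-up number of classes

inp = tf.keras.layers.Input(shape=input_shape)
conv = tf.keras.layers.Conv2D(4, 3, activation="relu")(inp)
cust = CustomLayer()(conv)
flat = tf.keras.layers.Flatten()(cust)
fc = tf.keras.layers.Dense(num_classes)(flat)  # outputs logits

model = tf.keras.models.Model(inputs=[inp], outputs=[fc])
model.compile(
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer="adam",
)

# Random data, only to show that training runs end to end.
rng = np.random.default_rng(0)
x_train = rng.normal(size=(16,) + input_shape).astype("float32")
y_train = rng.integers(0, num_classes, size=(16,))

history = model.fit(x_train, y_train, epochs=1, batch_size=8, verbose=0)
print(history.history["loss"])
```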
