How to replace (or insert) intermediate layer in Keras model?

The following function lets you insert a new layer before, after, or in place of every layer in the original model whose name matches a regular expression. It also works for non-sequential models such as DenseNet or ResNet.

import re
from keras.models import Model

def insert_layer_nonseq(model, layer_regex, insert_layer_factory,
                        insert_layer_name=None, position='after'):
    """Rebuild ``model`` with a new layer added at every matching layer.

    The layer graph is read from each layer's ``_outbound_nodes`` so that
    models with branches and merges can be rewired, not only purely
    sequential ones.

    Args:
        model: Source model. ``model.layers[0]`` is treated as the input
            layer and ``model.input`` as its output tensor (assumes a
            single-input model -- TODO confirm for multi-input graphs).
        layer_regex: Pattern tested with ``re.match`` against each layer
            name, i.e. anchored at the start of the name.
        insert_layer_factory: Zero-argument callable returning a fresh layer
            instance; it is called once per matching layer.
        insert_layer_name: Name given to every inserted layer. When falsy,
            the name is ``"<matched layer name>_<new layer's own name>"``.
        position: ``'before'``, ``'after'`` or ``'replace'``, relative to
            each matching layer.

    Returns:
        ``Model(inputs=model.inputs, outputs=...)`` where the outputs are the
        rebuilt tensors of the layers listed in ``model.output_names``.

    Raises:
        ValueError: If ``position`` is not one of the three accepted values.
    """
    # Fail fast: an invalid position is a caller error even when no layer
    # name happens to match the regular expression.
    if position not in ('before', 'after', 'replace'):
        raise ValueError('position must be: before, after or replace')

    # Layer name -> names of the layers feeding it, in discovery order.
    input_layers_of = {}
    for layer in model.layers:
        for node in layer._outbound_nodes:
            input_layers_of.setdefault(
                node.outbound_layer.name, []).append(layer.name)

    # Layer name -> output tensor of that layer in the rebuilt graph,
    # seeded with the input layer.
    new_output_tensor_of = {model.layers[0].name: model.input}

    model_outputs = []
    for layer in model.layers[1:]:

        # Gather the already rebuilt tensors that feed this layer; a single
        # input is passed bare instead of as a one-element list.
        layer_input = [new_output_tensor_of[input_name]
                       for input_name in input_layers_of[layer.name]]
        if len(layer_input) == 1:
            layer_input = layer_input[0]

        if re.match(layer_regex, layer.name):
            # 'after' feeds the new layer with the original layer's output;
            # 'before' and 'replace' feed it with the original layer's input.
            if position == 'after':
                x = layer(layer_input)
            else:
                x = layer_input

            new_layer = insert_layer_factory()
            # NOTE(review): this assigns to ``name`` directly; confirm the
            # Keras version in use allows setting it on an existing layer.
            if insert_layer_name:
                new_layer.name = insert_layer_name
            else:
                new_layer.name = "{}_{}".format(layer.name, new_layer.name)
            x = new_layer(x)
            print('New layer: {} Old layer: {} Type: {}'.format(
                new_layer.name, layer.name, position))

            if position == 'before':
                x = layer(x)
        else:
            x = layer(layer_input)

        # Downstream layers look this tensor up by the ORIGINAL layer name,
        # so an inserted or replacing layer is wired in transparently.
        new_output_tensor_of[layer.name] = x

        # Compare the current layer's name (not a leftover from the graph
        # parsing loop) with the output names of the original model.
        if layer.name in model.output_names:
            model_outputs.append(x)

    return Model(inputs=model.inputs, outputs=model_outputs)

The difference from the simpler case of a purely sequential model is that, before iterating over the layers to find the key layer, you first parse the graph and store the input layers of each layer in an auxiliary dictionary. Then, as you iterate over the layers, you also store the new output tensor of each layer; these stored tensors are looked up to determine the input tensors of each subsequent layer when building the new model.

A use case would be the following, where a Dropout layer is inserted after each activation layer of ResNet50:

from keras.applications.resnet50 import ResNet50
from keras.layers import Dropout
from keras.models import load_model

model = ResNet50()


def dropout_layer_factory():
    """Return a new Dropout layer; called once per insertion point."""
    return Dropout(rate=0.2, name="dropout")


# Insert a Dropout layer after every layer whose name matches the pattern.
model = insert_layer_nonseq(model, '.*activation.*', dropout_layer_factory)

# Fix possible problems with new model
model.save('temp.h5')
model = load_model('temp.h5')

model.summary()

Leave a Comment