Keras, TensorFlow: Merge two different model outputs into one

Trainable weights

OK. Since you are going to have custom trainable weights, the way to do this in Keras is to create a custom layer.

Since this custom layer has no real input, we will need a small hack, explained below.

Here is the layer definition for the custom weights:

import numpy as np
from keras.layers import Layer, Input, Embedding, LSTM, Dense, Add, Lambda
from keras.models import Model
from keras.initializers import get as get_init, serialize as serial_init
import keras.backend as K
import tensorflow as tf


class TrainableWeights(Layer):

    #you can pass keras initializers when creating this layer
    #kwargs will take base layer arguments, such as name and others if you want
    def __init__(self, shape, initializer="uniform", **kwargs):
        super(TrainableWeights, self).__init__(**kwargs)
        self.shape = shape
        self.initializer = get_init(initializer)
        

    #build is where you define the weights of the layer
    def build(self, input_shape):
        self.kernel = self.add_weight(name="kernel", 
                                      shape=self.shape, 
                                      initializer=self.initializer, 
                                      trainable=True)
        self.built = True
        

    #call is the layer operation; due to a Keras limitation, every layer needs an input
    #warning: this assumes the input is a tensor with value 1 and shape () or (1,)
    def call(self, x):
        return x * self.kernel
    

    #for keras to build the summary properly
    def compute_output_shape(self, input_shape):
        return self.shape
    

    #only needed for saving/loading this layer in model.save()
    def get_config(self):
        config = {'shape': self.shape, 'initializer': serial_init(self.initializer)}
        base_config = super(TrainableWeights, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))
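The `get_config` method matters because Keras rebuilds a saved layer by passing its config back to `__init__` (the default `Layer.from_config` is essentially `cls(**config)`), so every constructor argument, such as `shape` and `initializer`, must appear in the returned dict. Below is a plain-Python sketch of that round trip; `BaseLayerSketch` and `TrainableWeightsSketch` are hypothetical stand-ins, not Keras classes, and the base config is simplified to just `name` and `trainable`.

```python
# Hypothetical stand-ins: these are NOT Keras classes; they only mimic the
# get_config / from_config pattern used by the TrainableWeights layer above.
class BaseLayerSketch:
    def __init__(self, name="layer", trainable=True):
        self.name = name
        self.trainable = trainable

    def get_config(self):
        # simplified version of the base layer arguments Keras would store
        return {"name": self.name, "trainable": self.trainable}

    @classmethod
    def from_config(cls, config):
        # default reconstruction: feed the saved config back into __init__
        return cls(**config)


class TrainableWeightsSketch(BaseLayerSketch):
    def __init__(self, shape, initializer="uniform", **kwargs):
        super().__init__(**kwargs)
        self.shape = shape
        self.initializer = initializer

    def get_config(self):
        config = {"shape": self.shape, "initializer": self.initializer}
        base_config = super().get_config()
        # same merge as in the answer: custom keys are appended to the base keys
        return dict(list(base_config.items()) + list(config.items()))


layer = TrainableWeightsSketch((10, 10), name="x_weights")
cfg = layer.get_config()
restored = TrainableWeightsSketch.from_config(cfg)
print(cfg)
print(restored.shape, restored.name)
```

If `shape` were missing from the returned dict, `from_config` would fail because `__init__` requires it.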

This layer is then used like this:

dummyInputs = Input(tensor=K.constant([1]))
#shape is the weight shape you want, e.g. (10, 10)
trainableWeights = TrainableWeights(shape)(dummyInputs)
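The hack works because multiplying a constant 1 by the kernel returns the kernel unchanged, so the layer effectively "emits" its own trainable weights. This NumPy sketch mirrors what `TrainableWeights.call` computes; the `(10, 10)` shape is just an example, and the uniform range is assumed to be Keras's default `RandomUniform` range of -0.05 to 0.05.

```python
import numpy as np

# Stand-in for the layer's kernel: a (10, 10) matrix drawn uniformly
# (assumed range -0.05 to 0.05, like Keras's default RandomUniform).
rng = np.random.default_rng(0)
kernel = rng.uniform(-0.05, 0.05, size=(10, 10))

# Stand-in for K.constant([1]): a single value 1 with shape (1,).
dummy = np.array([1.0])

# What TrainableWeights.call does: x * self.kernel.
# Broadcasting (1,) against (10, 10) yields a (10, 10) result.
output = dummy * kernel

print(output.shape)
print(np.array_equal(output, kernel))
```

This prints `(10, 10)` and `True`: the output is exactly the weight matrix, which is why `compute_output_shape` can simply return `self.shape`.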

Model A

With the layer defined, we can start modeling.
First, let's look at the model_a side:

#general vars
length = 150
dic_size = 100
embed_size = 12

#for the model_a segment
input_text = Input(shape=(length,))
embedding = Embedding(dic_size, embed_size)(input_text)

#the following two lines are just a way to reach the desired shape (batch, 50)
embedding = LSTM(5)(embedding)
embedding = Dense(50)(embedding)

#creating model_a here is optional, only if you want to use model_a independently later
model_a = Model(input_text, embedding, name="model_a")
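The `Embedding` layer is essentially a trainable lookup table: each integer token id is replaced by a row of a `(dic_size, embed_size)` matrix. A NumPy sketch of that lookup with the same sizes as above (the batch size of 4 is an arbitrary choice for illustration):

```python
import numpy as np

length, dic_size, embed_size = 150, 100, 12
batch = 4  # arbitrary batch size for this sketch

rng = np.random.default_rng(0)
# The Embedding layer holds a (dic_size, embed_size) table of trainable vectors.
table = rng.normal(size=(dic_size, embed_size))

# Integer token ids in [0, dic_size), shaped like input_text: (batch, length).
tokens = rng.integers(0, dic_size, size=(batch, length))

# The lookup is plain fancy indexing: each id is replaced by its table row.
embedded = table[tokens]
print(embedded.shape)  # (4, 150, 12)
```

From there, `LSTM(5)` reduces the sequence to `(batch, 5)` and `Dense(50)` maps it to `(batch, 50)`, the shape the final multiplication expects.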

Model B

For this, we are going to use our TrainableWeights layer.
But first, let's simulate the New_model() described in the question.

#simulates New_model() #notice the explicit batch_shape for the matrices
newIn1 = Input(batch_shape = (10,10))
newIn2 = Input(batch_shape = (10,30))
newOut1 = Dense(50)(newIn1)
newOut2 = Dense(50)(newIn2)
newOut = Add()([newOut1, newOut2])
new_model = Model([newIn1, newIn2], newOut, name="new_model")   
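Because of the explicit `batch_shape`, the "batch" axis of size 10 is really the row count of each matrix, and each `Dense(50)` is an affine map applied row by row. A NumPy sketch of what this simulated `new_model` computes, using random stand-in weights and zero biases (Keras's default bias initialization):

```python
import numpy as np

rng = np.random.default_rng(0)
# Inputs with explicit batch_shape: axis 0 (size 10) is the matrix row count.
in1 = rng.normal(size=(10, 10))
in2 = rng.normal(size=(10, 30))

# Each Dense(50) computes x @ W + b, with W of shape (features, 50).
w1, b1 = rng.normal(size=(10, 50)), np.zeros(50)
w2, b2 = rng.normal(size=(30, 50)), np.zeros(50)
out1 = in1 @ w1 + b1   # (10, 50)
out2 = in2 @ w2 + b2   # (10, 50)

# Add() sums the two branches element-wise; their shapes must match exactly.
new_out = out1 + out2
print(new_out.shape)  # (10, 50)
```

Both branches land on `(10, 50)` even though their inputs have different widths, which is what lets `Add()` combine them.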

Now the entire branch:

#the matrices    
dummyInput = Input(tensor = K.constant([1]))
X_in = TrainableWeights((10,10), initializer="uniform")(dummyInput)
M_in = TrainableWeights((10,30), initializer="uniform")(dummyInput)

#the output of the branch   
md_1 = new_model([X_in, M_in])

#optional, only if you want to use model_s independently later
model_s = Model(dummyInput, md_1, name="model_s")

The whole model

Finally, we can join the branches into a whole model.
Notice that I didn't have to use model_a or model_s here. You can if you want, but those submodels are not needed unless you later want to use them individually for other purposes. (Even if you created them, you don't need to change the code below; they are already part of the same graph.)
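The submodels share the full model's layers by reference rather than copying them, so training the full model also updates `model_a` and `model_s`. The plain-Python sketch below illustrates that idea only; `ScaleLayer` and `Chain` are hypothetical mini-classes, not Keras APIs.

```python
# Hypothetical mini-classes (NOT Keras): a "model" here just holds references
# to layer objects, the way Keras models reuse the same layer instances.
class ScaleLayer:
    def __init__(self, w):
        self.w = w

    def __call__(self, x):
        return self.w * x


class Chain:
    def __init__(self, *layers):
        self.layers = list(layers)

    def __call__(self, x):
        for layer in self.layers:
            x = layer(x)
        return x


shared = ScaleLayer(2.0)
sub_model = Chain(shared)                    # plays the role of model_a
full_model = Chain(shared, ScaleLayer(3.0))  # reuses the very same layer

# "Training" the full model updates the shared layer in place...
full_model.layers[0].w = 5.0

# ...so the sub-model sees the new weight too, with no extra wiring.
print(sub_model(1.0))   # 5.0
print(full_model(1.0))  # 15.0
```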

#I prefer tf.matmul because it's explicit, while K.dot can behave unexpectedly with tensors of rank above 2
#(batch, 50) x (50, 10) after transposing md_1 -> (batch, 10)
mult = Lambda(lambda x: tf.matmul(x[0], x[1], transpose_b=True))([embedding, md_1])
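For 2D inputs, `tf.matmul(a, b, transpose_b=True)` computes `a @ b.T`, so each sample's 50-dimensional vector is dotted with every row of the `(10, 50)` branch output. A NumPy sketch of the same shapes, with random stand-in values and an arbitrary batch size of 4:

```python
import numpy as np

rng = np.random.default_rng(0)
batch = 4                                  # arbitrary batch size for this sketch
embedding = rng.normal(size=(batch, 50))   # model_a output: (batch, 50)
md_1 = rng.normal(size=(10, 50))           # model_s output: (10, 50)

# Equivalent of tf.matmul(embedding, md_1, transpose_b=True) for 2D inputs.
mult = embedding @ md_1.T
print(mult.shape)  # (4, 10)

# Entry [i, j] is the dot product of sample i with row j of md_1.
print(np.isclose(mult[1, 3], embedding[1] @ md_1[3]))  # True
```

Note that `md_1` has no batch axis of its own, so the same 10 rows are shared by every sample; that is why the output is `(batch, 10)`.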

#final model
model = Model([input_text, dummyInput], mult, name="full_model")

Now train it:

model.compile('adam', 'binary_crossentropy', metrics=['accuracy'])
model.fit(np.random.randint(0,dic_size, size=(128,length)),
          np.ones((128, 10)))

Since the output is now 2D, 'categorical_crossentropy' poses no problem; my earlier comment was only due to doubts about the output shape.
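Categorical cross-entropy treats the last axis as the class axis, so a `(batch, classes)` output gives one loss value per sample. The function below is a rough NumPy analogue of Keras's non-logit version (normalize each row, clip, sum `-y * log(p)`), written here as a hypothetical helper. It assumes positive predictions such as a softmax would produce; the raw `matmul` scores above are not guaranteed to be positive.

```python
import numpy as np

def categorical_crossentropy_sketch(y_true, y_pred, eps=1e-7):
    # Hypothetical helper, roughly mirroring Keras's non-logit behavior:
    # normalize each row to sum to 1, clip away 0 and 1, then sum over classes.
    p = y_pred / y_pred.sum(axis=-1, keepdims=True)
    p = np.clip(p, eps, 1.0 - eps)
    return -(y_true * np.log(p)).sum(axis=-1)

# A 2D (batch, classes) prediction: each row holds one sample's class scores.
y_pred = np.array([[0.7, 0.2, 0.1],
                   [0.1, 0.1, 0.8]])
y_true = np.array([[1.0, 0.0, 0.0],
                   [0.0, 0.0, 1.0]])

loss = categorical_crossentropy_sketch(y_true, y_pred)
print(loss.shape)  # (2,): one loss value per sample
```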
