Passing the output of a CNN to a BiLSTM

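The problem is the shape of the data passed to the LSTM, and it can be solved inside your network. An LSTM expects 3D input of shape (batch_size, timesteps, features), while Conv2D produces 4D output of shape (batch_size, H, W, channels). There are two possibilities you can adopt: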

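1) reshape to (batch_size, H, W*channel), so each row of the feature map becomes one timestep;

2) reshape to (batch_size, W, H*channel), so each column becomes one timestep. This requires swapping the H and W axes before reshaping.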
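To see what the two reshapes do to the data, here is a small NumPy sketch. NumPy stands in for the Keras tensors here (Keras `Reshape` also uses row-major order, so the resulting layout is the same), and the sizes are arbitrary toy values chosen for illustration.

```python
import numpy as np

# Toy "Conv2D output": (batch_size, H, W, channel), filled with distinct values
batch_size, H, W, C = 2, 4, 5, 3
feature_maps = np.arange(batch_size * H * W * C).reshape(batch_size, H, W, C)

# Possibility 1: each row is a timestep -> (batch_size, H, W*C)
rows_as_steps = feature_maps.reshape(batch_size, H, W * C)

# Possibility 2: each column is a timestep -> (batch_size, W, H*C)
# Swap H and W first, otherwise the values of different rows get mixed up
cols_as_steps = feature_maps.transpose(0, 2, 1, 3).reshape(batch_size, W, H * C)

# Reshaping straight to (batch_size, W, H*C) gives the same shape but the wrong layout
naive = feature_maps.reshape(batch_size, W, H * C)
```

After this, `rows_as_steps[0, 1]` is row 1 of the first sample flattened over (W, C), and `cols_as_steps[0, 2]` is column 2 flattened over (H, C). `naive` has the right shape but its timesteps are not columns, which is why the transpose (a `Permute` layer in Keras) is needed for the second option.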

Either way, you get 3D data of shape (batch_size, timesteps, features) that the LSTM can consume. Below is an example:

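```python
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Conv2D, Dense, Lambda, LSTM, Permute, Reshape

def ReshapeLayer(x):
    # x: Conv2D output with shape (batch_size, H, W, channel)
    shape = x.shape

    # 1 possibility: H timesteps of W*channel features
    reshape = Reshape((shape[1], shape[2] * shape[3]))(x)

    # 2 possibility: W timesteps of H*channel features
    # After Permute the axes are (batch_size, W, H, channel), so the target
    # shape is (W, H*channel), i.e. (shape[2], shape[1]*shape[3]).
    # transpose = Permute((2, 1, 3))(x)
    # reshape = Reshape((shape[2], shape[1] * shape[3]))(transpose)

    return reshape

model = Sequential()
model.add(Conv2D(filters=16, kernel_size=3, input_shape=(32, 32, 3)))  # -> (None, 30, 30, 16)
model.add(Lambda(ReshapeLayer))  # -> (None, 30, 480)
model.add(LSTM(16))
model.add(Dense(units=2, activation='softmax'))
model.compile(loss="categorical_crossentropy", optimizer="adam")
model.summary()
```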
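Since the question is about a BiLSTM, here is a sketch of the same idea with the LSTM wrapped in `Bidirectional`, written with the Functional API and plain `Reshape`/`Permute` layers instead of a `Lambda`. It assumes TensorFlow 2.x with `tf.keras`; the helper name `build_cnn_bilstm` and its `columns_as_timesteps` flag are hypothetical, introduced only for this illustration. Reading H and W from the Conv2D output keeps the target shape correct for non-square inputs.

```python
import numpy as np
from tensorflow.keras import Input, Model
from tensorflow.keras.layers import Bidirectional, Conv2D, Dense, LSTM, Permute, Reshape

def build_cnn_bilstm(input_shape=(32, 32, 3), columns_as_timesteps=False):
    """Hypothetical helper: Conv2D -> reshape to 3D -> BiLSTM -> softmax."""
    inputs = Input(shape=input_shape)
    x = Conv2D(filters=16, kernel_size=3)(inputs)  # (None, H', W', 16) with "valid" padding
    _, h, w, c = x.shape  # static sizes are known when the model is built

    if columns_as_timesteps:
        x = Permute((2, 1, 3))(x)   # (None, W', H', 16)
        x = Reshape((w, h * c))(x)  # W' timesteps of H'*16 features
    else:
        x = Reshape((h, w * c))(x)  # H' timesteps of W'*16 features

    # Bidirectional concatenates the forward and backward outputs by default
    x = Bidirectional(LSTM(16))(x)  # (None, 32)
    outputs = Dense(2, activation="softmax")(x)

    model = Model(inputs, outputs)
    model.compile(loss="categorical_crossentropy", optimizer="adam")
    return model
```

For example, `build_cnn_bilstm((32, 48, 3), columns_as_timesteps=True)` treats the 46 columns of the 30x46 feature map as timesteps. Which option works better depends on your data: use rows if the meaningful sequence runs top to bottom, columns if it runs left to right (as in text-line recognition).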
