[machine-learning] What is the role of "Flatten" in Keras?

I am trying to understand the role of the Flatten function in Keras. Below is my code, which is a simple two-layer network. It takes in 2-dimensional data of shape (3, 2), and outputs 1-dimensional data of shape (1, 4):

import numpy as np
from keras.models import Sequential
from keras.layers import Dense, Activation, Flatten

model = Sequential()
model.add(Dense(16, input_shape=(3, 2)))
model.add(Activation('relu'))
model.add(Flatten())
model.add(Dense(4))
model.compile(loss='mean_squared_error', optimizer='SGD')

x = np.array([[[1, 2], [3, 4], [5, 6]]])

y = model.predict(x)

print(y.shape)

This prints out that y has shape (1, 4). However, if I remove the Flatten line, then it prints out that y has shape (1, 3, 4).

I don't understand this. From my understanding of neural networks, the model.add(Dense(16, input_shape=(3, 2))) function is creating a hidden fully-connected layer, with 16 nodes. Each of these nodes is connected to each of the 3x2 input elements. Therefore, the 16 nodes at the output of this first layer are already "flat". So, the output shape of the first layer should be (1, 16). Then, the second layer takes this as an input, and outputs data of shape (1, 4).

So if the output of the first layer is already "flat" and of shape (1, 16), why do I need to further flatten it?

The answer is


Here I would like to present an alternative to the Flatten layer, which may help to show what is going on internally. The alternative replaces the single Flatten line with three lines of code. Instead of using

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

#==========================================Build a Model
model = tf.keras.models.Sequential()

model.add(keras.layers.Flatten(input_shape=(28, 28, 3)))  # reshapes to (2352,) = 28x28x3
model.add(layers.experimental.preprocessing.Rescaling(1./255))  # normalize
model.add(keras.layers.Dense(128, activation=tf.nn.relu))
model.add(keras.layers.Dense(2, activation=tf.nn.softmax))

model.build()
model.summary()  # summary of the model

we can use

    #==========================================Build a Model
    tensor = tf.keras.backend.placeholder(dtype=tf.float32, shape=(None, 28, 28, 3))
    
    model = tf.keras.models.Sequential()
    
    model.add(keras.layers.InputLayer(input_tensor=tensor))
    model.add(keras.layers.Reshape([2352]))
    model.add(layers.experimental.preprocessing.Rescaling(1./255))#normalize
    model.add(keras.layers.Dense(128,activation=tf.nn.relu))
    model.add(keras.layers.Dense(2,activation=tf.nn.softmax))
    
    model.build()
    model.summary()# summary of the model

In the second case, we first create a tensor (using a placeholder) and then create an input layer from it. Afterwards, we reshape the tensor into flat form. So basically,

Create tensor->Create InputLayer->Reshape == Flatten

Flatten is a convenience layer that does all of this automatically. Of course, each approach has its own use cases. Keras provides enough flexibility to build a model the way you want.
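As a rough illustration of the equivalence above, here is a NumPy-only sketch. NumPy stands in for Keras here, which is an assumption made purely for illustration, and the batch size of 4 and the variable names are made up. It shows that reshaping to an explicit size of 2352 gives the same result as collapsing every non-batch axis automatically:

```python
import numpy as np

# Hypothetical batch of 4 inputs with shape (28, 28, 3), filled with dummy values.
batch = np.arange(4 * 28 * 28 * 3, dtype=np.float32).reshape(4, 28, 28, 3)

# Roughly what Reshape([2352]) amounts to: the target size is spelled out by hand.
explicit = batch.reshape(batch.shape[0], 2352)

# Roughly what Flatten amounts to: the target size is inferred from the input.
inferred = batch.reshape(batch.shape[0], -1)

print(explicit.shape)                      # (4, 2352)
print(np.array_equal(explicit, inferred))  # True
```

The practical difference is only who computes 28x28x3: you, or the layer.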


short read:

Flattening a tensor means collapsing all of its dimensions into one (in Keras, the batch dimension is kept as it is). This is exactly what the Flatten layer does.

long read:

If we take the original model (with the Flatten layer) into consideration, we get the following model summary:

Layer (type)                 Output Shape              Param #   
=================================================================
D16 (Dense)                  (None, 3, 16)             48        
_________________________________________________________________
A (Activation)               (None, 3, 16)             0         
_________________________________________________________________
F (Flatten)                  (None, 48)                0         
_________________________________________________________________
D4 (Dense)                   (None, 4)                 196       
=================================================================
Total params: 244
Trainable params: 244
Non-trainable params: 0

For this summary, the image below will hopefully make the input and output sizes of each layer a little clearer.

The output shape of the Flatten layer, as you can read, is (None, 48). Here is the tip: read it as (1, 48) or (2, 48) or ... or (16, 48) or (32, 48), and so on.

None in that position simply means any batch size. Recall that, for the inputs, the first dimension is the batch size and the second is the number of input features.

The role of the Flatten layer in Keras is super simple:

A flatten operation reshapes a tensor so that its shape equals the number of elements the tensor contains, not including the batch dimension.
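To make that rule concrete, here is a small pure-Python sketch of the shape arithmetic only. The helper name flatten_output_shape is hypothetical and not part of Keras; it just keeps the batch entry and multiplies the rest together:

```python
from math import prod

def flatten_output_shape(input_shape):
    """Keep the first (batch) entry and multiply the remaining ones together."""
    batch, *rest = input_shape
    return (batch, prod(rest))

print(flatten_output_shape((None, 3, 16)))    # (None, 48)
print(flatten_output_shape((32, 28, 28, 3)))  # (32, 2352)
```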

[image: input and output sizes of each layer]


Note: I used the model.summary() method to provide the output shape and parameter details.
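The parameter counts in that summary can be checked by hand. The sketch below assumes the usual fully connected layout of one weight per input per unit plus one bias per unit; the helper name dense_params is made up for illustration:

```python
def dense_params(inputs_per_unit, units):
    """Weights plus one bias per unit for a fully connected layer."""
    return inputs_per_unit * units + units

# Dense(16) sees the last axis of the (3, 2) input, i.e. 2 values per unit.
d16 = dense_params(2, 16)

# After Flatten, Dense(4) sees all 3 * 16 = 48 values.
d4 = dense_params(3 * 16, 4)

print(d16, d4, d16 + d4)  # 48 196 244
```

Note that the first layer has only 48 parameters rather than 3 * 2 * 16 + 16 = 112, which is consistent with its output shape of (None, 3, 16): the layer is applied to each of the 3 rows separately.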


It is a rule of thumb that the first layer in your network should have the same shape as your data. For example, our data consists of 28x28 images, and 28 layers of 28 neurons would be infeasible, so it makes more sense to 'flatten' that 28x28 array into a 784x1 one. Instead of writing all the code to handle that ourselves, we add the Flatten() layer at the beginning, and when the arrays are loaded into the model later, they'll automatically be flattened for us.


Flatten makes explicit how you serialize a multidimensional tensor (typically the input one). This allows the mapping between the (flattened) input tensor and the first hidden layer. If the first hidden layer is "dense", each element of the (serialized) input tensor is connected to each element of the hidden array. If you do not use Flatten, a Keras Dense layer is applied along the last axis of the input only, so not every input element is connected to every hidden unit.
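The effect on the model from the question can be imitated with plain NumPy. This is only a sketch of the shape behaviour, under the assumption that a Dense layer amounts to a matrix product along the last axis plus a bias; the random weights and variable names are invented for illustration and are not what Keras would initialise:

```python
import numpy as np

rng = np.random.default_rng(0)
x = np.array([[[1, 2], [3, 4], [5, 6]]], dtype=np.float32)  # shape (1, 3, 2)

# Dense(16) on a (3, 2) input: the kernel has shape (2, 16) and is applied
# to each of the 3 rows separately.
w1, b1 = rng.normal(size=(2, 16)), np.zeros(16)
h = np.maximum(x @ w1 + b1, 0.0)  # ReLU, shape (1, 3, 16)

# Without Flatten: Dense(4) again acts on the last axis only.
w_no_flat = rng.normal(size=(16, 4))
print((h @ w_no_flat).shape)      # (1, 3, 4)

# With Flatten: the 3 * 16 values are merged before Dense(4).
flat = h.reshape(h.shape[0], -1)  # shape (1, 48)
w_flat = rng.normal(size=(48, 4))
print((flat @ w_flat).shape)      # (1, 4)
```

This matches the shapes reported in the question: (1, 3, 4) without Flatten and (1, 4) with it.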


[image] This is how Flatten works: it converts a matrix into a single array.
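The same idea in a few lines of pure Python, as a sketch only: the helper flatten_nested is hypothetical and reads the values row by row, which is the order a plain reshape would also produce:

```python
def flatten_nested(values):
    """Recursively walk nested lists and collect the leaves in row-major order."""
    flat = []
    for item in values:
        if isinstance(item, list):
            flat.extend(flatten_nested(item))
        else:
            flat.append(item)
    return flat

matrix = [[1, 2], [3, 4], [5, 6]]
print(flatten_nested(matrix))  # [1, 2, 3, 4, 5, 6]
```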


I came across this recently, it certainly helped me understand: https://www.cs.ryerson.ca/~aharley/vis/conv/

So there's an input, a Conv2D, a MaxPooling2D and so on. The Flatten layers are at the end, and the visualization shows exactly how they are formed and how they go on to define the final classifications (0-9).

