What is a neural network flatten layer?

A neural network is an artificial intelligence model inspired by how the human brain functions. Neural networks consist of numerous neurons that take inputs and produce outputs using a set of trainable parameters, such as weights and biases.

The flatten layer

The flatten layer is a component of convolutional neural networks (CNNs). A complete convolutional neural network can be broken down into two parts:

  • CNN: The convolutional neural network that comprises the convolutional layers.

  • ANN: The artificial neural network that comprises dense layers.

The flatten layer lies between the CNN and the ANN, and its job is to convert the output of the CNN into an input that the ANN can process, as we can see in the diagram below.

Structure of a complete convolutional neural network

Why is it used?

The flatten layer converts the feature maps that it receives from the max-pooling layer into a format that the dense layers can process.

A feature map is essentially a multi-dimensional array of values that the earlier layers compute from the image's pixels, whereas the dense layers require a one-dimensional array as input. The flatten layer therefore reshapes the feature maps into a single one-dimensional array for the dense layers.
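To make the shape requirement concrete, here is a minimal NumPy sketch. It assumes a dense layer can be modeled as a matrix multiplication plus a bias; the sizes, random values, and variable names are hypothetical and chosen only for illustration.

```python
import numpy as np

# Hypothetical sizes, chosen only for illustration.
height, width, depth = 2, 2, 3   # shape of the pooled feature maps
units = 4                        # number of neurons in the dense layer

rng = np.random.default_rng(0)
feature_maps = rng.random((height, width, depth))

# A dense layer is modeled here as x @ W + b, where x is a 1-D vector
# with one entry per value in the feature maps.
weights = rng.random((height * width * depth, units))
biases = np.zeros(units)

flat = feature_maps.flatten()             # shape (12,)
dense_output = flat @ weights + biases    # shape (4,)

print(flat.shape, dense_output.shape)     # (12,) (4,)
```

The weight matrix has one row per flattened value, which is why the three-dimensional feature maps have to be flattened before the dense layer can use them.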

Internal working

Now that we have looked at why we need a flatten layer, let's look at how the layer converts the multi-dimensional array to a one-dimensional array with the help of a diagram.

How the flatten layer works

The diagram shows that the flatten layer takes feature maps as inputs from the max-pooling layer. The feature maps can have the shape (Height, Width, Depth), where Height x Width is the spatial size of a single feature map and Depth is the number of feature maps.

It then iterates over the feature maps and inserts their values one by one into a one-dimensional array whose size equals Height x Width x Depth.
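The iteration described above can be sketched with explicit loops. This is only an illustration: it assumes a (Height, Width, Depth) layout and row-major ordering (NumPy's default), and the input values are hypothetical. Real frameworks typically reshape the array rather than copying values one at a time, and the exact ordering depends on the framework's memory layout.

```python
import numpy as np

# Hypothetical input in (Height, Width, Depth) layout: two 2x2 feature maps.
feature_maps = np.arange(8).reshape(2, 2, 2)
height, width, depth = feature_maps.shape

# Output array with Height x Width x Depth slots.
flattened = np.empty(height * width * depth, dtype=feature_maps.dtype)

# Visit every position and copy its value into the next free slot.
index = 0
for h in range(height):
    for w in range(width):
        for d in range(depth):
            flattened[index] = feature_maps[h, w, d]
            index += 1

# This loop order matches NumPy's default row-major flatten().
print(np.array_equal(flattened, feature_maps.flatten()))  # True
```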

Code example

Now let's look at a Python code example that shows the input and the output of the flatten layer by using the flatten() method of NumPy arrays.

import numpy as np
# Two feature maps, each of shape (3, 3).
featureMaps = [
    [[3, 4, 7], [5, 6, 1], [1, 8, 12]],
    [[6, 12, 10], [0, 21, 13], [17, 9, 19]],
]
# Convert the nested list to a NumPy array of shape (2, 3, 3).
featureMaps = np.array(featureMaps)

print("****************** INPUT TO FLATTEN LAYER ******************")
# Print each feature map in turn.
for i, featureMap in enumerate(featureMaps):
    print("\nFeature Map " + str(i + 1) + " : \n" + str(featureMap))
# Flatten the (2, 3, 3) array into a one-dimensional array of 18 values.
flattenOutput = featureMaps.flatten()

print("\n\n****************** OUTPUT OF FLATTEN LAYER ******************")

print("\nFlattened Layer Output: \n" + str(flattenOutput), end="\n\n")

When we run the code, we see two feature maps that represent the inputs to the flatten layer. We also get a one-dimensional array that represents the output of the flatten layer and contains all the values of the feature maps in sequence.

Code explanation

  • Line 1: First, we import the NumPy library.

  • Lines 3–6: Now, we create a list that represents two feature maps of the shape (3,3).

  • Line 8: We then convert the list to a NumPy array using the np.array() function.

  • Lines 12–13: We iterate over each feature map in the array and display it on the console.

  • Line 15: We use the flatten() method to flatten the multi-dimensional array into a one-dimensional array that represents the output of a flatten layer.

  • Line 19: Finally, we print the output of flatten() to the console to show what the output of the flatten layer would look like.
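In practice, flatten layers in deep learning frameworks usually receive a batch of samples and flatten each sample separately while keeping the batch dimension. The following is a minimal NumPy sketch of that behavior, not any framework's actual implementation; the batch size and shapes are hypothetical.

```python
import numpy as np

# Hypothetical batch: 4 samples, each with feature maps of shape (3, 3, 2).
batch = np.zeros((4, 3, 3, 2))

# Keep the first (batch) axis and merge the remaining axes into one.
flattened_batch = batch.reshape(batch.shape[0], -1)

print(flattened_batch.shape)  # (4, 18)
```

Each row of the result is the flattened version of one sample, so the dense layers still see one one-dimensional input per sample.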

Conclusion

The flatten layer is a crucial component of convolutional neural networks because it connects the convolutional part (CNN) to the dense part (ANN), allowing the model to learn complex patterns and make predictions. Although the flatten layer performs the simple operation of converting multi-dimensional arrays into a one-dimensional array, it is a fundamental tool in image-processing neural network models.

Copyright ©2025 Educative, Inc. All rights reserved