What is the numpy.squeeze() function in NumPy?

Overview

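The squeeze() function in NumPy removes axes of length 1 from an input array. By default, it removes every axis of length 1. The optional axis parameter limits the removal to the selected axes.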

Every NumPy array has one axis per dimension, so a 1-D array has a single axis (axis 0). A 2-D array has two axes: axis 0 runs vertically downward across the rows, and axis 1 runs horizontally across the columns.
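The short sketch below makes the axis numbering concrete. It uses only standard NumPy attributes (ndim, shape) and the sum() method with an axis argument. The variable names are illustrative.

```python
import numpy as np

# A 1-D array has exactly one axis (axis 0)
v = np.array([1, 2, 3])
print(v.ndim)  # 1

# A 2-D array with 2 rows and 3 columns has two axes
m = np.array([[1, 2, 3],
              [4, 5, 6]])
print(m.ndim)   # 2
print(m.shape)  # (2, 3): length along axis 0, then along axis 1

# Summing along axis 0 runs down the rows, giving one total per column
col_sums = m.sum(axis=0)  # values 5, 7, 9
# Summing along axis 1 runs across the columns, giving one total per row
row_sums = m.sum(axis=1)  # values 6, 15
print(col_sums, row_sums)
```

Each entry in shape is the length of the corresponding axis. squeeze() looks for entries equal to 1 and drops them.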

Syntax

numpy.squeeze(a, axis=None)

Parameter values

The squeeze() function takes the following parameters.

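  • a: This is the input array (any array-like object is accepted). It is a required parameter.
  • axis: This is None, an int, or a tuple of ints that selects which length-1 axes to remove. If a selected axis has a length greater than 1, a ValueError is raised. It is an optional parameter. The default, None, removes all axes of length 1.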

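Here is a minimal sketch of the three forms the axis parameter can take, plus the error case. The array x and the other variable names are illustrative.

```python
import numpy as np

# Shape (1, 3, 1): axes 0 and 2 have length 1, axis 1 has length 3
x = np.zeros((1, 3, 1))

all_removed = np.squeeze(x)                # axis=None: every length-1 axis is removed
first_removed = np.squeeze(x, axis=0)      # an int: only that axis is removed
both_removed = np.squeeze(x, axis=(0, 2))  # a tuple: each listed axis is removed
print(all_removed.shape, first_removed.shape, both_removed.shape)  # (3,) (3, 1) (3,)

# Selecting an axis whose length is not 1 raises ValueError
raised = False
try:
    np.squeeze(x, axis=1)
except ValueError as err:
    raised = True
    print("ValueError:", err)
```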
Return value

The squeeze() function returns the input array, a, with all of its length-1 dimensions removed, or only the ones selected by axis. The result is either a itself or a view into a, so the data is not copied.
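Because the result is a view, writing to it also changes the original array. The sketch below checks this with np.shares_memory(). It also shows that squeezing an array whose axes all have length 1 produces a 0-d array. The variable names are illustrative.

```python
import numpy as np

a = np.array([[[1], [2], [3]]])  # shape (1, 3, 1)
b = np.squeeze(a)                # shape (3,)

# b shares memory with a, so no data was copied
print(np.shares_memory(a, b))    # True

# A write through the view is visible in the original array
b[0] = 100
print(a[0, 0, 0])                # 100

# When every axis has length 1, the result is a 0-d array
c = np.squeeze(np.array([[7]]))
print(c.shape, c.ndim)           # () 0
```

If you need an independent array, call .copy() on the squeezed result.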

Example

import numpy as np
# creating an input array
a = np.array([[[1], [2], [3], [4]]])
# getting the shape of a
print(a.shape)
# removing the dimensions with length 1
b = np.squeeze(a)
# obtaining the shape of the new array
print(b.shape)

Code explanation

  • Line 1: We import the numpy module as np.
  • Line 3: We create an input array, a, using the array() function.
  • Line 5: We print the shape of a using the shape attribute. The output is (1, 4, 1).
  • Line 7: We remove the dimensions of length 1 from the input array, a, using the squeeze() function. The result is assigned to a variable, b.
  • Line 9: We print the shape of the squeezed array, b. The output is (4,), because both dimensions of length 1 have been removed.
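The example above lets squeeze() remove every length-1 axis. Below is a short sketch of removing only one of them from the same array by passing axis. It also shows the equivalent ndarray.squeeze() method. The variable names are illustrative.

```python
import numpy as np

a = np.array([[[1], [2], [3], [4]]])  # shape (1, 4, 1), as in the example

no_leading = np.squeeze(a, axis=0)    # drop only axis 0 -> (4, 1)
no_trailing = np.squeeze(a, axis=-1)  # drop only the last axis -> (1, 4)
method_form = a.squeeze()             # method form, same as np.squeeze(a) -> (4,)

print(no_leading.shape, no_trailing.shape, method_form.shape)
```

Negative axis values count from the end, so axis=-1 refers to the trailing axis here.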