What is the numpy.split() function in NumPy?

Overview

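The split() function in NumPy splits an input array into multiple subarrays. The split points are given either as an integer N, which divides the array into N equal sections, or as a sorted sequence of indices at which the array is cut.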

Syntax

numpy.split(ary, indices_or_sections, axis=0)
Syntax for the split() function in NumPy

Parameter values

The split() function takes the following parameter values:

  • ary: This is the input array to be split. This is a required parameter.
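  • indices_or_sections: This is either an integer or a 1-D array of sorted integers. An integer N divides the array into N equal sections along axis, and a ValueError is raised if that equal division is not possible. A 1-D array of indices gives the positions along axis at which the array is cut; for example, [2, 3] with axis=0 yields ary[:2], ary[2:3], and ary[3:]. This is a required parameter.
  • axis: This is the axis along which the split is done. This is an optional parameter and defaults to 0.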
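The short sketch below illustrates both forms of indices_or_sections and the effect of axis. The arrays and variable names (a, m, equal_parts, index_parts, left, right) are illustrative choices, not part of the NumPy API.

```python
import numpy as np

# A 1-D array with 8 elements: 0, 1, ..., 7
a = np.arange(8)

# Integer form: 4 equal sections of 2 elements each
equal_parts = np.split(a, 4)

# Index form: cut before positions 3 and 5,
# which gives a[:3], a[3:5] and a[5:]
index_parts = np.split(a, [3, 5])

# A 2-D array with 2 rows and 4 columns
m = np.arange(8).reshape(2, 4)

# axis=1 splits along the columns into two 2x2 blocks
left, right = np.split(m, 2, axis=1)

print(equal_parts)
print(index_parts)
print(left)
print(right)
```

With the index form the pieces need not be the same size, so no divisibility error can occur for that form.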

Return value

The split() function returns a Python list of subarrays of the input array. The subarrays are views into ary rather than copies.
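Because the subarrays are views, writing to one of them also changes the original array. The following sketch demonstrates this and shows one way to get independent pieces with .copy(); the array values and variable names are illustrative only.

```python
import numpy as np

a = np.array([1, 2, 3, 4, 5, 6])
parts = np.split(a, 3)

# The result is a plain Python list of ndarrays
print(type(parts), [p.shape for p in parts])

# Each subarray is a view into a, so this assignment also changes a[0]
parts[0][0] = 100
print(a)

# Copy the pieces when they must be independent of the original array
independent = [p.copy() for p in np.split(a, 3)]
independent[1][0] = -1  # modifies only the copy, a is unchanged
print(a)
```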

Example

import numpy as np
# creating the input array
a = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9])
# splitting the input array into 3 sub-arrays
myarray = np.split(a, 3)
# printing the split array
print(myarray)

Explanation

  • Line 1: We import the numpy module as np.
  • Line 3: We create an input array a using the array() function.
  • Line 5: We split the input array into 3 equal subarrays using the split() function. The result is assigned to the variable myarray.
  • Line 7: We print myarray, which is a list of three arrays: [array([1, 2, 3]), array([4, 5, 6]), array([7, 8, 9])].
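The example works because 9 elements divide evenly into 3 sections. The sketch below reuses the same 9-element array to show what happens when the division is not even, and uses np.array_split, a related NumPy function that allows unequal sections instead of raising an error. The variable names error_message and uneven are illustrative.

```python
import numpy as np

a = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9])

# 9 elements cannot be divided into 4 equal sections, so split() raises
error_message = None
try:
    np.split(a, 4)
except ValueError as err:
    error_message = str(err)
    print("ValueError:", error_message)

# array_split() accepts the same arguments but allows unequal sections;
# the leading sections receive the extra elements
uneven = np.array_split(a, 4)
print(uneven)
```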
