NumPy where() method

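The NumPy where() function tells you where in an array a given condition is met. Called with three arguments, it can also build a new array by choosing each element from one of two inputs, depending on that condition.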

Types of arguments


1. Single argument

If where() is called with a single argument, that argument is the condition. The call returns the indices of the elements for which the condition is True, packed as a tuple that holds one index array per dimension of the input.
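The short check below is a minimal sketch that uses the same array as the first example to confirm what the single-argument call returns. NumPy documents this form as shorthand for np.asarray(condition).nonzero(), so the comparison with np.nonzero() should always hold.

```python
import numpy as np

arr = np.array([1, 3, 5, 7, 11])

# With only a condition, where() returns a tuple with one index array per dimension
result = np.where(arr < 6)
print(type(result))   # <class 'tuple'>
print(len(result))    # 1, because arr is one-dimensional
print(result[0])      # [0 1 2]

# The same indices come from nonzero() applied to the boolean mask
print(np.array_equal(result[0], np.nonzero(arr < 6)[0]))  # True
```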

Examples:

  1. Indices where the array elements fulfill the given condition:
import numpy as np
arr = np.array([1, 3, 5, 7, 11])  # create an ndarray
index_arr = np.where(arr < 6)  # indices of the elements less than 6
print(index_arr)  # a tuple holding one index array with the values 0, 1, 2
  2. An array of elements rather than indices:
import numpy as np
arr = np.array([1, 3, 5, 7, 11])  # create an ndarray
elements_arr = arr[np.where(arr < 6)]  # index arr with the tuple returned by where()
print(elements_arr)  # [1 3 5]
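  3. Multi-dimensional arrays:
import numpy as np
arr = np.array([[1, 2, 3], [11, 22, 33]])
index_multi_arr = np.where(arr < 15)  # one array of row indices, one of column indices
print(index_multi_arr)  # rows [0, 0, 0, 1], columns [0, 1, 2, 0]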
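For a two-dimensional array, the two index arrays are easier to read once they are paired into (row, column) positions. The sketch below does this with zip(), and also shows np.argwhere(), a separate NumPy function that returns the same positions as one array with a row per match.

```python
import numpy as np

arr = np.array([[1, 2, 3], [11, 22, 33]])

# where() gives one array of row indices and one array of column indices
rows, cols = np.where(arr < 15)

# Pair them up to get the (row, column) position of each match
positions = list(zip(rows.tolist(), cols.tolist()))
print(positions)  # [(0, 0), (0, 1), (0, 2), (1, 0)]

# np.argwhere() returns the same positions as an (N, 2) array
print(np.argwhere(arr < 15))
```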

2. Multiple arguments

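With three arguments, where(condition, x, y) returns a new array that takes each element from x where condition is True and from y where it is False. In the example below, elements less than 5 are multiplied by 10 and every other element is replaced with 0.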

import numpy as np
arr = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9])
transformed_arr = np.where(arr < 5, arr * 10, 0)  # arr*10 where arr < 5, otherwise 0
print(transformed_arr)  # [10 20 30 40  0  0  0  0  0]
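One way to read the three-argument form is as "start from y, then overwrite with x wherever the condition holds." The sketch below rebuilds the same result with a boolean mask and checks that both approaches agree; the mask construction is an equivalent illustration, not how NumPy implements where() internally.

```python
import numpy as np

arr = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9])

# Three-argument form: take arr*10 where arr < 5, otherwise take 0
via_where = np.where(arr < 5, arr * 10, 0)

# Equivalent construction with a boolean mask
mask = arr < 5
via_mask = np.zeros_like(arr)          # start from the "otherwise" value, 0
via_mask[mask] = (arr * 10)[mask]      # overwrite positions where the condition holds

print(via_where)                            # [10 20 30 40  0  0  0  0  0]
print(np.array_equal(via_where, via_mask))  # True
```

Note that arr * 10 is computed for every element before where() selects from it, so the unselected branch is not skipped.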

A common challenge

Many people who are new to NumPy find the following notation confusing:

import numpy as np
transformed_arr = np.where([True, False, True], [1, 2, 3], [10, 20, 30])
print(transformed_arr)  # [ 1 20  3]
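The rule is applied one position at a time: at position 0 the condition is True, so 1 comes from x; at position 1 it is False, so 20 comes from y; at position 2 it is True again, so 3 comes from x. The plain-Python sketch below spells this out for one-dimensional lists. The helper where_by_hand is hypothetical, written only for illustration, and is not part of NumPy.

```python
import numpy as np

condition = [True, False, True]
x = [1, 2, 3]
y = [10, 20, 30]

# Hypothetical helper: a plain-Python reading of where(condition, x, y) for 1-D lists.
# At each position, take from x if the condition is True, otherwise from y.
def where_by_hand(condition, x, y):
    return [xi if c else yi for c, xi, yi in zip(condition, x, y)]

print(where_by_hand(condition, x, y))  # [1, 20, 3]
print(np.where(condition, x, y))       # [ 1 20  3]
```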

Now, let’s see a tougher example:

import numpy as np

# Try to determine the output yourself before
# executing the code or moving forward
transformed_arr = np.where([[True, False], [True, True]],
                           [[1, 2], [3, 4]],
                           [[10, 20], [30, 40]])
print(transformed_arr)
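Applying the same position-by-position rule to the tougher example, only the top-right condition is False, so that element comes from the third argument and the result is [[1, 20], [3, 4]]. The inputs do not need identical shapes either: NumPy's documentation states that condition, x and y must be broadcastable to a common shape. In the sketch below, the row [1, 2] and the scalar 0 are illustrative values chosen here; both are broadcast to the (2, 2) shape of the condition.

```python
import numpy as np

condition = np.array([[True, False], [True, True]])

# The tougher example, element by element
tougher = np.where(condition, [[1, 2], [3, 4]], [[10, 20], [30, 40]])
print(tougher)

# x is a single row and y is a scalar; both are broadcast to condition's (2, 2) shape
result = np.where(condition, np.array([1, 2]), 0)
print(result)
# [[1 0]
#  [1 2]]
```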


Copyright ©2024 Educative, Inc. All rights reserved