The NumPy `where()` function tells you where, in a NumPy array, a given condition is met.
If `where()` is called with a single argument, that argument is the condition. Such a call returns the indices of the elements that satisfy the condition, as a tuple containing one array of indices per dimension of the input.
```python
import numpy as np

arr = np.array([1, 3, 5, 7, 11])  # creating an ndarray
index_arr = np.where(arr < 6)  # calling where() with only a condition
print(index_arr)
```
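The NumPy documentation describes the single-argument form as shorthand for `np.asarray(condition).nonzero()`, and recommends calling `nonzero` directly in that case. A small check of that equivalence on the same array:

```python
import numpy as np

arr = np.array([1, 3, 5, 7, 11])

# Single-argument where(): a tuple with one index array per dimension
where_result = np.where(arr < 6)

# The call the NumPy docs say where(condition) is shorthand for
nonzero_result = np.nonzero(arr < 6)

# arr is one-dimensional, so both results are 1-tuples
print(len(where_result))
print(np.array_equal(where_result[0], nonzero_result[0]))
```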
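The returned indices can be used to index the original array, which gives you the elements that meet the condition rather than their positions:

```python
import numpy as np

arr = np.array([1, 3, 5, 7, 11])  # creating an ndarray
elements_arr = arr[np.where(arr < 6)]  # index with the tuple returned by where()
print(elements_arr)
```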
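When only the values are needed, indexing directly with the boolean mask `arr < 6` selects the same elements; the sketch below compares the two forms side by side.

```python
import numpy as np

arr = np.array([1, 3, 5, 7, 11])

# Integer (fancy) indexing with the index tuple from where()
via_where = arr[np.where(arr < 6)]

# Boolean-mask indexing with the condition itself
via_mask = arr[arr < 6]

print(via_where.tolist())  # [1, 3, 5]
print(via_mask.tolist())   # [1, 3, 5]
```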
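`where()` also works on multi-dimensional arrays. For a 2-D array, the returned tuple holds two arrays: the row indices and the column indices of the matching elements.

```python
import numpy as np

arr = np.array([[1, 2, 3], [11, 22, 33]])
index_multi_arr = np.where(arr < 15)  # (row_indices, col_indices)
print(index_multi_arr)
```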
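Reading a 2-D result is easier once the row and column arrays are paired up. The sketch below zips them into `(row, col)` tuples and compares with `np.argwhere`, which returns the same positions as an `(N, ndim)` array.

```python
import numpy as np

arr = np.array([[1, 2, 3], [11, 22, 33]])

# For a 2-D array, where() returns (row_indices, col_indices)
rows, cols = np.where(arr < 15)

# Pair them up to read off each matching position
pairs = [(int(r), int(c)) for r, c in zip(rows, cols)]
print(pairs)  # [(0, 0), (0, 1), (0, 2), (1, 0)]

# np.argwhere gives the same positions, one row per match
positions = np.argwhere(arr < 15)
print(positions.tolist())  # [[0, 0], [0, 1], [0, 2], [1, 0]]
```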
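When `where()` is called with three arguments, `np.where(condition, x, y)`, it builds a new array that takes each element from `x` where the condition is True and from `y` where it is False. The three arguments must be broadcastable to a common shape, so `x` or `y` can also be a scalar.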
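```python
import numpy as np

arr = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9])
# Elements below 5 are multiplied by 10; all others become 0
transformed_arr = np.where(arr < 5, arr * 10, 0)
print(transformed_arr)
```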
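To make the selection rule concrete, here is a plain-Python loop that produces the same values as the vectorized call. It is only an illustration; in practice the vectorized `np.where` is what you would use.

```python
import numpy as np

arr = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9])

vectorized = np.where(arr < 5, arr * 10, 0)

# Same rule as an explicit loop: take from x (arr * 10) when the
# condition holds, otherwise take from y (the scalar 0)
looped = [a * 10 if a < 5 else 0 for a in arr.tolist()]

print(vectorized.tolist())            # [10, 20, 30, 40, 0, 0, 0, 0, 0]
print(looped == vectorized.tolist())  # True
```

Note that `arr * 10` is evaluated for every element before `where()` is even called; `where()` only chooses between already-computed values.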
Many people new to NumPy find the following notation a bit confusing:
```python
import numpy as np

transformed_arr = np.where([True, False, True], [1, 2, 3], [10, 20, 30])
print(transformed_arr)
```
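The lists are compared position by position: the first and third conditions are True, so those slots come from `[1, 2, 3]`, and the second is False, so that slot comes from `[10, 20, 30]`. A sketch that spells out the same walk with `zip`:

```python
import numpy as np

condition = [True, False, True]
x = [1, 2, 3]
y = [10, 20, 30]

result = np.where(condition, x, y)

# Walk the three lists in lockstep: take x[i] where condition[i] is True,
# otherwise take y[i]
manual = [xi if c else yi for c, xi, yi in zip(condition, x, y)]

print(result.tolist())  # [1, 20, 3]
print(manual)           # [1, 20, 3]
```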
Now, let’s see a tougher example:
```python
import numpy as np

# Try to determine the output yourself before
# executing the code or moving forward
transformed_arr = np.where([[True, False], [True, True]],
                           [[1, 2], [3, 4]],
                           [[10, 20], [30, 40]])
print(transformed_arr)
```
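The same element-wise rule applies in two dimensions: only position `(0, 1)` has a False condition, so it is the only slot taken from the second array. The sketch below rebuilds the result one position at a time using `np.ndindex`, which yields every index tuple for a given shape, and checks it against `np.where`.

```python
import numpy as np

condition = np.array([[True, False], [True, True]])
x = np.array([[1, 2], [3, 4]])
y = np.array([[10, 20], [30, 40]])

result = np.where(condition, x, y)

# Rebuild the result position by position: x where the condition
# is True, y where it is False
manual = np.empty_like(x)
for idx in np.ndindex(condition.shape):
    manual[idx] = x[idx] if condition[idx] else y[idx]

print(result.tolist())                 # [[1, 20], [3, 4]]
print(np.array_equal(result, manual))  # True
```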