np.where
returns a tuple of length equal to the dimension of the numpy ndarray on which it is called (in other words ndim
) and each item of tuple is a numpy ndarray of indices of all those values in the initial ndarray for which the condition is True. (Please don't confuse dimension with shape)
For example:
x=np.arange(9).reshape(3,3)
print(x)
array([[0, 1, 2],
[3, 4, 5],
[6, 7, 8]])
y = np.where(x>4)
print(y)
array([1, 2, 2, 2], dtype=int64), array([2, 0, 1, 2], dtype=int64))
y is a tuple of length 2 because x.ndim
is 2. The 1st item in tuple contains row numbers of all elements greater than 4 and the 2nd item contains column numbers of all items greater than 4. As you can see, [1,2,2,2] corresponds to row numbers of 5,6,7,8 and [2,0,1,2] corresponds to column numbers of 5,6,7,8
Note that the ndarray is traversed along first dimension(row-wise).
Similarly,
x=np.arange(27).reshape(3,3,3)
np.where(x>4)
will return a tuple of length 3 because x has 3 dimensions.
But wait, there's more to np.where!
when two additional arguments are added to np.where
; it will do a replace operation for all those pairwise row-column combinations which are obtained by the above tuple.
x=np.arange(9).reshape(3,3)
y = np.where(x>4, 1, 0)
print(y)
array([[0, 0, 0],
[0, 0, 1],
[1, 1, 1]])