[python] How does python numpy.where() work?

I am playing with numpy and digging through documentation and I have come across some magic. Namely I am talking about numpy.where():

>>> x = np.arange(9.).reshape(3, 3)
>>> np.where( x > 5 )
(array([2, 2, 2]), array([0, 1, 2]))

How do they achieve internally that you are able to pass something like x > 5 into a method? I guess it has something to do with __gt__ but I am looking for a detailed explanation.

This question is related to python numpy magic-methods

The answer is


Old Answer it is kind of confusing. It gives you the LOCATIONS (all of them) of where your statment is true.

so:

>>> a = np.arange(100)
>>> np.where(a > 30)
(array([31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47,
       48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64,
       65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81,
       82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98,
       99]),)
>>> np.where(a == 90)
(array([90]),)

a = a*40
>>> np.where(a > 1000)
(array([26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42,
       43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59,
       60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76,
       77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93,
       94, 95, 96, 97, 98, 99]),)
>>> a[25]
1000
>>> a[26]
1040

I use it as an alternative to list.index(), but it has many other uses as well. I have never used it with 2D arrays.

http://docs.scipy.org/doc/numpy/reference/generated/numpy.where.html

New Answer It seems that the person was asking something more fundamental.

The question was how could YOU implement something that allows a function (such as where) to know what was requested.

First note that calling any of the comparison operators do an interesting thing.

a > 1000
array([False, False, False, False, False, False, False, False, False,
       False, False, False, False, False, False, False, False, False,
       False, False, False, False, False, False, False, False,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True`,  True,  True,  True,  True,  True,  True,  True,  True,  True], dtype=bool)`

This is done by overloading the "__gt__" method. For instance:

>>> class demo(object):
    def __gt__(self, item):
        print item


>>> a = demo()
>>> a > 4
4

As you can see, "a > 4" was valid code.

You can get a full list and documentation of all overloaded functions here: http://docs.python.org/reference/datamodel.html

Something that is incredible is how simple it is to do this. ALL operations in python are done in such a way. Saying a > b is equivalent to a.gt(b)!


np.where returns a tuple of length equal to the dimension of the numpy ndarray on which it is called (in other words ndim) and each item of tuple is a numpy ndarray of indices of all those values in the initial ndarray for which the condition is True. (Please don't confuse dimension with shape)

For example:

x=np.arange(9).reshape(3,3)
print(x)
array([[0, 1, 2],
      [3, 4, 5],
      [6, 7, 8]])
y = np.where(x>4)
print(y)
array([1, 2, 2, 2], dtype=int64), array([2, 0, 1, 2], dtype=int64))


y is a tuple of length 2 because x.ndim is 2. The 1st item in tuple contains row numbers of all elements greater than 4 and the 2nd item contains column numbers of all items greater than 4. As you can see, [1,2,2,2] corresponds to row numbers of 5,6,7,8 and [2,0,1,2] corresponds to column numbers of 5,6,7,8 Note that the ndarray is traversed along first dimension(row-wise).

Similarly,

x=np.arange(27).reshape(3,3,3)
np.where(x>4)


will return a tuple of length 3 because x has 3 dimensions.

But wait, there's more to np.where!

when two additional arguments are added to np.where; it will do a replace operation for all those pairwise row-column combinations which are obtained by the above tuple.

x=np.arange(9).reshape(3,3)
y = np.where(x>4, 1, 0)
print(y)
array([[0, 0, 0],
   [0, 0, 1],
   [1, 1, 1]])

Examples related to python

programming a servo thru a barometer Is there a way to view two blocks of code from the same file simultaneously in Sublime Text? python variable NameError Why my regexp for hyphenated words doesn't work? Comparing a variable with a string python not working when redirecting from bash script is it possible to add colors to python output? Get Public URL for File - Google Cloud Storage - App Engine (Python) Real time face detection OpenCV, Python xlrd.biffh.XLRDError: Excel xlsx file; not supported Could not load dynamic library 'cudart64_101.dll' on tensorflow CPU-only installation

Examples related to numpy

Unable to allocate array with shape and data type How to fix 'Object arrays cannot be loaded when allow_pickle=False' for imdb.load_data() function? Numpy, multiply array with scalar TypeError: only integer scalar arrays can be converted to a scalar index with 1D numpy indices array Could not install packages due to a "Environment error :[error 13]: permission denied : 'usr/local/bin/f2py'" Pytorch tensor to numpy array Numpy Resize/Rescale Image what does numpy ndarray shape do? How to round a numpy array? numpy array TypeError: only integer scalar arrays can be converted to a scalar index

Examples related to magic-methods

PHP - Indirect modification of overloaded property Best practice: PHP Magic Methods __set and __get Python __call__ special method practical example How does python numpy.where() work? PHP __get and __set magic methods What is the difference between __str__ and __repr__?