[python] Finding the index of elements based on a condition using python list comprehension

The following Python code appears to be very long winded when coming from a Matlab background

>>> a = [1, 2, 3, 1, 2, 3]
>>> [index for index,value in enumerate(a) if value > 2]
[2, 5]

When in Matlab I can write:

>> a = [1, 2, 3, 1, 2, 3];
>> find(a>2)
ans =
     3     6

Is there a short hand method of writing this in Python, or do I just stick with the long version?


Thank you for all the suggestions and explanation of the rationale for Python's syntax.

After finding the following on the numpy website, I think I have found a solution I like:

http://docs.scipy.org/doc/numpy/user/basics.indexing.html#boolean-or-mask-index-arrays

Applying the information from that website to my problem above, would give the following:

>>> from numpy import array
>>> a = array([1, 2, 3, 1, 2, 3])
>>> b = a>2 
array([False, False, True, False, False, True], dtype=bool)
>>> r = array(range(len(b)))
>>> r(b)
[2, 5]

The following should then work (but I haven't got a Python interpreter on hand to test it):

class my_array(numpy.array):
    def find(self, b):
        r = array(range(len(b)))
        return r(b)


>>> a = my_array([1, 2, 3, 1, 2, 3])
>>> a.find(a>2)
[2, 5]

This question is related to python

The answer is


Even if it's a late answer: I think this is still a very good question and IMHO Python (without additional libraries or toolkits like numpy) is still lacking a convenient method to access the indices of list elements according to a manually defined filter.

You could manually define a function, which provides that functionality:

def indices(list, filtr=lambda x: bool(x)):
    return [i for i,x in enumerate(list) if filtr(x)]

print(indices([1,0,3,5,1], lambda x: x==1))

Yields: [0, 4]

In my imagination the perfect way would be making a child class of list and adding the indices function as class method. In this way only the filter method would be needed:

class MyList(list):
    def __init__(self, *args):
        list.__init__(self, *args)
    def indices(self, filtr=lambda x: bool(x)):
        return [i for i,x in enumerate(self) if filtr(x)]

my_list = MyList([1,0,3,5,1])
my_list.indices(lambda x: x==1)

I elaborated a bit more on that topic here: http://tinyurl.com/jajrr87


For me it works well:

>>> import numpy as np
>>> a = np.array([1, 2, 3, 1, 2, 3])
>>> np.where(a > 2)[0]
[2 5]

Another way:

>>> [i for i in range(len(a)) if a[i] > 2]
[2, 5]

In general, remember that while find is a ready-cooked function, list comprehensions are a general, and thus very powerful solution. Nothing prevents you from writing a find function in Python and use it later as you wish. I.e.:

>>> def find_indices(lst, condition):
...   return [i for i, elem in enumerate(lst) if condition(elem)]
... 
>>> find_indices(a, lambda e: e > 2)
[2, 5]

Note that I'm using lists here to mimic Matlab. It would be more Pythonic to use generators and iterators.


  • In Python, you wouldn't use indexes for this at all, but just deal with the values—[value for value in a if value > 2]. Usually dealing with indexes means you're not doing something the best way.

  • If you do need an API similar to Matlab's, you would use numpy, a package for multidimensional arrays and numerical math in Python which is heavily inspired by Matlab. You would be using a numpy array instead of a list.

    >>> import numpy
    >>> a = numpy.array([1, 2, 3, 1, 2, 3])
    >>> a
    array([1, 2, 3, 1, 2, 3])
    >>> numpy.where(a > 2)
    (array([2, 5]),)
    >>> a > 2
    array([False, False,  True, False, False,  True], dtype=bool)
    >>> a[numpy.where(a > 2)]
    array([3, 3])
    >>> a[a > 2]
    array([3, 3])
    

Maybe another question is, "what are you going to do with those indices once you get them?" If you are going to use them to create another list, then in Python, they are an unnecessary middle step. If you want all the values that match a given condition, just use the builtin filter:

matchingVals = filter(lambda x : x>2, a)

Or write your own list comprhension:

matchingVals = [x for x in a if x > 2]

If you want to remove them from the list, then the Pythonic way is not to necessarily remove from the list, but write a list comprehension as if you were creating a new list, and assigning back in-place using the listvar[:] on the left-hand-side:

a[:] = [x for x in a if x <= 2]

Matlab supplies find because its array-centric model works by selecting items using their array indices. You can do this in Python, certainly, but the more Pythonic way is using iterators and generators, as already mentioned by @EliBendersky.