Efficiently return the index of the first value satisfying a condition in an array

numba

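With numba it’s possible to optimise both scenarios: the condition being met early in the array and the condition being met late. Syntactically, all you need is a function containing a simple for loop that returns as soon as the condition holds: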

from numba import njit

@njit
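def get_first_index_nb(A, k):
    # Scan left to right and stop at the first element greater than k
    for i in range(len(A)):
        if A[i] > k:
            return i
    return -1  # sentinel: no element satisfies the condition

idx = get_first_index_nb(A, 0.9)  # A: a 1-D NumPy array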
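
A minimal, self-contained sketch of the function above on a small hand-made array, showing both a match and the -1 sentinel for no match. The try/except around the numba import is an assumption added here so the logic can be checked without numba installed; in that case the function simply runs as plain Python.

```python
import numpy as np

try:
    from numba import njit
except ImportError:  # assumption: fall back to plain Python if numba is missing
    def njit(func):
        return func

@njit
def get_first_index_nb(A, k):
    # Stop at the first element greater than k
    for i in range(len(A)):
        if A[i] > k:
            return i
    return -1  # sentinel: no element satisfies the condition

A = np.array([0.1, 0.5, 0.95, 0.2, 0.99])
first = get_first_index_nb(A, 0.9)    # 2
missing = get_first_index_nb(A, 1.0)  # -1
```

Since -1 is also a valid index in Python (it refers to the last element), compare the result against -1 explicitly before using it to index A.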

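Numba improves performance by JIT (“just-in-time”) compiling the function to machine code (via LLVM), so the loop runs without interpreter overhead and returns as soon as a match is found. The first call with a given set of argument types also pays the compilation cost; later calls reuse the compiled code. A regular for loop without the @njit decorator would typically be slower than the methods you’ve already tried when the condition is met late.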
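
To illustrate why the compiled loop wins when the match is early, here is a sketch comparing it with an np.where-style baseline like the one used in the benchmarks. The helper get_first_index_where is a hypothetical name introduced for illustration; it uses np.flatnonzero, which returns the same positions as np.where(...)[0] for a 1-D array. The numba import falls back to plain Python if numba is unavailable (an assumption for portability).

```python
import numpy as np

try:
    from numba import njit
except ImportError:  # assumption: fall back to plain Python if numba is missing
    def njit(func):
        return func

@njit
def get_first_index_nb(A, k):
    # Compiled loop: exits at the first match
    for i in range(len(A)):
        if A[i] > k:
            return i
    return -1

def get_first_index_where(A, k):
    # Hypothetical vectorised baseline: A > k is evaluated on every element,
    # even when the first match is at position 0
    hits = np.flatnonzero(A > k)
    return int(hits[0]) if hits.size else -1

np.random.seed(0)
arr = np.random.rand(10**5)
results = [(get_first_index_nb(arr, k), get_first_index_where(arr, k))
           for k in (0.5, 0.9, 1.0)]
```

Both functions agree on every threshold; the difference is only in how much work they do. The vectorised version always scans the whole array, while the loop stops at the first hit.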

For a numeric pandas Series df['data'], you can simply feed its NumPy representation to the JIT-compiled function:

idx = get_first_index_nb(df['data'].values, 0.9)
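
Note that the function returns a position, not an index label. A sketch, assuming a small hypothetical DataFrame with a string index, showing how to map the position back to a label with df.index; Series.to_numpy() (pandas 0.24+) is used here as the newer spelling of .values. The numba import falls back to plain Python if numba is unavailable (an assumption for portability).

```python
import pandas as pd

try:
    from numba import njit
except ImportError:  # assumption: fall back to plain Python if numba is missing
    def njit(func):
        return func

@njit
def get_first_index_nb(A, k):
    for i in range(len(A)):
        if A[i] > k:
            return i
    return -1

# Hypothetical frame whose index is not 0..n-1
df = pd.DataFrame({'data': [0.2, 0.4, 0.95, 0.97]}, index=['a', 'b', 'c', 'd'])

pos = get_first_index_nb(df['data'].to_numpy(), 0.9)  # position in the array
label = df.index[pos] if pos != -1 else None          # corresponding index label
```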

Generalisation

Since numba permits functions as arguments, and assuming the passed function can also be JIT-compiled, you can write a method that returns the index of the nth value for which an arbitrary func evaluates to True:

@njit
def get_nth_index_count(A, func, count):
    c = 0
    for i in range(len(A)):
        if func(A[i]):
            c += 1
            if c == count:
                return i
    return -1

@njit
def func(val):
    return val > 0.9

# get index of 3rd value where func evaluates to True (arr: a 1-D NumPy array)
idx = get_nth_index_count(arr, func, 3)

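To find the 3rd value from the end instead, feed the reversed array, arr[::-1], and subtract the result from len(arr) - 1; the - 1 accounts for 0-based indexing. Check for the -1 sentinel first, since a failed search would otherwise map to len(arr).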
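
The reversal trick can be sketched as follows. get_nth_last_index is a hypothetical wrapper name introduced here; it checks for the -1 sentinel before mapping the position in the reversed view back to the original array. The numba import falls back to plain Python if numba is unavailable (an assumption for portability).

```python
import numpy as np

try:
    from numba import njit
except ImportError:  # assumption: fall back to plain Python if numba is missing
    def njit(func):
        return func

@njit
def get_nth_index_count(A, func, count):
    c = 0
    for i in range(len(A)):
        if func(A[i]):
            c += 1
            if c == count:
                return i
    return -1

@njit
def func(val):
    return val > 0.9

def get_nth_last_index(A, func, count):
    # Hypothetical wrapper: search the reversed view, then map back
    j = get_nth_index_count(A[::-1], func, count)
    return -1 if j == -1 else len(A) - 1 - j

arr = np.array([0.95, 0.1, 0.92, 0.3, 0.99, 0.5])
last = get_nth_last_index(arr, func, 1)        # 4
third_last = get_nth_last_index(arr, func, 3)  # 0
none_left = get_nth_last_index(arr, func, 4)   # -1
```

arr[::-1] is a view with a negative stride, so reversing does not copy the data.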

Performance benchmarking

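# Python 3.6.5, NumPy 1.14.3, Numba 0.38.0
# Run in IPython/Jupyter: %timeit is an IPython magic command

import numpy as np
from numba import njit

np.random.seed(0)
arr = np.random.rand(10**7)
m = 0.9       # first match occurs early in the array
n = 0.999999  # first match occurs late in the array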

@njit
def get_first_index_nb(A, k):
    for i in range(len(A)):
        if A[i] > k:
            return i
    return -1

# Same loop without @njit, i.e. plain interpreted Python (not NumPy, despite the _np suffix)
def get_first_index_np(A, k):
    for i in range(len(A)):
        if A[i] > k:
            return i
    return -1

%timeit get_first_index_nb(arr, m)                                 # 375 ns
%timeit get_first_index_np(arr, m)                                 # 2.71 µs
%timeit next(iter(np.where(arr > m)[0]), -1)                       # 43.5 ms
%timeit next((idx for idx, val in enumerate(arr) if val > m), -1)  # 2.5 µs

%timeit get_first_index_nb(arr, n)                                 # 204 µs
%timeit get_first_index_np(arr, n)                                 # 44.8 ms
%timeit next(iter(np.where(arr > n)[0]), -1)                       # 21.4 ms
%timeit next((idx for idx, val in enumerate(arr) if val > n), -1)  # 39.2 ms
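
Outside IPython, %timeit is unavailable. A rough equivalent using the standard-library timeit module, assuming a smaller array (10**5) to keep the run short; the numbers it prints depend on your machine and are not comparable to those above. Note the warm-up call, which keeps JIT compilation out of the measurement. The numba import falls back to plain Python if numba is unavailable (an assumption for portability).

```python
import timeit

import numpy as np

try:
    from numba import njit
except ImportError:  # assumption: fall back to plain Python if numba is missing
    def njit(func):
        return func

@njit
def get_first_index_nb(A, k):
    for i in range(len(A)):
        if A[i] > k:
            return i
    return -1

np.random.seed(0)
arr = np.random.rand(10**5)  # assumption: smaller than the 10**7 above
n = 0.999999

# Warm-up call: triggers JIT compilation so it is not timed below
expected = get_first_index_nb(arr, n)

def per_call(fn, number=5, repeat=3):
    # Best of `repeat` runs, averaged over `number` calls, in seconds
    return min(timeit.repeat(fn, number=number, repeat=repeat)) / number

t_loop = per_call(lambda: get_first_index_nb(arr, n))
t_where = per_call(lambda: next(iter(np.where(arr > n)[0]), -1))
print(f"loop: {t_loop:.2e} s per call, np.where: {t_where:.2e} s per call")
```

Taking the minimum over repeats, as %timeit also reports best-of timings, reduces noise from other processes.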
