Cleaner subset selection in pandas

Update: It turns out that this post addresses a problem that was already more or less solved: .loc[selection] accepts a callable as the selection. I did not know this, but I was kindly pointed to it when I opened an issue in the pandas issue tracker.

My suggested subset() methods still allow some more advanced use cases (see below for *args and **kwargs, and GroupBy.subset()), but the basic functionality similar to filter() in R is provided by the simple syntax .loc[callable].

See the "Selection by callable" section of the pandas indexing documentation.
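For comparison, here is a minimal sketch of the .pipe chain shown below, rewritten with .loc[callable]. Only the column names come from this post; the small DataFrame is made up for illustration. Each callable receives the DataFrame at that point in the chain and returns a boolean mask used for row selection.

```python
import pandas as pd

# Made-up example data using the column names from this post
data = pd.DataFrame({
    "column_A": [1, -2, 3, 4],
    "column_B": [5, 6, 1, 2],
    "column_C": [7, 0, 2, 1],
})

# .loc accepts a one-argument callable; pandas calls it with the current
# DataFrame and uses the returned boolean Series to select rows.
selected = (
    data
    .loc[lambda d: d["column_A"] > 0]
    .loc[lambda d: d["column_B"] < d["column_C"]]
)
```

Rows 1 (negative column_A) and 3 (column_B not less than column_C) are dropped, leaving rows 0 and 2.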

I was recently teaching some basics in pandas to two colleagues who already know how to work with dataframes in R. They did not seem very happy about my suggested way to chain subset selection operations depending on the values of a DataFrame:

selected = (
    data
    .pipe(lambda d: d[d['column_A'] > 0])
    .pipe(lambda d: d[d['column_B'] < d['column_C']])
    # etc
)

Talking to my colleagues made me realize that these lines are potentially quite confusing to a Python beginner. The d[d[...]] appearing on each line is not exactly beautiful. Much of each line is meaningless repetition: in .pipe(lambda d: d[something]) the only important content is the something.

Compare to the equivalent R (dplyr) code which is clean and crystal clear:

selected <- data %>%
    filter(column.A > 0) %>%
    filter(column.B < column.C) %>%
    #etc

Alternatives in pandas

To my knowledge there is no pandas equivalent of the dplyr filter function. There is a DataFrame.filter method, but it only selects according to labels (column or index names), not values. There is also DataFrame.where, which almost does the right thing, but entries where the condition is False are replaced with NaN instead of being dropped.

Don't get me wrong, both DataFrame.filter and DataFrame.where are useful, but they don't provide anything like the filter function in R.
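To make the difference concrete, here is a small sketch on made-up data contrasting the three behaviors: DataFrame.filter picks labels, DataFrame.where masks values with NaN while keeping the shape, and plain boolean indexing drops rows the way dplyr's filter does.

```python
import pandas as pd

df = pd.DataFrame({"column_A": [1, -2, 3], "column_B": [4, 5, -6]})

# DataFrame.filter selects by *labels*: here, the columns whose name
# matches the regular expression. Values play no role.
by_label = df.filter(regex="_A$")

# DataFrame.where keeps the original shape: entries where the condition
# is False become NaN, and no rows are removed.
masked = df.where(df > 0)

# Boolean indexing removes the rows that fail the condition.
dropped = df[df["column_A"] > 0]
```

Here masked still has three rows (with NaN in place of -2 and -6), whereas dropped only contains rows 0 and 2.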

A monkey-patch solution for NDFrame

So I ended up monkey-patching pandas to deal with the problem. The first patch goes on NDFrame and thus provides the syntax .subset(predicate) on both Series and DataFrame.

selected = (
    data
    .subset(lambda d: d['column_A'] > 0)
    .subset(lambda d: d['column_B'] < d['column_C'])
    # etc
)

Admittedly, this code is not much shorter than my usual way with .pipe(...), but I find it much more legible: fewer brackets to confuse the eyes, and the verb subset is more informative than pipe.
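Incidentally, the subset function from the code at the end of this post can also be chained without any monkey-patching, because DataFrame.pipe(func, *args, **kwargs) calls func(data, *args, **kwargs). The sketch below shows this, including the *args/**kwargs use case; the predicate column_above and the data are hypothetical examples.

```python
import pandas as pd

def subset(d, predicate, *args, **kwargs):
    """Same helper as in the code at the end: d[predicate(d, *args, **kwargs)]."""
    return d[predicate(d, *args, **kwargs)]

# Hypothetical reusable predicate with extra parameters
def column_above(d, column, threshold=0):
    return d[column] > threshold

data = pd.DataFrame({"column_A": [1, -2, 3], "column_B": [5, 6, 1]})

# pipe forwards the extra positional and keyword arguments to subset,
# which in turn forwards them to the predicate.
selected = (
    data
    .pipe(subset, column_above, "column_A")
    .pipe(subset, column_above, "column_B", threshold=2)
)

# The same helper works on a Series, since s[boolean Series] also selects.
s = pd.Series([3, -1, 4, -1, 5])
positive = s.pipe(subset, lambda x: x > 0)
```

The patched .subset(...) reads better, but this variant avoids touching pandas internals.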

And a similar one for GroupBy

The other monkey-patch goes on GroupBy. For example, to select all the above-mean items in each group:

above_mean = (
    my_series
    .groupby('A')  # usually some index level, but could be any grouping
    .subset(lambda s: s > s.mean())
)

This one I find very useful. However, it is slightly questionable because I built it to return an index with the same levels as the original data. This is somewhat unusual compared to aggregating GroupBy operations such as .mean(), whose results are indexed by the group keys (i.e., one row per group).
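For comparison, a similar result can be obtained without monkey-patching by comparing against groupby(...).transform("mean"), which broadcasts each group's mean back onto the original index. This is an alternative technique, not what GroupBy.subset does internally; the two-level index with a level named "A" is a made-up example.

```python
import pandas as pd

# Made-up Series whose index level "A" defines the groups
idx = pd.MultiIndex.from_tuples(
    [("x", 1), ("x", 2), ("x", 3), ("y", 1), ("y", 2)], names=["A", "B"]
)
my_series = pd.Series([1.0, 2.0, 6.0, 10.0, 20.0], index=idx)

# transform("mean") returns a Series aligned with my_series, so the
# comparison yields a boolean mask on the original index.
group_mean = my_series.groupby(level="A").transform("mean")
above_mean = my_series[my_series > group_mean]
```

Group x has mean 3 and group y has mean 15, so only ("x", 3) and ("y", 2) remain, and the result keeps both original index levels.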

The code and a notebook with examples

The code below shows my first try at an implementation. It seems to work alright, but I will probably find some edge cases as I use it more.

I also made a notebook with some more examples.

import pandas as pd

def subset(d, predicate, *args, **kwargs):
    """
    Get a subset of a pandas NDFrame using a predicate.

    Equivalent to d[predicate(d, *args, **kwargs)]
    """
    return d[predicate(d, *args, **kwargs)]

# Monkey-patch NDFrame (the base class of Series and DataFrame) to allow .subset(predicate)
pd.core.generic.NDFrame.subset = subset

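def groupby_subset(groupby, predicate, *args, **kwargs):
    """
    Apply subsetting for each group in a pandas GroupBy.

    Roughly equivalent to groupby.apply(subset, predicate, *args, **kwargs),
    but with the group-key levels that apply() prepends dropped again, so the
    result has the same index levels as the original data.
    """
    result = groupby.apply(subset, predicate, *args, **kwargs)

    # Use nlevels rather than len(index.levels): a plain (non-Multi) Index
    # has no .levels attribute, but every Index has nlevels.
    any_group_index = next(iter(groupby.groups.values()))
    orig_n_levels = any_group_index.nlevels
    result_n_levels = result.index.nlevels
    # apply() may prepend the group keys as extra outer levels; drop those
    levels_to_drop = list(range(result_n_levels - orig_n_levels))

    return result.droplevel(levels_to_drop)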

# And monkey-patch GroupBy to allow GroupBy.subset(predicate)
pd.core.groupby.GroupBy.subset = groupby_subset