Python set comprehension performance difference when changing order of code execution


I ran into this very strange issue at work and I truly could not figure out why. I would really like to post the source code directly, but unfortunately company policy does not allow that, so instead I will provide a minimal implementation that captures the gist of the problem. I hope you will understand.

Background and setup

I have a configurable .json file, modules_config.json, that basically defines the order in which certain Python functions are called (all described in detail later). It has the following format:

...
"modules": [
    {"module_name": "int_filter_module", "module_config": { "interval": ... }}, 
    {"module_name": "list_filter_module", "module_config": { "common_ints": ... }}
],
...

This controls the order in which the modules are applied to a list of data_objects (data: list[data_object]), essentially via a function load_and_run_modules that looks like the following:

import json
from typing import Any

def load_and_run_modules(
    modules_config_path: str, 
    data: list[data_object],
) -> list[data_object]:
    # bunch of validations of "modules_config_path"
    with open(modules_config_path) as f:
        modules_config = json.load(f)
    
    # bunch of validations of "modules_config"
    modules: list[dict[str, Any]] = modules_config["modules"]
    processed_data: list[data_object] = data
    for module in modules:
        module_name, module_config = module["module_name"], module["module_config"]

        # load_module essentially imports the corresponding module based on "module_name"
        # and calls its __init__() with "module_config"
        Module = load_module(module_name, module_config)
        
        # all Module objects subclass an abstract BaseModule and must override a run() method
        processed_data = Module.run(processed_data) 
        
    return processed_data

To summarize, this load_and_run_modules function does nothing but initialize each Module object and call its run() method sequentially, feeding the output of the previous module as the input to the next one.
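load_module itself is not shown here; as a rough sketch, it behaves like the following (the importlib lookup and the name mangling are my simplification, the real loader does more validation):

import importlib

def load_module(module_name: str, module_config: dict):
    # e.g. "int_filter_module" -> class IntFilterModule defined in int_filter_module.py
    py_module = importlib.import_module(module_name)
    class_name = "".join(part.capitalize() for part in module_name.split("_"))
    return getattr(py_module, class_name)(module_config)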

For our purposes, a data_object has 3 fields: an id, an int_field and a list_field. IntFilterModule essentially filters data_objects out of the input data list based on the int_field, whereas ListFilterModule filters them out based on the list_field. There is a bit more indirection going on, as the source of truth for each data_object's list_field is stored in a static file of "custom format" and we must rely on that file. You can think of the file as having the format [ data_object: { id: ..., list_field: ... }, ...] if it helps.
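If it helps to have something concrete, you can picture a data_object roughly as the dataclass below (this is just a stand-in for illustration, not our real class):

from dataclasses import dataclass

@dataclass
class data_object:
    id: int
    int_field: int
    list_field: list[int]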

For completeness, here are minimal implementations of IntFilterModule and ListFilterModule.

int_filter_module.py:

class IntFilterModule(BaseModule): 
    def __init__(self, module_config):
        # some initialization, the important thing is that we store a list of
        # [low, high] intervals which we later use to filter the data_object's in data
        self.intervals: list[list[int]] = module_config["intervals"]
        ...

    def run(self, data: list[data_object]): 
        processed_data: list[data_object] = []
        
        for data_object in data:
            # keep a data_object iff its int_field is in one of the intervals; now that
            # I look at this code I realize I could condense the whole thing into a single
            # line using filter() and any(), but it is what it is :(
            for interval in self.intervals:
                if interval[0] < data_object.int_field < interval[1]:
                    processed_data.append(data_object)
                    break  # avoid appending the same data_object twice
        
        return processed_data       
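(For what it's worth, the condensed version alluded to in that comment would look roughly like this:)

    def run(self, data: list[data_object]):
        return list(
            filter(
                lambda d: any(lo < d.int_field < hi for lo, hi in self.intervals),
                data,
            )
        )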

list_filter_module.py:

class ListFilterModule(BaseModule): 
    def __init__(self, module_config):
        # some initialization, the important thing is that we store a set of ints
        # which we later use to filter the data_object's in data
        self.common_ints: set[int] = set(module_config["common_ints"])

        # this module is somewhat more complex: each data_object itself does not act as
        # the source of truth for its own list_field, instead we must read it from an
        # external static file in a custom format, using a custom deserialization
        # function read()
        with open("source_of_truth.bin", "rb") as f:
            self.source_of_truth = read(f)
        ...

    def run(self, data: list[data_object]): 
        # keep a data_object iff its "list_field", as dictated by self.source_of_truth,
        # has a non-trivial intersection with self.common_ints
        data_to_keep: set[int] = {
            data_object.id 
            for data_object in self.source_of_truth
            if set(data_object.list_field) & self.common_ints
        }
        
        processed_data: list[data_object] = [
            data_object for data_object in data if data_object.id in data_to_keep
        ]
        
        return processed_data       

The problem

The mind-boggling behavior I discovered is the following: if I configure modules_config.json to run IntFilterModule first and ListFilterModule second, then ListFilterModule runs extremely slowly, taking 9 minutes in the worst case I have observed so far. However, if I simply switch their order in modules_config.json (note: I am not touching the Python files at all, only the .json config file) so that ListFilterModule runs first and IntFilterModule second, then ListFilterModule finishes significantly faster, averaging around 4 seconds across all the experiments I ran.

This made absolutely no sense to me, so I turned to pyinstrument to profile the program. As it turns out, when ListFilterModule runs as the second module, the set comprehension at the beginning of run() becomes super slow, and it is responsible for the entire performance difference.
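(For reference, this is more or less how I had pyinstrument wrapped around the pipeline call, simplified:)

from pyinstrument import Profiler

profiler = Profiler()
profiler.start()
processed_data = load_and_run_modules("modules_config.json", data)
profiler.stop()
print(profiler.output_text(unicode=True, color=True))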

My question

Why would the order in which these modules run have anything to do with the performance of a set comprehension? I want to call out that in both cases this set comprehension iterates over a fixed static file, so the number of iterations should (and, to the best of my ability, I verified that it does) stay the same!

Some additional info

There are a total of 2 million data_objects in our application, and the static file encodes the list_field for each of them. If it makes any difference, the "custom format" of the static file is protobuf.

Applying the modules in either order resulted in the same final data list, so these 2 modules are "commutative" in this specific scenario.

I am running the program on a CentOS Stream 9 machine with x86_64 architecture, and I reproduced the behavior on Python 3.7.9, 3.11.4 and 3.13.0. I was able to reproduce the problem on more than one machine, although admittedly they were all CentOS Stream 9 on x86_64.

Investigations that I've done

Initially, I had some theories about the possible cause, but unfortunately every one of them was refuted by my experiments.

Theory #1

It could just be that the modules scale differently with the size of their input data list, so if the more efficient module runs first and filters out a lot of data_objects, the less efficient module benefits from the reduced input size and runs faster.

Why I think this is not the cause

  1. From the pyinstrument results, I am certain that the performance difference comes from the set comprehension at the beginning of ListFilterModule.run(), but this set comprehension always iterates over a static structure, so in theory it is O(1) with respect to the input data list.
  2. If this were the cause, we would expect a performance degradation from running ListFilterModule first. Instead, we see a performance improvement by doing so. Moreover, if anything, IntFilterModule filters out around 300K data_objects, which ought to make ListFilterModule run faster.
  3. By inspecting the code above, we can see that both modules scale roughly linearly (O(n)) in the length n of the input data list, and in fact IntFilterModule scales slightly worse because of its nested for loop, so it is really O(kn) where k is the length of self.intervals. In my case, though, k == 1, so both scale with n on the same order.

Theory #2

The set comprehension somehow changed dynamically at runtime, meaning that it iterated over more data_object's than it should have.

Why I think this is not the cause

I used sys.settrace() with a custom trace function that basically grabs the underlying iterator object from the frame during the set comprehension. I inspected this iterator object and wrote out the sequence it generates to a file in both cases; I found no difference between the two files using Linux diff.
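A simplified version of that trace function is below. The hidden local ".0" holds the comprehension's iterator on CPython versions before 3.12, where a comprehension still runs in its own "<setcomp>" frame; on 3.12+ comprehensions are inlined, so this particular lookup does not apply there.

import sys

def trace_setcomp(frame, event, arg):
    # before CPython 3.12, a set comprehension runs in its own code object named
    # "<setcomp>" and receives its iterator as the hidden local ".0"
    if event == "call" and frame.f_code.co_name == "<setcomp>":
        iterator = frame.f_locals.get(".0")
        if iterator is not None:
            # note: this exhausts the iterator, so an instrumented run like this is
            # only good for dumping the sequence, not for producing real output
            with open("setcomp_sequence.txt", "w") as out:
                for item in iterator:
                    out.write(f"{item.id}\n")
    return trace_setcomp

# sys.settrace(trace_setcomp)  # enabled only around the suspicious module run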

Theory #3

Cache misses. It could be that somehow the cache is more "warmed up" if we choose to do things in a specific order.

Why I think this is not the cause

I used Linux perf stat -e cache-misses,cache-references,instructions,cycles,branches,branch-misses,context-switches to profile the program, in an attempt to find out whether the number of cache misses differed significantly between the two cases. Much to my surprise, both cases had roughly the same number of cache misses: the faster case had 2.8B and the slower case 3.1B. What did stand out was the number of instructions: while the faster case executed around 1.0T instructions, the slower case executed almost 2.9T.

Somehow, even though the CPython interpreter is running the same bytecode (I made sure of this by freezing every .py and .pyc file and removing write permission from each of them), it executes significantly more instructions in one case than in the other.
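(A more direct runtime check along the same lines would be to hash and disassemble the hot method's code object under both configurations, something like this:)

import dis
import hashlib

code = ListFilterModule.run.__code__
print(hashlib.sha256(code.co_code).hexdigest())  # compare across both configurations
dis.dis(code)  # or just compare the disassembly output directly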

Things that can help me proceed with my investigation

At this point, I am very much out of my depth. Of course, I'd very much appreciate an answer that explains the underlying cause, but I understand the problem is quite complex and probably not easy to answer, since I wasn't able to provide a minimal working example. So I also welcome any tips that can help me further my investigation.

Specifically, here are some things that I'd like to try, but couldn't figure out how:

  1. A bytecode profiler. If I had access to such a thing, I would profile the runtime of each type of bytecode during that set comprehension, which would give me a very good idea of exactly which bytecode is the offender. I tried to implement one myself as a trace function using time.perf_counter (a rough sketch of the idea is shown after this list), but I am afraid it is too crude to give an accurate measurement, if not outright wrong.
  2. A tutorial on profiling or debugging the CPython interpreter itself. I am pretty new to the CPython interpreter; in fact, as of a week ago I didn't even know it existed (lol). However, if there is a way to set a precise breakpoint in the C code that executes the set comprehension, I might be able to figure something out. I tried to read ceval.c and setobject.c, but that didn't prove very fruitful due to the complexity of the code.
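Here is the rough per-opcode timing sketch mentioned in point 1. It relies on frame.f_trace_opcodes (available since Python 3.7), and the tracing overhead dwarfs the work being measured, so at best the numbers are relative hints:

import dis
import sys
import time
from collections import defaultdict

opcode_time: dict[str, float] = defaultdict(float)
_last: dict = {}  # frame -> (opname, timestamp of its last opcode event)

def opcode_tracer(frame, event, arg):
    if event == "call":
        frame.f_trace_opcodes = True  # ask CPython to emit per-opcode trace events
        return opcode_tracer
    if event == "opcode":
        now = time.perf_counter()
        if frame in _last:
            prev_opname, prev_ts = _last[frame]
            opcode_time[prev_opname] += now - prev_ts
        opname = dis.opname[frame.f_code.co_code[frame.f_lasti]]
        _last[frame] = (opname, time.perf_counter())
    return opcode_tracer

# sys.settrace(opcode_tracer)  # then run the suspicious set comprehension and inspect opcode_time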

I am sorry for throwing this novel at anyone passing by, but if you got hit and are interested enough to stick around and help, I really appreciate your patience and kindness! In the meantime, I will try my best to reproduce this problem locally without any of my company's code.
