Python decorator for C++ class virtual methods using pybind11

41 Views Asked by At

Let me explain the issue using a simple code example.
counter_wrapper.h:

#include <iostream>

class CounterWrapper {
public:
    CounterWrapper();
    virtual ~CounterWrapper();
    virtual void count_a(const int a);
protected:
    int counter;
};
void register_counter(CounterWrapper *counter);
void call_counts();

counter_wrapper.cpp:

#include <iostream>
#include "counter_wrapper.h"

CounterWrapper::CounterWrapper() {
    std::cout << "CounterWrapper::CounterWrapper()" << std::endl;
}
CounterWrapper::~CounterWrapper() {
    std::cout << "CounterWrapper::~CounterWrapper()" << std::endl;
}
void CounterWrapper::count_a(const int a) {
    counter += a;
    std::cout << "CounterWrapper::count_a() with a=" << a << std::endl;
}

CounterWrapper *p_counter = 0;
void register_counter(CounterWrapper *counter) {
    p_counter = counter;
}
void call_counts() {
    p_counter->count_a(1);
}

counter_wrapper_pybind.cpp:

#include <pybind11/pybind11.h>
#include "counter_wrapper.h"

class PyCounterWrapper : public CounterWrapper {
public:
  // Inherit the constructors
  using CounterWrapper::CounterWrapper;
  // Trampolines (need one for each virtual function)
  void count_a(const int a) override {
    PYBIND11_OVERRIDE(void, CounterWrapper, count_a, a);
  }
};

PYBIND11_MODULE(counter_wrapper_pybind, m) {
  m.doc() = "CounterWrapper binding";
  // CounterWrapper bindings
  pybind11::class_<CounterWrapper, PyCounterWrapper /* <--- trampoline*/>(m, "CounterWrapper")
    .def(pybind11::init<>())
    .def("count_a", &CounterWrapper::count_a)
  ;
  m.def("register_counter",  &register_counter, "");
  m.def("call_counts",       &call_counts,      "");
}

test_counter_wrapper.py:

from counter_wrapper_pybind import CounterWrapper
from counter_wrapper_pybind import register_counter
from counter_wrapper_pybind import call_counts

def decorate_all_counts():
    def top_decorate(cls):
        def sub_decorate(func):
            def wrapper(*args, **kwargs):
                print('WRAPPER: function \'{}.{}({})\' was called'.format(cls.__name__, func.__name__, args))
                return func(*args, **kwargs)
            return wrapper
        method_l = [method for method in dir(cls) if callable(getattr(cls, method)) and method.startswith("count")]
        for method in method_l:
            setattr(cls, method, sub_decorate(getattr(cls, method)))
        return cls
    return top_decorate

@decorate_all_counts()
class MyCounterWrapper(CounterWrapper):
    def __init__(self):
        super().__init__()
        print('MyCounterWrapper() is created', flush=True)
    # # Uncomment to pass the TC
    # def count_a(self, a):
    #     print(f'MyCounterWrapper::count_a({a})', flush=True)
    #     super().count_a(a)

my_counter_wrapper = MyCounterWrapper()
register_counter(my_counter_wrapper)
call_counts()

CMakeLists.txt:

cmake_minimum_required(VERSION 3.25)
find_package(pybind11 REQUIRED)

set(CMAKE_CXX_STANDARD 11)
set(CMAKE_CXX_STANDARD_REQUIRED ON)
set(CMAKE_CXX_EXTENSIONS OFF)

add_library( counter_wrapper counter_wrapper.cpp )
pybind11_add_module( counter_wrapper_pybind SHARED counter_wrapper_pybind.cpp )

target_link_libraries(counter_wrapper_pybind PRIVATE counter_wrapper )
target_include_directories(counter_wrapper_pybind PRIVATE counter_wrapper )

To reproduce the issue:

  1. Make sure that both pybind11 and cmake packages are available
  2. Put 5 files to a single folder
  3. Build: mkdir build && cd build && cmake .. && make -j16 && cd ..
  4. export PYTHONPATH=$PYTHONPATH:$PWD/build
  5. python3 test_counter_wrapper.py

I would like to introduce a Py callback (where the WRAPPER print is placed) for every call of C++ count_* function (in the original code there are hundreds of different C++ count_* functions: count_a, count_b, count_c ...).
The current example works only for the cases when count_a function is overridden in MyCounterWrapper() class. Otherwise, such error appears:

RecursionError: maximum recursion depth exceeded while getting the repr of an object

Function call_counts() is needed to call other count functions from C++ side.
The pybind trampoline is needed to catch the C++ function calls in Py.
The class decorator is to have a Py callback for every count function call.\

I would appreciate any ideas on either how to fix that or how to achieve the same results using a different approach.

0

There are 0 best solutions below