Update
Question / Goal
How can I write a decorator that takes in a function fn and creates a dataclass where each argument / keyword argument is a field and the docstring is copied over for better intellisense support?
I don't want to see **kwargs: Any; I want to know what the variables are.
In the MWE below there are the following things:
- get_func_params: a utility function for getting the default parameters of a function
- SomeModuleType: a dummy type to check whether intellisense shows what kind of type hints are being copied over
- mwe_func: a dummy function that has a docstring and some type annotations. The decorator func_to_class should copy over the docstring and function annotations to the new class constructor. Ideally all of the arguments become fields of the new dataclass.
- Class2Subclass: a dummy class for the decorated class to subclass
- mwe_class: the dummy class that we are decorating
- func_to_class: the decorator function
M.W.E.
Imports
from dataclasses import dataclass, field, fields, _FIELD
from typing import get_type_hints, Optional, NamedTuple, Callable, Dict, Any, List
import inspect
from inspect import Signature, Parameter
from functools import wraps
import numpy as np
Utils
def get_func_params(
fn: Callable,
drop_self: Optional[bool] = True,
drop_before: Optional[int] = 0,
drop_idxs: Optional[List[int]] = list(),
drop_names: Optional[List[str]] = list(),
drop_after: Optional[int] = None,
) -> Dict[str, Any]:
params = inspect.signature(fn).parameters
params = {k: v.default for k, v in params.items()}
if drop_self and 'self' in params:
params.pop('self')
params = {
n: p for i, (n, p) in enumerate(params.items())
if (
            # keep drop_before <= i < drop_after
            (drop_before <= i and (drop_after is None or i < drop_after))
# i not in drop_idxs and n not in drop_names
and (i not in drop_idxs and n not in drop_names)
)
}
return params
Dummy Functions / Classes
class SomeModuleType(NamedTuple):
a: int
b: str
def mwe_func(data:np.ndarray, a_bool:bool=False, a_thing:Optional[SomeModuleType]=None) -> np.ndarray:
'''
Parameters
----------
data : np.ndarray
A numpy array of data
a_bool : bool, default=False
A boolean
    a_thing : Optional[SomeModuleType]
A thing
'''
# ...
return data
class Class2Subclass:
def expected_method(self, a:int=0):
pass
pass
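As a quick sanity check of the utility above, get_func_params applied to mwe_func returns a name -> default mapping (parameters without a default show up as inspect.Parameter.empty):
print(get_func_params(mwe_func))
# {'data': <class 'inspect._empty'>, 'a_bool': False, 'a_thing': None}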
Example
@func_to_class(mwe_func)
class mwe_class:
pass
No intellisense appears for mwe_class (what does show up is just Copilot suggesting things).
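For clarity, this is roughly what I would expect @func_to_class(mwe_func) to produce, written out by hand (fields, type hints and the docstring all lifted from mwe_func, with Class2Subclass as the base):
@dataclass
class mwe_class(Class2Subclass):
    data: np.ndarray
    a_bool: bool = False
    a_thing: Optional[SomeModuleType] = None

# plus the docstring copied over for intellisense
mwe_class.__doc__ = mwe_func.__doc__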
Original
I am working with scanpy and sklearn. Currently I have the following:
import scanpy as sc, anndata as ad, numpy as np, pandas as pd
from sklearn.base import BaseEstimator, TransformerMixin
from dataclasses import dataclass
import inspect
from functools import wraps
from typing import List, Any, Optional, Callable, Union, Tuple, Iterable, Set, TypeAlias, Type, Dict
def filter_kwargs_for_func(fn: Callable, **kwargs:Optional[dict]):
params = inspect.signature(fn).parameters
return {k:v for k,v in kwargs.items() if k in params}
def filter_kwargs_for_class(cls: Callable, **kwargs:Optional[dict]):
params = inspect.signature(cls.__init__).parameters
return {k:v for k,v in kwargs.items() if k in params}
def wrangle_kwargs_for_func(
fn: Callable,
defaults: Optional[dict]=None,
**kwargs:Optional[dict]
) -> dict:
# copy defaults
params = (defaults or {}).copy()
# update with kwargs of our function
params.update(kwargs or {})
# filter for only the params that other function accepts
params = filter_kwargs_for_func(fn, **params)
return params
def wrangle_kwargs_for_class(
cls: Callable,
defaults: Optional[dict]=None,
**kwargs:Optional[dict]
) -> dict:
# copy defaults
params = (defaults or {}).copy()
# update with kwargs of our class
params.update(kwargs or {})
# filter for only the params that other class accepts
params = filter_kwargs_for_class(cls, **params)
return params
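For reference, these helpers behave like this (f is a throwaway function used just for illustration):
def f(a, b=1):
    return a + b

print(filter_kwargs_for_func(f, a=1, b=2, c=3))
# {'a': 1, 'b': 2}
print(wrangle_kwargs_for_func(f, defaults={'b': 5}, b=2, c=3))
# {'b': 2}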
def get_func_params(
fn: Callable,
drop_self: Optional[bool] = True,
drop_before: Optional[int] = 0,
drop_idxs: Optional[List[int]] = list(),
drop_names: Optional[List[str]] = list(),
drop_after: Optional[int] = None,
) -> Dict[str, Any]:
params = inspect.signature(fn).parameters
params = {k: v.default for k, v in params.items()}
if drop_self and 'self' in params:
params.pop('self')
params = {
n: p for i, (n, p) in enumerate(params.items())
if (
            # keep drop_before <= i < drop_after
            (drop_before <= i and (drop_after is None or i < drop_after))
# i not in drop_idxs and n not in drop_names
and (i not in drop_idxs and n not in drop_names)
)
}
return params
@dataclass
class MyPipeline:
...
def preprocess_data(self, min_genes: int = 200, min_cells: int = 3):
sc.pp.filter_cells(self.data, min_genes=min_genes)
sc.pp.filter_genes(self.data, min_cells=min_cells)
self.data.raw = self.data
sc.pp.normalize_total(self.data, target_sum=1e4)
sc.pp.log1p(self.data)
sc.pp.highly_variable_genes(self.data, min_mean=0.0125, max_mean=3, min_disp=0.5)
self.data = self.data[:, self.data.var.highly_variable]
sc.pp.scale(self.data, max_value=10)
sc.tl.pca(self.data, svd_solver='arpack')
sc.pp.neighbors(self.data, n_neighbors=10, n_pcs=40)
sc.tl.umap(self.data)
The focus here is on MyPipeline. Right now it isn't very flexible, because very few keyword arguments are exposed, and when some functions share the same keyword argument it is not clear which function it belongs to.
Initially all I wanted was a way to specify something like
@fn_kwargs(sc.pp.filter_cells)
class FilterCellKWArgs:
pass
...
def pipeline(filter_cells_kwargs:FilterCellKWArgs, ...):
...
...
and then have intellisense show me (or anyone else) what args / keyword arguments these functions have available, what their defaults are, and maybe even the docstring of the original function. That doesn't seem very tenable.
So now I think it would be useful to wrap functions like sc.pp.filter_cells, sc.pp.highly_variable_genes and sc.pp.scale as sklearn operators, e.g. BaseEstimator / TransformerMixin / etc.
The arguments / keyword arguments would be set on construction, and then an sklearn Pipeline could handle the rest. So I am looking for something like this:
@scop(sc.pp.filter_cells)
@dataclass
class FilterCells:
pass
which should be functionally equivalent to
@dataclass
class FilterCells(BaseEstimator, TransformerMixin):
# NOTE: these are the defaults for sc.pp.filter_cells
# you can get them from inspect.signature(sc.pp.filter_cells)
    # data: ad.AnnData
min_counts: Optional[int] = None
min_genes: Optional[int] = None
max_counts: Optional[int] = None
max_genes: Optional[int] = None
inplace: bool = True
copy: bool = False
    def fit(self, X: ad.AnnData, y=None):
        # NOTE: this is a dummy method; there is nothing to fit,
        # the wrapped function sc.pp.filter_cells is called in transform
        return self
def transform(self, X):
Y = sc.pp.filter_cells(
X, min_counts=self.min_counts, min_genes=self.min_genes,
max_counts=self.max_counts, max_genes=self.max_genes,
inplace=self.inplace, copy=self.copy
)
return X if self.inplace else Y
def fit_transform(self, X, y=None):
return self.fit(X, y).transform(X)
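The point of wrapping them this way is that they can then be composed with an sklearn Pipeline, roughly like this (just a sketch; adata stands in for an already-loaded ad.AnnData):
from sklearn.pipeline import Pipeline

pipe = Pipeline([
    ('filter_cells', FilterCells(min_genes=200)),
    # ... more wrapped scanpy steps here ...
])
adata = pipe.fit_transform(adata)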
And I have tried quite a few things (see below).
Specific Question
How can I write a decorator that takes in one of these functions (or any function really) and creates a dataclass where each argument / keyword argument is a field and the docstring is copied over for better intellisense support? I don't want to see **kwargs: Any; I want to know what the variables are.
Attempts
Current Attempt
from dataclasses import dataclass, field, fields, _FIELD
from typing import get_type_hints
from inspect import Signature, Parameter
def scop(fn):
params = get_func_params(fn, drop_self=False)
params = {k: v for k, v in params.items() if v is not inspect.Parameter.empty}
def class_decorator(cls):
cls = dataclass(cls) # Ensure cls is a dataclass
# Add fields from fn to cls
for name, default in params.items():
if name not in get_type_hints(cls):
field_obj = field(default=default)
setattr(cls, name, field_obj)
cls.__annotations__[name] = type(default)
# Update __init__ method to include new fields
def __init__(self, **kwargs):
for name, value in kwargs.items():
setattr(self, name, value)
cls.__init__ = __init__
# Update __init__ method signature
sig = inspect.signature(fn)
parameters = [
Parameter(name, Parameter.KEYWORD_ONLY, default=default)
for name, default in params.items()
]
cls.__init__.__signature__ = sig.replace(parameters=parameters)
# Add methods to cls
def fit(self, X, y=None):
return self
def transform(self, X):
kwargs = {f.name: getattr(self, f.name) for f in fields(self)}
fn(X, **kwargs)
return X
def fit_transform(self, X, y=None):
return self.fit(X, y).transform(X)
cls.fit = fit
cls.transform = transform
cls.fit_transform = fit_transform
# Update docstrings
cls.__doc__ = fn.__doc__
cls.fit.__doc__ = fn.__doc__
cls.transform.__doc__ = fn.__doc__
cls.fit_transform.__doc__ = fn.__doc__
return cls
return class_decorator
which does basically work:
@dataclass
@scop(sc.pp.filter_cells)
class FilterCells:
pass
fc = FilterCells(min_genes=3)
print(fc.min_genes)
# 3
But I still cannot see the docstring / args / keyword args when I am typing FilterCells(...), and now fit and fit_transform just show Any rather than (X, y=None).
I also lose the BaseEstimator repr, so there is that...
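For what it is worth, the runtime metadata does look right; as far as I can tell the problem is that intellisense relies on static analysis, so it never sees these dynamically attached signatures and docstrings:
print(inspect.signature(FilterCells.__init__))
# roughly: (*, min_counts=None, min_genes=None, ..., inplace=True, copy=False)
print(FilterCells.__doc__ == sc.pp.filter_cells.__doc__)
# True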
Another Notable Attempt
def scop(fn):
params = get_func_params(fn, drop_self=False)
def class_decorator(cls):
class Wrapper(cls, BaseEstimator, TransformerMixin):
fn_params = {k: v for k, v in params.items() if v is not inspect.Parameter.empty}
def __init__(self, **kwargs):
self.params = {**self.fn_params, **kwargs}
super().__init__()
def fit(self, X, y=None):
return self
def transform(self, X):
fn(X, **self.params)
return X
def fit_transform(self, X, y=None):
return self.fit(X, y).transform(X)
Wrapper.__name__ = cls.__name__
Wrapper.__doc__ = fn.__doc__
Wrapper.__annotations__ = {**cls.__annotations__, **params}
return Wrapper
return class_decorator
that can be used like
@scop(sc.pp.filter_cells)
@dataclass
class FilterCells:
pass
fc = FilterCells(min_genes=200)
fc
Of note:
- we have the BaseEstimator repr... (sort of: it shows up as class_decorator.<locals>.Wrapper())
- fc.min_genes results in AttributeError, so we lose access to the dataclass fields
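i.e. the kwargs end up inside the params dict instead of being set as attributes:
fc = FilterCells(min_genes=200)
print(fc.params['min_genes'])
# 200
fc.min_genes
# AttributeError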
Original-ish Attempt
def scop(fn):
params = get_func_params(fn, drop_self=False)
class Wrapper(BaseEstimator, TransformerMixin):
@wraps(fn)
def __init__(self, *args, **kwargs) -> None:
super().__init__()
print('ARGS', args)
print('KWARGS', kwargs)
print('SIGNATURE', inspect.signature(sc.pp.filter_cells))
print('PARAMS', params)
@wraps(fn)
def fit(self, X, y=None):
return self
@wraps(fn)
def transform(self, X):
fn(self.data, X)
return self.data
@wraps(fn)
def fit_transform(self, X, y=None):
return self.fit(X, y).transform(X)
Wrapper.__init__.__doc__ = fn.__doc__
Wrapper.__init__.__annotations__ = fn.__annotations__
# Wrapper = inspect.signature(sc.pp.filter_cells)
Wrapper.fit.__doc__ = fn.__doc__
Wrapper.fit.__annotations__ = fn.__annotations__
Wrapper.transform.__doc__ = fn.__doc__
Wrapper.transform.__annotations__ = fn.__annotations__
Wrapper.fit_transform.__doc__ = fn.__doc__
Wrapper.fit_transform.__annotations__ = fn.__annotations__
Wrapper.__name__ = fn.__name__
Wrapper.__doc__ = fn.__doc__
Wrapper.__annotations__ = fn.__annotations__
# methods = {
# '__init__': Wrapper.__init__,
# 'fit': Wrapper.fit,
# 'transform': Wrapper.transform,
# 'fit_transform': Wrapper.fit_transform,
# }
# Wrapper = type(fn.__name__, (BaseEstimator, TransformerMixin), methods)
return Wrapper
but notice that it prints ARGS (<class '__main__.FilterCells'>,), as this decorator gets called on the class itself, not on class initialization.
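In other words, since @scop(fn) returns the Wrapper class itself, applying it to a class immediately instantiates Wrapper with that class as the argument, so the name FilterCells ends up bound to a Wrapper instance rather than a class:
print(inspect.isclass(FilterCells))
# False -- FilterCells is a Wrapper instance here
FilterCells(min_genes=200)
# TypeError -- the instance is not callable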

