Python 3 type hint for a factory method on a base class returning a child class instance


Let's say I have two classes, Base and Child, with a factory method in Base. The factory method calls another classmethod which may be overridden by Base's child classes.

from typing import Any

class Base:
    @classmethod
    def create(cls, *args: Any) -> 'Base':
        value = cls._prepare(*args)
        return cls(value)

    @classmethod
    def _prepare(cls, *args: Any) -> Any:
        return args[0] if args else None

    def __init__(self, value: Any) -> None:
        self.value = value


class Child(Base):
    @classmethod
    def _prepare(cls, *args: Any) -> Any:
        return args[1] if len(args) > 1 else None

    def method_not_present_on_base(self) -> None:
        pass

Is there a way to annotate Base.create so that a static type checker could infer that Base.create() returned an instance of Base and Child.create() returned an instance of Child, so that the following example would pass static analysis?

base = Base.create(1)
child = Child.create(2, 3)
child.method_not_present_on_base()

In the example above, a static type checker would rightly complain that method_not_present_on_base is, well, not present on the Base class.


I thought about turning Base into a generic class and having the child classes specify themselves as type arguments, i.e. bringing the curiously recurring template pattern (CRTP) to Python.

from typing import Any, Generic, TypeVar

T = TypeVar('T')

class Base(Generic[T]):
    @classmethod
    def create(cls, *args: Any) -> T: ...

class Child(Base['Child']): ...

But this feels rather unpythonic with CRTP coming from C++ and all...
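For comparison, here is the generic-base idea fleshed out with the question's own classes. It is a sketch, not a recommendation: inside create, a checker such as mypy will typically type cls(value) as Base[T] rather than T, so the return needs a cast, which is part of why this approach feels clunky.

```python
from typing import Any, Generic, TypeVar, cast

T = TypeVar("T")

class Base(Generic[T]):
    @classmethod
    def create(cls, *args: Any) -> T:
        value = cls._prepare(*args)
        # cls(value) is seen as Base[T], not T, so we have to cast it
        return cast(T, cls(value))

    @classmethod
    def _prepare(cls, *args: Any) -> Any:
        return args[0] if args else None

    def __init__(self, value: Any) -> None:
        self.value = value

# Each subclass has to pass itself as the type argument (the CRTP part)
class Child(Base["Child"]):
    @classmethod
    def _prepare(cls, *args: Any) -> Any:
        return args[1] if len(args) > 1 else None

    def method_not_present_on_base(self) -> None:
        pass

child = Child.create(2, 3)  # T is Child, so this is typed as Child
child.method_not_present_on_base()
```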



Best answer:

It is indeed possible: the feature is known as a TypeVar with generic self, i.e. annotating self (or, here, cls) with a TypeVar. The name is slightly misleading because we're using it for a class method in this case. I believe it behaves roughly equivalently to the CRTP technique you mentioned (though I'm not a C++ expert, so I can't say for certain).

In any case, you would declare your base and child classes like so:

from typing import Any, Type, TypeVar

T = TypeVar('T', bound='Base')

class Base:
    @classmethod
    def create(cls: Type[T], *args: Any) -> T: ...

class Child(Base):
    @classmethod
    def create(cls, *args: Any) -> 'Child': ...

Note that:

  1. We don't need to make the class itself generic, since we only need a generic method.
  2. Setting the TypeVar's bound to 'Base' is strictly speaking optional, but is a good idea: it tells the checker that T is always Base or one of its subclasses, so code typed with T (including the body of create, e.g. a call to cls._prepare) can use anything defined on Base even when the exact subclass isn't known.
  3. We can omit the annotation on cls for the child definition.
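
Putting this together with the question's own Base and Child gives something like the following minimal sketch. Note that Child does not need to override create at all: T is solved from the class the method is called on, so the inherited create already returns a Child.

```python
from typing import Any, Type, TypeVar

# Bound to Base so that, inside create(), cls is known to have _prepare()
T = TypeVar("T", bound="Base")

class Base:
    @classmethod
    def create(cls: Type[T], *args: Any) -> T:
        value = cls._prepare(*args)
        return cls(value)  # cls is Type[T], so this is a T

    @classmethod
    def _prepare(cls, *args: Any) -> Any:
        return args[0] if args else None

    def __init__(self, value: Any) -> None:
        self.value = value

class Child(Base):
    # create() is inherited unchanged; only _prepare() is overridden
    @classmethod
    def _prepare(cls, *args: Any) -> Any:
        return args[1] if len(args) > 1 else None

    def method_not_present_on_base(self) -> None:
        pass

base = Base.create(1)       # inferred as Base
child = Child.create(2, 3)  # inferred as Child
child.method_not_present_on_base()
```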

Another answer:

Python 3.11 now has a Self type, in case anyone else stumbles on this old question. mypy added support for it in version 1.0, released on PyPI in February 2023.

https://docs.python.org/3/library/typing.html#typing.Self

It's a DRY way of annotating the return value of class method factories:

from typing import Self
from collections import defaultdict

class NestedDefaultDict(defaultdict):
    def __init__(self, *args, **kwargs):
        super().__init__(NestedDefaultDict, *args, **kwargs)

    @classmethod
    def from_nested_dict(cls, dict_: dict) -> Self:
        # cls() rather than NestedDefaultDict(): Self may be a subclass,
        # and a type checker rejects returning the base class here
        inst = cls()
        for key, val in dict_.items():
            inst[key] = cls.from_nested_dict(val) if isinstance(val, dict) else val
        return inst
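
Applied to the question's own classes, Self replaces the explicit TypeVar entirely. A minimal sketch, assuming Python 3.11+ (on older versions it assumes the third-party typing_extensions package, which backports Self, is installed):

```python
import sys
from typing import Any

if sys.version_info >= (3, 11):
    from typing import Self
else:  # backport of the same name for older Pythons
    from typing_extensions import Self


class Base:
    @classmethod
    def create(cls, *args: Any) -> Self:
        # Self means "whatever class create() was called on"
        value = cls._prepare(*args)
        return cls(value)

    @classmethod
    def _prepare(cls, *args: Any) -> Any:
        return args[0] if args else None

    def __init__(self, value: Any) -> None:
        self.value = value


class Child(Base):
    @classmethod
    def _prepare(cls, *args: Any) -> Any:
        return args[1] if len(args) > 1 else None

    def method_not_present_on_base(self) -> None:
        pass


base = Base.create(1)       # typed as Base
child = Child.create(2, 3)  # typed as Child
child.method_not_present_on_base()
```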

Self is great for method-chaining APIs too. I've copied the Self-less example below from James Murphy's (mCoding) video and just added the three annotations.


# https://github.com/mCodingLLC/VideosSampleCode/blob/master/videos/095_method_chaining_and_self/method_chaining_and_self.py
from typing import Self  # Python 3.11+

class Player:
    def __init__(self, name, position, fatigue=0):
        self.name = name
        self.position = position
        self.fatigue = fatigue

    def draw(self) -> Self:
        print(f"drawing {self.name} to screen at {self.position}")
        return self

    def move(self, delta) -> Self:
        self.position += delta
        self.fatigue += 1
        return self

    def rest(self) -> Self:
        self.fatigue = 0
        return self
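
Where Self really pays off in chaining is subclassing: with an annotation like -> 'Player', move() and rest() would return a plain Player, so a subclass-only method could not be chained after them. A minimal sketch, trimmed to move() and rest(), with a hypothetical Goalkeeper subclass and a hypothetical dive() method that are not part of the original video code:

```python
import sys

if sys.version_info >= (3, 11):
    from typing import Self
else:  # backport of the same name for older Pythons
    from typing_extensions import Self


class Player:
    def __init__(self, name: str, position: int, fatigue: int = 0) -> None:
        self.name = name
        self.position = position
        self.fatigue = fatigue

    def move(self, delta: int) -> Self:
        self.position += delta
        self.fatigue += 1
        return self

    def rest(self) -> Self:
        self.fatigue = 0
        return self


class Goalkeeper(Player):
    # Hypothetical subclass-only method, added for illustration
    def dive(self) -> Self:
        self.fatigue += 2
        return self


# On a Goalkeeper, move() and rest() are typed as returning Goalkeeper,
# so dive() can be chained after them
keeper = Goalkeeper("Alice", 0).move(3).rest().dive()
```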