How can I use MyPy to overload the __init__ method to adjust a getter's return value?

697 Views Asked by At

Let's say I have a class like this (pseudo-code, please ignore the odd db structure):

class Blog():
    """Example model from the question (pseudo-code).

    Goal: mypy should type ``Blog().title`` as ``str`` and
    ``Blog().comments`` as ``List[str]``. ``StringProperty`` is not defined
    in this snippet.
    """
    title = StringProperty()
    # repeated=True is what should switch the attribute's type to List[str].
    comments = StringProperty(repeated=True)

I want to type check StringProperty such that Blog().title is typed as str and Blog().comments is typed as List[str]. The mypy documentation mentions that something like this is possible by annotating self differently in overloads of the __init__ method.

Here's what I've tried:

# NOTE(review): this is the question's failing attempt. Property, T and the
# typing names (TypeVar, overload, Literal, List, Optional, Union) are assumed
# to come from elsewhere. bound=StringProperty is evaluated at runtime, before
# the class below exists, so it would have to be quoted: 'StringProperty'.
U = TypeVar('U', bound=StringProperty)
# V is declared exactly like U, so mypy cannot tell StringProperty[U] and
# StringProperty[V] apart as self types.
V = TypeVar('V', bound=StringProperty)

class StringProperty(Property[T]):
    # Each overload binds self to StringProperty[<TypeVar>]; since U and V are
    # interchangeable, this does not record whether repeated was True or False.
    # NOTE(review): neither the __init__ nor the __get__ overloads are followed
    # by a non-@overload implementation, which mypy requires outside stub files.
    @overload
    def __init__(self: StringProperty[U], repeated: Literal[False]=False, **kwargs) -> None: ...

    @overload
    def __init__(self: StringProperty[V], repeated: Literal[True]=True, **kwargs) -> None: ...

    # These two overloads differ only in U vs. V, so the first accepts every
    # self the second would; hence mypy's "signature 2 will never be matched".
    # Parametrizing StringProperty with concrete types (e.g. str and
    # list[str]) instead lets each overload select a distinct self type.
    @overload
    def __get__(self: StringProperty[U], instance, cls) -> str: ...
    
    @overload
    def __get__(self: StringProperty[V], instance, cls) -> List[str]: ...
    
    def __set__(self, instance, value: Optional[Union[str, List[str]]]) -> None: ...

However, mypy reports that the second __get__ overload signature will never be matched. How can I make mypy infer the return type of StringProperty.__get__ from whether the repeated argument is True or False?

2

There are 2 best solutions below

0
On

I also had to overload get in @Slix's first example:

from __future__ import annotations
from typing import TYPE_CHECKING, Generic, Literal, TypeVar, overload

_GetReturnT = TypeVar('_GetReturnT', str, list[str], str | list[str])

class StringProperty(Generic[_GetReturnT]):
    
    # Handles the default value case too.
    @overload
    def __init__(self: StringProperty[str], repeated: Literal[False]=False, **kwargs) -> None: ...

    @overload
    def __init__(self: StringProperty[list[str]], repeated: Literal[True], **kwargs) -> None: ...
    
    # Callers won't always pass a literal bool right at the call site. The bool
    # could come from somewhere far. Then we can't know what exactly get()
    # will return.
    @overload
    def __init__(self: StringProperty[str | list[str]], repeated: bool, **kwargs) -> None: ...
    
    def __init__(self, repeated: bool = False, **kwargs) -> None:
        self._repeated = repeated

    @overload
    def get(self: StringProperty[str]) -> str:
        ...

    @overload
    def get(self: StringProperty[list[str]]) -> list[str]:
        ...

    @overload
    def get(self: StringProperty[str | list[str]]) -> str | list[str]:
        ...

    def get(self) -> str | list[str]:
        if self._repeated:
            return ["Hello", "world!"]
        else:
            return "just one string"


default = StringProperty()  # StringProperty[str]
default_get = default.get()  # str

false_literal = StringProperty(repeated=False)  # StringProperty[str]
false_literal_get = false_literal.get()  # str

true_literal = StringProperty(repeated=True)  # StringProperty[list[str]]
true_literal_get = true_literal.get()  # list[str]

import random
some_bool = random.choice([True, False])  # bool
unknown_bool = StringProperty(repeated=some_bool)  # StringProperty[str | list[str]]
unknown_bool_get = unknown_bool.get()  # str | list[str]
if TYPE_CHECKING:  # reveal_locals() is mypy-only; NameError if run by python
    reveal_locals()

# Deliberate error; mypy reports: Value of type variable "_GetReturnT" of
# "StringProperty" cannot be "int". _GetReturnT is limited to str, list[str]
# and str | list[str] in TypeVar(), so annotations can't name a
# StringProperty type that no constructor could produce. prop.get() is also
# flagged, because no get() overload accepts a StringProperty[int] self.
def some_user_function(prop: StringProperty[int]) -> None:
    prop.get()
> venv/bin/mypy --version
mypy 0.991 (compiled: yes)
> venv/bin/mypy field.py
field.py:57: note: Revealed local types are:
field.py:57: note:     default: field.StringProperty[builtins.str]
field.py:57: note:     default_get: builtins.str
field.py:57: note:     false_literal: field.StringProperty[builtins.str]
field.py:57: note:     false_literal_get: builtins.str
field.py:57: note:     some_bool: builtins.bool
field.py:57: note:     true_literal: field.StringProperty[builtins.list[builtins.str]]
field.py:57: note:     true_literal_get: builtins.list[builtins.str]
field.py:57: note:     unknown_bool: field.StringProperty[Union[builtins.str, builtins.list[builtins.str]]]
field.py:57: note:     unknown_bool_get: Union[builtins.str, builtins.list[builtins.str]]
field.py:64: error: Value of type variable "_GetReturnT" of "StringProperty" cannot be "int"  [type-var]
field.py:65: error: Invalid self argument "StringProperty[int]" to attribute function "get" with type "Callable[[StringProperty[str]], str]"  [misc]
Found 2 errors in 1 file (checked 1 source file)
1
On

__init__ can be overloaded, and in each overload the annotated type of self becomes the type of the constructed instance.

A TypeVar has to be replaced by a concrete type during type analysis; it can't stay as T, U, or V. It must be filled in with a type such as str or Literal["foo"].

from __future__ import annotations
from typing import TypeVar, overload, Literal, Generic

_GetReturnT = TypeVar('_GetReturnT', str, list[str], str | list[str])

class StringProperty(Generic[_GetReturnT]):
    
    # Handles the default value case too.
    @overload
    def __init__(self: StringProperty[str], repeated: Literal[False]=False, **kwargs) -> None: ...

    @overload
    def __init__(self: StringProperty[list[str]], repeated: Literal[True], **kwargs) -> None: ...
    
    # Callers won't always pass a literal bool right at the call site. The bool
    # could come from somewhere far. Then we can't know what exactly get()
    # will return.
    @overload
    def __init__(self: StringProperty[str | list[str]], repeated: bool, **kwargs) -> None: ...
    
    def __init__(self, repeated: bool = False, **kwargs) -> None:
        self._repeated = repeated

    def get(self) -> _GetReturnT:
        if self._repeated:
            return ["Hello", "world!"]
        else:
            return "just one string"


# Trailing comments give the type mypy is expected to infer for each name.
default = StringProperty()  # StringProperty[str]
default_get = default.get()  # str

false_literal = StringProperty(repeated=False)  # StringProperty[str]
false_literal_get = false_literal.get()  # str

true_literal = StringProperty(repeated=True)  # StringProperty[list[str]]
true_literal_get = true_literal.get()  # list[str]

import random
from typing import TYPE_CHECKING

some_bool = random.choice([True, False])  # bool
unknown_bool = StringProperty(repeated=some_bool)  # StringProperty[str | list[str]]
unknown_bool_get = unknown_bool.get()  # str | list[str]

# reveal_locals() is understood only by mypy and is undefined at runtime
# (NameError); TYPE_CHECKING is True for mypy and False when the script runs.
if TYPE_CHECKING:
    reveal_locals()

# Deliberately ill-typed example. mypy reports:
#   error: Value of type variable "_GetReturnT" of "StringProperty" cannot be "int"
#
# This error happens because _GetReturnT's possible types are limited in
# TypeVar(). Without that restriction, users could accidentally name a type
# in an annotation that no StringProperty constructor call can produce.
def some_user_function(prop: StringProperty[int]) -> None:
    """Show the constrained TypeVar rejecting StringProperty[int]."""
    prop.get()

Note that setting and reading self._repeated does not help the typing here in any way: StringProperty gets its type only from the types of the arguments passed to the constructor. If someone runs false_literal._repeated = True, then false_literal.get() returns ["Hello", "world!"] while mypy still believes it returns str, so the type information is wrong.

Using str or list[str] as StringProperty's type argument was convenient here, because it is exactly what get() returns. For more unusual classes, the type argument doesn't have to be the return type itself. Here we could have used Literal[True], Literal[False], and Literal[True] | Literal[False] to represent whether the property is repeated. get() would then need overloads on the type of self to determine the return type.

# The constraints are literal "tags", not return types; get() maps each tag
# to a return type through overloads on self. NOTE(review): MyEnum is not
# defined in this snippet -- presumably an enum.Enum with AMAZING and
# KINDA_OK_I_GUESS members, shown only to illustrate enum-member literals.
_T = TypeVar(
    '_T',
    Literal["NOT_REPEATED"],
    Literal["REPEATED"],
    Literal[MyEnum.AMAZING],
    Literal[MyEnum.KINDA_OK_I_GUESS],
)

# Unions are left out for brevity: a class that works properly would also
# need an __init__ overload for a plain (non-literal) bool and a matching
# get() overload for the union of both tags.
class StringProperty(Generic[_T]):
    """Tag-parametrized variant: _T is a tag (here, whether it repeats).

    The tag is not itself the return type, so get() maps each tag to its
    return type through overloads on ``self``.
    """

    @overload
    def __init__(
        self: StringProperty[Literal["NOT_REPEATED"]],
        repeated: Literal[False],
    ) -> None: ...

    @overload
    def __init__(
        self: StringProperty[Literal["REPEATED"]],
        repeated: Literal[True],
    ) -> None: ...

    def __init__(self, repeated: bool) -> None:
        self._repeated = repeated

    @overload
    def get(self: StringProperty[Literal["NOT_REPEATED"]]) -> str: ...

    @overload
    def get(self: StringProperty[Literal["REPEATED"]]) -> list[str]: ...

    def get(self) -> str | list[str]:
        """Return ``["Hello", "world!"]`` if repeated, else ``"just one string"``."""
        return ["Hello", "world!"] if self._repeated else "just one string"