Require decorated function to accept argument matching bound `TypeVar` without narrowing to that type


If I define my decorator like this:

from typing import Callable, ParamSpec, Type, TypeVar, Union

class Event:
    pass

T = TypeVar('T', bound=Event)

def register1(evtype: Type[T]) -> Callable[[Callable[[T], None]], Callable[[T], None]]:
    def decorator(handler):
        # register handler for event type
        return handler
    return decorator

I get a proper error if I use it on the wrong function:

class A(Event):
    pass

class B(Event):
    pass

@register1(A) # Argument of type "(ev: B) -> None" cannot be assigned to parameter of type "(A) -> None"
def handler1_1(ev: B):
    pass

However, it does not work if I apply the decorator multiple times:

@register1(A) # Argument of type "(B) -> None" cannot be assigned to parameter of type "(A) -> None"
@register1(B)
def handler1_3(ev: A|B):
    pass

What I want is for stacked decorators to build up a Union of allowed/required argument types.

I think ParamSpec is the way to solve it, but how can I use ParamSpec so that the decorator does not overwrite the argument type, while still requiring that the argument type matches the type passed to the decorator?

Using ParamSpec does not result in any type error:

P = ParamSpec("P")

def register2(evtype: Type[T]) -> Callable[[Callable[P, None]], Callable[P, None]]:
    def decorator(handler):
        # ...
        return handler
    return decorator

@register2(A) # This should be an error
def handler2_1(ev: B):
    pass

If I add another TypeVar and use a Union, it works for double- and even triple-decorated functions, but not for single-decorated ones:

T2 = TypeVar('T2')

def register3(evtype: Type[T]) -> Callable[[Callable[[Union[T,T2]], None]], Callable[[Union[T,T2]], None]]:
    def decorator(handler):
        # ...
        return handler
    return decorator

# Expected error:
@register3(A) # Argument of type "(ev: B) -> None" cannot be assigned to parameter of type "(A | T2@register3) -> None"
def handler3_1(ev: B):
    pass

# Wrong error:
@register3(A) # Argument of type "(ev: A) -> None" cannot be assigned to parameter of type "(A | T2@register3) -> None"
def handler3_2(ev: A):
    pass

# Works fine
@register3(A)
@register3(B)
def handler3_3(ev: A|B):
    pass

While writing this question, I got closer and closer to a solution, which I will post as an answer.

However, I'm interested in whether there are better ways to solve this.

Accepted answer, by Chris Fu

Transferred from https://github.com/microsoft/pyright/discussions/7404:

from __future__ import annotations

from typing import Any, Callable, Protocol, TypeVar, overload

T_co = TypeVar("T_co", covariant=True)

T0 = TypeVar("T0")
T1 = TypeVar('T1')

class RegisterResult(Protocol[T_co]):
    @overload
    def __call__(self, handler: Callable[[T_co | T1], None]) -> Callable[[T_co | T1], None]: ...

    @overload
    def __call__(self, handler: Callable[[T_co], None]) -> Callable[[T_co], None]: ...

def register(evtype: type[T0]) -> RegisterResult[T0]:
    def decorator(handler: Any) -> Any:
        return handler
    
    return decorator

class A: ...
class B: ...
class C: ...

@register(A)
def handle_a(ev: A): ...

handle_a(A())

@register(A)
@register(B)
# ... Should support infinite amount of @register calls
def handle_ab(ev: A|B): ... 

handle_ab(A())
handle_ab(B())

# Expected errors, because the argument types are wrong:
@register(A)
def handle_b(ev: B): ... 
handle_a(B())
handle_ab(C())

Note that the code above works with the latest Pyright at the time of writing (v1.1.353) and may stop working in a future version, depending on how Pyright handles compatibility between overloaded functions. I also checked the latest Mypy (v1.9.0), and it does NOT fully work there.

Answer by MaPePeR

This is only a partial solution: it works on the registration side, but does not type-check the calling side properly.

By adding, via a Union, the case where the decorated function accepts only the type passed to the decorator, I no longer get any unexpected errors from pyright:

def register4(evtype: Type[T]) -> Callable[[Union[Callable[[T|T2], None],Callable[[T], None]]], Callable[[T|T2], None]]:
    def decorator(handler):
        # ...
        return handler
    return decorator

# Expected error:
@register4(A) # Argument of type "(ev: B) -> None" cannot be assigned to parameter of type "((A | T2@register4) -> None) | ((A) -> None)"
def handler4_1(ev: B):
    pass

# Works fine:
@register4(A)
def handler4_2(ev: A):
    pass

@register4(A)
@register4(B)
#@register4(C)
def handler4_3(ev: A|B|C):
    pass

As was pointed out in the comments:

handler4_2(B())

does not result in an error, even though it should.

I tried to fix it by splitting the Union into @overload declarations, but that doesn't work:

@overload
def register4(evtype: Type[T]) -> Callable[[Callable[[T | T2], None]],Callable[[T | T2], None]]: ...


@overload
# Overload 2 for "register2" will never be used because its parameters overlap overload 1
def register4(evtype: Type[T]) -> Callable[[Callable[[T], None]], Callable[[T], None]]: ...


def register4(evtype):
    def decorator(handler):
        return handler
    return decorator

I think pyright ignores the second overload because it considers the two signatures overlapping, but we've already seen in register3 that the two cases are interpreted differently. The behavior on the examples also changes if I swap the two declarations, so this might be a pyright bug.

This is with pyright 1.1.352.