# Test cases related to the functools.singledispatch decorator.
#
# Only the functools.singledispatchmethod cases below carry the -xfail suffix
# (expected failures); every other case is expected to compile and pass under mypyc.
# (The xfail cases can be re-enabled once mypyc supports singledispatchmethod.)

[case testSpecializedImplementationUsed]
from functools import singledispatch

@singledispatch
def fun(arg) -> bool:
    return False

@fun.register
def fun_specialized(arg: str) -> bool:
    return True

def test_specialize() -> None:
    assert fun('a')
    assert not fun(3)

[case testSubclassesOfExpectedTypeUseSpecialized]
from functools import singledispatch
class A:
    pass

class B(A):
    pass

@singledispatch
def fun(arg) -> bool:
    return False

@fun.register
def fun_specialized(arg: A) -> bool:
    return True

def test_specialize() -> None:
    assert fun(B())
    assert fun(A())

[case testSuperclassImplementationNotUsedWhenSubclassHasImplementation]
from functools import singledispatch
class A:
    pass

class B(A):
    pass

@singledispatch
def fun(arg) -> bool:
    # shouldn't be using this
    assert False

@fun.register
def fun_specialized(arg: A) -> bool:
    return False

@fun.register
def fun_specialized2(arg: B) -> bool:
    return True

def test_specialize() -> None:
    assert fun(B())
    assert not fun(A())

[case testMultipleUnderscoreFunctionsIsntError]
from functools import singledispatch

@singledispatch
def fun(arg) -> str:
    return 'default'

@fun.register
def _(arg: str) -> str:
    return 'str'

@fun.register
def _(arg: int) -> str:
    return 'int'

# extra function to make sure all 3 underscore functions aren't treated as one OverloadedFuncDef
def a(b): pass

@fun.register
def _(arg: list) -> str:
    return 'list'

def test_singledispatch() -> None:
    assert fun(0) == 'int'
    assert fun('a') == 'str'
    assert fun([1, 2]) == 'list'
    assert fun({'a': 'b'}) == 'default'

[case testCanRegisterCompiledClasses]
from functools import singledispatch
class A: pass

@singledispatch
def fun(arg) -> bool:
    return False

@fun.register
def fun_specialized(arg: A) -> bool:
    return True

def test_singledispatch() -> None:
    assert fun(A())
    assert not fun(1)

[case testTypeUsedAsArgumentToRegister]
from functools import singledispatch

@singledispatch
def fun(arg) -> bool:
    return False

@fun.register(int)
def fun_specialized(arg) -> bool:
    return True

def test_singledispatch() -> None:
    assert fun(1)
    assert not fun('a')

[case testUseRegisterAsAFunction]
from functools import singledispatch

@singledispatch
def fun(arg) -> bool:
    return False

def fun_specialized_impl(arg) -> bool:
    return True

fun.register(int, fun_specialized_impl)

def test_singledispatch() -> None:
    assert fun(0)
    assert not fun('a')

[case testRegisterDoesntChangeFunction]
from functools import singledispatch

@singledispatch
def fun(arg) -> bool:
    return False

@fun.register(int)
def fun_specialized(arg) -> bool:
    return True

def test_singledispatch() -> None:
    # The decorated name still refers to the implementation itself, so calling it
    # directly with a non-int runs its body instead of dispatching.
    assert fun_specialized('a')

# TODO: turn this into a mypy error
[case testNoneIsntATypeWhenUsedAsArgumentToRegister]
from functools import singledispatch

@singledispatch
def fun(arg) -> bool:
    return False

# The registration is wrapped in try/except TypeError, so this case passes whether
# or not register() rejects the `None` annotation.
try:
    @fun.register
    def fun_specialized(arg: None) -> bool:
        return True
except TypeError:
    pass

[case testRegisteringTheSameFunctionSeveralTimes]
from functools import singledispatch

@singledispatch
def fun(arg) -> bool:
    return False

@fun.register(int)
@fun.register(str)
def fun_specialized(arg) -> bool:
    return True

def test_singledispatch() -> None:
    assert fun(0)
    assert fun('a')
    assert not fun([1, 2])

[case testTypeIsAnABC]
from functools import singledispatch
from collections.abc import Mapping

@singledispatch
def fun(arg) -> bool:
    return False

@fun.register
def fun_specialized(arg: Mapping) -> bool:
    return True

def test_singledispatch() -> None:
    assert not fun(1)
    assert fun({'a': 'b'})

[case testSingleDispatchMethod-xfail]
from functools import singledispatchmethod
class A:
    @singledispatchmethod
    def fun(self, arg) -> str:
        return 'default'

    @fun.register
    def fun_int(self, arg: int) -> str:
        return 'int'

    @fun.register
    def fun_str(self, arg: str) -> str:
        return 'str'

def test_singledispatchmethod() -> None:
    x = A()
    assert x.fun(5) == 'int'
    assert x.fun('a') == 'str'
    assert x.fun([1, 2]) == 'default'

[case testSingleDispatchMethodWithOtherDecorator-xfail]
from functools import singledispatchmethod
class A:
    @singledispatchmethod
    @staticmethod
    def fun(arg) -> str:
        return 'default'

    @fun.register
    @staticmethod
    def fun_int(arg: int) -> str:
        return 'int'

    @fun.register
    @staticmethod
    def fun_str(arg: str) -> str:
        return 'str'

def test_singledispatchmethod() -> None:
    x = A()
    assert x.fun(5) == 'int'
    assert x.fun('a') == 'str'
    assert x.fun([1, 2]) == 'default'

[case testSingledispatchTreeSumAndEqual]
from functools import singledispatch

class Tree:
    pass
class Leaf(Tree):
    pass
class Node(Tree):
    def __init__(self, value: int, left: Tree, right: Tree) -> None:
        self.value = value
        self.left = left
        self.right = right

@singledispatch
def calc_sum(x: Tree) -> int:
    raise TypeError('invalid type for x')

@calc_sum.register
def _(x: Leaf) -> int:
    return 0

@calc_sum.register
def _(x: Node) -> int:
    return x.value + calc_sum(x.left) + calc_sum(x.right)

@singledispatch
def equal(to_compare: Tree, known: Tree) -> bool:
    raise TypeError('invalid type for x')

@equal.register
def _(to_compare: Leaf, known: Tree) -> bool:
    return isinstance(known, Leaf)

@equal.register
def _(to_compare: Node, known: Tree) -> bool:
    if isinstance(known, Node):
        if to_compare.value != known.value:
            return False
        else:
            return equal(to_compare.left, known.left) and equal(to_compare.right, known.right)
    return False

def build(n: int) -> Tree:
    if n == 0:
        return Leaf()
    return Node(n, build(n - 1), build(n - 1))

def test_sum_and_equal():
    tree = build(5)
    tree2 = build(5)
    # This node is three levels below the root, so its original value is 2;
    # setting it to 10 raises tree2's sum by 8.
    tree2.right.right.right.value = 10
    # build(5) has 1 node valued 5, 2 valued 4, 4 valued 3, 8 valued 2 and 16 valued 1:
    # 5 + 8 + 12 + 16 + 16 == 57.
    assert calc_sum(tree) == 57
    assert calc_sum(tree2) == 65
    assert equal(tree, tree)
    assert not equal(tree, tree2)
    tree3 = build(4)
    assert not equal(tree, tree3)

[case testSimulateMypySingledispatch]
from functools import singledispatch
from mypy_extensions import trait
from typing import Iterator, Union, TypeVar, Any, List, Type
# based on use of singledispatch in stubtest.py
class Error:
    def __init__(self, msg: str) -> None:
        self.msg = msg

@trait
class Node: pass

class MypyFile(Node): pass
class TypeInfo(Node): pass

@trait
class SymbolNode(Node): pass
@trait
class Expression(Node): pass
class TypeVarLikeExpr(SymbolNode, Expression): pass
class TypeVarExpr(TypeVarLikeExpr): pass
class TypeAlias(SymbolNode): pass

class Missing: pass
MISSING = Missing()

T = TypeVar("T")

MaybeMissing = Union[T, Missing]

@singledispatch
def verify(stub: Node, a: MaybeMissing[Any], b: List[str]) -> Iterator[Error]:
    yield Error('unknown node type')

@verify.register(MypyFile)
def verify_mypyfile(stub: MypyFile, a: MaybeMissing[int], b: List[str]) -> Iterator[Error]:
    if isinstance(a, Missing):
        yield Error("shouldn't be missing")
        return
    if not isinstance(a, int):
        # this check should be unnecessary because of the type signature and the previous check,
        # but stubtest.py has this check
        yield Error("should be an int")
        return
    yield from verify(TypeInfo(), str, ['abc', 'def'])

@verify.register(TypeInfo)
def verify_typeinfo(stub: TypeInfo, a: MaybeMissing[Type[Any]], b: List[str]) -> Iterator[Error]:
    yield Error('in TypeInfo')
    yield Error('hello')

@verify.register(TypeVarExpr)
def verify_typevarexpr(stub: TypeVarExpr, a: MaybeMissing[Any], b: List[str]) -> Iterator[Error]:
    # The never-executed yield only makes this a generator that yields nothing.
    if False:
        yield None

def verify_list(stub, a, b) -> List[str]:
    """Helper function that converts iterator of errors to list of messages"""
    return list(err.msg for err in verify(stub, a, b))

def test_verify() -> None:
    assert verify_list(TypeAlias(), 'a', ['a', 'b']) == ['unknown node type']
    assert verify_list(MypyFile(), MISSING, ['a', 'b']) == ["shouldn't be missing"]
    assert verify_list(MypyFile(), 5, ['a', 'b']) == ['in TypeInfo', 'hello']
    assert verify_list(TypeInfo(), str, ['a', 'b']) == ['in TypeInfo', 'hello']
    assert verify_list(TypeVarExpr(), 'a', ['x', 'y']) == []

[case testArgsInRegisteredImplNamedDifferentlyFromMainFunction]
from functools import singledispatch

@singledispatch
def f(a) -> bool:
    return False

@f.register
def g(b: int) -> bool:
    return True

def test_singledispatch():
    assert f(5)
    assert not f('a')

[case testKeywordArguments]
from functools import singledispatch

@singledispatch
def f(arg, *, kwarg: int = 0) -> int:
    return kwarg + 10

@f.register
def g(arg: int, *, kwarg: int = 5) -> int:
    return kwarg - 10

def test_keywords():
    assert f('a') == 10
    assert f('a', kwarg=3) == 13
    assert f('a', kwarg=7) == 17

    # The int implementation's own keyword default (5) applies, not the main function's (0).
    assert f(1) == -5
    assert f(1, kwarg=4) == -6
    assert f(1, kwarg=6) == -4

[case testGeneratorAndMultipleTypesOfIterable]
from functools import singledispatch
from typing import *

@singledispatch
def f(arg: Any) -> Iterable[int]:
    yield 1

@f.register
def g(arg: str) -> Iterable[int]:
    return [0]

def test_iterables():
    # The default implementation is a generator, so f(1) returns a generator object
    # rather than a list.
    assert f(1) != [1]
    assert list(f(1)) == [1]
    assert f('a') == [0]

[case testRegisterUsedAtSameTimeAsOtherDecorators]
from functools import singledispatch
from typing import TypeVar

class A: pass
class B: pass

T = TypeVar('T')

def decorator(f: T) -> T:
    return f

@singledispatch
def f(arg) -> int:
    return 0

@f.register
@decorator
def h(arg: str) -> int:
    return 2

def test_singledispatch():
    assert f(1) == 0
    assert f('a') == 2

[case testDecoratorModifiesFunction]
from functools import singledispatch
from typing import Callable, Any

class A: pass

def decorator(f: Callable[[Any], int]) -> Callable[[Any], int]:
    def wrapper(x) -> int:
        return f(x) * 7
    return wrapper

@singledispatch
def f(arg) -> int:
    return 10

@f.register
@decorator
def h(arg: str) -> int:
    return 5

def test_singledispatch():
    # h returns 5 and the wrapping decorator multiplies that by 7.
    assert f('a') == 35
    assert f(A()) == 10

[case testMoreSpecificTypeBeforeLessSpecificType]
from functools import singledispatch
class A: pass
class B(A): pass

@singledispatch
def f(arg) -> str:
    return 'default'

@f.register
def g(arg: B) -> str:
    return 'b'

@f.register
def h(arg: A) -> str:
    return 'a'

def test_singledispatch():
    assert f(B()) == 'b'
    assert f(A()) == 'a'
    assert f(5) == 'default'

[case testMultipleRelatedClassesBeingRegistered]
from functools import singledispatch

class A: pass
class B(A): pass
class C(B): pass

@singledispatch
def f(arg) -> str: return 'default'

@f.register
def _(arg: A) -> str: return 'a'

@f.register
def _(arg: C) -> str: return 'c'

@f.register
def _(arg: B) -> str: return 'b'

def test_singledispatch():
    assert f(A()) == 'a'
    assert f(B()) == 'b'
    assert f(C()) == 'c'
    assert f(1) == 'default'

[case testRegisteredImplementationsInDifferentFiles]
from other_a import f, A, B, C

@f.register
def a(arg: A) -> int:
    return 2

@f.register
def _(arg: C) -> int:
    return 3

def test_singledispatch():
    assert f(B()) == 1
    assert f(A()) == 2
    assert f(C()) == 3
    assert f(1) == 0

[file other_a.py]
from functools import singledispatch

class A: pass
class B(A): pass
class C(B): pass

@singledispatch
def f(arg) -> int:
    return 0

@f.register
def g(arg: B) -> int:
    return 1

[case testOrderCanOnlyBeDeterminedFromMRONotIsinstanceChecks]
from mypy_extensions import trait
from functools import singledispatch

@trait
class A: pass
@trait
class B: pass
class AB(A, B): pass
class BA(B, A): pass

@singledispatch
def f(arg) -> str:
    return "default"
    pass  # NOTE(review): unreachable after the return above

@f.register
def fa(arg: A) -> str:
    return "a"

@f.register
def fb(arg: B) -> str:
    return "b"

def test_singledispatch():
    # AB() and BA() are both instances of A and of B, so isinstance checks alone
    # cannot choose an implementation; only the base-class order in the MRO decides.
    assert f(AB()) == "a"
    assert f(BA()) == "b"

[case testCallingFunctionBeforeAllImplementationsRegistered]
from functools import singledispatch

class A: pass
class B(A): pass

@singledispatch
def f(arg) -> str:
    return 'default'

assert f(A()) == 'default'
assert f(B()) == 'default'
assert f(1) == 'default'

@f.register
def g(arg: A) -> str:
    return 'a'

assert f(A()) == 'a'
assert f(B()) == 'a'
assert f(1) == 'default'

@f.register
def _(arg: B) -> str:
    return 'b'

assert f(A()) == 'a'
assert f(B()) == 'b'
assert f(1) == 'default'

[case testDynamicallyRegisteringFunctionFromInterpretedCode]
from functools import singledispatch

class A: pass
class B(A): pass
class C(B): pass
class D(C): pass

@singledispatch
def f(arg) -> str:
    return "default"

@f.register
def _(arg: B) -> str:
    return 'b'

[file register_impl.py]
from native import f, A, B, C

@f.register(A)
def a(arg) -> str:
    return 'a'

@f.register
def c(arg: C) -> str:
    return 'c'

[file driver.py]
from native import f, A, B, C
from register_impl import a, c
# We need a custom driver here because register_impl has to be run before we test this (so that the
# additional implementations are registered)
assert f(C()) == 'c'
assert f(A()) == 'a'
assert f(B()) == 'b'
assert a(C()) == 'a'
assert c(A()) == 'c'

[case testMalformedDynamicRegisterCall]
from functools import singledispatch

@singledispatch
def f(arg) -> None:
    pass

[file register.py]
from native import f
from testutil import assertRaises

with assertRaises(TypeError, 'Invalid first argument to `register()`'):
    @f.register
    def _():
        pass

[file driver.py]
import register

[case testCacheClearedWhenNewFunctionRegistered]
from functools import singledispatch

@singledispatch
def f(arg) -> str:
    return 'default'

[file register.py]
from native import f
class A: pass
class B: pass
class C: pass

# annotated function
assert f(A()) == 'default'
@f.register
def _(arg: A) -> str:
    return 'a'
assert f(A()) == 'a'

# type passed as argument
assert f(B()) == 'default'
@f.register(B)
def _(arg: B) -> str:
    return 'b'
assert f(B()) == 'b'

# 2 argument form
assert f(C()) == 'default'
def c(arg) -> str:
    return 'c'
f.register(C, c)
assert f(C()) == 'c'

[file driver.py]
import register