from __future__ import annotations

from typing import Final, NamedTuple, Sequence, TypeVar, Union
from typing_extensions import TypeAlias as _TypeAlias

from mypy.messages import format_type
from mypy.nodes import ARG_POS, Argument, Block, ClassDef, Context, SymbolTable, TypeInfo, Var
from mypy.options import Options
from mypy.plugin import CheckerPluginInterface, FunctionContext, MethodContext, MethodSigContext
from mypy.plugins.common import add_method_to_class
from mypy.subtypes import is_subtype
from mypy.types import (
    AnyType,
    CallableType,
    FunctionLike,
    Instance,
    NoneType,
    Overloaded,
    Type,
    TypeOfAny,
    get_proper_type,
)

# State tracking: ``create_singledispatch_function_callback`` appends the decorated
# (fallback) function to the type arguments of the ``functools._SingleDispatchCallable``
# instance returned by ``singledispatch``. An instance with exactly two type arguments is
# therefore one this plugin has already recorded (see ``get_singledispatch_info``).


class SingledispatchTypeVars(NamedTuple):
    """Type arguments of a ``_SingleDispatchCallable`` instance recorded by this plugin."""

    # The instance's original type argument (named for the dispatcher's return type).
    return_type: Type
    # The undecorated function; registered implementations are checked against it.
    fallback: CallableType


class RegisterCallableInfo(NamedTuple):
    """Type arguments stored on the fake ``_SingleDispatchRegisterCallable`` instance."""

    # The type passed to ``register(<type>)``; used as the dispatch type.
    register_type: Type
    # The ``_SingleDispatchCallable`` instance that ``register`` was called on.
    singledispatch_obj: Instance


SINGLEDISPATCH_TYPE: Final = "functools._SingleDispatchCallable"

SINGLEDISPATCH_REGISTER_METHOD: Final = f"{SINGLEDISPATCH_TYPE}.register"

SINGLEDISPATCH_CALLABLE_CALL_METHOD: Final = f"{SINGLEDISPATCH_TYPE}.__call__"


def get_singledispatch_info(typ: Instance) -> SingledispatchTypeVars | None:
    """Return the recorded (return_type, fallback) pair for a singledispatch instance.

    Returns None when the instance does not carry exactly two type arguments, i.e. when
    the fallback function was never appended (an error has already been reported then).
    """
    if len(typ.args) == 2:
        return SingledispatchTypeVars(*typ.args)  # type: ignore[arg-type]
    return None


T = TypeVar("T")


def get_first_arg(args: list[list[T]]) -> T | None:
    """Get the element that corresponds to the first argument passed to the function.

    ``args`` is grouped per formal parameter; returns None if there is no first formal
    or nothing was passed for it.
    """
    if args and args[0]:
        return args[0][0]
    return None


REGISTER_RETURN_CLASS: Final = "_SingleDispatchRegisterCallable"

REGISTER_CALLABLE_CALL_METHOD: Final = f"functools.{REGISTER_RETURN_CLASS}.__call__"


def make_fake_register_class_instance(
    api: CheckerPluginInterface, type_args: Sequence[Type]
) -> Instance:
    """Build an instance of a synthetic ``functools._SingleDispatchRegisterCallable`` class.

    The class derives only from ``object`` and has a single ``__call__(name: Any) -> None``
    method, so that a method hook on REGISTER_CALLABLE_CALL_METHOD fires when the result
    of ``register(<type>)`` is applied to a function. ``type_args`` are stored as the
    instance's type arguments (callers pass a ``RegisterCallableInfo``).
    """
    defn = ClassDef(REGISTER_RETURN_CLASS, Block([]))
    defn.fullname = f"functools.{REGISTER_RETURN_CLASS}"
    info = TypeInfo(SymbolTable(), defn, "functools")
    obj_type = api.named_generic_type("builtins.object", []).type
    info.bases = [Instance(obj_type, [])]
    info.mro = [info, obj_type]
    defn.info = info

    func_arg = Argument(Var("name"), AnyType(TypeOfAny.implementation_artifact), None, ARG_POS)
    add_method_to_class(api, defn, "__call__", [func_arg], NoneType())

    return Instance(info, type_args)


PluginContext: _TypeAlias = Union[FunctionContext, MethodContext]


def fail(ctx: PluginContext, msg: str, context: Context | None) -> None:
    """Emit an error message.

    This tries to emit an error message at the location specified by `context`, falling back
    to the location specified by `ctx.context`. This is helpful when the only context
    information about where you want to put the error message may be None (like it is for
    `CallableType.definition`) and falling back to the location of the calling function is
    fine.
    """
    # TODO: figure out if there is some more reliable way of getting context information, so
    # this function isn't necessary
    err_context = context if context is not None else ctx.context
    ctx.api.fail(msg, err_context)


def create_singledispatch_function_callback(ctx: FunctionContext) -> Type:
    """Called for functools.singledispatch.

    Validates the decorated function and, if it is acceptable, appends it to the type
    arguments of the returned ``_SingleDispatchCallable`` instance so later hooks can
    retrieve it as the fallback.
    """
    func_type = get_proper_type(get_first_arg(ctx.arg_types))
    if isinstance(func_type, CallableType):
        if not func_type.arg_kinds:
            fail(
                ctx, "Singledispatch function requires at least one argument", func_type.definition
            )
            return ctx.default_return_type

        # NOTE(review): star=True presumably also admits *args here — confirm against
        # ArgKind.is_positional.
        if not func_type.arg_kinds[0].is_positional(star=True):
            fail(
                ctx,
                "First argument to singledispatch function must be a positional argument",
                func_type.definition,
            )
            return ctx.default_return_type

        # singledispatch returns an instance of functools._SingleDispatchCallable according to
        # typeshed
        singledispatch_obj = get_proper_type(ctx.default_return_type)
        assert isinstance(singledispatch_obj, Instance)
        # Record the fallback in the instance's type arguments (read by
        # get_singledispatch_info).
        singledispatch_obj.args += (func_type,)

    return ctx.default_return_type


def singledispatch_register_callback(ctx: MethodContext) -> Type:
    """Called for functools._SingleDispatchCallable.register.

    Handles both ``@f.register(SomeType)`` (a class object passed to register) and
    ``@f.register`` applied directly to an implementation function.
    """
    assert isinstance(ctx.type, Instance)
    # TODO: check that there's only one argument
    first_arg_type = get_proper_type(get_first_arg(ctx.arg_types))
    if isinstance(first_arg_type, (CallableType, Overloaded)) and first_arg_type.is_type_obj():
        # HACK: We received a class as an argument to register. We need to be able
        # to access the function that register is being applied to, and the typeshed
        # definition of register has it return a generic Callable, so we create a new
        # SingleDispatchRegisterCallable class, define a __call__ method, and then add a
        # plugin hook for that.

        # is_subtype doesn't work when the right type is Overloaded, so we need the
        # actual type
        register_type = first_arg_type.items[0].ret_type
        type_args = RegisterCallableInfo(register_type, ctx.type)
        return make_fake_register_class_instance(ctx.api, type_args)

    if isinstance(first_arg_type, CallableType):
        # TODO: do more checking for registered functions
        register_function(ctx, ctx.type, first_arg_type, ctx.api.options)
        # The typeshed stubs for register say that the function returned is Callable[..., T],
        # even though the function returned is the same as the one passed in. We return the
        # type of the function so that mypy can properly type check cases where the
        # registered function is used directly (instead of through singledispatch)
        return first_arg_type

    # fallback in case we don't recognize the arguments
    return ctx.default_return_type


def register_function(
    ctx: PluginContext,
    singledispatch_obj: Instance,
    func: Type,
    options: Options,
    register_arg: Type | None = None,
) -> None:
    """Check an implementation being registered on a singledispatch function.

    Reports an error if the dispatch type (``register_arg`` when given, otherwise the type of
    ``func``'s first parameter) is not a subtype of the fallback function's first parameter.
    Silently does nothing when ``func`` is not a CallableType, when no fallback was recorded,
    or when no dispatch type can be determined.
    """
    func = get_proper_type(func)
    if not isinstance(func, CallableType):
        return
    metadata = get_singledispatch_info(singledispatch_obj)
    if metadata is None:
        # if we never added the fallback to the type variables, we already reported an error,
        # so just don't do anything here
        return
    dispatch_type = get_dispatch_type(func, register_arg)
    if dispatch_type is None:
        # TODO: report an error here that singledispatch requires at least one argument
        # (might want to do the error reporting in get_dispatch_type)
        return
    fallback = metadata.fallback

    # The fallback always has at least one argument: create_singledispatch_function_callback
    # only records callables that passed that check.
    fallback_dispatch_type = fallback.arg_types[0]
    if not is_subtype(dispatch_type, fallback_dispatch_type):
        dispatch_str = format_type(dispatch_type, options)
        fallback_str = format_type(fallback_dispatch_type, options)
        fail(
            ctx,
            f"Dispatch type {dispatch_str} must be subtype of fallback function first argument "
            f"{fallback_str}",
            func.definition,
        )


def get_dispatch_type(func: CallableType, register_arg: Type | None) -> Type | None:
    """Return the type used for dispatch.

    This is ``register_arg`` if given, otherwise the type of ``func``'s first parameter, or
    None if ``func`` has no parameters.
    """
    if register_arg is not None:
        return register_arg
    if func.arg_types:
        return func.arg_types[0]
    return None


def call_singledispatch_function_after_register_argument(ctx: MethodContext) -> Type:
    """Called on the function after passing a type to register.

    ``ctx.type`` is the fake ``_SingleDispatchRegisterCallable`` instance built by
    ``make_fake_register_class_instance``, whose type arguments are a RegisterCallableInfo.
    """
    register_callable = ctx.type
    # Guard the unpack the same way get_singledispatch_info does, so an unexpected instance
    # falls back to the default type instead of raising TypeError inside the plugin.
    if isinstance(register_callable, Instance) and len(register_callable.args) == 2:
        type_args = RegisterCallableInfo(*register_callable.args)  # type: ignore[arg-type]
        func = get_first_arg(ctx.arg_types)
        if func is not None:
            register_function(
                ctx, type_args.singledispatch_obj, func, ctx.api.options, type_args.register_type
            )
            # Return the function's own type rather than the stub's generic return type (the
            # same reasoning as in singledispatch_register_callback).
            return func
    return ctx.default_return_type


def call_singledispatch_function_callback(ctx: MethodSigContext) -> FunctionLike:
    """Called for functools._SingleDispatchCallable.__call__.

    Replaces the call signature with the recorded fallback function's signature, or keeps
    the default signature if nothing was recorded.
    """
    if not isinstance(ctx.type, Instance):
        return ctx.default_signature
    metadata = get_singledispatch_info(ctx.type)
    if metadata is None:
        return ctx.default_signature
    return metadata.fallback