"""Pattern checker. This file is conceptually part of TypeChecker.""" from __future__ import annotations from collections import defaultdict from typing import Final, NamedTuple import mypy.checker from mypy import message_registry from mypy.checkmember import analyze_member_access from mypy.expandtype import expand_type_by_instance from mypy.join import join_types from mypy.literals import literal_hash from mypy.maptype import map_instance_to_supertype from mypy.meet import narrow_declared_type from mypy.messages import MessageBuilder from mypy.nodes import ARG_POS, Context, Expression, NameExpr, TypeAlias, TypeInfo, Var from mypy.options import Options from mypy.patterns import ( AsPattern, ClassPattern, MappingPattern, OrPattern, Pattern, SequencePattern, SingletonPattern, StarredPattern, ValuePattern, ) from mypy.plugin import Plugin from mypy.subtypes import is_subtype from mypy.typeops import ( coerce_to_literal, make_simplified_union, try_getting_str_literals_from_type, tuple_fallback, ) from mypy.types import ( AnyType, Instance, LiteralType, NoneType, ProperType, TupleType, Type, TypedDictType, TypeOfAny, UninhabitedType, UnionType, get_proper_type, ) from mypy.typevars import fill_typevars from mypy.visitor import PatternVisitor self_match_type_names: Final = [ "builtins.bool", "builtins.bytearray", "builtins.bytes", "builtins.dict", "builtins.float", "builtins.frozenset", "builtins.int", "builtins.list", "builtins.set", "builtins.str", "builtins.tuple", ] non_sequence_match_type_names: Final = ["builtins.str", "builtins.bytes", "builtins.bytearray"] # For every Pattern a PatternType can be calculated. This requires recursively calculating # the PatternTypes of the sub-patterns first. # Using the data in the PatternType the match subject and captured names can be narrowed/inferred. 
class PatternType(NamedTuple):
    """Result of checking one pattern against one subject type."""

    type: Type  # The type the match subject can be narrowed to
    rest_type: Type  # The remaining type if the pattern didn't match
    captures: dict[Expression, Type]  # The variables captured by the pattern


class PatternChecker(PatternVisitor[PatternType]):
    """Pattern checker.

    This class checks if a pattern can match a type, what the type can be narrowed to, and what
    type capture patterns should be inferred as.
    """

    # Some services are provided by a TypeChecker instance.
    chk: mypy.checker.TypeChecker
    # This is shared with TypeChecker, but stored also here for convenience.
    msg: MessageBuilder
    # Currently unused
    plugin: Plugin
    # The expression being matched against the pattern
    subject: Expression

    subject_type: Type
    # Type of the subject to check the (sub)pattern against
    type_context: list[Type]
    # Types that match against self instead of their __match_args__ if used as a class pattern
    # Filled in from self_match_type_names
    self_match_types: list[Type]
    # Types that are sequences, but don't match sequence patterns. Filled in from
    # non_sequence_match_type_names
    non_sequence_match_types: list[Type]

    options: Options

    def __init__(
        self, chk: mypy.checker.TypeChecker, msg: MessageBuilder, plugin: Plugin, options: Options
    ) -> None:
        """Store the collaborating services and resolve the special-cased builtin types."""
        self.chk = chk
        self.msg = msg
        self.plugin = plugin

        self.type_context = []
        self.self_match_types = self.generate_types_from_names(self_match_type_names)
        self.non_sequence_match_types = self.generate_types_from_names(
            non_sequence_match_type_names
        )
        self.options = options

    def accept(self, o: Pattern, type_context: Type) -> PatternType:
        """Check pattern `o` against `type_context`.

        The type is pushed on the `type_context` stack while the pattern is visited, so the
        visit_* methods read the type they are matched against from `self.type_context[-1]`.
        """
        self.type_context.append(type_context)
        result = o.accept(self)
        self.type_context.pop()

        return result

    def visit_as_pattern(self, o: AsPattern) -> PatternType:
        """Check `pattern as name`, a bare capture `name` or the wildcard.

        Without a sub-pattern the whole current type matches and nothing remains. If a name
        is given and the narrowed type is inhabited, the name is added to the captures.
        """
        current_type = self.type_context[-1]
        if o.pattern is not None:
            pattern_type = self.accept(o.pattern, current_type)
            typ, rest_type, type_map = pattern_type
        else:
            typ, rest_type, type_map = current_type, UninhabitedType(), {}

        if not is_uninhabited(typ) and o.name is not None:
            typ, _ = self.chk.conditional_types_with_intersection(
                current_type, [get_type_range(typ)], o, default=current_type
            )
            if not is_uninhabited(typ):
                type_map[o.name] = typ

        return PatternType(typ, rest_type, type_map)

    def visit_or_pattern(self, o: OrPattern) -> PatternType:
        """Check `p1 | p2 | ...`.

        Each alternative is checked against the rest type of the previous one. The result
        type is the union of the inhabited alternative types, and the type of each captured
        variable is the join of its types over all alternatives. An error is reported on every
        alternative that does not bind the same names as the alternatives before it.
        """
        current_type = self.type_context[-1]

        #
        # Check all the subpatterns
        #
        pattern_types = []
        for pattern in o.patterns:
            pattern_type = self.accept(pattern, current_type)
            pattern_types.append(pattern_type)
            current_type = pattern_type.rest_type

        #
        # Collect the final type
        #
        types = []
        for pattern_type in pattern_types:
            if not is_uninhabited(pattern_type.type):
                types.append(pattern_type.type)

        #
        # Check the capture types
        #
        capture_types: dict[Var, list[tuple[Expression, Type]]] = defaultdict(list)
        # Collect captures from the first subpattern
        for expr, typ in pattern_types[0].captures.items():
            node = get_var(expr)
            capture_types[node].append((expr, typ))

        # Check if other subpatterns capture the same names.
        # The patterns are paired with their own pattern types, so that the error is reported
        # on the alternative whose captures differ and not on the one preceding it.
        for pattern, pattern_type in zip(o.patterns[1:], pattern_types[1:]):
            captured_vars = {get_var(expr) for expr in pattern_type.captures}
            if capture_types.keys() != captured_vars:
                self.msg.fail(message_registry.OR_PATTERN_ALTERNATIVE_NAMES, pattern)
            for expr, typ in pattern_type.captures.items():
                node = get_var(expr)
                capture_types[node].append((expr, typ))

        captures: dict[Expression, Type] = {}
        for capture_list in capture_types.values():
            typ = UninhabitedType()
            for _, other in capture_list:
                typ = join_types(typ, other)

            # The first expression that captured the variable represents it in the result
            captures[capture_list[0][0]] = typ

        union_type = make_simplified_union(types)
        return PatternType(union_type, current_type, captures)

    def visit_value_pattern(self, o: ValuePattern) -> PatternType:
        """Check a value pattern, i.e. a pattern given by the expression `o.expr`.

        The rest type is only narrowed if the matched type is a literal (or uninhabited).
        Otherwise the narrowed type is added back to the rest type.
        """
        current_type = self.type_context[-1]
        typ = self.chk.expr_checker.accept(o.expr)
        typ = coerce_to_literal(typ)
        narrowed_type, rest_type = self.chk.conditional_types_with_intersection(
            current_type, [get_type_range(typ)], o, default=current_type
        )
        if not isinstance(get_proper_type(narrowed_type), (LiteralType, UninhabitedType)):
            return PatternType(narrowed_type, UnionType.make_union([narrowed_type, rest_type]), {})
        return PatternType(narrowed_type, rest_type, {})

    def visit_singleton_pattern(self, o: SingletonPattern) -> PatternType:
        """Check the singleton patterns `True`, `False` and `None`."""
        current_type = self.type_context[-1]
        value: bool | None = o.value
        if isinstance(value, bool):
            typ = self.chk.expr_checker.infer_literal_expr_type(value, "builtins.bool")
        elif value is None:
            typ = NoneType()
        else:
            assert False

        narrowed_type, rest_type = self.chk.conditional_types_with_intersection(
            current_type, [get_type_range(typ)], o, default=current_type
        )
        return PatternType(narrowed_type, rest_type, {})

    def visit_sequence_pattern(self, o: SequencePattern) -> PatternType:
        """Check a sequence pattern such as `[a, *b, c]`.

        For tuple types the sub-patterns are matched item by item and the pattern can't match
        if the lengths are incompatible. For all other types every sub-pattern is matched
        against the item type of the sequence (or `object` if it can't be determined).
        """
        #
        # check for existence of a starred pattern
        #
        current_type = get_proper_type(self.type_context[-1])
        if not self.can_match_sequence(current_type):
            return self.early_non_match()
        star_positions = [i for i, p in enumerate(o.patterns) if isinstance(p, StarredPattern)]
        star_position: int | None = None
        if len(star_positions) == 1:
            star_position = star_positions[0]
        elif len(star_positions) >= 2:
            assert False, "Parser should prevent multiple starred patterns"
        required_patterns = len(o.patterns)
        if star_position is not None:
            required_patterns -= 1

        #
        # get inner types of original type
        #
        if isinstance(current_type, TupleType):
            inner_types = current_type.items
            size_diff = len(inner_types) - required_patterns
            if size_diff < 0:
                return self.early_non_match()
            elif size_diff > 0 and star_position is None:
                return self.early_non_match()
        else:
            inner_type = self.get_sequence_type(current_type, o)
            if inner_type is None:
                inner_type = self.chk.named_type("builtins.object")
            inner_types = [inner_type] * len(o.patterns)

        #
        # match inner patterns
        #
        contracted_new_inner_types: list[Type] = []
        contracted_rest_inner_types: list[Type] = []
        captures: dict[Expression, Type] = {}

        contracted_inner_types = self.contract_starred_pattern_types(
            inner_types, star_position, required_patterns
        )
        for p, t in zip(o.patterns, contracted_inner_types):
            pattern_type = self.accept(p, t)
            typ, rest, type_map = pattern_type
            contracted_new_inner_types.append(typ)
            contracted_rest_inner_types.append(rest)
            self.update_type_map(captures, type_map)

        new_inner_types = self.expand_starred_pattern_types(
            contracted_new_inner_types, star_position, len(inner_types)
        )
        rest_inner_types = self.expand_starred_pattern_types(
            contracted_rest_inner_types, star_position, len(inner_types)
        )

        #
        # Calculate new type
        #
        new_type: Type
        rest_type: Type = current_type
        if isinstance(current_type, TupleType):
            narrowed_inner_types = []
            inner_rest_types = []
            for inner_type, new_inner_type in zip(inner_types, new_inner_types):
                (
                    narrowed_inner_type,
                    inner_rest_type,
                ) = self.chk.conditional_types_with_intersection(
                    new_inner_type, [get_type_range(inner_type)], o, default=new_inner_type
                )
                narrowed_inner_types.append(narrowed_inner_type)
                inner_rest_types.append(inner_rest_type)
            if all(not is_uninhabited(typ) for typ in narrowed_inner_types):
                new_type = TupleType(narrowed_inner_types, current_type.partial_fallback)
            else:
                new_type = UninhabitedType()

            if all(is_uninhabited(typ) for typ in inner_rest_types):
                # All subpatterns always match, so we can apply negative narrowing
                rest_type = TupleType(rest_inner_types, current_type.partial_fallback)
        else:
            new_inner_type = UninhabitedType()
            for typ in new_inner_types:
                new_inner_type = join_types(new_inner_type, typ)
            new_type = self.construct_sequence_child(current_type, new_inner_type)
            if is_subtype(new_type, current_type):
                new_type, _ = self.chk.conditional_types_with_intersection(
                    current_type, [get_type_range(new_type)], o, default=current_type
                )
            else:
                new_type = current_type
        return PatternType(new_type, rest_type, captures)

    def get_sequence_type(self, t: Type, context: Context) -> Type | None:
        """Return the item type of `t` when it is iterated, or None if it can't be determined.

        For a union the item types of those members that have one are combined. None is
        returned if no member has one.
        """
        t = get_proper_type(t)
        if isinstance(t, AnyType):
            return AnyType(TypeOfAny.from_another_any, t)
        if isinstance(t, UnionType):
            items = [self.get_sequence_type(item, context) for item in t.items]
            not_none_items = [item for item in items if item is not None]
            if not_none_items:
                return make_simplified_union(not_none_items)
            else:
                return None

        if self.chk.type_is_iterable(t) and isinstance(t, (Instance, TupleType)):
            if isinstance(t, TupleType):
                t = tuple_fallback(t)
            return self.chk.iterable_item_type(t, context)
        else:
            return None

    def contract_starred_pattern_types(
        self, types: list[Type], star_pos: int | None, num_patterns: int
    ) -> list[Type]:
        """
        Contracts a list of types in a sequence pattern depending on the position of a starred
        capture pattern.

        For example if the sequence pattern [a, *b, c] is matched against types [bool, int, str,
        bytes] the contracted types are [bool, Union[int, str], bytes].

        If star_pos is None the types are returned unchanged.
        """
        if star_pos is None:
            return types
        new_types = types[:star_pos]
        # Number of types that are consumed by the starred pattern
        star_length = len(types) - num_patterns
        new_types.append(make_simplified_union(types[star_pos : star_pos + star_length]))
        new_types += types[star_pos + star_length :]

        return new_types

    def expand_starred_pattern_types(
        self, types: list[Type], star_pos: int | None, num_types: int
    ) -> list[Type]:
        """Undoes the contraction done by contract_starred_pattern_types.

        For example if the sequence pattern is [a, *b, c] and types [bool, int, str] are extended
        to length 4 the result is [bool, int, int, str].

        If star_pos is None the types are returned unchanged.
        """
        if star_pos is None:
            return types
        new_types = types[:star_pos]
        # Number of times the type at the star position has to be repeated
        star_length = num_types - len(types) + 1
        new_types += [types[star_pos]] * star_length
        new_types += types[star_pos + 1 :]

        return new_types

    def visit_starred_pattern(self, o: StarredPattern) -> PatternType:
        """Check `*name` or `*_` inside a sequence pattern.

        A captured name is inferred as a `builtins.list` of the current type. The pattern
        always matches, so the rest type is uninhabited.
        """
        captures: dict[Expression, Type] = {}
        if o.capture is not None:
            list_type = self.chk.named_generic_type("builtins.list", [self.type_context[-1]])
            captures[o.capture] = list_type
        return PatternType(self.type_context[-1], UninhabitedType(), captures)

    def visit_mapping_pattern(self, o: MappingPattern) -> PatternType:
        """Check a mapping pattern such as `{"key": pattern, **rest}`.

        The subject type itself is never narrowed. The result type is either the current type
        or uninhabited if one of the keys or value patterns can't match.
        """
        current_type = get_proper_type(self.type_context[-1])
        can_match = True
        captures: dict[Expression, Type] = {}
        for key, value in zip(o.keys, o.values):
            inner_type = self.get_mapping_item_type(o, current_type, key)
            if inner_type is None:
                can_match = False
                # Still check the value pattern, so that its captures get a type
                inner_type = self.chk.named_type("builtins.object")
            pattern_type = self.accept(value, inner_type)
            if is_uninhabited(pattern_type.type):
                can_match = False
            else:
                self.update_type_map(captures, pattern_type.captures)

        if o.rest is not None:
            mapping = self.chk.named_type("typing.Mapping")
            if is_subtype(current_type, mapping) and isinstance(current_type, Instance):
                mapping_inst = map_instance_to_supertype(current_type, mapping.type)
                dict_typeinfo = self.chk.lookup_typeinfo("builtins.dict")
                rest_type = Instance(dict_typeinfo, mapping_inst.args)
            else:
                object_type = self.chk.named_type("builtins.object")
                rest_type = self.chk.named_generic_type(
                    "builtins.dict", [object_type, object_type]
                )

            captures[o.rest] = rest_type

        if can_match:
            # We can't narrow the type here, as Mapping key is invariant.
            new_type = self.type_context[-1]
        else:
            new_type = UninhabitedType()
        return PatternType(new_type, current_type, captures)

    def get_mapping_item_type(
        self, pattern: MappingPattern, mapping_type: Type, key: Expression
    ) -> Type | None:
        """Return the type of `mapping_type[key]`.

        Errors produced by the lookups are filtered out and not reported. None is only
        returned for a TypedDict for which both the TypedDict lookup and the `__getitem__`
        lookup produced errors.
        """
        mapping_type = get_proper_type(mapping_type)
        if isinstance(mapping_type, TypedDictType):
            with self.msg.filter_errors() as local_errors:
                result: Type | None = self.chk.expr_checker.visit_typeddict_index_expr(
                    mapping_type, key
                )
                has_local_errors = local_errors.has_new_errors()
            # If we can't determine the type statically fall back to treating it as a normal
            # mapping
            if has_local_errors:
                with self.msg.filter_errors() as local_errors:
                    result = self.get_simple_mapping_item_type(pattern, mapping_type, key)

                    if local_errors.has_new_errors():
                        result = None
        else:
            with self.msg.filter_errors():
                result = self.get_simple_mapping_item_type(pattern, mapping_type, key)
        return result

    def get_simple_mapping_item_type(
        self, pattern: MappingPattern, mapping_type: Type, key: Expression
    ) -> Type:
        """Return the result type of calling `__getitem__` on `mapping_type` with `key`."""
        result, _ = self.chk.expr_checker.check_method_call_by_name(
            "__getitem__", mapping_type, [key], [ARG_POS], pattern
        )
        return result

    def visit_class_pattern(self, o: ClassPattern) -> PatternType:
        """Check a class pattern such as `Cls(a, b, key=c)`.

        The current type is narrowed to the class. Positional sub-patterns are converted to
        keyword sub-patterns using `__match_args__`, except for the types that match against
        the subject itself. Every keyword sub-pattern is then checked against the type of
        the corresponding attribute.
        """
        current_type = get_proper_type(self.type_context[-1])

        #
        # Check class type
        #
        type_info = o.class_ref.node
        if type_info is None:
            return PatternType(AnyType(TypeOfAny.from_error), AnyType(TypeOfAny.from_error), {})
        if isinstance(type_info, TypeAlias) and not type_info.no_args:
            self.msg.fail(message_registry.CLASS_PATTERN_GENERIC_TYPE_ALIAS, o)
            return self.early_non_match()
        if isinstance(type_info, TypeInfo):
            any_type = AnyType(TypeOfAny.implementation_artifact)
            typ: Type = Instance(type_info, [any_type] * len(type_info.defn.type_vars))
        elif isinstance(type_info, TypeAlias):
            typ = type_info.target
        else:
            if isinstance(type_info, Var) and type_info.type is not None:
                name = type_info.type.str_with_options(self.options)
            else:
                name = type_info.name
            self.msg.fail(message_registry.CLASS_PATTERN_TYPE_REQUIRED.format(name), o)
            return self.early_non_match()

        new_type, rest_type = self.chk.conditional_types_with_intersection(
            current_type, [get_type_range(typ)], o, default=current_type
        )
        if is_uninhabited(new_type):
            return self.early_non_match()
        # TODO: Do I need this?
        narrowed_type = narrow_declared_type(current_type, new_type)

        #
        # Convert positional to keyword patterns
        #
        keyword_pairs: list[tuple[str | None, Pattern]] = []
        match_arg_set: set[str] = set()

        captures: dict[Expression, Type] = {}

        if len(o.positionals) != 0:
            if self.should_self_match(typ):
                if len(o.positionals) > 1:
                    self.msg.fail(message_registry.CLASS_PATTERN_TOO_MANY_POSITIONAL_ARGS, o)
                pattern_type = self.accept(o.positionals[0], narrowed_type)
                if not is_uninhabited(pattern_type.type):
                    return PatternType(
                        pattern_type.type,
                        join_types(rest_type, pattern_type.rest_type),
                        pattern_type.captures,
                    )
                captures = pattern_type.captures
            else:
                with self.msg.filter_errors() as local_errors:
                    match_args_type = analyze_member_access(
                        "__match_args__",
                        typ,
                        o,
                        False,
                        False,
                        False,
                        self.msg,
                        original_type=typ,
                        chk=self.chk,
                    )
                    has_local_errors = local_errors.has_new_errors()
                if has_local_errors:
                    self.msg.fail(
                        message_registry.MISSING_MATCH_ARGS.format(
                            typ.str_with_options(self.options)
                        ),
                        o,
                    )
                    return self.early_non_match()

                proper_match_args_type = get_proper_type(match_args_type)
                if isinstance(proper_match_args_type, TupleType):
                    match_arg_names = get_match_arg_names(proper_match_args_type)

                    if len(o.positionals) > len(match_arg_names):
                        self.msg.fail(message_registry.CLASS_PATTERN_TOO_MANY_POSITIONAL_ARGS, o)
                        return self.early_non_match()
                else:
                    # The attribute names are unknown
                    match_arg_names = [None] * len(o.positionals)

                for arg_name, pos in zip(match_arg_names, o.positionals):
                    keyword_pairs.append((arg_name, pos))
                    if arg_name is not None:
                        match_arg_set.add(arg_name)

        #
        # Check for duplicate patterns
        #
        keyword_arg_set = set()
        has_duplicates = False
        for key, value in zip(o.keyword_keys, o.keyword_values):
            keyword_pairs.append((key, value))
            if key in match_arg_set:
                self.msg.fail(
                    message_registry.CLASS_PATTERN_KEYWORD_MATCHES_POSITIONAL.format(key), value
                )
                has_duplicates = True
            elif key in keyword_arg_set:
                self.msg.fail(
                    message_registry.CLASS_PATTERN_DUPLICATE_KEYWORD_PATTERN.format(key), value
                )
                has_duplicates = True
            keyword_arg_set.add(key)

        if has_duplicates:
            return self.early_non_match()

        #
        # Check keyword patterns
        #
        can_match = True
        for keyword, pattern in keyword_pairs:
            key_type: Type | None = None
            with self.msg.filter_errors() as local_errors:
                if keyword is not None:
                    key_type = analyze_member_access(
                        keyword,
                        narrowed_type,
                        pattern,
                        False,
                        False,
                        False,
                        self.msg,
                        original_type=new_type,
                        chk=self.chk,
                    )
                else:
                    key_type = AnyType(TypeOfAny.from_error)
                has_local_errors = local_errors.has_new_errors()
            if has_local_errors or key_type is None:
                key_type = AnyType(TypeOfAny.from_error)
                self.msg.fail(
                    message_registry.CLASS_PATTERN_UNKNOWN_KEYWORD.format(
                        typ.str_with_options(self.options), keyword
                    ),
                    pattern,
                )

            inner_type, inner_rest_type, inner_captures = self.accept(pattern, key_type)
            if is_uninhabited(inner_type):
                can_match = False
            else:
                self.update_type_map(captures, inner_captures)
                if not is_uninhabited(inner_rest_type):
                    # The sub-pattern may fail to match, so nothing can be ruled out
                    rest_type = current_type

        if not can_match:
            new_type = UninhabitedType()
        return PatternType(new_type, rest_type, captures)

    def should_self_match(self, typ: Type) -> bool:
        """Return whether `typ` is a subtype of one of `self_match_types`.

        Named tuples are excluded.
        """
        typ = get_proper_type(typ)
        if isinstance(typ, Instance) and typ.type.is_named_tuple:
            return False
        for other in self.self_match_types:
            if is_subtype(typ, other):
                return True
        return False

    def can_match_sequence(self, typ: ProperType) -> bool:
        """Return whether a value of type `typ` could match a sequence pattern."""
        if isinstance(typ, UnionType):
            return any(self.can_match_sequence(get_proper_type(item)) for item in typ.items)
        for other in self.non_sequence_match_types:
            # We have to ignore promotions, as memoryview should match, but bytes,
            # which it can be promoted to, shouldn't
            if is_subtype(typ, other, ignore_promotions=True):
                return False
        sequence = self.chk.named_type("typing.Sequence")
        # If the static type is more general than sequence the actual type could still match
        return is_subtype(typ, sequence) or is_subtype(sequence, typ)

    def generate_types_from_names(self, type_names: list[str]) -> list[Type]:
        """Return the named types for `type_names`.

        Names starting with "builtins." that can't be looked up are skipped. For all other
        names the KeyError is propagated.
        """
        types: list[Type] = []
        for name in type_names:
            try:
                types.append(self.chk.named_type(name))
            except KeyError:
                # Some built in types are not defined in all test cases
                if not name.startswith("builtins."):
                    raise
        return types

    def update_type_map(
        self, original_type_map: dict[Expression, Type], extra_type_map: dict[Expression, Type]
    ) -> None:
        """Add the entries of `extra_type_map` to `original_type_map` in place.

        An error is reported for every expression that is already captured in
        `original_type_map`. The existing entry is kept in that case.
        """
        # Calculating this would not be needed if TypeMap directly used literal hashes instead of
        # expressions, as suggested in the TODO above its definition
        already_captured = {literal_hash(expr) for expr in original_type_map}
        for expr, typ in extra_type_map.items():
            if literal_hash(expr) in already_captured:
                node = get_var(expr)
                self.msg.fail(
                    message_registry.MULTIPLE_ASSIGNMENTS_IN_PATTERN.format(node.name), expr
                )
            else:
                original_type_map[expr] = typ

    def construct_sequence_child(self, outer_type: Type, inner_type: Type) -> Type:
        """
        If outer_type is a child class of typing.Sequence returns a new instance of
        outer_type, that is a Sequence of inner_type. If outer_type is not a child class of
        typing.Sequence just returns a Sequence of inner_type

        For example:
        construct_sequence_child(List[int], str) = List[str]

        TODO: this doesn't make sense. For example if one has class S(Sequence[int], Generic[T])
        or class T(Sequence[Tuple[T, T]]), there is no way any of those can map to Sequence[str].
        """
        proper_type = get_proper_type(outer_type)
        if isinstance(proper_type, UnionType):
            types = [
                self.construct_sequence_child(item, inner_type)
                for item in proper_type.items
                if self.can_match_sequence(get_proper_type(item))
            ]
            return make_simplified_union(types)
        sequence = self.chk.named_generic_type("typing.Sequence", [inner_type])
        if is_subtype(outer_type, self.chk.named_type("typing.Sequence")):
            proper_type = get_proper_type(outer_type)
            if isinstance(proper_type, TupleType):
                proper_type = tuple_fallback(proper_type)
            assert isinstance(proper_type, Instance)
            empty_type = fill_typevars(proper_type.type)
            partial_type = expand_type_by_instance(empty_type, sequence)
            return expand_type_by_instance(partial_type, proper_type)
        else:
            return sequence

    def early_non_match(self) -> PatternType:
        """Return the result for a pattern that can't match: nothing matches, all remains."""
        return PatternType(UninhabitedType(), self.type_context[-1], {})


def get_match_arg_names(typ: TupleType) -> list[str | None]:
    """Return the attribute names given by the items of a `__match_args__` tuple type.

    An entry is None if the item doesn't resolve to exactly one string literal.
    """
    args: list[str | None] = []
    for item in typ.items:
        values = try_getting_str_literals_from_type(item)
        if values is None or len(values) != 1:
            args.append(None)
        else:
            args.append(values[0])
    return args


def get_var(expr: Expression) -> Var:
    """
    Warning: this is only true for expressions captured by a match statement.
    Don't call it from anywhere else
    """
    assert isinstance(expr, NameExpr)
    node = expr.node
    assert isinstance(node, Var)
    return node


def get_type_range(typ: Type) -> mypy.checker.TypeRange:
    """Wrap `typ` in a TypeRange that is not an upper bound.

    An instance whose last known value is a bool is replaced by that last known value.
    """
    typ = get_proper_type(typ)
    if (
        isinstance(typ, Instance)
        and typ.last_known_value
        and isinstance(typ.last_known_value.value, bool)
    ):
        typ = typ.last_known_value
    return mypy.checker.TypeRange(typ, is_upper_bound=False)


def is_uninhabited(typ: Type) -> bool:
    """Return whether `typ` is the uninhabited type once resolved to a proper type."""
    return isinstance(get_proper_type(typ), UninhabitedType)