gtn/.venv/Lib/site-packages/mypy/test/helpers.py

464 lines
16 KiB
Python
Raw Normal View History

from __future__ import annotations
import contextlib
import os
import pathlib
import re
import shutil
import sys
import time
from typing import Any, Callable, Iterable, Iterator, Pattern
# Exporting Suite as alias to TestCase for backwards compatibility
# TODO: avoid aliasing - import and subclass TestCase directly
from unittest import TestCase
Suite = TestCase # re-exporting
import pytest
import mypy.api as api
import mypy.version
from mypy import defaults
from mypy.main import process_options
from mypy.options import Options
from mypy.test.config import test_data_prefix, test_temp_dir
from mypy.test.data import DataDrivenTestCase, DeleteFile, UpdateFile, fix_cobertura_filename
# Shorthand for pytest's skip marker.
skip = pytest.mark.skip
# assert_string_arrays_equal displays special line alignment helper messages if
# the first different line has at least this many characters.
MIN_LINE_LENGTH_FOR_ALIGNMENT = 5
def run_mypy(args: list[str]) -> None:
    """Run mypy in-process on ``args`` and fail the current test on a non-zero status.

    When the run fails, the two text results returned by ``api.run`` are
    written to stdout and stderr respectively before the test is failed.
    """
    __tracebackhide__ = True
    # Site packages must stay enabled even though they could cause problems,
    # because the stubs for typing_extensions live there.
    extra_flags = ["--show-traceback", "--no-silence-site-packages"]
    outval, errval, status = api.run(args + extra_flags)
    if status == 0:
        return
    sys.stdout.write(outval)
    sys.stderr.write(errval)
    pytest.fail(msg="Sample check failed", pytrace=False)
def assert_string_arrays_equal(expected: list[str], actual: list[str], msg: str) -> None:
    """Assert that two string arrays are equal.

    ``actual`` is passed through clean_up() first. On a mismatch, both arrays
    are written to stderr in a human-readable form (long runs of matching
    leading/trailing lines are elided), optionally followed by an alignment
    hint for the first differing line, and the test is failed with ``msg``.
    """
    actual = clean_up(actual)
    if actual == expected:
        return

    skip_head = num_skipped_prefix_lines(expected, actual)
    skip_tail = num_skipped_suffix_lines(expected, actual)

    # Display only this many first characters of identical lines.
    width = 75

    def write_side(shown: list[str], other: list[str]) -> int:
        """Write the non-elided lines of ``shown``.

        Return the index of the first line flagged as different, or -1.
        """
        first_flagged = -1
        for idx in range(skip_head, len(shown) - skip_tail):
            line = shown[idx]
            if idx < len(other) and expected[idx] == actual[idx]:
                sys.stderr.write(" " + line[:width])
                if len(line) > width:
                    sys.stderr.write("...")
            else:
                if first_flagged < 0:
                    first_flagged = idx
                sys.stderr.write(f" {line:<45} (diff)")
            sys.stderr.write("\n")
        return first_flagged

    sys.stderr.write("Expected:\n")
    # If some lines at the beginning are omitted, indicate it with a '...' line.
    if skip_head > 0:
        sys.stderr.write(" ...\n")
    # Only the expected side determines the first different line.
    first_diff = write_side(expected, actual)
    if skip_tail > 0:
        sys.stderr.write(" ...\n")

    sys.stderr.write("Actual:\n")
    if skip_head > 0:
        sys.stderr.write(" ...\n")
    write_side(actual, expected)
    if not actual:
        sys.stderr.write(" (empty)\n")
    if skip_tail > 0:
        sys.stderr.write(" ...\n")

    sys.stderr.write("\n")

    if 0 <= first_diff < len(actual):
        expected_line = expected[first_diff]
        actual_line = actual[first_diff]
        if max(len(expected_line), len(actual_line)) >= MIN_LINE_LENGTH_FOR_ALIGNMENT:
            # Display message that helps visualize the differences between two
            # long lines.
            show_align_message(expected_line, actual_line)

    pytest.fail(msg, pytrace=False)
def assert_module_equivalence(name: str, expected: Iterable[str], actual: Iterable[str]) -> None:
    """Compare actual and expected module names (order insensitive).

    ``expected`` is sorted; ``actual`` is de-duplicated, stripped of
    "__main__" and sorted before the comparison.
    """
    wanted = sorted(expected)
    found = sorted(set(actual) - {"__main__"})
    template = 'Actual modules ({}) do not match expected modules ({}) for "[{} ...]"'
    failure_msg = template.format(", ".join(found), ", ".join(wanted), name)
    assert_string_arrays_equal(wanted, found, failure_msg)
def assert_target_equivalence(name: str, expected: list[str], actual: list[str]) -> None:
    """Compare actual and expected targets (order sensitive)."""
    template = 'Actual targets ({}) do not match expected targets ({}) for "[{} ...]"'
    failure_msg = template.format(", ".join(actual), ", ".join(expected), name)
    assert_string_arrays_equal(expected, actual, failure_msg)
def show_align_message(s1: str, s2: str) -> None:
    """Align s1 and s2 so that their first difference is highlighted.

    For example, if s1 is 'foobar' and s2 is 'fobar', write the following
    lines to stderr (after a one-line header):

     E: foobar
     A: fobar
          ^

    If s1 and s2 are long, only display a fragment of the strings around the
    first difference. If s1 is very short, or the two strings are identical,
    do nothing.
    """
    # Seeing what went wrong is trivial even without alignment if the expected
    # string is very short. Identical strings have no difference to point at,
    # and without this guard the prefix-trimming loop below would never end.
    if len(s1) < 4 or s1 == s2:
        return

    maxw = 72  # Maximum number of characters shown
    expected_label = " E: "
    actual_label = " A: "

    sys.stderr.write("Alignment of first line difference:\n")

    # Drop shared leading text in chunks of 10 until the difference falls
    # within the first 30 characters. This terminates because s1 != s2.
    trunc = False
    while s1[:30] == s2[:30]:
        s1 = s1[10:]
        s2 = s2[10:]
        trunc = True

    if trunc:
        s1 = "..." + s1
        s2 = "..." + s2

    extra = "..." if max(len(s1), len(s2)) > maxw else ""

    # Write a chunk of both lines, aligned.
    sys.stderr.write(f"{expected_label}{s1[:maxw]}{extra}\n")
    sys.stderr.write(f"{actual_label}{s2[:maxw]}{extra}\n")
    # Write an indicator character under the first differing column. The
    # padding is derived from the label so the caret lines up with the text.
    sys.stderr.write(" " * len(expected_label))
    for j in range(min(maxw, max(len(s1), len(s2)))):
        if s1[j : j + 1] != s2[j : j + 1]:
            sys.stderr.write("^")  # Difference
            break
        sys.stderr.write(" ")  # Equal
    sys.stderr.write("\n")
def clean_up(a: list[str]) -> list[str]:
    """Normalize output lines before they are compared.

    For each string in ``a``, in this order: remove trailing spaces, replace
    the path of driver.py under the current working directory with plain
    "driver.py", and remove one trailing carriage return. This uses a naive
    string replace; no other directory prefix is stripped.

    Returns a new list; ``a`` is not modified.
    """
    # The "/" is a literal here, not os.sep.
    driver = os.getcwd() + "/driver.py"
    res = []
    for s in a:
        # Ignore spaces at end of line.
        ss = re.sub(r" +$", "", s)
        # Remove pwd from driver.py's path.
        ss = ss.replace(driver, "driver.py")
        res.append(re.sub(r"\r$", "", ss))
    return res
@contextlib.contextmanager
def local_sys_path_set() -> Iterator[None]:
    """Temporarily make the current directory importable via sys.path.

    An empty entry is prepended unless "" or "." is already present. On exit,
    sys.path is rebound to a copy taken on entry, so changes made inside the
    context are discarded. This can be used by test cases that do runtime
    imports, for example by the stubgen tests.
    """
    saved_path = sys.path.copy()
    if "" not in sys.path and "." not in sys.path:
        sys.path.insert(0, "")
    try:
        yield
    finally:
        sys.path = saved_path
def num_skipped_prefix_lines(a1: list[str], a2: list[str]) -> int:
    """Return how many leading lines to elide: the common prefix length minus 4, never below 0."""
    matching = 0
    for left, right in zip(a1, a2):
        if left != right:
            break
        matching += 1
    return max(0, matching - 4)
def num_skipped_suffix_lines(a1: list[str], a2: list[str]) -> int:
    """Return how many trailing lines to elide: the common suffix length minus 4, never below 0."""
    matching = 0
    for left, right in zip(reversed(a1), reversed(a2)):
        if left != right:
            break
        matching += 1
    return max(0, matching - 4)
def testfile_pyversion(path: str) -> tuple[int, int]:
if path.endswith("python312.test"):
return 3, 12
elif path.endswith("python311.test"):
return 3, 11
elif path.endswith("python310.test"):
return 3, 10
elif path.endswith("python39.test"):
return 3, 9
elif path.endswith("python38.test"):
return 3, 8
else:
return defaults.PYTHON3_VERSION
def normalize_error_messages(messages: list[str]) -> list[str]:
    """Return a copy of the messages with every os.sep replaced by "/"."""
    return [message.replace(os.sep, "/") for message in messages]
def retry_on_error(func: Callable[[], Any], max_wait: float = 1.0) -> None:
    """Call ``func`` until it succeeds, backing off exponentially on OSError.

    The delay doubles after every failure but is capped by the time left
    until ``max_wait`` seconds after the first attempt. Once the next delay
    would be 0.01 seconds or less, the last OSError is propagated. The return
    value of ``func`` is discarded.

    This can be effective against random file system operation failures on
    Windows.
    """
    deadline = time.time() + max_wait
    delay = 0.01
    while True:
        try:
            func()
        except OSError:
            delay = min(delay * 2, deadline - time.time())
            if delay <= 0.01:
                # Done enough waiting, the error seems persistent.
                raise
            time.sleep(delay)
        else:
            return
def good_repr(obj: object) -> str:
    """Return repr(obj), except that strings with more than one newline are
    rendered as a triple-single-quoted block with one source line per line.
    """
    if not isinstance(obj, str) or obj.count("\n") <= 1:
        return repr(obj)
    # Prefixing each line with '"' forces repr to use ' not "; the opening
    # quote pair and the closing quote are then cut off.
    bits = ["'''\\"] + [repr('"' + line)[2:-1] for line in obj.split("\n")]
    bits[-1] += "'''"
    return "\n".join(bits)
def assert_equal(a: object, b: object, fmt: str = "{} != {}") -> None:
    """Raise AssertionError if ``a != b``.

    The message is ``fmt`` formatted with good_repr() of both values.
    """
    __tracebackhide__ = True
    if a != b:
        left, right = good_repr(a), good_repr(b)
        raise AssertionError(fmt.format(left, right))
def typename(t: type) -> str:
    """Return a short name for ``t`` derived from ``str(t)``.

    If the text contains a dot, the part after the last dot is returned
    without trailing "'" and ">" characters; otherwise the first 8 and last 2
    characters are cut off (turning "<class 'int'>" into "int").
    """
    text = str(t)
    if "." not in text:
        return text[8:-2]
    return text.rpartition(".")[2].rstrip("'>")
def assert_type(typ: type, value: object) -> None:
    """Raise AssertionError unless ``type(value)`` is equal to ``typ``.

    This is an exact type comparison, so instances of subclasses are rejected.
    """
    __tracebackhide__ = True
    actual_type = type(value)
    if actual_type != typ:
        raise AssertionError(f"Invalid type {typename(actual_type)}, expected {typename(typ)}")
def parse_options(
    program_text: str, testcase: DataDrivenTestCase, incremental_step: int
) -> Options:
    """Parse comments like '# flags: --foo' in a test case.

    The first '# flags: ...' line of ``program_text`` supplies the flags. For
    ``incremental_step`` > 1, a '# flags<step>: ...' line takes precedence
    when one is present. "--no-site-packages" is always appended to the
    parsed flags before they are handed to process_options().

    Unless a flag starting with "--python-version" is given, the Python
    version is derived from the test file name via testfile_pyversion().
    The pytest option "--mypy-verbose", when truthy, sets the verbosity.

    Raises:
        RuntimeError: if the flags pragma yields any targets.
    """
    flags = re.search("# flags: (.*)$", program_text, flags=re.MULTILINE)
    if incremental_step > 1:
        flags2 = re.search(f"# flags{incremental_step}: (.*)$", program_text, flags=re.MULTILINE)
        if flags2:
            flags = flags2

    if flags:
        flag_list = flags.group(1).split()
        flag_list.append("--no-site-packages")  # the tests shouldn't need an installed Python
        targets, options = process_options(flag_list, require_targets=False)
        if targets:
            # TODO: support specifying targets via the flags pragma
            raise RuntimeError("Specifying targets via the flags pragma is not supported.")
        if "--show-error-codes" not in flag_list:
            options.hide_error_codes = True
    else:
        flag_list = []
        options = Options()
        # NOTE(review): these defaults are taken to apply only when no flags
        # pragma is present (they directly follow the else-branch) -- confirm.
        options.error_summary = False
        options.hide_error_codes = True
        options.force_uppercase_builtins = True
        options.force_union_syntax = True

    # Allow custom python version to override testfile_pyversion.
    if all(flag.split("=")[0] != "--python-version" for flag in flag_list):
        options.python_version = testfile_pyversion(testcase.file)

    verbosity = testcase.config.getoption("--mypy-verbose")
    if verbosity:
        options.verbosity = verbosity

    return options
def split_lines(*streams: bytes) -> list[str]:
    """Decode each byte stream as UTF-8 and return all of their lines as one list."""
    lines: list[str] = []
    for stream in streams:
        lines.extend(stream.decode("utf8").splitlines())
    return lines
def write_and_fudge_mtime(content: str, target_path: str) -> None:
    """Write ``content`` to ``target_path`` (UTF-8) and force a changed mtime.

    Missing parent directories are created. If the file already existed, its
    mtime (and atime) are set to the previous mtime plus one second.
    """
    # In some systems, mtime has a resolution of 1 second which can
    # cause annoying-to-debug issues when a file has the same size
    # after a change. We manually set the mtime to circumvent this.
    # Note that we increment the old file's mtime, which guarantees a
    # different value, rather than incrementing the mtime after the
    # copy, which could leave the mtime unchanged if the old file had
    # a similarly fudged mtime.
    new_time = None
    if os.path.isfile(target_path):
        new_time = os.stat(target_path).st_mtime + 1

    parent = os.path.dirname(target_path)
    # A bare file name has no directory part, and os.makedirs("") raises.
    if parent:
        os.makedirs(parent, exist_ok=True)
    with open(target_path, "w", encoding="utf-8") as target:
        target.write(content)

    if new_time is not None:
        os.utime(target_path, times=(new_time, new_time))
def perform_file_operations(operations: list[UpdateFile | DeleteFile]) -> None:
    """Apply a sequence of file updates and deletions, in order.

    UpdateFile operations are written with write_and_fudge_mtime(). Any other
    operation deletes ``op.path``: directories are removed recursively and
    must have a path starting with "tmp"; files are removed with retries.
    """
    for op in operations:
        if isinstance(op, UpdateFile):
            # Modify/create file
            write_and_fudge_mtime(op.content, op.target_path)
            continue

        # Delete file/directory
        if not os.path.isdir(op.path):
            # Use retries to work around potential flakiness on Windows (AppVeyor).
            doomed = op.path
            retry_on_error(lambda: os.remove(doomed))
            continue

        # Sanity check to avoid unexpected deletions
        assert op.path.startswith("tmp")
        shutil.rmtree(op.path)
def check_test_output_files(
    testcase: DataDrivenTestCase, step: int, strip_prefix: str = ""
) -> None:
    """Check every expected output file of ``testcase`` against what is on disk.

    ``strip_prefix`` is removed from the start of each expected path that has
    it. Expected content that is a compiled pattern must fully match the file
    text; otherwise the normalized file lines are compared with
    assert_string_arrays_equal().

    Raises:
        AssertionError: if a file is missing or does not match its pattern.
    """

    def step_suffix() -> str:
        # Only mention the step when the test case has a second-step output.
        return " on step %d" % step if testcase.output2 else ""

    for path, expected_content in testcase.output_files:
        if path.startswith(strip_prefix):
            path = path[len(strip_prefix) :]
        if not os.path.exists(path):
            raise AssertionError(
                "Expected file {} was not produced by test case{}".format(path, step_suffix())
            )
        with open(path, encoding="utf8") as output_file:
            produced = output_file.read()

        if isinstance(expected_content, Pattern):
            if expected_content.fullmatch(produced) is None:
                raise AssertionError(
                    "Output file {} did not match its expected output pattern\n---\n{}\n---".format(
                        path, produced
                    )
                )
            continue

        lines = normalize_file_output(produced.splitlines(), os.path.abspath(test_temp_dir))
        # We always normalize things like timestamp, but only handle operating-system
        # specific things if requested.
        if testcase.normalize_output:
            if testcase.suite.native_sep and os.path.sep == "\\":
                lines = [fix_cobertura_filename(line) for line in lines]
            # NOTE(review): separator normalization is taken to be part of the
            # "if requested" branch -- confirm against the unflattened source.
            lines = normalize_error_messages(lines)
        assert_string_arrays_equal(
            expected_content.splitlines(),
            lines,
            "Output file {} did not match its expected output{}".format(path, step_suffix()),
        )
def normalize_file_output(content: list[str], current_abs_path: str) -> list[str]:
    """Normalize file output for comparison.

    In every line, in this order: ``current_abs_path`` becomes "$PWD", the
    mypy version (and then the base version) as a whole word becomes
    "$VERSION", and every run of 10 digits becomes "$TIMESTAMP".
    """
    timestamp_regex = re.compile(r"\d{10}")
    version = mypy.version.__version__
    # We generate a new mypy.version when building mypy wheels that
    # lacks base_version, so handle that case.
    base_version = getattr(mypy.version, "base_version", version)
    version_regexes = [
        re.compile(r"\b" + re.escape(candidate) + r"\b") for candidate in (version, base_version)
    ]

    result = []
    for raw_line in content:
        line = raw_line.replace(current_abs_path, "$PWD")
        for version_regex in version_regexes:
            line = version_regex.sub("$VERSION", line)
        result.append(timestamp_regex.sub("$TIMESTAMP", line))
    return result
def find_test_files(pattern: str, exclude: list[str] | None = None) -> list[str]:
    """Return the file names (not full paths) under test_data_prefix that match ``pattern``.

    The search is recursive. Names listed in ``exclude`` are left out.
    """
    skipped = exclude or []
    root = pathlib.Path(test_data_prefix)
    names = (entry.name for entry in root.rglob(pattern))
    return [name for name in names if name not in skipped]