From 08a6746f795ef2cc9eadd5a2dec72e39b055cf19 Mon Sep 17 00:00:00 2001 From: AlexWaygood Date: Tue, 7 Nov 2023 16:50:07 +0000 Subject: [PATCH 01/11] Add the itertools recipes as test cases --- test_cases/stdlib/check_itertools.py | 483 +++++++++++++++++++++++++++ tests/regr_test.py | 2 +- 2 files changed, 484 insertions(+), 1 deletion(-) create mode 100644 test_cases/stdlib/check_itertools.py diff --git a/test_cases/stdlib/check_itertools.py b/test_cases/stdlib/check_itertools.py new file mode 100644 index 000000000000..2871c8009641 --- /dev/null +++ b/test_cases/stdlib/check_itertools.py @@ -0,0 +1,483 @@ +"""Type-annotated versions of the recipes from the itertools docs. + +These are all meant to be examples of idiomatic itertools usage, +so they should all type-check without error. +""" +from __future__ import annotations + +import collections +import functools +import math +import operator +import sys +from itertools import chain, combinations, count, cycle, filterfalse, groupby, islice, product, repeat, starmap, tee, zip_longest +from typing import Any, Callable, Collection, Hashable, Iterable, Iterator, Sequence, Sized, TypeVar, overload +from typing_extensions import Literal + +_T = TypeVar("_T") +_T1 = TypeVar("_T1") +_T2 = TypeVar("_T2") + + +def take(n: int, iterable: Iterable[_T]) -> list[_T]: + "Return first n items of the iterable as a list" + return list(islice(iterable, n)) + + +# Note: the itertools docs uses the parameter name "iterator", +# but the function actually accepts any iterable +# as its second argument +def prepend(value: _T1, iterator: Iterable[_T2]) -> chain[_T1 | _T2]: + "Prepend a single value in front of an iterator" + # prepend(1, [2, 3, 4]) --> 1 2 3 4 + return chain([value], iterator) + + +def tabulate(function: Callable[[int], _T], start: int = 0) -> Iterator[_T]: + "Return function(0), function(1), ..." + return map(function, count(start)) + + +# TODO: Uncomment when we can use PEP-646 in typeshed: +# +# _Ts = TypeVarTuple("_Ts") +# +# def repeatfunc(func: Callable[[Unpack[_Ts]], _T], times: int | None = None, *args: Unpack[_Ts]) -> Iterator[_T]: +# """Repeat calls to func with specified arguments. +# +# Example: repeatfunc(random.random) +# """ +# if times is None: +# return starmap(func, repeat(args)) +# return starmap(func, repeat(args, times)) + + +def flatten(list_of_lists: Iterable[Iterable[_T]]) -> chain[_T]: + "Flatten one level of nesting" + return chain.from_iterable(list_of_lists) + + +def ncycles(iterable: Iterable[_T], n: int) -> Iterator[_T]: + "Returns the sequence elements n times" + return chain.from_iterable(repeat(tuple(iterable), n)) + + +def tail(n: int, iterable: Iterable[_T]) -> Iterator[_T]: + "Return an iterator over the last n items" + # tail(3, 'ABCDEFG') --> E F G + return iter(collections.deque(iterable, maxlen=n)) + + +# This function *accepts* any iterable, +# but it only *makes sense* to use it with an iterator +def consume(iterator: Iterator[object], n: int | None = None) -> None: + "Advance the iterator n-steps ahead. If n is None, consume entirely." + # Use functions that consume iterators at C speed. + if n is None: + # feed the entire iterator into a zero-length deque + collections.deque(iterator, maxlen=0) + else: + # advance to the empty slice starting at position n + next(islice(iterator, n, n), None) + + +@overload +def nth(iterable: Iterable[_T], n: int, default: None = None) -> _T | None: + ... + + +@overload +def nth(iterable: Iterable[_T], n: int, default: _T1) -> _T | _T1: + ... + + +def nth(iterable: Iterable[object], n: int, default: object = None) -> object: + "Returns the nth item or a default value" + return next(islice(iterable, n, None), default) + + +@overload +def quantify(iterable: Iterable[object]) -> int: + ... + + +@overload +def quantify(iterable: Iterable[_T], pred: Callable[[_T], bool]) -> int: + ... + + +def quantify(iterable: Iterable[object], pred: Callable[[Any], bool] = bool) -> int: + "Given a predicate that returns True or False, count the True results." + return sum(map(pred, iterable)) + + +# Slightly adapted from the itertools docs recipe. +# See https://github.com/python/typeshed/issues/10980#issuecomment-1794927596 +def all_equal(iterable: Iterable[object]) -> bool: + "Returns True if all the elements are equal to each other" + g = groupby(iterable) + next(g, True) + return not next(g, False) + + +@overload +def first_true(iterable: Iterable[_T], default: bool = False, pred: Callable[[_T], bool] | None = None) -> _T | bool: + ... + + +@overload +def first_true(iterable: Iterable[_T], default: _T1, pred: Callable[[_T], bool] | None = None) -> _T | _T1: + ... + + +def first_true(iterable: Iterable[object], default: object = False, pred: Callable[[Any], bool] | None = None) -> object: + """Returns the first true value in the iterable. + + If no true value is found, returns *default* + + If *pred* is not None, returns the first item + for which pred(item) is true. + + """ + # first_true([a,b,c], x) --> a or b or c or x + # first_true([a,b], x, f) --> a if f(a) else b if f(b) else x + return next(filter(pred, iterable), default) + + +# This one has an extra `assert isinstance(iterable, Sized)` call +# compared to the itertools docs +def iter_index(iterable: Iterable[_T], value: _T, start: int = 0, stop: int | None = None) -> Iterator[int]: + "Return indices where a value occurs in a sequence or iterable." + # iter_index('AABCADEAF', 'A') --> 0 1 4 7 + seq_index = getattr(iterable, "index", None) + if seq_index is None: + # Slow path for general iterables + it = islice(iterable, start, stop) + for i, element in enumerate(it, start): + if element is value or element == value: + yield i + else: + # Fast path for sequences + assert isinstance(iterable, Sized) + stop = len(iterable) if stop is None else stop + i = start - 1 + try: + while True: + yield (i := seq_index(value, i + 1, stop)) + except ValueError: + pass + + +@overload +def iter_except(func: Callable[[], _T], exception: type[BaseException], first: None = None) -> Iterator[_T]: + ... + + +@overload +def iter_except(func: Callable[[], _T], exception: type[BaseException], first: Callable[[], _T1]) -> Iterator[_T | _T1]: + ... + + +def iter_except( + func: Callable[[], object], exception: type[BaseException], first: Callable[[], object] | None = None +) -> Iterator[object]: + """Call a function repeatedly until an exception is raised. + + Converts a call-until-exception interface to an iterator interface. + Like builtins.iter(func, sentinel) but uses an exception instead + of a sentinel to end the loop. + + Examples: + iter_except(functools.partial(heappop, h), IndexError) # priority queue iterator + iter_except(d.popitem, KeyError) # non-blocking dict iterator + iter_except(d.popleft, IndexError) # non-blocking deque iterator + iter_except(q.get_nowait, Queue.Empty) # loop over a producer Queue + iter_except(s.pop, KeyError) # non-blocking set iterator + + """ + try: + if first is not None: + yield first() # For database APIs needing an initial cast to db.first() + while True: + yield func() + except exception: + pass + + +def sliding_window(iterable: Iterable[_T], n: int) -> Iterator[tuple[_T, ...]]: + # sliding_window('ABCDEFG', 4) --> ABCD BCDE CDEF DEFG + it = iter(iterable) + window = collections.deque(islice(it, n - 1), maxlen=n) + for x in it: + window.append(x) + yield tuple(window) + + +def roundrobin(*iterables: Iterable[_T]) -> Iterator[_T]: + "roundrobin('ABC', 'D', 'EF') --> A D E B F C" + # Recipe credited to George Sakkis + num_active = len(iterables) + nexts: Iterator[Callable[[], _T]] = cycle(iter(it).__next__ for it in iterables) + while num_active: + try: + for next in nexts: + yield next() + except StopIteration: + # Remove the iterator we just exhausted from the cycle. + num_active -= 1 + nexts = cycle(islice(nexts, num_active)) + + +def partition(pred: Callable[[_T], bool], iterable: Iterable[_T]) -> tuple[Iterator[_T], Iterator[_T]]: + """Partition entries into false entries and true entries. + + If *pred* is slow, consider wrapping it with functools.lru_cache(). + """ + # partition(is_odd, range(10)) --> 0 2 4 6 8 and 1 3 5 7 9 + t1, t2 = tee(iterable) + return filterfalse(pred, t1), filter(pred, t2) + + +def subslices(seq: Sequence[_T]) -> Iterator[Sequence[_T]]: + "Return all contiguous non-empty subslices of a sequence" + # subslices('ABCD') --> A AB ABC ABCD B BC BCD C CD D + slices = starmap(slice, combinations(range(len(seq) + 1), 2)) + return map(operator.getitem, repeat(seq), slices) + + +def before_and_after(predicate: Callable[[_T], bool], it: Iterable[_T]) -> tuple[Iterator[_T], Iterator[_T]]: + """Variant of takewhile() that allows complete + access to the remainder of the iterator. + + >>> it = iter('ABCdEfGhI') + >>> all_upper, remainder = before_and_after(str.isupper, it) + >>> ''.join(all_upper) + 'ABC' + >>> ''.join(remainder) # takewhile() would lose the 'd' + 'dEfGhI' + + Note that the first iterator must be fully + consumed before the second iterator can + generate valid results. + """ + it = iter(it) + transition: list[_T] = [] + + def true_iterator() -> Iterator[_T]: + for elem in it: + if predicate(elem): + yield elem + else: + transition.append(elem) + return + + def remainder_iterator() -> Iterator[_T]: + yield from transition + yield from it + + return true_iterator(), remainder_iterator() + + +def unique_everseen(iterable: Iterable[_T], key: Callable[[_T], Hashable] | None = None) -> Iterator[_T]: + "List unique elements, preserving order. Remember all elements ever seen." + # unique_everseen('AAAABBBCCDAABBB') --> A B C D + # unique_everseen('ABBcCAD', str.lower) --> A B c D + seen: set[Hashable] = set() + if key is None: + for element in filterfalse(seen.__contains__, iterable): + seen.add(element) + yield element + # For order preserving deduplication, + # a faster but non-lazy solution is: + # yield from dict.fromkeys(iterable) + else: + for element in iterable: + k = key(element) + if k not in seen: + seen.add(k) + yield element + # For use cases that allow the last matching element to be returned, + # a faster but non-lazy solution is: + # t1, t2 = tee(iterable) + # yield from dict(zip(map(key, t1), t2)).values() + + +# Slightly adapted from the docs recipe; a one-liner was a bit much for pyright +def unique_justseen(iterable: Iterable[_T], key: Callable[[_T], bool] | None = None) -> Iterator[_T]: + "List unique elements, preserving order. Remember only the element just seen." + # unique_justseen('AAAABBBCCDAABBB') --> A B C D A B + # unique_justseen('ABBcCAD', str.lower) --> A B c A D + g: groupby[_T, _T | bool] = groupby(iterable, key) + return map(next, map(operator.itemgetter(1), g)) + + +def powerset(iterable: Iterable[_T]) -> Iterator[tuple[_T, ...]]: + "powerset([1,2,3]) --> () (1,) (2,) (3,) (1,2) (1,3) (2,3) (1,2,3)" + s = list(iterable) + return chain.from_iterable(combinations(s, r) for r in range(len(s) + 1)) + + +def polynomial_derivative(coefficients: Sequence[int]) -> list[int]: + """Compute the first derivative of a polynomial. + + f(x) = x³ -4x² -17x + 60 + f'(x) = 3x² -8x -17 + """ + # polynomial_derivative([1, -4, -17, 60]) -> [3, -8, -17] + n = len(coefficients) + powers = reversed(range(1, n)) + return list(map(operator.mul, coefficients, powers)) + + +def sieve(n: int) -> Iterator[int]: + "Primes less than n." + # sieve(30) --> 2 3 5 7 11 13 17 19 23 29 + if n > 2: + yield 2 + start = 3 + data = bytearray((0, 1)) * (n // 2) + limit = math.isqrt(n) + 1 + for p in iter_index(data, 1, start, limit): + yield from iter_index(data, 1, start, p * p) + data[p * p : n : p + p] = bytes(len(range(p * p, n, p + p))) + start = p * p + yield from iter_index(data, 1, start) + + +def factor(n: int) -> Iterator[int]: + "Prime factors of n." + # factor(99) --> 3 3 11 + # factor(1_000_000_000_000_007) --> 47 59 360620266859 + # factor(1_000_000_000_000_403) --> 1000000000000403 + for prime in sieve(math.isqrt(n) + 1): + while not n % prime: + yield prime + n //= prime + if n == 1: + return + if n > 1: + yield n + + +def nth_combination(iterable: Iterable[_T], r: int, index: int) -> tuple[_T, ...]: + "Equivalent to list(combinations(iterable, r))[index]" + pool = tuple(iterable) + n = len(pool) + c = math.comb(n, r) + if index < 0: + index += c + if index < 0 or index >= c: + raise IndexError + result: list[_T] = [] + while r: + c, n, r = c * r // n, n - 1, r - 1 + while index >= c: + index -= c + c, n = c * (n - r) // n, n - 1 + result.append(pool[-1 - n]) + return tuple(result) + + +if sys.version_info >= (3, 10): + + @overload + def grouper( + iterable: Iterable[_T], n: int, *, incomplete: Literal["fill"] = "fill", fillvalue: None = None + ) -> Iterator[tuple[_T | None, ...]]: + ... + + @overload + def grouper( + iterable: Iterable[_T], n: int, *, incomplete: Literal["fill"] = "fill", fillvalue: _T1 + ) -> Iterator[tuple[_T | _T1, ...]]: + ... + + @overload + def grouper( + iterable: Iterable[_T], n: int, *, incomplete: Literal["strict", "ignore"], fillvalue: None = None + ) -> Iterator[tuple[_T, ...]]: + ... + + def grouper( + iterable: Iterable[object], n: int, *, incomplete: Literal["fill", "strict", "ignore"] = "fill", fillvalue: object = None + ) -> Iterator[tuple[object, ...]]: + "Collect data into non-overlapping fixed-length chunks or blocks" + # grouper('ABCDEFG', 3, fillvalue='x') --> ABC DEF Gxx + # grouper('ABCDEFG', 3, incomplete='strict') --> ABC DEF ValueError + # grouper('ABCDEFG', 3, incomplete='ignore') --> ABC DEF + args = [iter(iterable)] * n + if incomplete == "fill": + return zip_longest(*args, fillvalue=fillvalue) + if incomplete == "strict": + return zip(*args, strict=True) + if incomplete == "ignore": + return zip(*args) + else: + raise ValueError("Expected fill, strict, or ignore") + + def transpose(it: Iterable[Iterable[_T]]) -> Iterator[tuple[_T, ...]]: + "Swap the rows and columns of the input." + # transpose([(1, 2, 3), (11, 22, 33)]) --> (1, 11) (2, 22) (3, 33) + return zip(*it, strict=True) + + +if sys.version_info >= (3, 12): + from itertools import batched + + def sum_of_squares(it: Iterable[float]) -> float: + "Add up the squares of the input values." + # sum_of_squares([10, 20, 30]) -> 1400 + return math.sumprod(*tee(it)) + + def convolve(signal: Iterable[float], kernel: Iterable[float]) -> Iterator[float]: + """Discrete linear convolution of two iterables. + + The kernel is fully consumed before the calculations begin. + The signal is consumed lazily and can be infinite. + + Convolutions are mathematically commutative. + If the signal and kernel are swapped, + the output will be the same. + + Article: https://betterexplained.com/articles/intuitive-convolution/ + Video: https://www.youtube.com/watch?v=KuXjwB4LzSA + """ + # convolve(data, [0.25, 0.25, 0.25, 0.25]) --> Moving average (blur) + # convolve(data, [1/2, 0, -1/2]) --> 1st derivative estimate + # convolve(data, [1, -2, 1]) --> 2nd derivative estimate + kernel = tuple(kernel)[::-1] + n = len(kernel) + padded_signal = chain(repeat(0, n - 1), signal, repeat(0, n - 1)) + windowed_signal = sliding_window(padded_signal, n) + return map(math.sumprod, repeat(kernel), windowed_signal) + + def polynomial_from_roots(roots: Iterable[int]) -> list[float]: + """Compute a polynomial's coefficients from its roots. + + (x - 5) (x + 4) (x - 3) expands to: x³ -4x² -17x + 60 + """ + # polynomial_from_roots([5, -4, 3]) --> [1, -4, -17, 60] + factors = zip(repeat(1), map(operator.neg, roots)) + return list(functools.reduce(convolve, factors, [1])) + + def polynomial_eval(coefficients: Sequence[int], x: float) -> float: + """Evaluate a polynomial at a specific value. + + Computes with better numeric stability than Horner's method. + """ + # Evaluate x³ -4x² -17x + 60 at x = 2.5 + # polynomial_eval([1, -4, -17, 60], x=2.5) --> 8.125 + n = len(coefficients) + if not n: + return type(x)(0) + powers = map(pow, repeat(x), reversed(range(n))) + return math.sumprod(coefficients, powers) + + # Slightly adapted from the itertools docs, + # to make things a little easier for pyright + def matmul(m1: Sequence[Collection[float]], m2: Sequence[Collection[float]]) -> Iterator[tuple[float, ...]]: + "Multiply two matrices." + # matmul([(7, 5), (3, 5)], [(2, 5), (7, 9)]) --> (49, 80), (41, 60) + n = len(m2[0]) + it: Iterator[float] = starmap(math.sumprod, product(m1, transpose(m2))) + return batched(it, n) diff --git a/tests/regr_test.py b/tests/regr_test.py index cc7ede290b9a..3d50aec406ea 100755 --- a/tests/regr_test.py +++ b/tests/regr_test.py @@ -41,7 +41,7 @@ TYPESHED = "typeshed" SUPPORTED_PLATFORMS = ["linux", "darwin", "win32"] -SUPPORTED_VERSIONS = ["3.12", "3.11", "3.10", "3.9", "3.8", "3.7"] +SUPPORTED_VERSIONS = ["3.12", "3.11", "3.10", "3.9", "3.8"] def package_with_test_cases(package_name: str) -> PackageInfo: From 97b5b5b2cf1dcb116b002a080e69d8cb483007c1 Mon Sep 17 00:00:00 2001 From: Alex Waygood Date: Tue, 7 Nov 2023 16:57:10 +0000 Subject: [PATCH 02/11] Update itertools.pyi --- stdlib/itertools.pyi | 2 ++ 1 file changed, 2 insertions(+) diff --git a/stdlib/itertools.pyi b/stdlib/itertools.pyi index 881fb236be07..669958897a06 100644 --- a/stdlib/itertools.pyi +++ b/stdlib/itertools.pyi @@ -7,7 +7,9 @@ if sys.version_info >= (3, 9): from types import GenericAlias _T = TypeVar("_T") +_T_co = TypeVar("_T_co", covariant=True) _S = TypeVar("_S") +_S_co = TypeVar("_S_co", covariant=True) _N = TypeVar("_N", int, float, SupportsFloat, SupportsInt, SupportsIndex, SupportsComplex) _T_co = TypeVar("_T_co", covariant=True) _S_co = TypeVar("_S_co", covariant=True) From 3209e32d66f92d9dbef75633fd5f3af3410f68a0 Mon Sep 17 00:00:00 2001 From: Alex Waygood Date: Tue, 7 Nov 2023 17:05:14 +0000 Subject: [PATCH 03/11] Update itertools.pyi --- stdlib/itertools.pyi | 2 -- 1 file changed, 2 deletions(-) diff --git a/stdlib/itertools.pyi b/stdlib/itertools.pyi index 669958897a06..881fb236be07 100644 --- a/stdlib/itertools.pyi +++ b/stdlib/itertools.pyi @@ -7,9 +7,7 @@ if sys.version_info >= (3, 9): from types import GenericAlias _T = TypeVar("_T") -_T_co = TypeVar("_T_co", covariant=True) _S = TypeVar("_S") -_S_co = TypeVar("_S_co", covariant=True) _N = TypeVar("_N", int, float, SupportsFloat, SupportsInt, SupportsIndex, SupportsComplex) _T_co = TypeVar("_T_co", covariant=True) _S_co = TypeVar("_S_co", covariant=True) From f80425eb34be9e34fbae9d7cd72f3cac1f7f90c0 Mon Sep 17 00:00:00 2001 From: Alex Waygood Date: Tue, 7 Nov 2023 17:13:05 +0000 Subject: [PATCH 04/11] Update tests.yml --- .github/workflows/tests.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index a4b7c009b596..4ff0a41c3963 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -173,6 +173,7 @@ jobs: no-comments: ${{ matrix.python-version != '3.11' || matrix.python-platform != 'Linux' }} # Having each job create the same comment is too noisy. project: ./pyrightconfig.stricter.json - name: Run pyright on the test cases + if: ${{ matrix.python-version != '3.7' }} uses: jakebailey/pyright-action@v1 with: version: ${{ steps.pyright_version.outputs.value }} From 8268732cb583a703f61251d23ee5a9237fdae2db Mon Sep 17 00:00:00 2001 From: Alex Waygood Date: Tue, 7 Nov 2023 17:16:58 +0000 Subject: [PATCH 05/11] Update check_itertools.py --- test_cases/stdlib/check_itertools.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/test_cases/stdlib/check_itertools.py b/test_cases/stdlib/check_itertools.py index 2871c8009641..d3043b8d730f 100644 --- a/test_cases/stdlib/check_itertools.py +++ b/test_cases/stdlib/check_itertools.py @@ -307,8 +307,7 @@ def unique_justseen(iterable: Iterable[_T], key: Callable[[_T], bool] | None = N "List unique elements, preserving order. Remember only the element just seen." # unique_justseen('AAAABBBCCDAABBB') --> A B C D A B # unique_justseen('ABBcCAD', str.lower) --> A B c A D - g: groupby[_T, _T | bool] = groupby(iterable, key) - return map(next, map(operator.itemgetter(1), g)) + return map(next, map(operator.itemgetter(1), groupby(iterable, key))) def powerset(iterable: Iterable[_T]) -> Iterator[tuple[_T, ...]]: From 0f32c331223b2c976801b0006a81a97b58686234 Mon Sep 17 00:00:00 2001 From: AlexWaygood Date: Tue, 7 Nov 2023 17:32:33 +0000 Subject: [PATCH 06/11] another attempt --- test_cases/stdlib/check_itertools.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/test_cases/stdlib/check_itertools.py b/test_cases/stdlib/check_itertools.py index d3043b8d730f..3ce667dec262 100644 --- a/test_cases/stdlib/check_itertools.py +++ b/test_cases/stdlib/check_itertools.py @@ -307,7 +307,8 @@ def unique_justseen(iterable: Iterable[_T], key: Callable[[_T], bool] | None = N "List unique elements, preserving order. Remember only the element just seen." # unique_justseen('AAAABBBCCDAABBB') --> A B C D A B # unique_justseen('ABBcCAD', str.lower) --> A B c A D - return map(next, map(operator.itemgetter(1), groupby(iterable, key))) + g: groupby[_T | bool, _T] = groupby(iterable, key) + return map(next, map(operator.itemgetter(1), g)) def powerset(iterable: Iterable[_T]) -> Iterator[tuple[_T, ...]]: From 0b4222c83bda811058cf8aee0ca69eeac41bc257 Mon Sep 17 00:00:00 2001 From: AlexWaygood Date: Tue, 7 Nov 2023 22:29:15 +0000 Subject: [PATCH 07/11] fix mypy errors --- test_cases/stdlib/check_itertools.py | 3 ++- tests/regr_test.py | 1 + 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/test_cases/stdlib/check_itertools.py b/test_cases/stdlib/check_itertools.py index 3ce667dec262..8479f39b89f6 100644 --- a/test_cases/stdlib/check_itertools.py +++ b/test_cases/stdlib/check_itertools.py @@ -458,7 +458,8 @@ def polynomial_from_roots(roots: Iterable[int]) -> list[float]: """ # polynomial_from_roots([5, -4, 3]) --> [1, -4, -17, 60] factors = zip(repeat(1), map(operator.neg, roots)) - return list(functools.reduce(convolve, factors, [1])) + it: Iterable[float] = functools.reduce(convolve, factors, [1]) + return list(it) def polynomial_eval(coefficients: Sequence[int], x: float) -> float: """Evaluate a polynomial at a specific value. diff --git a/tests/regr_test.py b/tests/regr_test.py index 3d50aec406ea..a728b89fde09 100755 --- a/tests/regr_test.py +++ b/tests/regr_test.py @@ -177,6 +177,7 @@ def run_testcases( platform, "--strict", "--pretty", + "--new-type-inference", ] if package.is_stdlib: From c45bd2d9a7bfaeaedabc0b3e9ab27cdbdb39ffec Mon Sep 17 00:00:00 2001 From: AlexWaygood Date: Tue, 7 Nov 2023 23:00:33 +0000 Subject: [PATCH 08/11] move around --- test_cases/stdlib/check_itertools.py | 484 ------------------ .../itertools/check_itertools_recipes.py | 83 +++ 2 files changed, 83 insertions(+), 484 deletions(-) delete mode 100644 test_cases/stdlib/check_itertools.py diff --git a/test_cases/stdlib/check_itertools.py b/test_cases/stdlib/check_itertools.py deleted file mode 100644 index 8479f39b89f6..000000000000 --- a/test_cases/stdlib/check_itertools.py +++ /dev/null @@ -1,484 +0,0 @@ -"""Type-annotated versions of the recipes from the itertools docs. - -These are all meant to be examples of idiomatic itertools usage, -so they should all type-check without error. -""" -from __future__ import annotations - -import collections -import functools -import math -import operator -import sys -from itertools import chain, combinations, count, cycle, filterfalse, groupby, islice, product, repeat, starmap, tee, zip_longest -from typing import Any, Callable, Collection, Hashable, Iterable, Iterator, Sequence, Sized, TypeVar, overload -from typing_extensions import Literal - -_T = TypeVar("_T") -_T1 = TypeVar("_T1") -_T2 = TypeVar("_T2") - - -def take(n: int, iterable: Iterable[_T]) -> list[_T]: - "Return first n items of the iterable as a list" - return list(islice(iterable, n)) - - -# Note: the itertools docs uses the parameter name "iterator", -# but the function actually accepts any iterable -# as its second argument -def prepend(value: _T1, iterator: Iterable[_T2]) -> chain[_T1 | _T2]: - "Prepend a single value in front of an iterator" - # prepend(1, [2, 3, 4]) --> 1 2 3 4 - return chain([value], iterator) - - -def tabulate(function: Callable[[int], _T], start: int = 0) -> Iterator[_T]: - "Return function(0), function(1), ..." - return map(function, count(start)) - - -# TODO: Uncomment when we can use PEP-646 in typeshed: -# -# _Ts = TypeVarTuple("_Ts") -# -# def repeatfunc(func: Callable[[Unpack[_Ts]], _T], times: int | None = None, *args: Unpack[_Ts]) -> Iterator[_T]: -# """Repeat calls to func with specified arguments. -# -# Example: repeatfunc(random.random) -# """ -# if times is None: -# return starmap(func, repeat(args)) -# return starmap(func, repeat(args, times)) - - -def flatten(list_of_lists: Iterable[Iterable[_T]]) -> chain[_T]: - "Flatten one level of nesting" - return chain.from_iterable(list_of_lists) - - -def ncycles(iterable: Iterable[_T], n: int) -> Iterator[_T]: - "Returns the sequence elements n times" - return chain.from_iterable(repeat(tuple(iterable), n)) - - -def tail(n: int, iterable: Iterable[_T]) -> Iterator[_T]: - "Return an iterator over the last n items" - # tail(3, 'ABCDEFG') --> E F G - return iter(collections.deque(iterable, maxlen=n)) - - -# This function *accepts* any iterable, -# but it only *makes sense* to use it with an iterator -def consume(iterator: Iterator[object], n: int | None = None) -> None: - "Advance the iterator n-steps ahead. If n is None, consume entirely." - # Use functions that consume iterators at C speed. - if n is None: - # feed the entire iterator into a zero-length deque - collections.deque(iterator, maxlen=0) - else: - # advance to the empty slice starting at position n - next(islice(iterator, n, n), None) - - -@overload -def nth(iterable: Iterable[_T], n: int, default: None = None) -> _T | None: - ... - - -@overload -def nth(iterable: Iterable[_T], n: int, default: _T1) -> _T | _T1: - ... - - -def nth(iterable: Iterable[object], n: int, default: object = None) -> object: - "Returns the nth item or a default value" - return next(islice(iterable, n, None), default) - - -@overload -def quantify(iterable: Iterable[object]) -> int: - ... - - -@overload -def quantify(iterable: Iterable[_T], pred: Callable[[_T], bool]) -> int: - ... - - -def quantify(iterable: Iterable[object], pred: Callable[[Any], bool] = bool) -> int: - "Given a predicate that returns True or False, count the True results." - return sum(map(pred, iterable)) - - -# Slightly adapted from the itertools docs recipe. -# See https://github.com/python/typeshed/issues/10980#issuecomment-1794927596 -def all_equal(iterable: Iterable[object]) -> bool: - "Returns True if all the elements are equal to each other" - g = groupby(iterable) - next(g, True) - return not next(g, False) - - -@overload -def first_true(iterable: Iterable[_T], default: bool = False, pred: Callable[[_T], bool] | None = None) -> _T | bool: - ... - - -@overload -def first_true(iterable: Iterable[_T], default: _T1, pred: Callable[[_T], bool] | None = None) -> _T | _T1: - ... - - -def first_true(iterable: Iterable[object], default: object = False, pred: Callable[[Any], bool] | None = None) -> object: - """Returns the first true value in the iterable. - - If no true value is found, returns *default* - - If *pred* is not None, returns the first item - for which pred(item) is true. - - """ - # first_true([a,b,c], x) --> a or b or c or x - # first_true([a,b], x, f) --> a if f(a) else b if f(b) else x - return next(filter(pred, iterable), default) - - -# This one has an extra `assert isinstance(iterable, Sized)` call -# compared to the itertools docs -def iter_index(iterable: Iterable[_T], value: _T, start: int = 0, stop: int | None = None) -> Iterator[int]: - "Return indices where a value occurs in a sequence or iterable." - # iter_index('AABCADEAF', 'A') --> 0 1 4 7 - seq_index = getattr(iterable, "index", None) - if seq_index is None: - # Slow path for general iterables - it = islice(iterable, start, stop) - for i, element in enumerate(it, start): - if element is value or element == value: - yield i - else: - # Fast path for sequences - assert isinstance(iterable, Sized) - stop = len(iterable) if stop is None else stop - i = start - 1 - try: - while True: - yield (i := seq_index(value, i + 1, stop)) - except ValueError: - pass - - -@overload -def iter_except(func: Callable[[], _T], exception: type[BaseException], first: None = None) -> Iterator[_T]: - ... - - -@overload -def iter_except(func: Callable[[], _T], exception: type[BaseException], first: Callable[[], _T1]) -> Iterator[_T | _T1]: - ... - - -def iter_except( - func: Callable[[], object], exception: type[BaseException], first: Callable[[], object] | None = None -) -> Iterator[object]: - """Call a function repeatedly until an exception is raised. - - Converts a call-until-exception interface to an iterator interface. - Like builtins.iter(func, sentinel) but uses an exception instead - of a sentinel to end the loop. - - Examples: - iter_except(functools.partial(heappop, h), IndexError) # priority queue iterator - iter_except(d.popitem, KeyError) # non-blocking dict iterator - iter_except(d.popleft, IndexError) # non-blocking deque iterator - iter_except(q.get_nowait, Queue.Empty) # loop over a producer Queue - iter_except(s.pop, KeyError) # non-blocking set iterator - - """ - try: - if first is not None: - yield first() # For database APIs needing an initial cast to db.first() - while True: - yield func() - except exception: - pass - - -def sliding_window(iterable: Iterable[_T], n: int) -> Iterator[tuple[_T, ...]]: - # sliding_window('ABCDEFG', 4) --> ABCD BCDE CDEF DEFG - it = iter(iterable) - window = collections.deque(islice(it, n - 1), maxlen=n) - for x in it: - window.append(x) - yield tuple(window) - - -def roundrobin(*iterables: Iterable[_T]) -> Iterator[_T]: - "roundrobin('ABC', 'D', 'EF') --> A D E B F C" - # Recipe credited to George Sakkis - num_active = len(iterables) - nexts: Iterator[Callable[[], _T]] = cycle(iter(it).__next__ for it in iterables) - while num_active: - try: - for next in nexts: - yield next() - except StopIteration: - # Remove the iterator we just exhausted from the cycle. - num_active -= 1 - nexts = cycle(islice(nexts, num_active)) - - -def partition(pred: Callable[[_T], bool], iterable: Iterable[_T]) -> tuple[Iterator[_T], Iterator[_T]]: - """Partition entries into false entries and true entries. - - If *pred* is slow, consider wrapping it with functools.lru_cache(). - """ - # partition(is_odd, range(10)) --> 0 2 4 6 8 and 1 3 5 7 9 - t1, t2 = tee(iterable) - return filterfalse(pred, t1), filter(pred, t2) - - -def subslices(seq: Sequence[_T]) -> Iterator[Sequence[_T]]: - "Return all contiguous non-empty subslices of a sequence" - # subslices('ABCD') --> A AB ABC ABCD B BC BCD C CD D - slices = starmap(slice, combinations(range(len(seq) + 1), 2)) - return map(operator.getitem, repeat(seq), slices) - - -def before_and_after(predicate: Callable[[_T], bool], it: Iterable[_T]) -> tuple[Iterator[_T], Iterator[_T]]: - """Variant of takewhile() that allows complete - access to the remainder of the iterator. - - >>> it = iter('ABCdEfGhI') - >>> all_upper, remainder = before_and_after(str.isupper, it) - >>> ''.join(all_upper) - 'ABC' - >>> ''.join(remainder) # takewhile() would lose the 'd' - 'dEfGhI' - - Note that the first iterator must be fully - consumed before the second iterator can - generate valid results. - """ - it = iter(it) - transition: list[_T] = [] - - def true_iterator() -> Iterator[_T]: - for elem in it: - if predicate(elem): - yield elem - else: - transition.append(elem) - return - - def remainder_iterator() -> Iterator[_T]: - yield from transition - yield from it - - return true_iterator(), remainder_iterator() - - -def unique_everseen(iterable: Iterable[_T], key: Callable[[_T], Hashable] | None = None) -> Iterator[_T]: - "List unique elements, preserving order. Remember all elements ever seen." - # unique_everseen('AAAABBBCCDAABBB') --> A B C D - # unique_everseen('ABBcCAD', str.lower) --> A B c D - seen: set[Hashable] = set() - if key is None: - for element in filterfalse(seen.__contains__, iterable): - seen.add(element) - yield element - # For order preserving deduplication, - # a faster but non-lazy solution is: - # yield from dict.fromkeys(iterable) - else: - for element in iterable: - k = key(element) - if k not in seen: - seen.add(k) - yield element - # For use cases that allow the last matching element to be returned, - # a faster but non-lazy solution is: - # t1, t2 = tee(iterable) - # yield from dict(zip(map(key, t1), t2)).values() - - -# Slightly adapted from the docs recipe; a one-liner was a bit much for pyright -def unique_justseen(iterable: Iterable[_T], key: Callable[[_T], bool] | None = None) -> Iterator[_T]: - "List unique elements, preserving order. Remember only the element just seen." - # unique_justseen('AAAABBBCCDAABBB') --> A B C D A B - # unique_justseen('ABBcCAD', str.lower) --> A B c A D - g: groupby[_T | bool, _T] = groupby(iterable, key) - return map(next, map(operator.itemgetter(1), g)) - - -def powerset(iterable: Iterable[_T]) -> Iterator[tuple[_T, ...]]: - "powerset([1,2,3]) --> () (1,) (2,) (3,) (1,2) (1,3) (2,3) (1,2,3)" - s = list(iterable) - return chain.from_iterable(combinations(s, r) for r in range(len(s) + 1)) - - -def polynomial_derivative(coefficients: Sequence[int]) -> list[int]: - """Compute the first derivative of a polynomial. - - f(x) = x³ -4x² -17x + 60 - f'(x) = 3x² -8x -17 - """ - # polynomial_derivative([1, -4, -17, 60]) -> [3, -8, -17] - n = len(coefficients) - powers = reversed(range(1, n)) - return list(map(operator.mul, coefficients, powers)) - - -def sieve(n: int) -> Iterator[int]: - "Primes less than n." - # sieve(30) --> 2 3 5 7 11 13 17 19 23 29 - if n > 2: - yield 2 - start = 3 - data = bytearray((0, 1)) * (n // 2) - limit = math.isqrt(n) + 1 - for p in iter_index(data, 1, start, limit): - yield from iter_index(data, 1, start, p * p) - data[p * p : n : p + p] = bytes(len(range(p * p, n, p + p))) - start = p * p - yield from iter_index(data, 1, start) - - -def factor(n: int) -> Iterator[int]: - "Prime factors of n." - # factor(99) --> 3 3 11 - # factor(1_000_000_000_000_007) --> 47 59 360620266859 - # factor(1_000_000_000_000_403) --> 1000000000000403 - for prime in sieve(math.isqrt(n) + 1): - while not n % prime: - yield prime - n //= prime - if n == 1: - return - if n > 1: - yield n - - -def nth_combination(iterable: Iterable[_T], r: int, index: int) -> tuple[_T, ...]: - "Equivalent to list(combinations(iterable, r))[index]" - pool = tuple(iterable) - n = len(pool) - c = math.comb(n, r) - if index < 0: - index += c - if index < 0 or index >= c: - raise IndexError - result: list[_T] = [] - while r: - c, n, r = c * r // n, n - 1, r - 1 - while index >= c: - index -= c - c, n = c * (n - r) // n, n - 1 - result.append(pool[-1 - n]) - return tuple(result) - - -if sys.version_info >= (3, 10): - - @overload - def grouper( - iterable: Iterable[_T], n: int, *, incomplete: Literal["fill"] = "fill", fillvalue: None = None - ) -> Iterator[tuple[_T | None, ...]]: - ... - - @overload - def grouper( - iterable: Iterable[_T], n: int, *, incomplete: Literal["fill"] = "fill", fillvalue: _T1 - ) -> Iterator[tuple[_T | _T1, ...]]: - ... - - @overload - def grouper( - iterable: Iterable[_T], n: int, *, incomplete: Literal["strict", "ignore"], fillvalue: None = None - ) -> Iterator[tuple[_T, ...]]: - ... - - def grouper( - iterable: Iterable[object], n: int, *, incomplete: Literal["fill", "strict", "ignore"] = "fill", fillvalue: object = None - ) -> Iterator[tuple[object, ...]]: - "Collect data into non-overlapping fixed-length chunks or blocks" - # grouper('ABCDEFG', 3, fillvalue='x') --> ABC DEF Gxx - # grouper('ABCDEFG', 3, incomplete='strict') --> ABC DEF ValueError - # grouper('ABCDEFG', 3, incomplete='ignore') --> ABC DEF - args = [iter(iterable)] * n - if incomplete == "fill": - return zip_longest(*args, fillvalue=fillvalue) - if incomplete == "strict": - return zip(*args, strict=True) - if incomplete == "ignore": - return zip(*args) - else: - raise ValueError("Expected fill, strict, or ignore") - - def transpose(it: Iterable[Iterable[_T]]) -> Iterator[tuple[_T, ...]]: - "Swap the rows and columns of the input." - # transpose([(1, 2, 3), (11, 22, 33)]) --> (1, 11) (2, 22) (3, 33) - return zip(*it, strict=True) - - -if sys.version_info >= (3, 12): - from itertools import batched - - def sum_of_squares(it: Iterable[float]) -> float: - "Add up the squares of the input values." - # sum_of_squares([10, 20, 30]) -> 1400 - return math.sumprod(*tee(it)) - - def convolve(signal: Iterable[float], kernel: Iterable[float]) -> Iterator[float]: - """Discrete linear convolution of two iterables. - - The kernel is fully consumed before the calculations begin. - The signal is consumed lazily and can be infinite. - - Convolutions are mathematically commutative. - If the signal and kernel are swapped, - the output will be the same. - - Article: https://betterexplained.com/articles/intuitive-convolution/ - Video: https://www.youtube.com/watch?v=KuXjwB4LzSA - """ - # convolve(data, [0.25, 0.25, 0.25, 0.25]) --> Moving average (blur) - # convolve(data, [1/2, 0, -1/2]) --> 1st derivative estimate - # convolve(data, [1, -2, 1]) --> 2nd derivative estimate - kernel = tuple(kernel)[::-1] - n = len(kernel) - padded_signal = chain(repeat(0, n - 1), signal, repeat(0, n - 1)) - windowed_signal = sliding_window(padded_signal, n) - return map(math.sumprod, repeat(kernel), windowed_signal) - - def polynomial_from_roots(roots: Iterable[int]) -> list[float]: - """Compute a polynomial's coefficients from its roots. - - (x - 5) (x + 4) (x - 3) expands to: x³ -4x² -17x + 60 - """ - # polynomial_from_roots([5, -4, 3]) --> [1, -4, -17, 60] - factors = zip(repeat(1), map(operator.neg, roots)) - it: Iterable[float] = functools.reduce(convolve, factors, [1]) - return list(it) - - def polynomial_eval(coefficients: Sequence[int], x: float) -> float: - """Evaluate a polynomial at a specific value. - - Computes with better numeric stability than Horner's method. - """ - # Evaluate x³ -4x² -17x + 60 at x = 2.5 - # polynomial_eval([1, -4, -17, 60], x=2.5) --> 8.125 - n = len(coefficients) - if not n: - return type(x)(0) - powers = map(pow, repeat(x), reversed(range(n))) - return math.sumprod(coefficients, powers) - - # Slightly adapted from the itertools docs, - # to make things a little easier for pyright - def matmul(m1: Sequence[Collection[float]], m2: Sequence[Collection[float]]) -> Iterator[tuple[float, ...]]: - "Multiply two matrices." - # matmul([(7, 5), (3, 5)], [(2, 5), (7, 9)]) --> (49, 80), (41, 60) - n = len(m2[0]) - it: Iterator[float] = starmap(math.sumprod, product(m1, transpose(m2))) - return batched(it, n) diff --git a/test_cases/stdlib/itertools/check_itertools_recipes.py b/test_cases/stdlib/itertools/check_itertools_recipes.py index 340811eec1c5..cd26a90a895b 100644 --- a/test_cases/stdlib/itertools/check_itertools_recipes.py +++ b/test_cases/stdlib/itertools/check_itertools_recipes.py @@ -108,6 +108,15 @@ def quantify(iterable: Iterable[object], pred: Callable[[Any], bool] = bool) -> return sum(map(pred, iterable)) +# Slightly adapted from the itertools docs recipe. +# See https://github.com/python/typeshed/issues/10980#issuecomment-1794927596 +def all_equal(iterable: Iterable[object]) -> bool: + "Returns True if all the elements are equal to each other" + g = groupby(iterable) + next(g, True) + return not next(g, False) + + @overload def first_true( iterable: Iterable[_T], default: Literal[False] = False, pred: Callable[[_T], bool] | None = None @@ -134,6 +143,30 @@ def first_true(iterable: Iterable[object], default: object = False, pred: Callab _ExceptionOrExceptionTuple: TypeAlias = Union[Type[BaseException], Tuple[Type[BaseException], ...]] +# This one has an extra `assert isinstance(iterable, Sized)` call +# compared to the itertools docs +def iter_index(iterable: Iterable[_T], value: _T, start: int = 0, stop: int | None = None) -> Iterator[int]: + "Return indices where a value occurs in a sequence or iterable." + # iter_index('AABCADEAF', 'A') --> 0 1 4 7 + seq_index = getattr(iterable, "index", None) + if seq_index is None: + # Slow path for general iterables + it = islice(iterable, start, stop) + for i, element in enumerate(it, start): + if element is value or element == value: + yield i + else: + # Fast path for sequences + assert isinstance(iterable, Sized) + stop = len(iterable) if stop is None else stop + i = start - 1 + try: + while True: + yield (i := seq_index(value, i + 1, stop)) + except ValueError: + pass + + @overload def iter_except(func: Callable[[], _T], exception: _ExceptionOrExceptionTuple, first: None = None) -> Iterator[_T]: ... @@ -298,6 +331,36 @@ def polynomial_derivative(coefficients: Sequence[float]) -> list[float]: return list(map(operator.mul, coefficients, powers)) +def sieve(n: int) -> Iterator[int]: + "Primes less than n." + # sieve(30) --> 2 3 5 7 11 13 17 19 23 29 + if n > 2: + yield 2 + start = 3 + data = bytearray((0, 1)) * (n // 2) + limit = math.isqrt(n) + 1 + for p in iter_index(data, 1, start, limit): + yield from iter_index(data, 1, start, p * p) + data[p * p : n : p + p] = bytes(len(range(p * p, n, p + p))) + start = p * p + yield from iter_index(data, 1, start) + + +def factor(n: int) -> Iterator[int]: + "Prime factors of n." + # factor(99) --> 3 3 11 + # factor(1_000_000_000_000_007) --> 47 59 360620266859 + # factor(1_000_000_000_000_403) --> 1000000000000403 + for prime in sieve(math.isqrt(n) + 1): + while not n % prime: + yield prime + n //= prime + if n == 1: + return + if n > 1: + yield n + + if sys.version_info >= (3, 8): def nth_combination(iterable: Iterable[_T], r: int, index: int) -> tuple[_T, ...]: @@ -363,6 +426,7 @@ def transpose(it: Iterable[Iterable[_T]]) -> Iterator[tuple[_T, ...]]: if sys.version_info >= (3, 12): + from itertools import batched def sum_of_squares(it: Iterable[float]) -> float: "Add up the squares of the input values." @@ -388,6 +452,16 @@ def convolve(signal: Iterable[float], kernel: Iterable[float]) -> Iterator[float windowed_signal = sliding_window(padded_signal, n) return map(math.sumprod, repeat(kernel), windowed_signal) + def polynomial_from_roots(roots: Iterable[int]) -> list[float]: + """Compute a polynomial's coefficients from its roots. + + (x - 5) (x + 4) (x - 3) expands to: x³ -4x² -17x + 60 + """ + # polynomial_from_roots([5, -4, 3]) --> [1, -4, -17, 60] + factors = zip(repeat(1), map(operator.neg, roots)) + it: Iterable[float] = functools.reduce(convolve, factors, [1]) + return list(it) + def polynomial_eval(coefficients: Sequence[float], x: float) -> float: """Evaluate a polynomial at a specific value. Computes with better numeric stability than Horner's method. @@ -399,3 +473,12 @@ def polynomial_eval(coefficients: Sequence[float], x: float) -> float: return type(x)(0) powers = map(pow, repeat(x), reversed(range(n))) return math.sumprod(coefficients, powers) + + # Slightly adapted from the itertools docs, + # to make things a little easier for pyright + def matmul(m1: Sequence[Collection[float]], m2: Sequence[Collection[float]]) -> Iterator[tuple[float, ...]]: + "Multiply two matrices." + # matmul([(7, 5), (3, 5)], [(2, 5), (7, 9)]) --> (49, 80), (41, 60) + n = len(m2[0]) + it: Iterator[float] = starmap(math.sumprod, product(m1, transpose(m2))) + return batched(it, n) From cee12932dee96f1aff5c0681d4acafa8aa308a78 Mon Sep 17 00:00:00 2001 From: AlexWaygood Date: Wed, 15 Nov 2023 15:50:59 +0000 Subject: [PATCH 09/11] remove `--new-type-inference`, now enabled by default --- tests/regr_test.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/regr_test.py b/tests/regr_test.py index a728b89fde09..3d50aec406ea 100755 --- a/tests/regr_test.py +++ b/tests/regr_test.py @@ -177,7 +177,6 @@ def run_testcases( platform, "--strict", "--pretty", - "--new-type-inference", ] if package.is_stdlib: From a7bef32b7f5e8dba4c4919f485b994b26d74ecf8 Mon Sep 17 00:00:00 2001 From: AlexWaygood Date: Wed, 15 Nov 2023 16:02:56 +0000 Subject: [PATCH 10/11] missing imports --- test_cases/stdlib/itertools/check_itertools_recipes.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/test_cases/stdlib/itertools/check_itertools_recipes.py b/test_cases/stdlib/itertools/check_itertools_recipes.py index cd26a90a895b..cee3bdd73596 100644 --- a/test_cases/stdlib/itertools/check_itertools_recipes.py +++ b/test_cases/stdlib/itertools/check_itertools_recipes.py @@ -6,11 +6,12 @@ from __future__ import annotations import collections +import functools import math import operator import sys -from itertools import chain, combinations, count, cycle, filterfalse, groupby, islice, repeat, starmap, tee, zip_longest -from typing import Any, Callable, Hashable, Iterable, Iterator, Sequence, Tuple, Type, TypeVar, Union, overload +from itertools import chain, combinations, count, cycle, filterfalse, groupby, islice, product, repeat, starmap, tee, zip_longest +from typing import Any, Callable, Collection, Hashable, Iterable, Iterator, Sequence, Sized, Tuple, Type, TypeVar, Union, overload from typing_extensions import Literal, TypeAlias, TypeVarTuple, Unpack _T = TypeVar("_T") From 07efdcf577eaced0090a9f9df9cf736df1031848 Mon Sep 17 00:00:00 2001 From: Alex Waygood Date: Wed, 29 Nov 2023 11:30:01 +0000 Subject: [PATCH 11/11] Update test_cases/stdlib/itertools/check_itertools_recipes.py --- test_cases/stdlib/itertools/check_itertools_recipes.py | 1 + 1 file changed, 1 insertion(+) diff --git a/test_cases/stdlib/itertools/check_itertools_recipes.py b/test_cases/stdlib/itertools/check_itertools_recipes.py index 5362fdae2ad3..4f9d917ae171 100644 --- a/test_cases/stdlib/itertools/check_itertools_recipes.py +++ b/test_cases/stdlib/itertools/check_itertools_recipes.py @@ -478,4 +478,5 @@ def polynomial_eval(coefficients: Sequence[float], x: float) -> float: def matmul(m1: Sequence[Collection[float]], m2: Sequence[Collection[float]]) -> Iterator[tuple[float, ...]]: "Multiply two matrices." # matmul([(7, 5), (3, 5)], [(2, 5), (7, 9)]) --> (49, 80), (41, 60) + n = len(m2[0]) return batched(starmap(math.sumprod, product(m1, transpose(m2))), n)