Skip to content
1 change: 1 addition & 0 deletions .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,7 @@ jobs:
no-comments: ${{ matrix.python-version != '3.11' || matrix.python-platform != 'Linux' }} # Having each job create the same comment is too noisy.
project: ./pyrightconfig.stricter.json
- name: Run pyright on the test cases
if: ${{ matrix.python-version != '3.7' }}
uses: jakebailey/pyright-action@v1
with:
version: ${{ steps.pyright_version.outputs.value }}
Expand Down
76 changes: 75 additions & 1 deletion test_cases/stdlib/itertools/check_itertools_recipes.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,12 @@
from __future__ import annotations

import collections
import functools
import math
import operator
import sys
from itertools import chain, combinations, count, cycle, filterfalse, groupby, islice, product, repeat, starmap, tee, zip_longest
from typing import Any, Callable, Collection, Hashable, Iterable, Iterator, Sequence, Tuple, Type, TypeVar, Union, overload
from typing import Any, Callable, Collection, Hashable, Iterable, Iterator, Sequence, Sized, Tuple, Type, TypeVar, Union, overload
from typing_extensions import Literal, TypeAlias, TypeVarTuple, Unpack

_T = TypeVar("_T")
Expand Down Expand Up @@ -108,6 +109,15 @@ def quantify(iterable: Iterable[object], pred: Callable[[Any], bool] = bool) ->
return sum(map(pred, iterable))


# Slightly adapted from the itertools docs recipe.
# See https://github.com/python/typeshed/issues/10980#issuecomment-1794927596
def all_equal(iterable: Iterable[object]) -> bool:
"Returns True if all the elements are equal to each other"
g = groupby(iterable)
next(g, True)
return not next(g, False)


@overload
def first_true(
iterable: Iterable[_T], default: Literal[False] = False, pred: Callable[[_T], bool] | None = None
Expand All @@ -134,6 +144,30 @@ def first_true(iterable: Iterable[object], default: object = False, pred: Callab
_ExceptionOrExceptionTuple: TypeAlias = Union[Type[BaseException], Tuple[Type[BaseException], ...]]


# This one has an extra `assert isinstance(iterable, Sized)` call
# compared to the itertools docs
def iter_index(iterable: Iterable[_T], value: _T, start: int = 0, stop: int | None = None) -> Iterator[int]:
"Return indices where a value occurs in a sequence or iterable."
# iter_index('AABCADEAF', 'A') --> 0 1 4 7
seq_index = getattr(iterable, "index", None)
if seq_index is None:
# Slow path for general iterables
it = islice(iterable, start, stop)
for i, element in enumerate(it, start):
if element is value or element == value:
yield i
else:
# Fast path for sequences
assert isinstance(iterable, Sized)
stop = len(iterable) if stop is None else stop
i = start - 1
try:
while True:
yield (i := seq_index(value, i + 1, stop))
except ValueError:
pass


@overload
def iter_except(func: Callable[[], _T], exception: _ExceptionOrExceptionTuple, first: None = None) -> Iterator[_T]:
...
Expand Down Expand Up @@ -298,6 +332,36 @@ def polynomial_derivative(coefficients: Sequence[float]) -> list[float]:
return list(map(operator.mul, coefficients, powers))


def sieve(n: int) -> Iterator[int]:
"Primes less than n."
# sieve(30) --> 2 3 5 7 11 13 17 19 23 29
if n > 2:
yield 2
start = 3
data = bytearray((0, 1)) * (n // 2)
limit = math.isqrt(n) + 1
for p in iter_index(data, 1, start, limit):
yield from iter_index(data, 1, start, p * p)
data[p * p : n : p + p] = bytes(len(range(p * p, n, p + p)))
start = p * p
yield from iter_index(data, 1, start)


def factor(n: int) -> Iterator[int]:
"Prime factors of n."
# factor(99) --> 3 3 11
# factor(1_000_000_000_000_007) --> 47 59 360620266859
# factor(1_000_000_000_000_403) --> 1000000000000403
for prime in sieve(math.isqrt(n) + 1):
while not n % prime:
yield prime
n //= prime
if n == 1:
return
if n > 1:
yield n


if sys.version_info >= (3, 8):

def nth_combination(iterable: Iterable[_T], r: int, index: int) -> tuple[_T, ...]:
Expand Down Expand Up @@ -389,6 +453,16 @@ def convolve(signal: Iterable[float], kernel: Iterable[float]) -> Iterator[float
windowed_signal = sliding_window(padded_signal, n)
return map(math.sumprod, repeat(kernel), windowed_signal)

def polynomial_from_roots(roots: Iterable[int]) -> list[float]:
"""Compute a polynomial's coefficients from its roots.

(x - 5) (x + 4) (x - 3) expands to: x³ -4x² -17x + 60
"""
# polynomial_from_roots([5, -4, 3]) --> [1, -4, -17, 60]
factors = zip(repeat(1), map(operator.neg, roots))
it: Iterable[float] = functools.reduce(convolve, factors, [1])
return list(it)

def polynomial_eval(coefficients: Sequence[float], x: float) -> float:
"""Evaluate a polynomial at a specific value.
Computes with better numeric stability than Horner's method.
Expand Down
2 changes: 1 addition & 1 deletion tests/regr_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@
TYPESHED = "typeshed"

SUPPORTED_PLATFORMS = ["linux", "darwin", "win32"]
SUPPORTED_VERSIONS = ["3.12", "3.11", "3.10", "3.9", "3.8", "3.7"]
SUPPORTED_VERSIONS = ["3.12", "3.11", "3.10", "3.9", "3.8"]


def package_with_test_cases(package_name: str) -> PackageInfo:
Expand Down