|
1 | 1 | from __future__ import annotations |
2 | 2 |
|
3 | | -import fnmatch |
4 | | -import glob |
| 3 | +import fnmatch, glob |
5 | 4 | from pathlib import Path |
6 | | -from typing import TYPE_CHECKING |
| 5 | +from typing import TYPE_CHECKING, Iterable, Any, cast |
7 | 6 |
|
8 | 7 | from tomlkit import TOMLDocument, dumps, parse |
9 | 8 | from tomlkit.exceptions import NonExistentKey |
10 | | - |
11 | 9 | from commitizen.providers.base_provider import TomlProvider |
12 | 10 |
|
13 | 11 | if TYPE_CHECKING: |
14 | 12 | from tomlkit.items import AoT |
15 | 13 |
|
16 | 14 |
|
17 | | -class CargoProvider(TomlProvider): |
18 | | - """ |
19 | | - Cargo version management |
| 15 | +DictLike = dict[str, Any] |
20 | 16 |
|
21 | | - With support for `workspaces` |
22 | | - """ |
| 17 | + |
| 18 | +class CargoProvider(TomlProvider): |
| 19 | + """Cargo version management for virtual workspace manifests + version.workspace=true members.""" |
23 | 20 |
|
24 | 21 | filename = "Cargo.toml" |
25 | 22 | lock_filename = "Cargo.lock" |
26 | 23 |
|
27 | 24 | @property |
28 | 25 | def lock_file(self) -> Path: |
29 | | - return Path() / self.lock_filename |
| 26 | + return Path(self.lock_filename) |
30 | 27 |
|
31 | 28 | def get(self, document: TOMLDocument) -> str: |
32 | | - out = _try_get_workspace(document)["package"]["version"] |
33 | | - if TYPE_CHECKING: |
34 | | - assert isinstance(out, str) |
35 | | - return out |
| 29 | + t = _root_version_table(document) |
| 30 | + v = t.get("version") |
| 31 | + if not isinstance(v, str): |
| 32 | + raise TypeError("expected root version to be a string") |
| 33 | + return v |
36 | 34 |
|
37 | 35 | def set(self, document: TOMLDocument, version: str) -> None: |
38 | | - _try_get_workspace(document)["package"]["version"] = version |
| 36 | + _root_version_table(document)["version"] = version |
39 | 37 |
|
40 | 38 | def set_version(self, version: str) -> None: |
41 | 39 | super().set_version(version) |
42 | 40 | if self.lock_file.exists(): |
43 | 41 | self.set_lock_version(version) |
44 | 42 |
|
45 | 43 | def set_lock_version(self, version: str) -> None: |
46 | | - cargo_toml_content = parse(self.file.read_text()) |
47 | | - cargo_lock_content = parse(self.lock_file.read_text()) |
48 | | - packages = cargo_lock_content["package"] |
49 | | - |
| 44 | + cargo_toml = parse(self.file.read_text()) |
| 45 | + cargo_lock = parse(self.lock_file.read_text()) |
| 46 | + packages = cargo_lock["package"] |
50 | 47 | if TYPE_CHECKING: |
51 | 48 | assert isinstance(packages, AoT) |
52 | 49 |
|
53 | | - try: |
54 | | - cargo_package_name = cargo_toml_content["package"]["name"] # type: ignore[index] |
55 | | - if TYPE_CHECKING: |
56 | | - assert isinstance(cargo_package_name, str) |
57 | | - for i, package in enumerate(packages): |
58 | | - if package["name"] == cargo_package_name: |
59 | | - cargo_lock_content["package"][i]["version"] = version # type: ignore[index] |
60 | | - break |
61 | | - except NonExistentKey: |
62 | | - workspace = cargo_toml_content.get("workspace", {}) |
63 | | - if TYPE_CHECKING: |
64 | | - assert isinstance(workspace, dict) |
65 | | - workspace_members = workspace.get("members", []) |
66 | | - excluded_workspace_members = workspace.get("exclude", []) |
67 | | - members_inheriting: list[str] = [] |
68 | | - |
69 | | - for member in workspace_members: |
70 | | - for path in glob.glob(member, recursive=True): |
71 | | - if any( |
72 | | - fnmatch.fnmatch(path, pattern) |
73 | | - for pattern in excluded_workspace_members |
74 | | - ): |
75 | | - continue |
76 | | - |
77 | | - cargo_file = Path(path) / "Cargo.toml" |
78 | | - package_content = parse(cargo_file.read_text()).get("package", {}) |
79 | | - if TYPE_CHECKING: |
80 | | - assert isinstance(package_content, dict) |
81 | | - try: |
82 | | - version_workspace = package_content["version"]["workspace"] |
83 | | - if version_workspace is True: |
84 | | - package_name = package_content["name"] |
85 | | - if TYPE_CHECKING: |
86 | | - assert isinstance(package_name, str) |
87 | | - members_inheriting.append(package_name) |
88 | | - except NonExistentKey: |
89 | | - pass |
90 | | - |
91 | | - for i, package in enumerate(packages): |
92 | | - if package["name"] in members_inheriting: |
93 | | - cargo_lock_content["package"][i]["version"] = version # type: ignore[index] |
94 | | - |
95 | | - self.lock_file.write_text(dumps(cargo_lock_content)) |
96 | | - |
97 | | - |
98 | | -def _try_get_workspace(document: TOMLDocument) -> dict: |
| 50 | + root_pkg = _table_get(cargo_toml, "package") |
| 51 | + if root_pkg is not None: |
| 52 | + name = root_pkg.get("name") |
| 53 | + if isinstance(name, str): |
| 54 | + _lock_set_versions(packages, {name}, version) |
| 55 | + self.lock_file.write_text(dumps(cargo_lock)) |
| 56 | + return |
| 57 | + |
| 58 | + ws = _table_get(cargo_toml, "workspace") or {} |
| 59 | + members = cast(list[str], ws.get("members", []) or []) |
| 60 | + excludes = cast(list[str], ws.get("exclude", []) or []) |
| 61 | + inheriting = _workspace_inheriting_member_names(members, excludes) |
| 62 | + _lock_set_versions(packages, inheriting, version) |
| 63 | + self.lock_file.write_text(dumps(cargo_lock)) |
| 64 | + |
| 65 | + |
| 66 | +def _table_get(doc: TOMLDocument, key: str) -> DictLike | None: |
| 67 | + """Return a dict-like table for `key` if present, else None (type-safe for Pylance).""" |
99 | 68 | try: |
100 | | - workspace = document["workspace"] |
101 | | - if TYPE_CHECKING: |
102 | | - assert isinstance(workspace, dict) |
103 | | - return workspace |
| 69 | + v = doc[key] # tomlkit returns Container/Table-like; typing is loose |
104 | 70 | except NonExistentKey: |
105 | | - return document |
| 71 | + return None |
| 72 | + return cast(DictLike, v) if hasattr(v, "get") else None |
| 73 | + |
| 74 | + |
| 75 | +def _root_version_table(doc: TOMLDocument) -> DictLike: |
| 76 | + """Prefer [workspace.package]; fallback to [package].""" |
| 77 | + ws = _table_get(doc, "workspace") |
| 78 | + if ws is not None: |
| 79 | + pkg = ws.get("package") |
| 80 | + if hasattr(pkg, "get"): |
| 81 | + return cast(DictLike, pkg) |
| 82 | + pkg = _table_get(doc, "package") |
| 83 | + if pkg is None: |
| 84 | + raise NonExistentKey('expected either [workspace.package] or [package]') |
| 85 | + return pkg |
| 86 | + |
| 87 | + |
| 88 | +def _is_workspace_inherited_version(v: Any) -> bool: |
| 89 | + return hasattr(v, "get") and cast(DictLike, v).get("workspace") is True |
| 90 | + |
| 91 | + |
| 92 | +def _iter_member_dirs(members: Iterable[str], excludes: Iterable[str]) -> Iterable[Path]: |
| 93 | + for pat in members: |
| 94 | + for p in glob.glob(pat, recursive=True): |
| 95 | + if any(fnmatch.fnmatch(p, ex) for ex in excludes): |
| 96 | + continue |
| 97 | + yield Path(p) |
| 98 | + |
| 99 | + |
| 100 | +def _workspace_inheriting_member_names(members: Iterable[str], excludes: Iterable[str]) -> set[str]: |
| 101 | + out: set[str] = set() |
| 102 | + for d in _iter_member_dirs(members, excludes): |
| 103 | + cargo_file = d / "Cargo.toml" |
| 104 | + if not cargo_file.exists(): |
| 105 | + continue |
| 106 | + pkg = parse(cargo_file.read_text()).get("package") |
| 107 | + if not hasattr(pkg, "get"): |
| 108 | + continue |
| 109 | + pkgd = cast(DictLike, pkg) |
| 110 | + if _is_workspace_inherited_version(pkgd.get("version")): |
| 111 | + name = pkgd.get("name") |
| 112 | + if isinstance(name, str): |
| 113 | + out.add(name) |
| 114 | + return out |
| 115 | + |
| 116 | + |
| 117 | +def _lock_set_versions(packages: Any, names: set[str], version: str) -> None: |
| 118 | + if not names: |
| 119 | + return |
| 120 | + for i, p in enumerate(packages): |
| 121 | + if getattr(p, "get", None) and p.get("name") in names: |
| 122 | + packages[i]["version"] = version # type: ignore[index] |
0 commit comments