Skip to content

Commit a904802

Browse files
Update smac planner types (#4927)
* Update smac planner types Signed-off-by: Michael Carlstrom <[email protected]> * Test ament_mypy Signed-off-by: Michael Carlstrom <[email protected]> * Add packages Signed-off-by: Michael Carlstrom <[email protected]> * Fix arg name Signed-off-by: Michael Carlstrom <[email protected]> * Add ** Signed-off-by: Michael Carlstrom <[email protected]> * Specific package Signed-off-by: Michael Carlstrom <[email protected]> * re-run ci Signed-off-by: Michael Carlstrom <[email protected]> * re-run ci Signed-off-by: Michael Carlstrom <[email protected]> --------- Signed-off-by: Michael Carlstrom <[email protected]>
1 parent d8dcb8d commit a904802

File tree

10 files changed

+208
-84
lines changed

10 files changed

+208
-84
lines changed

.github/workflows/lint.yml

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,3 +19,16 @@ jobs:
1919
linter: ${{ matrix.linter }}
2020
distribution: rolling
2121
package-name: "*"
22+
23+
ament_lint_mypy:
24+
name: ament_mypy
25+
runs-on: ubuntu-latest
26+
container:
27+
image: rostooling/setup-ros-docker:ubuntu-noble-ros-rolling-ros-base-latest
28+
steps:
29+
- uses: actions/checkout@v4
30+
- uses: ros-tooling/[email protected]
31+
with:
32+
linter: mypy
33+
distribution: rolling
34+
package-name: "nav2_smac_planner"

nav2_smac_planner/lattice_primitives/generate_motion_primitives.py

Lines changed: 21 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -18,17 +18,26 @@
1818
import logging
1919
from pathlib import Path
2020
import time
21+
from typing import Any, cast, Dict, List, TypedDict
2122

2223
import constants
23-
from lattice_generator import LatticeGenerator
24+
from lattice_generator import ConfigDict, LatticeGenerator
2425
import matplotlib.pyplot as plt
2526
import numpy as np
27+
from trajectory import Trajectory
2628

2729
logging.basicConfig(level=logging.INFO)
2830
logger = logging.getLogger(__name__)
2931

3032

31-
def handle_arg_parsing():
33+
class HeaderDict(TypedDict):
34+
version: float
35+
date_generated: str
36+
lattice_metadata: Dict[str, Any]
37+
primitives: List[Dict[str, Any]]
38+
39+
40+
def handle_arg_parsing() -> argparse.Namespace:
3241
"""
3342
Handle the parsing of arguments.
3443
@@ -64,7 +73,8 @@ def handle_arg_parsing():
6473
return parser.parse_args()
6574

6675

67-
def create_heading_angle_list(minimal_set_trajectories: dict) -> list:
76+
def create_heading_angle_list(minimal_set_trajectories: Dict[float, List[Trajectory]]
77+
) -> List[float]:
6878
"""
6979
Create a sorted list of heading angles from the minimal trajectory set.
7080
@@ -83,7 +93,7 @@ def create_heading_angle_list(minimal_set_trajectories: dict) -> list:
8393
return sorted(heading_angles, key=lambda x: (x < 0, x))
8494

8595

86-
def read_config(config_path) -> dict:
96+
def read_config(config_path: Path) -> ConfigDict:
8797
"""
8898
Read in the user defined parameters via JSON.
8999
@@ -101,10 +111,11 @@ def read_config(config_path) -> dict:
101111
with open(config_path) as config_file:
102112
config = json.load(config_file)
103113

104-
return config
114+
return cast(ConfigDict, config)
105115

106116

107-
def create_header(config: dict, minimal_set_trajectories: dict) -> dict:
117+
def create_header(config: ConfigDict, minimal_set_trajectories: Dict[float, List[Trajectory]]
118+
) -> HeaderDict:
108119
"""
109120
Create a dict containing all the fields to populate the header with.
110121
@@ -121,7 +132,7 @@ def create_header(config: dict, minimal_set_trajectories: dict) -> dict:
121132
A dictionary containing the fields to populate the header with
122133
123134
"""
124-
header_dict = {
135+
header_dict: HeaderDict = {
125136
'version': constants.VERSION,
126137
'date_generated': datetime.today().strftime('%Y-%m-%d'),
127138
'lattice_metadata': {},
@@ -142,7 +153,7 @@ def create_header(config: dict, minimal_set_trajectories: dict) -> dict:
142153

143154

144155
def write_to_json(
145-
output_path: Path, minimal_set_trajectories: dict, config: dict
156+
output_path: Path, minimal_set_trajectories: Dict[float, List[Trajectory]], config: ConfigDict
146157
) -> None:
147158
"""
148159
Write the minimal spanning set to an output file.
@@ -171,7 +182,7 @@ def write_to_json(
171182
minimal_set_trajectories[start_angle], key=lambda x: x.parameters.end_angle
172183
):
173184

174-
traj_info = {}
185+
traj_info: Dict[str, Any] = {}
175186
traj_info['trajectory_id'] = idx
176187
traj_info['start_angle_index'] = heading_lookup[
177188
trajectory.parameters.start_angle
@@ -202,7 +213,7 @@ def write_to_json(
202213

203214

204215
def save_visualizations(
205-
visualizations_folder: Path, minimal_set_trajectories: dict
216+
visualizations_folder: Path, minimal_set_trajectories: Dict[float, List[Trajectory]]
206217
) -> None:
207218
"""
208219
Draw the visualizations for every trajectory and save it as an image.

nav2_smac_planner/lattice_primitives/helper.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,13 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License. Reserved.
1414

15+
from typing import Any, Optional
16+
1517
import numpy as np
18+
from numpy.typing import NDArray
1619

1720

18-
def normalize_angle(angle):
21+
def normalize_angle(angle: float) -> float:
1922
"""
2023
Normalize the angle to between [0, 2pi).
2124
@@ -37,7 +40,8 @@ def normalize_angle(angle):
3740
return angle
3841

3942

40-
def angle_difference(angle_1, angle_2, left_turn=None):
43+
def angle_difference(angle_1: float, angle_2: float,
44+
left_turn: Optional[float] = None) -> float:
4145
"""
4246
Calculate the difference between two angles based on a given direction.
4347
@@ -76,7 +80,8 @@ def angle_difference(angle_1, angle_2, left_turn=None):
7680
return 2 * np.pi - abs(angle_1 - angle_2)
7781

7882

79-
def interpolate_yaws(start_angle, end_angle, left_turn, steps):
83+
def interpolate_yaws(start_angle: float, end_angle: float,
84+
left_turn: bool, steps: int) -> Any:
8085
"""
8186
Create equally spaced yaws between two angles.
8287
@@ -110,7 +115,7 @@ def interpolate_yaws(start_angle, end_angle, left_turn, steps):
110115
return yaws
111116

112117

113-
def get_rotation_matrix(angle):
118+
def get_rotation_matrix(angle: float) -> NDArray[np.floating[Any]]:
114119
"""
115120
Return a rotation matrix that is equivalent to a 2D rotation of angle.
116121

nav2_smac_planner/lattice_primitives/lattice_generator.py

Lines changed: 40 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -14,14 +14,25 @@
1414

1515
from collections import defaultdict
1616
from enum import Enum
17+
from typing import Any, Dict, List, Tuple, TypedDict
1718

1819
from helper import angle_difference, interpolate_yaws
1920
import numpy as np
2021
from rtree import index
21-
from trajectory import Path, Trajectory, TrajectoryParameters
22+
23+
from trajectory import AnyFloat, FloatNDArray, Path, Trajectory, TrajectoryParameters
24+
2225
from trajectory_generator import TrajectoryGenerator
2326

2427

28+
class ConfigDict(TypedDict):
29+
grid_resolution: float
30+
turning_radius: float
31+
stopping_threshold: int
32+
num_of_headings: int
33+
motion_model: str
34+
35+
2536
class LatticeGenerator:
2637
"""
2738
Handles all the logic for computing the minimal control set.
@@ -46,7 +57,7 @@ class Flip(Enum):
4657
Y = 2
4758
BOTH = 3
4859

49-
def __init__(self, config: dict):
60+
def __init__(self, config: ConfigDict):
5061
"""Init the lattice generator from the user supplied config."""
5162
self.trajectory_generator = TrajectoryGenerator(config)
5263
self.grid_resolution = config['grid_resolution']
@@ -60,7 +71,7 @@ def __init__(self, config: dict):
6071
self.DISTANCE_THRESHOLD = 0.5 * self.grid_resolution
6172
self.ROTATION_THRESHOLD = 0.5 * (2 * np.pi / self.num_of_headings)
6273

63-
def _get_wave_front_points(self, pos: int) -> np.array:
74+
def _get_wave_front_points(self, pos: int) -> FloatNDArray:
6475
"""
6576
Calculate the end points that lie on the wave front.
6677
@@ -97,7 +108,7 @@ def _get_wave_front_points(self, pos: int) -> np.array:
97108

98109
return np.array(positions)
99110

100-
def _get_heading_discretization(self, number_of_headings: int) -> list:
111+
def _get_heading_discretization(self, number_of_headings: int) -> List[int]:
101112
"""
102113
Calculate the heading discretization based on the number of headings.
103114
@@ -131,7 +142,8 @@ def _get_heading_discretization(self, number_of_headings: int) -> list:
131142

132143
return sorted([np.arctan2(j, i) for i, j in zip(outer_edge_x, outer_edge_y)])
133144

134-
def _point_to_line_distance(self, p1: np.array, p2: np.array, q: np.array) -> float:
145+
def _point_to_line_distance(self, p1: FloatNDArray, p2: FloatNDArray,
146+
q: FloatNDArray) -> AnyFloat:
135147
"""
136148
Return the shortest distance from a point to a line segment.
137149
@@ -241,7 +253,7 @@ def _compute_min_trajectory_length(self) -> float:
241253

242254
return self.turning_radius * min(heading_diff)
243255

244-
def _generate_minimal_spanning_set(self) -> dict:
256+
def _generate_minimal_spanning_set(self) -> Dict[float, List[Trajectory]]:
245257
"""
246258
Generate the minimal spanning set.
247259
@@ -255,7 +267,7 @@ def _generate_minimal_spanning_set(self) -> dict:
255267
a list of trajectories that begin at that angle
256268
257269
"""
258-
quadrant1_end_poses = defaultdict(list)
270+
quadrant1_end_poses: Dict[int, List[Tuple[Any, int]]] = defaultdict(list)
259271

260272
# Since we only compute for quadrant 1 we only need headings between
261273
# 0 and 90 degrees
@@ -338,7 +350,7 @@ def _generate_minimal_spanning_set(self) -> dict:
338350
# we can leverage symmetry to create the complete minimal set
339351
return self._create_complete_minimal_spanning_set(quadrant1_end_poses)
340352

341-
def _flip_angle(self, angle: float, flip_type: Flip) -> float:
353+
def _flip_angle(self, angle: int, flip_type: Flip) -> float:
342354
"""
343355
Return the the appropriate flip of the angle in self.headings.
344356
@@ -370,8 +382,8 @@ def _flip_angle(self, angle: float, flip_type: Flip) -> float:
370382
return self.headings[int(heading_idx)]
371383

372384
def _create_complete_minimal_spanning_set(
373-
self, single_quadrant_minimal_set: dict
374-
) -> dict:
385+
self, single_quadrant_minimal_set: Dict[int, List[Tuple[Any, int]]]
386+
) -> Dict[float, List[Trajectory]]:
375387
"""
376388
Create the full minimal spanning set from a single quadrant set.
377389
@@ -390,7 +402,7 @@ def _create_complete_minimal_spanning_set(
390402
in all quadrants
391403
392404
"""
393-
all_trajectories = defaultdict(list)
405+
all_trajectories: Dict[float, List[Trajectory]] = defaultdict(list)
394406

395407
for start_angle in single_quadrant_minimal_set.keys():
396408

@@ -425,6 +437,9 @@ def _create_complete_minimal_spanning_set(
425437
)
426438
)
427439

440+
if unflipped_trajectory is None or flipped_x_trajectory is None:
441+
raise ValueError('No trajectory was found')
442+
428443
all_trajectories[
429444
unflipped_trajectory.parameters.start_angle
430445
].append(unflipped_trajectory)
@@ -459,6 +474,9 @@ def _create_complete_minimal_spanning_set(
459474
)
460475
)
461476

477+
if unflipped_trajectory is None or flipped_y_trajectory is None:
478+
raise ValueError('No trajectory was found')
479+
462480
all_trajectories[
463481
unflipped_trajectory.parameters.start_angle
464482
].append(unflipped_trajectory)
@@ -513,6 +531,10 @@ def _create_complete_minimal_spanning_set(
513531
)
514532
)
515533

534+
if (unflipped_trajectory is None or flipped_y_trajectory is None or
535+
flipped_x_trajectory is None or flipped_xy_trajectory is None):
536+
raise ValueError('No trajectory was found')
537+
516538
all_trajectories[
517539
unflipped_trajectory.parameters.start_angle
518540
].append(unflipped_trajectory)
@@ -528,7 +550,8 @@ def _create_complete_minimal_spanning_set(
528550

529551
return all_trajectories
530552

531-
def _handle_motion_model(self, spanning_set: dict) -> dict:
553+
def _handle_motion_model(self, spanning_set: Dict[float, List[Trajectory]]
554+
) -> Dict[float, List[Trajectory]]:
532555
"""
533556
Add the appropriate motions for the user supplied motion model.
534557
@@ -565,7 +588,8 @@ def _handle_motion_model(self, spanning_set: dict) -> dict:
565588
print('No handling implemented for Motion Model: ' + f'{self.motion_model}')
566589
raise NotImplementedError
567590

568-
def _add_in_place_turns(self, spanning_set: dict) -> dict:
591+
def _add_in_place_turns(self, spanning_set: Dict[float, List[Trajectory]]
592+
) -> Dict[float, List[Trajectory]]:
569593
"""
570594
Add in place turns to the spanning set.
571595
@@ -623,7 +647,8 @@ def _add_in_place_turns(self, spanning_set: dict) -> dict:
623647

624648
return spanning_set
625649

626-
def _add_horizontal_motions(self, spanning_set: dict) -> dict:
650+
def _add_horizontal_motions(self, spanning_set: Dict[float, List[Trajectory]]
651+
) -> Dict[float, List[Trajectory]]:
627652
"""
628653
Add horizontal sliding motions to the spanning set.
629654
@@ -723,7 +748,7 @@ def _add_horizontal_motions(self, spanning_set: dict) -> dict:
723748

724749
return spanning_set
725750

726-
def run(self):
751+
def run(self) -> Dict[float, List[Trajectory]]:
727752
"""
728753
Run the lattice generator.
729754

nav2_smac_planner/lattice_primitives/py.typed

Whitespace-only changes.

nav2_smac_planner/lattice_primitives/tests/test_lattice_generator.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414

1515
import unittest
1616

17-
from lattice_generator import LatticeGenerator
17+
from lattice_generator import ConfigDict, LatticeGenerator
1818
import numpy as np
1919

2020
MOTION_MODEL = 'ackermann'
@@ -28,7 +28,7 @@ class TestLatticeGenerator(unittest.TestCase):
2828
"""Contains the unit tests for the TrajectoryGenerator."""
2929

3030
def setUp(self) -> None:
31-
config = {
31+
config: ConfigDict = {
3232
'motion_model': MOTION_MODEL,
3333
'turning_radius': TURNING_RADIUS,
3434
'grid_resolution': GRID_RESOLUTION,
@@ -40,7 +40,7 @@ def setUp(self) -> None:
4040

4141
self.minimal_set = lattice_gen.run()
4242

43-
def test_minimal_set_lengths_are_positive(self):
43+
def test_minimal_set_lengths_are_positive(self) -> None:
4444
# Test that lengths are all positive
4545

4646
for start_angle in self.minimal_set.keys():
@@ -51,7 +51,7 @@ def test_minimal_set_lengths_are_positive(self):
5151
self.assertGreaterEqual(trajectory.parameters.end_straight_length, 0)
5252
self.assertGreaterEqual(trajectory.parameters.total_length, 0)
5353

54-
def test_minimal_set_end_points_lie_on_grid(self):
54+
def test_minimal_set_end_points_lie_on_grid(self) -> None:
5555
# Test that end points lie on the grid resolution
5656

5757
for start_angle in self.minimal_set.keys():
@@ -66,7 +66,7 @@ def test_minimal_set_end_points_lie_on_grid(self):
6666
self.assertAlmostEqual(div_x, np.round(div_x), delta=0.00001)
6767
self.assertAlmostEqual(div_y, np.round(div_y), delta=0.00001)
6868

69-
def test_minimal_set_end_angle_is_correct(self):
69+
def test_minimal_set_end_angle_is_correct(self) -> None:
7070
# Test that end angle agrees with the end angle parameter
7171

7272
for start_angle in self.minimal_set.keys():
@@ -76,7 +76,7 @@ def test_minimal_set_end_angle_is_correct(self):
7676

7777
self.assertEqual(end_point_angle, trajectory.parameters.end_angle)
7878

79-
def test_output_angles_in_correct_range(self):
79+
def test_output_angles_in_correct_range(self) -> None:
8080
# Test that the outputted angles always lie within 0 to 2*pi
8181

8282
for start_angle in self.minimal_set.keys():

0 commit comments

Comments
 (0)