Skip to content

Update smac planner types #4927

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Mar 25, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions .github/workflows/lint.yml
Original file line number Diff line number Diff line change
Expand Up @@ -19,3 +19,16 @@ jobs:
linter: ${{ matrix.linter }}
distribution: rolling
package-name: "*"

ament_lint_mypy:
name: ament_mypy
runs-on: ubuntu-latest
container:
image: rostooling/setup-ros-docker:ubuntu-noble-ros-rolling-ros-base-latest
steps:
- uses: actions/checkout@v4
- uses: ros-tooling/[email protected]
with:
linter: mypy
distribution: rolling
package-name: "nav2_smac_planner"
31 changes: 21 additions & 10 deletions nav2_smac_planner/lattice_primitives/generate_motion_primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,17 +18,26 @@
import logging
from pathlib import Path
import time
from typing import Any, cast, Dict, List, TypedDict

import constants
from lattice_generator import LatticeGenerator
from lattice_generator import ConfigDict, LatticeGenerator
import matplotlib.pyplot as plt
import numpy as np
from trajectory import Trajectory

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


def handle_arg_parsing():
class HeaderDict(TypedDict):
version: float
date_generated: str
lattice_metadata: Dict[str, Any]
primitives: List[Dict[str, Any]]


def handle_arg_parsing() -> argparse.Namespace:
"""
Handle the parsing of arguments.

Expand Down Expand Up @@ -64,7 +73,8 @@ def handle_arg_parsing():
return parser.parse_args()


def create_heading_angle_list(minimal_set_trajectories: dict) -> list:
def create_heading_angle_list(minimal_set_trajectories: Dict[float, List[Trajectory]]
) -> List[float]:
"""
Create a sorted list of heading angles from the minimal trajectory set.

Expand All @@ -83,7 +93,7 @@ def create_heading_angle_list(minimal_set_trajectories: dict) -> list:
return sorted(heading_angles, key=lambda x: (x < 0, x))


def read_config(config_path) -> dict:
def read_config(config_path: Path) -> ConfigDict:
"""
Read in the user defined parameters via JSON.

Expand All @@ -101,10 +111,11 @@ def read_config(config_path) -> dict:
with open(config_path) as config_file:
config = json.load(config_file)

return config
return cast(ConfigDict, config)


def create_header(config: dict, minimal_set_trajectories: dict) -> dict:
def create_header(config: ConfigDict, minimal_set_trajectories: Dict[float, List[Trajectory]]
) -> HeaderDict:
"""
Create a dict containing all the fields to populate the header with.

Expand All @@ -121,7 +132,7 @@ def create_header(config: dict, minimal_set_trajectories: dict) -> dict:
A dictionary containing the fields to populate the header with

"""
header_dict = {
header_dict: HeaderDict = {
'version': constants.VERSION,
'date_generated': datetime.today().strftime('%Y-%m-%d'),
'lattice_metadata': {},
Expand All @@ -142,7 +153,7 @@ def create_header(config: dict, minimal_set_trajectories: dict) -> dict:


def write_to_json(
output_path: Path, minimal_set_trajectories: dict, config: dict
output_path: Path, minimal_set_trajectories: Dict[float, List[Trajectory]], config: ConfigDict
) -> None:
"""
Write the minimal spanning set to an output file.
Expand Down Expand Up @@ -171,7 +182,7 @@ def write_to_json(
minimal_set_trajectories[start_angle], key=lambda x: x.parameters.end_angle
):

traj_info = {}
traj_info: Dict[str, Any] = {}
traj_info['trajectory_id'] = idx
traj_info['start_angle_index'] = heading_lookup[
trajectory.parameters.start_angle
Expand Down Expand Up @@ -202,7 +213,7 @@ def write_to_json(


def save_visualizations(
visualizations_folder: Path, minimal_set_trajectories: dict
visualizations_folder: Path, minimal_set_trajectories: Dict[float, List[Trajectory]]
) -> None:
"""
Draw the visualizations for every trajectory and save it as an image.
Expand Down
13 changes: 9 additions & 4 deletions nav2_smac_planner/lattice_primitives/helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,13 @@
# See the License for the specific language governing permissions and
# limitations under the License. Reserved.

from typing import Any, Optional

import numpy as np
from numpy.typing import NDArray


def normalize_angle(angle):
def normalize_angle(angle: float) -> float:
"""
Normalize the angle to between [0, 2pi).

Expand All @@ -37,7 +40,8 @@ def normalize_angle(angle):
return angle


def angle_difference(angle_1, angle_2, left_turn=None):
def angle_difference(angle_1: float, angle_2: float,
left_turn: Optional[float] = None) -> float:
"""
Calculate the difference between two angles based on a given direction.

Expand Down Expand Up @@ -76,7 +80,8 @@ def angle_difference(angle_1, angle_2, left_turn=None):
return 2 * np.pi - abs(angle_1 - angle_2)


def interpolate_yaws(start_angle, end_angle, left_turn, steps):
def interpolate_yaws(start_angle: float, end_angle: float,
left_turn: bool, steps: int) -> Any:
"""
Create equally spaced yaws between two angles.

Expand Down Expand Up @@ -110,7 +115,7 @@ def interpolate_yaws(start_angle, end_angle, left_turn, steps):
return yaws


def get_rotation_matrix(angle):
def get_rotation_matrix(angle: float) -> NDArray[np.floating[Any]]:
"""
Return a rotation matrix that is equivalent to a 2D rotation of angle.

Expand Down
55 changes: 40 additions & 15 deletions nav2_smac_planner/lattice_primitives/lattice_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,25 @@

from collections import defaultdict
from enum import Enum
from typing import Any, Dict, List, Tuple, TypedDict

from helper import angle_difference, interpolate_yaws
import numpy as np
from rtree import index
from trajectory import Path, Trajectory, TrajectoryParameters

from trajectory import AnyFloat, FloatNDArray, Path, Trajectory, TrajectoryParameters

from trajectory_generator import TrajectoryGenerator


class ConfigDict(TypedDict):
grid_resolution: float
turning_radius: float
stopping_threshold: int
num_of_headings: int
motion_model: str


class LatticeGenerator:
"""
Handles all the logic for computing the minimal control set.
Expand All @@ -46,7 +57,7 @@ class Flip(Enum):
Y = 2
BOTH = 3

def __init__(self, config: dict):
def __init__(self, config: ConfigDict):
"""Init the lattice generator from the user supplied config."""
self.trajectory_generator = TrajectoryGenerator(config)
self.grid_resolution = config['grid_resolution']
Expand All @@ -60,7 +71,7 @@ def __init__(self, config: dict):
self.DISTANCE_THRESHOLD = 0.5 * self.grid_resolution
self.ROTATION_THRESHOLD = 0.5 * (2 * np.pi / self.num_of_headings)

def _get_wave_front_points(self, pos: int) -> np.array:
def _get_wave_front_points(self, pos: int) -> FloatNDArray:
"""
Calculate the end points that lie on the wave front.

Expand Down Expand Up @@ -97,7 +108,7 @@ def _get_wave_front_points(self, pos: int) -> np.array:

return np.array(positions)

def _get_heading_discretization(self, number_of_headings: int) -> list:
def _get_heading_discretization(self, number_of_headings: int) -> List[int]:
"""
Calculate the heading discretization based on the number of headings.

Expand Down Expand Up @@ -131,7 +142,8 @@ def _get_heading_discretization(self, number_of_headings: int) -> list:

return sorted([np.arctan2(j, i) for i, j in zip(outer_edge_x, outer_edge_y)])

def _point_to_line_distance(self, p1: np.array, p2: np.array, q: np.array) -> float:
def _point_to_line_distance(self, p1: FloatNDArray, p2: FloatNDArray,
q: FloatNDArray) -> AnyFloat:
"""
Return the shortest distance from a point to a line segment.

Expand Down Expand Up @@ -241,7 +253,7 @@ def _compute_min_trajectory_length(self) -> float:

return self.turning_radius * min(heading_diff)

def _generate_minimal_spanning_set(self) -> dict:
def _generate_minimal_spanning_set(self) -> Dict[float, List[Trajectory]]:
"""
Generate the minimal spanning set.

Expand All @@ -255,7 +267,7 @@ def _generate_minimal_spanning_set(self) -> dict:
a list of trajectories that begin at that angle

"""
quadrant1_end_poses = defaultdict(list)
quadrant1_end_poses: Dict[int, List[Tuple[Any, int]]] = defaultdict(list)

# Since we only compute for quadrant 1 we only need headings between
# 0 and 90 degrees
Expand Down Expand Up @@ -338,7 +350,7 @@ def _generate_minimal_spanning_set(self) -> dict:
# we can leverage symmetry to create the complete minimal set
return self._create_complete_minimal_spanning_set(quadrant1_end_poses)

def _flip_angle(self, angle: float, flip_type: Flip) -> float:
def _flip_angle(self, angle: int, flip_type: Flip) -> float:
"""
Return the the appropriate flip of the angle in self.headings.

Expand Down Expand Up @@ -370,8 +382,8 @@ def _flip_angle(self, angle: float, flip_type: Flip) -> float:
return self.headings[int(heading_idx)]

def _create_complete_minimal_spanning_set(
self, single_quadrant_minimal_set: dict
) -> dict:
self, single_quadrant_minimal_set: Dict[int, List[Tuple[Any, int]]]
) -> Dict[float, List[Trajectory]]:
"""
Create the full minimal spanning set from a single quadrant set.

Expand All @@ -390,7 +402,7 @@ def _create_complete_minimal_spanning_set(
in all quadrants

"""
all_trajectories = defaultdict(list)
all_trajectories: Dict[float, List[Trajectory]] = defaultdict(list)

for start_angle in single_quadrant_minimal_set.keys():

Expand Down Expand Up @@ -425,6 +437,9 @@ def _create_complete_minimal_spanning_set(
)
)

if unflipped_trajectory is None or flipped_x_trajectory is None:
raise ValueError('No trajectory was found')

all_trajectories[
unflipped_trajectory.parameters.start_angle
].append(unflipped_trajectory)
Expand Down Expand Up @@ -459,6 +474,9 @@ def _create_complete_minimal_spanning_set(
)
)

if unflipped_trajectory is None or flipped_y_trajectory is None:
raise ValueError('No trajectory was found')

all_trajectories[
unflipped_trajectory.parameters.start_angle
].append(unflipped_trajectory)
Expand Down Expand Up @@ -513,6 +531,10 @@ def _create_complete_minimal_spanning_set(
)
)

if (unflipped_trajectory is None or flipped_y_trajectory is None or
flipped_x_trajectory is None or flipped_xy_trajectory is None):
raise ValueError('No trajectory was found')

all_trajectories[
unflipped_trajectory.parameters.start_angle
].append(unflipped_trajectory)
Expand All @@ -528,7 +550,8 @@ def _create_complete_minimal_spanning_set(

return all_trajectories

def _handle_motion_model(self, spanning_set: dict) -> dict:
def _handle_motion_model(self, spanning_set: Dict[float, List[Trajectory]]
) -> Dict[float, List[Trajectory]]:
"""
Add the appropriate motions for the user supplied motion model.

Expand Down Expand Up @@ -565,7 +588,8 @@ def _handle_motion_model(self, spanning_set: dict) -> dict:
print('No handling implemented for Motion Model: ' + f'{self.motion_model}')
raise NotImplementedError

def _add_in_place_turns(self, spanning_set: dict) -> dict:
def _add_in_place_turns(self, spanning_set: Dict[float, List[Trajectory]]
) -> Dict[float, List[Trajectory]]:
"""
Add in place turns to the spanning set.

Expand Down Expand Up @@ -623,7 +647,8 @@ def _add_in_place_turns(self, spanning_set: dict) -> dict:

return spanning_set

def _add_horizontal_motions(self, spanning_set: dict) -> dict:
def _add_horizontal_motions(self, spanning_set: Dict[float, List[Trajectory]]
) -> Dict[float, List[Trajectory]]:
"""
Add horizontal sliding motions to the spanning set.

Expand Down Expand Up @@ -723,7 +748,7 @@ def _add_horizontal_motions(self, spanning_set: dict) -> dict:

return spanning_set

def run(self):
def run(self) -> Dict[float, List[Trajectory]]:
"""
Run the lattice generator.

Expand Down
Empty file.
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

import unittest

from lattice_generator import LatticeGenerator
from lattice_generator import ConfigDict, LatticeGenerator
import numpy as np

MOTION_MODEL = 'ackermann'
Expand All @@ -28,7 +28,7 @@ class TestLatticeGenerator(unittest.TestCase):
"""Contains the unit tests for the TrajectoryGenerator."""

def setUp(self) -> None:
config = {
config: ConfigDict = {
'motion_model': MOTION_MODEL,
'turning_radius': TURNING_RADIUS,
'grid_resolution': GRID_RESOLUTION,
Expand All @@ -40,7 +40,7 @@ def setUp(self) -> None:

self.minimal_set = lattice_gen.run()

def test_minimal_set_lengths_are_positive(self):
def test_minimal_set_lengths_are_positive(self) -> None:
# Test that lengths are all positive

for start_angle in self.minimal_set.keys():
Expand All @@ -51,7 +51,7 @@ def test_minimal_set_lengths_are_positive(self):
self.assertGreaterEqual(trajectory.parameters.end_straight_length, 0)
self.assertGreaterEqual(trajectory.parameters.total_length, 0)

def test_minimal_set_end_points_lie_on_grid(self):
def test_minimal_set_end_points_lie_on_grid(self) -> None:
# Test that end points lie on the grid resolution

for start_angle in self.minimal_set.keys():
Expand All @@ -66,7 +66,7 @@ def test_minimal_set_end_points_lie_on_grid(self):
self.assertAlmostEqual(div_x, np.round(div_x), delta=0.00001)
self.assertAlmostEqual(div_y, np.round(div_y), delta=0.00001)

def test_minimal_set_end_angle_is_correct(self):
def test_minimal_set_end_angle_is_correct(self) -> None:
# Test that end angle agrees with the end angle parameter

for start_angle in self.minimal_set.keys():
Expand All @@ -76,7 +76,7 @@ def test_minimal_set_end_angle_is_correct(self):

self.assertEqual(end_point_angle, trajectory.parameters.end_angle)

def test_output_angles_in_correct_range(self):
def test_output_angles_in_correct_range(self) -> None:
# Test that the outputted angles always lie within 0 to 2*pi

for start_angle in self.minimal_set.keys():
Expand Down
Loading
Loading