Skip to content

Commit

Permalink
spiff up comments
Browse files Browse the repository at this point in the history
  • Loading branch information
SchwartzCode committed Mar 8, 2025
1 parent 29e9b73 commit bb7bb85
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 28 deletions.
18 changes: 13 additions & 5 deletions PathPlanning/TimeBasedPathPlanning/GridWithDynamicObstacles.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,13 @@ class ObstacleArrangement(Enum):
# Obstacles start in a line in y at center of grid and move side-to-side in x
ARRANGEMENT1 = 1

"""
Generates a 2d numpy array with lists for elements.
"""
def empty_2d_array_of_lists(x: int, y: int) -> np.ndarray:
arr = np.empty((x, y), dtype=object)
arr[:] = [[[] for _ in range(y)] for _ in range(x)]
return arr

class Grid:
# Set in constructor
Expand Down Expand Up @@ -239,17 +246,20 @@ def get_obstacle_positions_at_time(self, t: int) -> tuple[list[int], list[int]]:
return (x_positions, y_positions)

"""
Returns safe intervals for each cell
Returns safe intervals for each cell.
"""
def get_safe_intervals(self) -> np.ndarray:
intervals = np.empty((self.grid_size[0], self.grid_size[1]), dtype=object)
intervals[:] = [[[] for _ in range(intervals.shape[1])] for _ in range(intervals.shape[0])]
intervals = empty_2d_array_of_lists(self.grid_size[0], self.grid_size[1])
for x in range(intervals.shape[0]):
for y in range(intervals.shape[1]):
intervals[x, y] = self.get_safe_intervals_at_cell(Position(x, y))

return intervals

"""
Generate the safe intervals for a given cell. The intervals will be in order of start time.
ex: Interval (2, 3) will be before Interval (4, 5)
"""
def get_safe_intervals_at_cell(self, cell: Position) -> list[Interval]:
vals = self.reservation_matrix[cell.x, cell.y, :]
# Find where the array is zero
Expand All @@ -271,10 +281,8 @@ def get_safe_intervals_at_cell(self, cell: Position) -> list[Interval]:
end_indices = np.append(end_indices, len(vals) - 1)

# Create pairs of (first zero, last zero)
# TODO - this is generating np.int instead of normal int, is that alright?
intervals = [Interval(start, end) for start, end in zip(start_indices, end_indices)]


for interval in intervals:
if interval.start_time == interval.end_time:
# TODO: hate this modification in the loop
Expand Down
59 changes: 36 additions & 23 deletions PathPlanning/TimeBasedPathPlanning/SafeInterval.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
"""
Safe interval path planner
This script implements a safe-interval path planner for a 2d grid with dynamic obstacles. It is faster than
SpaceTime A* because it reduces the number of redundant node expansions by pre-computing regions of adjacent
time steps that are safe ("safe intervals") at each position. This allows the algorithm to skip expanding nodes
that are in intervals that have already been visited earlier.
TODO: populate docstring
Reference: https://www.cs.cmu.edu/~maxim/files/sipp_icra11.pdf
"""

import numpy as np
Expand All @@ -11,6 +15,7 @@
Interval,
ObstacleArrangement,
Position,
empty_2d_array_of_lists,
)
import heapq
import random
Expand All @@ -25,9 +30,9 @@

@dataclass()
# Note: Total_ordering is used instead of adding `order=True` to the @dataclass decorator because
# this class needs to override the __lt__ and __eq__ methods to ignore parent_index. Parent
# index is just used to track the path found by the algorithm, and has no effect on the quality
# of a node.
# this class needs to override the __lt__ and __eq__ methods to ignore parent_index. The Parent
# index and interval member variables are just used to track the path found by the algorithm,
# and has no effect on the quality of a node.
@total_ordering
class Node:
position: Position
Expand All @@ -43,16 +48,16 @@ class Node:
def __lt__(self, other: object):
if not isinstance(other, Node):
return NotImplementedError(f"Cannot compare Node with object of type: {type(other)}")
# TODO: assumption that these two carry all the info needed for intervals. I think that makes sense but should think about it
return (self.time + self.heuristic) < (other.time + other.heuristic)

"""
TODO - note about interval being included here
Equality only cares about position and time. Heuristic and interval will always be the same for a given
(position, time) pairing, so they are not considered in equality.
"""
def __eq__(self, other: object):
if not isinstance(other, Node):
return NotImplementedError(f"Cannot compare Node with object of type: {type(other)}")
return self.position == other.position and self.time == other.time and self.interval == other.interval
return self.position == other.position and self.time == other.time

@dataclass
class EntryTimeAndInterval:
Expand Down Expand Up @@ -103,6 +108,12 @@ def __init__(self, grid: Grid, start: Position, goal: Position):
self.start = start
self.goal = goal

"""
Generate a plan given the loaded problem statement. Raises an exception if it fails to find a path.
Arguments:
verbose (bool): set to True to print debug information
"""
def plan(self, verbose: bool = False) -> NodePath:

safe_intervals = self.grid.get_safe_intervals()
Expand All @@ -114,11 +125,7 @@ def plan(self, verbose: bool = False) -> NodePath:
)

expanded_list: list[Node] = []
# TODO: copy pasta from Grid file
# 2d np array of lists of (entry time, interval tuples)
# TODO: use a dataclass for the tuple
visited_intervals = np.empty((self.grid.grid_size[0], self.grid.grid_size[1]), dtype=object)
visited_intervals[:] = [[[] for _ in range(visited_intervals.shape[1])] for _ in range(visited_intervals.shape[0])]
visited_intervals = empty_2d_array_of_lists(self.grid.grid_size[0], self.grid.grid_size[1])
while open_set:
expanded_node: Node = heapq.heappop(open_set)
if verbose:
Expand Down Expand Up @@ -154,9 +161,8 @@ def plan(self, verbose: bool = False) -> NodePath:
raise Exception("No path found")

"""
Generate possible successors of the provided `parent_node`
Generate list of possible successors of the provided `parent_node` that are worth expanding
"""
# TODO: is intervals being passed by ref? (i think so?)
def generate_successors(
self, parent_node: Node, parent_node_idx: int, intervals: np.ndarray, visited_intervals: np.ndarray
) -> list[Node]:
Expand All @@ -177,16 +183,16 @@ def generate_successors(

new_cell_intervals: list[Interval] = intervals[new_pos.x, new_pos.y]
for interval in new_cell_intervals:
# if interval ends before current starts, skip
if interval.end_time < current_interval.start_time:
continue

# if interval starts after current ends, break
# TODO: assumption here that intervals are sorted (they should be)
# assumption: intervals are sorted by start time, so all future intervals will hit this condition as well
if interval.start_time > current_interval.end_time:
break

# if we have already expanded a node in this interval with a <= starting time, continue
# if interval ends before current starts, skip
if interval.end_time < current_interval.start_time:
continue

# if we have already expanded a node in this interval with a <= starting time, skip
better_node_expanded = False
for visited in visited_intervals[new_pos.x, new_pos.y]:
if interval == visited.interval and visited.entry_time <= parent_node.time + 1:
Expand All @@ -195,14 +201,14 @@ def generate_successors(
if better_node_expanded:
continue

# We know there is some overlap. Generate successor at the earliest possible time the
# We know there is a node worth expanding. Generate successor at the earliest possible time the
# new interval can be entered
for possible_t in range(max(parent_node.time + 1, interval.start_time), min(current_interval.end_time, interval.end_time)):
if self.grid.valid_position(new_pos, possible_t):
new_nodes.append(Node(
new_pos,
# entry is max of interval start and parent node start time (get there as soon as possible)
max(parent_node.time + 1, interval.start_time),
# entry is max of interval start and parent node time + 1 (get there as soon as possible)
max(interval.start_time, parent_node.time + 1),
self.calculate_heuristic(new_pos),
parent_node_idx,
interval,
Expand All @@ -212,11 +218,18 @@ def generate_successors(

return new_nodes

"""
Calculate the heuristic for a given position - Manhattan distance to the goal
"""
def calculate_heuristic(self, position) -> int:
diff = self.goal - position
return abs(diff.x) + abs(diff.y)


"""
Adds a new entry to the visited intervals array. If the entry is already present, the entry time is updated if the new
entry time is better. Otherwise, the entry is added to `visited_intervals` at the position of `expanded_node`.
"""
def add_entry_to_visited_intervals_array(entry_time_and_interval: EntryTimeAndInterval, visited_intervals: np.ndarray, expanded_node: Node):
# if entry is present, update entry time if better
for existing_entry_and_interval in visited_intervals[expanded_node.position.x, expanded_node.position.y]:
Expand Down

0 comments on commit bb7bb85

Please sign in to comment.