Skip to content

Commit 732ed3e

Browse files
authoredJan 26, 2023
clean up informed_rrt_star.py (AtsushiSakai#785)
* clean up informed_rrt_star.py * clean up informed_rrt_star.py
1 parent 489ee5c commit 732ed3e

File tree

1 file changed

+126
-135
lines changed

1 file changed

+126
-135
lines changed
 

‎PathPlanning/InformedRRTStar/informed_rrt_star.py

+126-135
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
"""
1212
import sys
1313
import pathlib
14+
1415
sys.path.append(str(pathlib.Path(__file__).parent.parent.parent))
1516

1617
import copy
@@ -27,129 +28,127 @@
2728

2829
class InformedRRTStar:
2930

30-
def __init__(self, start, goal,
31-
obstacleList, randArea,
32-
expandDis=0.5, goalSampleRate=10, maxIter=200):
31+
def __init__(self, start, goal, obstacle_list, rand_area, expand_dis=0.5,
32+
goal_sample_rate=10, max_iter=200):
3333

3434
self.start = Node(start[0], start[1])
3535
self.goal = Node(goal[0], goal[1])
36-
self.min_rand = randArea[0]
37-
self.max_rand = randArea[1]
38-
self.expand_dis = expandDis
39-
self.goal_sample_rate = goalSampleRate
40-
self.max_iter = maxIter
41-
self.obstacle_list = obstacleList
36+
self.min_rand = rand_area[0]
37+
self.max_rand = rand_area[1]
38+
self.expand_dis = expand_dis
39+
self.goal_sample_rate = goal_sample_rate
40+
self.max_iter = max_iter
41+
self.obstacle_list = obstacle_list
4242
self.node_list = None
4343

4444
def informed_rrt_star_search(self, animation=True):
4545

4646
self.node_list = [self.start]
4747
# max length we expect to find in our 'informed' sample space,
4848
# starts as infinite
49-
cBest = float('inf')
50-
solutionSet = set()
49+
c_best = float('inf')
50+
solution_set = set()
5151
path = None
5252

5353
# Computing the sampling space
54-
cMin = math.sqrt(pow(self.start.x - self.goal.x, 2)
55-
+ pow(self.start.y - self.goal.y, 2))
56-
xCenter = np.array([[(self.start.x + self.goal.x) / 2.0],
57-
[(self.start.y + self.goal.y) / 2.0], [0]])
58-
a1 = np.array([[(self.goal.x - self.start.x) / cMin],
59-
[(self.goal.y - self.start.y) / cMin], [0]])
54+
c_min = math.hypot(self.start.x - self.goal.x,
55+
self.start.y - self.goal.y)
56+
x_center = np.array([[(self.start.x + self.goal.x) / 2.0],
57+
[(self.start.y + self.goal.y) / 2.0], [0]])
58+
a1 = np.array([[(self.goal.x - self.start.x) / c_min],
59+
[(self.goal.y - self.start.y) / c_min], [0]])
6060

6161
e_theta = math.atan2(a1[1], a1[0])
6262
# first column of identity matrix transposed
6363
id1_t = np.array([1.0, 0.0, 0.0]).reshape(1, 3)
64-
M = a1 @ id1_t
65-
U, S, Vh = np.linalg.svd(M, True, True)
66-
C = np.dot(np.dot(U, np.diag(
67-
[1.0, 1.0, np.linalg.det(U) * np.linalg.det(np.transpose(Vh))])),
68-
Vh)
64+
m = a1 @ id1_t
65+
u, s, vh = np.linalg.svd(m, True, True)
66+
c = u @ np.diag(
67+
[1.0, 1.0,
68+
np.linalg.det(u) * np.linalg.det(np.transpose(vh))]) @ vh
6969

7070
for i in range(self.max_iter):
71-
# Sample space is defined by cBest
72-
# cMin is the minimum distance between the start point and the goal
73-
# xCenter is the midpoint between the start and the goal
74-
# cBest changes when a new path is found
71+
# Sample space is defined by c_best
72+
# c_min is the minimum distance between the start point and
73+
# the goal x_center is the midpoint between the start and the
74+
# goal c_best changes when a new path is found
7575

76-
rnd = self.informed_sample(cBest, cMin, xCenter, C)
76+
rnd = self.informed_sample(c_best, c_min, x_center, c)
7777
n_ind = self.get_nearest_list_index(self.node_list, rnd)
78-
nearestNode = self.node_list[n_ind]
78+
nearest_node = self.node_list[n_ind]
7979
# steer
80-
theta = math.atan2(rnd[1] - nearestNode.y, rnd[0] - nearestNode.x)
81-
newNode = self.get_new_node(theta, n_ind, nearestNode)
82-
d = self.line_cost(nearestNode, newNode)
80+
theta = math.atan2(rnd[1] - nearest_node.y,
81+
rnd[0] - nearest_node.x)
82+
new_node = self.get_new_node(theta, n_ind, nearest_node)
83+
d = self.line_cost(nearest_node, new_node)
8384

84-
noCollision = self.check_collision(nearestNode, theta, d)
85+
no_collision = self.check_collision(nearest_node, theta, d)
8586

86-
if noCollision:
87-
nearInds = self.find_near_nodes(newNode)
88-
newNode = self.choose_parent(newNode, nearInds)
87+
if no_collision:
88+
near_inds = self.find_near_nodes(new_node)
89+
new_node = self.choose_parent(new_node, near_inds)
8990

90-
self.node_list.append(newNode)
91-
self.rewire(newNode, nearInds)
91+
self.node_list.append(new_node)
92+
self.rewire(new_node, near_inds)
9293

93-
if self.is_near_goal(newNode):
94-
if self.check_segment_collision(newNode.x, newNode.y,
94+
if self.is_near_goal(new_node):
95+
if self.check_segment_collision(new_node.x, new_node.y,
9596
self.goal.x, self.goal.y):
96-
solutionSet.add(newNode)
97-
lastIndex = len(self.node_list) - 1
98-
tempPath = self.get_final_course(lastIndex)
99-
tempPathLen = self.get_path_len(tempPath)
100-
if tempPathLen < cBest:
101-
path = tempPath
102-
cBest = tempPathLen
97+
solution_set.add(new_node)
98+
last_index = len(self.node_list) - 1
99+
temp_path = self.get_final_course(last_index)
100+
temp_path_len = self.get_path_len(temp_path)
101+
if temp_path_len < c_best:
102+
path = temp_path
103+
c_best = temp_path_len
103104
if animation:
104-
self.draw_graph(xCenter=xCenter,
105-
cBest=cBest, cMin=cMin,
105+
self.draw_graph(x_center=x_center, c_best=c_best, c_min=c_min,
106106
e_theta=e_theta, rnd=rnd)
107107

108108
return path
109109

110-
def choose_parent(self, newNode, nearInds):
111-
if len(nearInds) == 0:
112-
return newNode
110+
def choose_parent(self, new_node, near_inds):
111+
if len(near_inds) == 0:
112+
return new_node
113113

114-
dList = []
115-
for i in nearInds:
116-
dx = newNode.x - self.node_list[i].x
117-
dy = newNode.y - self.node_list[i].y
114+
d_list = []
115+
for i in near_inds:
116+
dx = new_node.x - self.node_list[i].x
117+
dy = new_node.y - self.node_list[i].y
118118
d = math.hypot(dx, dy)
119119
theta = math.atan2(dy, dx)
120120
if self.check_collision(self.node_list[i], theta, d):
121-
dList.append(self.node_list[i].cost + d)
121+
d_list.append(self.node_list[i].cost + d)
122122
else:
123-
dList.append(float('inf'))
123+
d_list.append(float('inf'))
124124

125-
minCost = min(dList)
126-
minInd = nearInds[dList.index(minCost)]
125+
min_cost = min(d_list)
126+
min_ind = near_inds[d_list.index(min_cost)]
127127

128-
if minCost == float('inf'):
128+
if min_cost == float('inf'):
129129
print("min cost is inf")
130-
return newNode
130+
return new_node
131131

132-
newNode.cost = minCost
133-
newNode.parent = minInd
132+
new_node.cost = min_cost
133+
new_node.parent = min_ind
134134

135-
return newNode
135+
return new_node
136136

137-
def find_near_nodes(self, newNode):
137+
def find_near_nodes(self, new_node):
138138
n_node = len(self.node_list)
139139
r = 50.0 * math.sqrt((math.log(n_node) / n_node))
140-
d_list = [(node.x - newNode.x) ** 2 + (node.y - newNode.y) ** 2
141-
for node in self.node_list]
140+
d_list = [(node.x - new_node.x) ** 2 + (node.y - new_node.y) ** 2 for
141+
node in self.node_list]
142142
near_inds = [d_list.index(i) for i in d_list if i <= r ** 2]
143143
return near_inds
144144

145-
def informed_sample(self, cMax, cMin, xCenter, C):
146-
if cMax < float('inf'):
147-
r = [cMax / 2.0,
148-
math.sqrt(cMax ** 2 - cMin ** 2) / 2.0,
149-
math.sqrt(cMax ** 2 - cMin ** 2) / 2.0]
150-
L = np.diag(r)
151-
xBall = self.sample_unit_ball()
152-
rnd = np.dot(np.dot(C, L), xBall) + xCenter
145+
def informed_sample(self, c_max, c_min, x_center, c):
146+
if c_max < float('inf'):
147+
r = [c_max / 2.0, math.sqrt(c_max ** 2 - c_min ** 2) / 2.0,
148+
math.sqrt(c_max ** 2 - c_min ** 2) / 2.0]
149+
rl = np.diag(r)
150+
x_ball = self.sample_unit_ball()
151+
rnd = np.dot(np.dot(c, rl), x_ball) + x_center
153152
rnd = [rnd[(0, 0)], rnd[(1, 0)]]
154153
else:
155154
rnd = self.sample_free_space()
@@ -179,59 +178,58 @@ def sample_free_space(self):
179178

180179
@staticmethod
181180
def get_path_len(path):
182-
pathLen = 0
181+
path_len = 0
183182
for i in range(1, len(path)):
184183
node1_x = path[i][0]
185184
node1_y = path[i][1]
186185
node2_x = path[i - 1][0]
187186
node2_y = path[i - 1][1]
188-
pathLen += math.sqrt((node1_x - node2_x)
189-
** 2 + (node1_y - node2_y) ** 2)
187+
path_len += math.hypot(node1_x - node2_x, node1_y - node2_y)
190188

191-
return pathLen
189+
return path_len
192190

193191
@staticmethod
194192
def line_cost(node1, node2):
195193
return math.hypot(node1.x - node2.x, node1.y - node2.y)
196194

197195
@staticmethod
198196
def get_nearest_list_index(nodes, rnd):
199-
dList = [(node.x - rnd[0]) ** 2
200-
+ (node.y - rnd[1]) ** 2 for node in nodes]
201-
minIndex = dList.index(min(dList))
202-
return minIndex
197+
d_list = [(node.x - rnd[0]) ** 2 + (node.y - rnd[1]) ** 2 for node in
198+
nodes]
199+
min_index = d_list.index(min(d_list))
200+
return min_index
203201

204-
def get_new_node(self, theta, n_ind, nearestNode):
205-
newNode = copy.deepcopy(nearestNode)
202+
def get_new_node(self, theta, n_ind, nearest_node):
203+
new_node = copy.deepcopy(nearest_node)
206204

207-
newNode.x += self.expand_dis * math.cos(theta)
208-
newNode.y += self.expand_dis * math.sin(theta)
205+
new_node.x += self.expand_dis * math.cos(theta)
206+
new_node.y += self.expand_dis * math.sin(theta)
209207

210-
newNode.cost += self.expand_dis
211-
newNode.parent = n_ind
212-
return newNode
208+
new_node.cost += self.expand_dis
209+
new_node.parent = n_ind
210+
return new_node
213211

214212
def is_near_goal(self, node):
215213
d = self.line_cost(node, self.goal)
216214
if d < self.expand_dis:
217215
return True
218216
return False
219217

220-
def rewire(self, newNode, nearInds):
218+
def rewire(self, new_node, near_inds):
221219
n_node = len(self.node_list)
222-
for i in nearInds:
223-
nearNode = self.node_list[i]
220+
for i in near_inds:
221+
near_node = self.node_list[i]
224222

225-
d = math.hypot(nearNode.x - newNode.x, nearNode.y - newNode.y)
223+
d = math.hypot(near_node.x - new_node.x, near_node.y - new_node.y)
226224

227-
s_cost = newNode.cost + d
225+
s_cost = new_node.cost + d
228226

229-
if nearNode.cost > s_cost:
230-
theta = math.atan2(newNode.y - nearNode.y,
231-
newNode.x - nearNode.x)
232-
if self.check_collision(nearNode, theta, d):
233-
nearNode.parent = n_node - 1
234-
nearNode.cost = s_cost
227+
if near_node.cost > s_cost:
228+
theta = math.atan2(new_node.y - near_node.y,
229+
new_node.x - near_node.x)
230+
if self.check_collision(near_node, theta, d):
231+
near_node.parent = n_node - 1
232+
near_node.cost = s_cost
235233

236234
@staticmethod
237235
def distance_squared_point_to_segment(v, w, p):
@@ -251,45 +249,44 @@ def distance_squared_point_to_segment(v, w, p):
251249
def check_segment_collision(self, x1, y1, x2, y2):
252250
for (ox, oy, size) in self.obstacle_list:
253251
dd = self.distance_squared_point_to_segment(
254-
np.array([x1, y1]),
255-
np.array([x2, y2]),
256-
np.array([ox, oy]))
252+
np.array([x1, y1]), np.array([x2, y2]), np.array([ox, oy]))
257253
if dd <= size ** 2:
258254
return False # collision
259255
return True
260256

261-
def check_collision(self, nearNode, theta, d):
262-
tmpNode = copy.deepcopy(nearNode)
263-
end_x = tmpNode.x + math.cos(theta) * d
264-
end_y = tmpNode.y + math.sin(theta) * d
265-
return self.check_segment_collision(tmpNode.x, tmpNode.y, end_x, end_y)
257+
def check_collision(self, near_node, theta, d):
258+
tmp_node = copy.deepcopy(near_node)
259+
end_x = tmp_node.x + math.cos(theta) * d
260+
end_y = tmp_node.y + math.sin(theta) * d
261+
return self.check_segment_collision(tmp_node.x, tmp_node.y,
262+
end_x, end_y)
266263

267-
def get_final_course(self, lastIndex):
264+
def get_final_course(self, last_index):
268265
path = [[self.goal.x, self.goal.y]]
269-
while self.node_list[lastIndex].parent is not None:
270-
node = self.node_list[lastIndex]
266+
while self.node_list[last_index].parent is not None:
267+
node = self.node_list[last_index]
271268
path.append([node.x, node.y])
272-
lastIndex = node.parent
269+
last_index = node.parent
273270
path.append([self.start.x, self.start.y])
274271
return path
275272

276-
def draw_graph(self, xCenter=None, cBest=None, cMin=None, e_theta=None,
273+
def draw_graph(self, x_center=None, c_best=None, c_min=None, e_theta=None,
277274
rnd=None):
278275
plt.clf()
279276
# for stopping simulation with the esc key.
280277
plt.gcf().canvas.mpl_connect(
281-
'key_release_event',
282-
lambda event: [exit(0) if event.key == 'escape' else None])
278+
'key_release_event', lambda event:
279+
[exit(0) if event.key == 'escape' else None])
283280
if rnd is not None:
284281
plt.plot(rnd[0], rnd[1], "^k")
285-
if cBest != float('inf'):
286-
self.plot_ellipse(xCenter, cBest, cMin, e_theta)
282+
if c_best != float('inf'):
283+
self.plot_ellipse(x_center, c_best, c_min, e_theta)
287284

288285
for node in self.node_list:
289286
if node.parent is not None:
290287
if node.x or node.y is not None:
291-
plt.plot([node.x, self.node_list[node.parent].x], [
292-
node.y, self.node_list[node.parent].y], "-g")
288+
plt.plot([node.x, self.node_list[node.parent].x],
289+
[node.y, self.node_list[node.parent].y], "-g")
293290

294291
for (ox, oy, size) in self.obstacle_list:
295292
plt.plot(ox, oy, "ok", ms=30 * size)
@@ -301,13 +298,13 @@ def draw_graph(self, xCenter=None, cBest=None, cMin=None, e_theta=None,
301298
plt.pause(0.01)
302299

303300
@staticmethod
304-
def plot_ellipse(xCenter, cBest, cMin, e_theta): # pragma: no cover
301+
def plot_ellipse(x_center, c_best, c_min, e_theta): # pragma: no cover
305302

306-
a = math.sqrt(cBest ** 2 - cMin ** 2) / 2.0
307-
b = cBest / 2.0
303+
a = math.sqrt(c_best ** 2 - c_min ** 2) / 2.0
304+
b = c_best / 2.0
308305
angle = math.pi / 2.0 - e_theta
309-
cx = xCenter[0]
310-
cy = xCenter[1]
306+
cx = x_center[0]
307+
cy = x_center[1]
311308
t = np.arange(0, 2 * math.pi + 0.1, 0.1)
312309
x = [a * math.cos(it) for it in t]
313310
y = [b * math.sin(it) for it in t]
@@ -331,18 +328,12 @@ def main():
331328
print("Start informed rrt star planning")
332329

333330
# create obstacles
334-
obstacleList = [
335-
(5, 5, 0.5),
336-
(9, 6, 1),
337-
(7, 5, 1),
338-
(1, 5, 1),
339-
(3, 6, 1),
340-
(7, 9, 1)
341-
]
331+
obstacle_list = [(5, 5, 0.5), (9, 6, 1), (7, 5, 1), (1, 5, 1), (3, 6, 1),
332+
(7, 9, 1)]
342333

343334
# Set params
344-
rrt = InformedRRTStar(start=[0, 0], goal=[5, 10],
345-
randArea=[-2, 15], obstacleList=obstacleList)
335+
rrt = InformedRRTStar(start=[0, 0], goal=[5, 10], rand_area=[-2, 15],
336+
obstacle_list=obstacle_list)
346337
path = rrt.informed_rrt_star_search(animation=show_animation)
347338
print("Done!!")
348339

0 commit comments

Comments
 (0)
Please sign in to comment.