# Source: RRT.py (7.88 KiB), authored by rachelmoan
"""
RRT implementation
"""
import math
import numpy as np
from guided_mrmp.utils import Node, Env
class RRT:
    """
    Rapidly-exploring Random Tree (RRT) planner.

    The tree is stored as a flat list of nodes (``sampled_vertices``). Nodes
    created by this class record their parent as the parent's ``current``
    value (its coordinates), not as a reference to the parent node, and
    ``g`` accumulates the travelled distance along the tree.

    Parameters:
        env: environment object; must provide ``boundary``, ``circle_obs``,
            ``rect_obs``, ``delta`` and the geometric helpers called in
            ``is_collision``
        s_start (tuple): start position
        s_goal (tuple): goal position
        step_len (float): maximum length of a single tree extension
        goal_sample_rate (float): probability of sampling the goal itself
            instead of a random position
        iter_max (int): maximum number of sampling iterations in ``plan``
        sampled_vertices (list): optional existing tree to keep extending;
            the list is used as-is (not copied). When omitted, the tree
            starts with the start node only.
    """

    def __init__(self, env, s_start, s_goal, step_len, goal_sample_rate, iter_max, sampled_vertices=None):
        # Start and goal are created with themselves as parent and zero cost.
        self.s_start = Node(s_start, s_start, 0, 0)
        self.s_goal = Node(s_goal, s_goal, 0, 0)
        self.step_len = step_len
        self.goal_sample_rate = goal_sample_rate
        self.iter_max = iter_max
        if sampled_vertices is None:
            self.sampled_vertices = [self.s_start]
        else:
            self.sampled_vertices = sampled_vertices
        self.env = env
        self.x_range = self.env.boundary[0]
        self.y_range = self.env.boundary[1]
        self.obs_circle = self.env.circle_obs
        self.obs_rectangle = self.env.rect_obs

    def dist(self, node1, node2):
        """Return the Euclidean distance between two nodes."""
        return math.hypot(node2.x - node1.x, node2.y - node1.y)

    def angle(self, node1, node2):
        """Return the heading (radians) of the vector from node1 to node2."""
        return math.atan2(node2.y - node1.y, node2.x - node1.x)

    def get_distance_and_angle(self, node_start, node_end):
        """Return ``(distance, heading)`` of the vector node_start -> node_end."""
        dx = node_end.x - node_start.x
        dy = node_end.y - node_start.y
        return math.hypot(dx, dy), math.atan2(dy, dx)

    def is_collision(self, start, end):
        """
        Check whether the segment between two nodes is blocked.

        Parameters:
            start (Node): first end point of the segment
            end (Node): second end point of the segment
        Returns
            blocked (bool): True if either end point lies inside an obstacle
                or the segment intersects a rectangle edge or a circle
        """
        if self.env.is_inside_obs(start) or self.env.is_inside_obs(end):
            return True

        # Ray representation of the segment, as produced by the environment.
        o, d = self.env.get_ray(start, end)

        # Test the four edges of every rectangular obstacle, in order.
        for v1, v2, v3, v4 in self.env.get_obs_vertex():
            for edge_a, edge_b in ((v1, v2), (v2, v3), (v3, v4), (v4, v1)):
                if self.env.is_intersect_rec(start, end, o, d, edge_a, edge_b):
                    return True

        for x, y, r in self.obs_circle:
            if self.env.is_intersect_circle(o, d, [x, y], r):
                return True
        return False

    def _steer(self, node_list, node, max_len):
        """
        Extend the tree from the vertex nearest to ``node`` by at most
        ``max_len``.

        Returns the new node (parent and cost already filled in), or None
        when the segment from the nearest vertex to the new node collides.
        """
        distances = [self.dist(node, candidate) for candidate in node_list]
        node_near = node_list[int(np.argmin(distances))]

        # Move from the nearest vertex towards the sample, clamped to max_len.
        dist, theta = self.dist(node_near, node), self.angle(node_near, node)
        dist = min(max_len, dist)
        node_new = Node(
            (node_near.x + dist * math.cos(theta),
             node_near.y + dist * math.sin(theta)),
            node_near.current, node_near.g + dist, 0)

        if self.is_collision(node_new, node_near):
            return None
        return node_new

    def get_nearest_node(self, node_list, node):
        """
        Extend the tree towards ``node`` and rewire neighbours of the new node.

        The extension length is limited by ``self.max_dist`` and the rewiring
        radius is ``self.r``. Neither attribute is set by ``__init__``; when
        ``max_dist`` is missing ``step_len`` is used, and when ``r`` is
        missing no rewiring is performed (previously both cases raised
        AttributeError).

        Parameters:
            node_list (list): current tree vertices
            node (Node): sampled node to grow towards
        Returns
            node_new (Node): the new vertex, or None if the extension collides
        """
        node_new = self._steer(node_list, node, getattr(self, "max_dist", self.step_len))
        if node_new is None:
            return None

        radius = getattr(self, "r", None)
        if radius is None:
            return node_new

        for node_n in node_list:
            dist_to_new_node = self.dist(node_n, node_new)
            # Only vertices strictly inside the rewiring radius are considered.
            if dist_to_new_node < radius:
                cost = node_n.g + dist_to_new_node
                if node_new.g > cost and not self.is_collision(node_n, node_new):
                    # Cheaper to reach the new node through this vertex.
                    node_new.parent = node_n.current
                    node_new.g = cost
                else:
                    # Otherwise check if this vertex is cheaper via the new node.
                    cost = node_new.g + dist_to_new_node
                    if node_n.g > cost and not self.is_collision(node_n, node_new):
                        node_n.parent = node_new.current
                        node_n.g = cost
        return node_new

    def generate_random_node(self, goal_sample_rate):
        """
        Sample a node: the goal with probability ``goal_sample_rate``,
        otherwise a random position with integer coordinates.

        The lower bound of each axis is shifted by ``env.delta``; the upper
        bound is used as given and is exclusive (``np.random.randint``).
        """
        delta = self.env.delta
        if np.random.random() > goal_sample_rate:
            return Node((np.random.randint(self.x_range[0] + delta, self.x_range[1]),
                         np.random.randint(self.y_range[0] + delta, self.y_range[1])))
        return self.s_goal

    def nearest_neighbor(self, node_list, node):
        """
        Extend the tree from its vertex nearest to ``node`` by at most
        ``step_len``.

        Parameters:
            node_list (list): current tree vertices
            node (Node): sampled node to grow towards
        Returns
            node_new (Node): the new vertex, or None if the extension collides
        """
        return self._steer(node_list, node, self.step_len)

    def new_state(self, node_start, node_end):
        """
        Create a node at most ``step_len`` away from ``node_start`` in the
        direction of ``node_end``; no collision check is performed.
        """
        dist, theta = self.get_distance_and_angle(node_start, node_end)
        dist = min(self.step_len, dist)
        node_new = Node((node_start.x + dist * math.cos(theta),
                         node_start.y + dist * math.sin(theta)))
        node_new.parent = node_start.current
        return node_new

    def extract_path(self, node_end):
        """
        Walk ``parent`` links from ``node_end`` until a node whose parent is
        None, collecting ``(x, y)`` tuples; the goal position comes first.

        NOTE(review): this expects ``parent`` to be a node object (it reads
        ``.x``/``.y``/``.parent`` on it), whereas the other methods of this
        class store ``parent`` as coordinates and the start node is its own
        parent. For trees built by ``plan`` use ``extractPath`` instead.
        """
        path = [(self.s_goal.x, self.s_goal.y)]
        node_now = node_end
        while node_now.parent is not None:
            node_now = node_now.parent
            path.append((node_now.x, node_now.y))
        return path

    def extractPath(self, closed_set):
        """
        Extract the path based on the CLOSED set.

        Starting at the goal, each node's parent is looked up in
        ``closed_set`` (via ``list.index``, so it relies on Node equality)
        until the start node is reached.

        Parameters:
            closed_set (list): CLOSED set
        Returns
            cost (float): the cost of planning path
            path (list): the planning path, ordered from goal to start
            closed_set (list): the tree of sampled vertices, unchanged
        """
        node = closed_set[closed_set.index(self.s_goal)]
        path = [node.current]
        cost = node.g
        while node != self.s_start:
            # Parents are stored as coordinates; a probe node locates them.
            node = closed_set[closed_set.index(Node(node.parent, None, None, None))]
            path.append(node.current)
        return cost, path, closed_set

    def plan(self):
        """
        Grow the tree for at most ``iter_max`` iterations.

        Returns
            cost, path, tree: result of ``extractPath`` once a new vertex is
                within ``step_len`` of the goal with a collision-free
                connection; ``(0, None, None)`` if no path was found
        """
        for _ in range(self.iter_max):
            node_rand = self.generate_random_node(self.goal_sample_rate)

            # Skip samples that are already part of the tree.
            if node_rand in self.sampled_vertices:
                continue

            node_new = self.nearest_neighbor(self.sampled_vertices, node_rand)
            if node_new is None:
                continue
            self.sampled_vertices.append(node_new)

            dist_to_goal = self.dist(node_new, self.s_goal)
            if dist_to_goal <= self.step_len and not self.is_collision(node_new, self.s_goal):
                # Connect the goal to the tree and finish.
                self.s_goal.parent = node_new.current
                self.s_goal.g = node_new.g + dist_to_goal
                self.sampled_vertices.append(self.s_goal)
                return self.extractPath(self.sampled_vertices)
        return 0, None, None

    def run(self):
        """Run ``plan`` and return ``(path, tree)``; both are None on failure."""
        _cost, path, tree = self.plan()
        return path, tree
if __name__ == "__main__":
    # Ad-hoc demo: plan between two hard-coded points and print the result.
    x_start = (6.394267984578837, 0.25010755222666936) # Starting node
    x_goal = (2.7502931836911926, 2.2321073814882277) # Goal node
    # NOTE(review): RRT.__init__ expects the environment as its first
    # positional argument (env, s_start, s_goal, step_len, goal_sample_rate,
    # iter_max). This call supplies only five arguments and no environment,
    # so it raises TypeError (missing `iter_max`). An Env instance has to be
    # built and passed first; how to construct one is not visible in this
    # file. TODO: confirm against guided_mrmp.utils.Env and fix the call.
    rrt = RRT(x_start, x_goal, 0.5, 0.05, 50000)
    # run() returns a (path, tree) tuple, so `path` holds both values here.
    path = rrt.run()
    print(path)