Skip to content
Snippets Groups Projects
RRT.py 7.88 KiB
"""
RRT implementation
"""
import math
import numpy as np

from guided_mrmp.utils import Node, Env

class RRT:
    """Rapidly-exploring Random Tree (RRT) planner in a 2-D environment.

    Tree vertices are ``Node`` objects. Every ``Node(...)`` call in this class
    passes ``(current, parent, g, h)``:
    - ``parent`` stores the parent's *coordinates* (its ``current``), not a
      Node object.
    - ``g`` is the accumulated path length from the root.
    - The 4th argument is always 0 here; it is presumably a heuristic ``h``
      (TODO confirm against ``guided_mrmp.utils.Node``).

    The start node is created as its own parent and acts as the tree root.
    """

    def __init__(self, env, s_start, s_goal, step_len, goal_sample_rate, iter_max,
                 sampled_vertices=None, *, max_dist=None, search_radius=None):
        """Set up the planner.

        Args:
            env: environment object. It must expose ``boundary``,
                ``circle_obs``, ``rect_obs``, ``delta``, ``is_inside_obs``,
                ``get_ray``, ``get_obs_vertex``, ``is_intersect_rec`` and
                ``is_intersect_circle`` (all used below).
            s_start: start coordinates; becomes the self-parented root node.
            s_goal: goal coordinates.
            step_len: maximum edge length used by ``nearest_neighbor``,
                ``new_state`` and the goal-connection test in ``plan``.
            goal_sample_rate: probability that ``generate_random_node``
                returns the goal instead of a random sample.
            iter_max: maximum number of sampling iterations in ``plan``.
            sampled_vertices: existing tree (list of Nodes) to extend. If
                omitted, the tree starts as ``[start node]``. A provided list
                is used as-is and mutated in place.
            max_dist: maximum edge length for ``get_nearest_node``; defaults
                to ``step_len``.
            search_radius: rewiring radius for ``get_nearest_node`` (stored as
                ``self.r``); defaults to 0, which disables rewiring.
        """
        self.s_start = Node(s_start, s_start, 0, 0)
        self.s_goal = Node(s_goal, s_goal, 0, 0)
        self.step_len = step_len
        self.goal_sample_rate = goal_sample_rate
        self.iter_max = iter_max
        # get_nearest_node needs these two settings; previously they were never
        # assigned, so that method always raised AttributeError.
        self.max_dist = step_len if max_dist is None else max_dist
        self.r = 0.0 if search_radius is None else search_radius

        if sampled_vertices is None:
            self.sampled_vertices = [self.s_start]
        else:
            self.sampled_vertices = sampled_vertices

        self.env = env
        self.x_range = self.env.boundary[0]
        self.y_range = self.env.boundary[1]
        self.obs_circle = self.env.circle_obs
        self.obs_rectangle = self.env.rect_obs

    def dist(self, node1, node2):
        """Return the Euclidean distance between two nodes."""
        return math.hypot(node2.x - node1.x, node2.y - node1.y)

    def angle(self, node1, node2):
        """Return the heading (radians) from ``node1`` toward ``node2``."""
        return math.atan2(node2.y - node1.y, node2.x - node1.x)

    def get_distance_and_angle(self, node_start, node_end):
        """Return ``(distance, heading)`` from ``node_start`` to ``node_end``."""
        dx = node_end.x - node_start.x
        dy = node_end.y - node_start.y
        return math.hypot(dx, dy), math.atan2(dy, dx)

    def is_collision(self, start, end):
        """Return True if the segment ``start``-``end`` hits an obstacle.

        First checks whether either endpoint lies inside an obstacle. It then
        tests the segment, as the ray from ``env.get_ray``, against:
        - each edge of every 4-vertex obstacle from ``env.get_obs_vertex``;
        - every circle in ``self.obs_circle``.
        """
        if self.env.is_inside_obs(start) or self.env.is_inside_obs(end):
            return True

        o, d = self.env.get_ray(start, end)
        obs_vertex = self.env.get_obs_vertex()

        for (v1, v2, v3, v4) in obs_vertex:
            # Test the closed polygon edges in order: v1-v2, v2-v3, v3-v4, v4-v1.
            for a, b in ((v1, v2), (v2, v3), (v3, v4), (v4, v1)):
                if self.env.is_intersect_rec(start, end, o, d, a, b):
                    return True

        for (x, y, r) in self.obs_circle:
            if self.env.is_intersect_circle(o, d, [x, y], r):
                return True

        return False

    def _nearest(self, node_list, node):
        """Return the first vertex of ``node_list`` closest to ``node``."""
        distances = [self.dist(node, vertex) for vertex in node_list]
        return node_list[int(np.argmin(distances))]

    def _steer(self, node_from, node_to, max_len):
        """Return a new Node at most ``max_len`` from ``node_from`` toward ``node_to``.

        The new node's parent is ``node_from.current``. Its cost ``g`` is
        ``node_from.g`` plus the distance actually travelled.
        """
        dist, theta = self.get_distance_and_angle(node_from, node_to)
        dist = min(max_len, dist)
        return Node((node_from.x + dist * math.cos(theta),
                     node_from.y + dist * math.sin(theta)),
                    node_from.current, node_from.g + dist, 0)

    def get_nearest_node(self, node_list, node):
        """Extend the tree toward ``node`` with RRT*-style rewiring.

        Steers from the nearest vertex by at most ``self.max_dist``. Each
        vertex within ``self.r`` of the new node is then handled as follows:
        - if reaching the new node through it is cheaper and collision-free,
          it becomes the new node's parent;
        - otherwise, it is re-parented onto the new node when that lowers its
          own cost.

        Vertices of ``node_list`` may be modified in place. Costs of their
        descendants are not propagated. The new node is NOT appended to
        ``node_list``; the caller must do that.

        Returns:
            The new Node, or None if the steering segment collides.
        """
        node_near = self._nearest(node_list, node)
        node_new = self._steer(node_near, node, self.max_dist)

        if self.is_collision(node_new, node_near):
            return None

        for node_n in node_list:
            dist_to_new_node = self.dist(node_n, node_new)
            if dist_to_new_node < self.r:
                cost = node_n.g + dist_to_new_node
                if node_new.g > cost and not self.is_collision(node_n, node_new):
                    # Cheaper to reach the new node through node_n.
                    node_new.parent = node_n.current
                    node_new.g = cost
                else:
                    # Otherwise try routing node_n through the new node.
                    cost = node_new.g + dist_to_new_node
                    if node_n.g > cost and not self.is_collision(node_n, node_new):
                        node_n.parent = node_new.current
                        node_n.g = cost
        return node_new

    def generate_random_node(self, goal_sample_rate):
        """Sample a configuration; return ``self.s_goal`` with probability ``goal_sample_rate``.

        Any other sample is a Node with integer coordinates drawn by
        ``np.random.randint`` from ``[x_range[0] + delta, x_range[1])`` and
        ``[y_range[0] + delta, y_range[1])``.
        """
        delta = self.env.delta

        if np.random.random() > goal_sample_rate:
            # NOTE(review): randint gives integer samples only, and the upper
            # bound is not shrunk by delta, unlike the lower bound. The
            # continuous alternative is
            # np.random.uniform(lo + delta, hi - delta). Kept as-is; confirm
            # this is intended.
            return Node((np.random.randint(self.x_range[0] + delta, self.x_range[1]),
                         np.random.randint(self.y_range[0] + delta, self.y_range[1])))

        return self.s_goal

    def nearest_neighbor(self, node_list, node):
        """Steer at most ``self.step_len`` from the vertex nearest ``node``.

        Returns:
            The new (not yet inserted) Node, or None if the segment collides.
        """
        node_near = self._nearest(node_list, node)
        node_new = self._steer(node_near, node, self.step_len)

        if self.is_collision(node_new, node_near):
            return None
        return node_new

    def new_state(self, node_start, node_end):
        """Return a Node at most ``self.step_len`` from ``node_start`` toward ``node_end``.

        Only coordinates and ``parent`` are set on the returned node.
        """
        dist, theta = self.get_distance_and_angle(node_start, node_end)

        dist = min(self.step_len, dist)
        node_new = Node((node_start.x + dist * math.cos(theta),
                         node_start.y + dist * math.sin(theta)))
        node_new.parent = node_start.current

        return node_new

    def extract_path(self, node_end):
        """Return coordinates from the goal back through the ancestors of ``node_end``.

        The list starts with the goal's coordinates, followed by each ancestor
        of ``node_end`` (``node_end`` itself is not included) up to the
        self-parented root. Because ``parent`` holds coordinates, ancestors
        are resolved through ``self.sampled_vertices`` (first match wins).
        After a successful ``plan()``, passing ``self.s_goal`` yields the
        goal-to-start path.

        Raises:
            ValueError: if an ancestor is missing from the tree, or the
                parent links form a cycle.
        """
        by_coords = {}
        for vertex in self.sampled_vertices:
            by_coords.setdefault(tuple(vertex.current), vertex)

        path = [(self.s_goal.x, self.s_goal.y)]
        node_now = node_end
        seen = set()
        while node_now.parent is not None:
            parent_key = tuple(node_now.parent)
            if parent_key == tuple(node_now.current):
                break  # reached the root, which is its own parent
            if parent_key in seen:
                raise ValueError(f"cycle in parent links at {parent_key}")
            seen.add(parent_key)
            node_now = by_coords.get(parent_key)
            if node_now is None:
                raise ValueError(f"parent {parent_key} is not in the tree")
            path.append((node_now.x, node_now.y))

        return path

    def extractPath(self, closed_set):
        """
        Extract the path from the goal back to the start using the tree.

        Parent nodes are found with ``list.index`` on a probe
        ``Node(node.parent, None, None, None)``. This presumably relies on
        Node equality comparing ``current`` only (TODO confirm).

        Parameters:
            closed_set (list): tree vertices; must contain ``self.s_goal``
                and every ancestor back to ``self.s_start``.

        Returns:
            cost (float): ``g`` of the goal node.
            path (list): ``current`` coordinates from goal to start.
            closed_set (list): the tree that was passed in.
        """
        node = closed_set[closed_set.index(self.s_goal)]
        path = [node.current]
        cost = node.g
        while node != self.s_start:
            node = closed_set[closed_set.index(Node(node.parent, None, None, None))]
            path.append(node.current)

        return cost, path, closed_set

    def plan(self):
        """Grow the tree until the goal is connected or ``iter_max`` iterations elapse.

        Returns:
            ``(cost, path, tree)`` from ``extractPath`` on success, otherwise
            ``(0, None, None)``.
        """
        for _ in range(self.iter_max):
            node_rand = self.generate_random_node(self.goal_sample_rate)

            # Ignore samples that coincide with an existing vertex.
            if node_rand in self.sampled_vertices:
                continue

            node_new = self.nearest_neighbor(self.sampled_vertices, node_rand)
            if not node_new:
                continue

            self.sampled_vertices.append(node_new)
            dist, _ = self.get_distance_and_angle(node_new, self.s_goal)

            # Connect the goal directly once it is within one step and visible.
            if dist <= self.step_len and not self.is_collision(node_new, self.s_goal):
                self.s_goal.parent = node_new.current
                self.s_goal.g = node_new.g + dist
                self.sampled_vertices.append(self.s_goal)
                return self.extractPath(self.sampled_vertices)

        return 0, None, None

    def run(self):
        """Run ``plan`` and return ``(path, tree)`` (both None on failure)."""
        _cost, path, tree = self.plan()
        return path, tree


if __name__ == "__main__":
    x_start = (6.394267984578837, 0.25010755222666936)  # Starting node
    x_goal = (2.7502931836911926, 2.2321073814882277)  # Goal node

    # RRT takes the environment as its first argument; the previous call
    # omitted it and raised TypeError.
    # NOTE(review): assumes Env() can be constructed without arguments;
    # confirm against guided_mrmp.utils.Env.
    env = Env()
    rrt = RRT(env, x_start, x_goal, 0.5, 0.05, 50000)
    # run() returns (path, tree); only the path is printed.
    path, _tree = rrt.run()

    print(path)