"""
RRT*
"""
from guided_mrmp.planners.RRT import RRT

class RRTStar(RRT):
    """RRT* planner: RRT extended with parent selection and rewiring.

    After the base class produces a new node, every node within radius ``r``
    of it is used to (1) pick the cheapest collision-free parent for the new
    node and (2) re-parent neighbours whose cost drops when routed through
    the new node.
    """

    def __init__(self, env, s_start, s_goal, step_len, goal_sample_rate, iter_max, r, sampled_vertices=None):
        """Create an RRT* planner.

        Args:
            env: planning environment, forwarded to ``RRT``.
            s_start: start state, forwarded to ``RRT``.
            s_goal: goal state, forwarded to ``RRT``.
            step_len: forwarded to ``RRT``.
            goal_sample_rate: forwarded to ``RRT``.
            iter_max: forwarded to ``RRT``.
            r: radius of the neighbourhood used for parent selection and
                rewiring (compared against ``self.dist``).
            sampled_vertices: optional, forwarded to ``RRT``.
        """
        # Forward the caller's sampled_vertices; it used to be hard-coded to
        # None here, silently discarding whatever the caller passed.
        super().__init__(
            env,
            s_start,
            s_goal,
            step_len,
            goal_sample_rate,
            iter_max,
            sampled_vertices=sampled_vertices,
        )
        self.r = r
        self.name = "RRT*"

    def nearest_neighbor(self, node_list, node):
        """Extend the tree towards ``node`` and apply the RRT* optimisation.

        Delegates to ``RRT.nearest_neighbor`` to obtain the new node, then
        considers every node in ``node_list`` within ``self.r`` of it.

        Args:
            node_list: nodes currently in the tree.
            node: the sampled target node.

        Returns:
            The new node (its ``parent``/``g`` possibly improved), or ``None``
            when the base class returned a falsy value.
        """
        node_new = super().nearest_neighbor(node_list, node)
        if not node_new:
            return None

        # Collect neighbours inside the optimisation radius once, together
        # with their distance to the new node.
        neighbors = []
        for node_n in node_list:
            dist = self.dist(node_n, node_new)
            if dist < self.r:
                neighbors.append((node_n, dist))

        # Pass 1 (choose parent): settle node_new's cheapest collision-free
        # parent BEFORE rewiring, so rewiring below uses its final cost.
        for node_n, dist in neighbors:
            cost = node_n.g + dist
            if node_new.g > cost and not self.is_collision(node_n, node_new):
                node_new.parent = node_n.current
                node_new.g = cost

        # Pass 2 (rewire): re-parent neighbours that become cheaper when
        # reached through node_new.
        for node_n, dist in neighbors:
            cost = node_new.g + dist
            if node_n.g > cost and not self.is_collision(node_n, node_new):
                node_n.parent = node_new.current
                node_n.g = cost

        return node_new

if __name__ == "__main__":
    # x_start = (18, 8)  # Starting node
    # x_goal = (37, 18)  # Goal node
    x_start = (6.394267984578837, 0.25010755222666936)  # Starting node
    x_goal = (2.7502931836911926, 2.2321073814882277)  # Goal node

    # rrt = RRT(x_start, x_goal, 0.5, 0.05, 10000)
    # path = rrt.run()

    rrtstar = RRTStar(x_start, x_goal, 0.5, 0.05, 10000, r=10.0)
    path = rrtstar.run()
    print(path)