diff --git a/guided_mrmp/planners/singlerobot/RRT.py b/guided_mrmp/planners/singlerobot/RRT.py
index 507aeb9e838471baef6364a88994f062f4b0273f..53944f12b872957c0b2ca78da66aea4869bab370 100644
--- a/guided_mrmp/planners/singlerobot/RRT.py
+++ b/guided_mrmp/planners/singlerobot/RRT.py
@@ -7,13 +7,17 @@ import numpy as np
 from guided_mrmp.utils import Node, Env
 
 class RRT:
-    def __init__(self, env, s_start, s_goal, step_len, goal_sample_rate, iter_max):
+    def __init__(self, env, s_start, s_goal, step_len, goal_sample_rate, iter_max, sampled_vertices=None):
         self.s_start = Node(s_start,s_start,0,0)
         self.s_goal = Node(s_goal,s_goal, 0,0)
         self.step_len = step_len
         self.goal_sample_rate = goal_sample_rate
         self.iter_max = iter_max
-        self.sampled_vertices = [self.s_start]
+
+        if sampled_vertices is None:
+            self.sampled_vertices = [self.s_start]
+        else:
+            self.sampled_vertices = sampled_vertices
 
         self.env = env
         # self.plotting = Plotting(self.env)
@@ -141,6 +145,7 @@ class RRT:
         dx = node_end.x - node_start.x
         dy = node_end.y - node_start.y
         return math.hypot(dx, dy), math.atan2(dy, dx)
+    
     def extractPath(self, closed_set):
         """
         Extract the path based on the CLOSED set.
@@ -161,7 +166,9 @@ class RRT:
             
             path.append(node.current)
 
-        return cost, path
+
+        # return the cost, path, and the tree of the sampled vertices
+        return cost, path, closed_set
 
     def get_distance_and_angle(self, node_start, node_end):
         dx = node_end.x - node_start.x
@@ -193,16 +200,16 @@ class RRT:
                     return self.extractPath(self.sampled_vertices)
 
 
-        return 0, None
+        return 0, None, None
     
     def run(self):
-        cost, path = self.plan()
+        cost, path, tree = self.plan()
         # self.plotting.animation([path], "RRT", cost, self.sampled_vertices)
         # print(f"num of sampled vertices = {len(self.sampled_vertices)}")
 
         # for node in self.sampled_vertices:
         #     print(f"{node.current}")
-        return path
+        return path, tree
 
 
 if __name__ == "__main__":
diff --git a/guided_mrmp/planners/singlerobot/RRTStar.py b/guided_mrmp/planners/singlerobot/RRTStar.py
index 8f1a0da5250c88a452fb6d7d85e910582a82c0d3..385a92338030d9ee7562c12b6ae0292aac091faa 100644
--- a/guided_mrmp/planners/singlerobot/RRTStar.py
+++ b/guided_mrmp/planners/singlerobot/RRTStar.py
@@ -4,8 +4,8 @@ RRT*
 from guided_mrmp.planners.singlerobot.RRT import RRT
 
 class RRTStar(RRT):
-    def __init__(self, env, s_start, s_goal, step_len, goal_sample_rate, iter_max, r):
-        super().__init__(env, s_start, s_goal, step_len, goal_sample_rate, iter_max)
+    def __init__(self, env, s_start, s_goal, step_len, goal_sample_rate, iter_max, r, sampled_vertices=None):
+        super().__init__(env, s_start, s_goal, step_len, goal_sample_rate, iter_max, sampled_vertices=None)
         self.r = r
         self.name="RRT*"