diff --git a/guided_mrmp/conflict_resolvers/traj_opt_resolver.py b/guided_mrmp/conflict_resolvers/traj_opt_resolver.py
index cd62328d2795b6131baed8f065d13052933050dd..92e3a725491b153f00e0965a148353a6a2d595ca 100644
--- a/guided_mrmp/conflict_resolvers/traj_opt_resolver.py
+++ b/guided_mrmp/conflict_resolvers/traj_opt_resolver.py
@@ -3,34 +3,23 @@ import matplotlib.pyplot as plt
 from matplotlib.patches import Circle, Rectangle
 from casadi import *
 
-from guided_mrmp.conflict_resolvers.local_resolver import LocalResolver
-
-class TrajOptResolver(LocalResolver):
+class TrajOptResolver():
     """
     A class that resolves conflicts using trajectoy optimization.
     """
-    def __init__(self, conflicts, all_robots, dt, robot_radius, circle_obstacles, 
-                 rectangle_obstacles, rob_dist_weight, obs_dist_weight, time_weight):
-        """
-        inputs:
-            - starts (list): starts for all robots in the traj opt problem
-            - goals (list): goals for all robots in the traj opt problem
-        """
-        super.__init__(conflicts, all_robots, dt)
-        self.num_robots = len(all_robots)
-        self.starts = None
-        self.goals = None
+    def __init__(self, num_robots, robot_radius, starts, goals, circle_obstacles, rectangle_obstacles,
+                 rob_dist_weight, obs_dist_weight, control_weight, time_weight):
+        self.num_robots = num_robots
+        self.starts = starts
+        self.goals = goals
         self.circle_obs = circle_obstacles
         self.rect_obs = rectangle_obstacles
         self.rob_dist_weight = rob_dist_weight
         self.obs_dist_weight = obs_dist_weight
+        self.control_weight =control_weight
         self.time_weight = time_weight
         self.robot_radius = MX(robot_radius)
 
-        # Set the starts and goals for the robots
-        self.starts = [r.current_position for r in all_robots]
-        # the goals should be some point in the near future ... 
-
     def dist(self, robot_position, circle):
         """
         Returns the distance between a robot and a circle
@@ -58,14 +47,19 @@ class TrajOptResolver(LocalResolver):
     def log_normal_barrier(self, sigma, d, c):
         return c*fmax(0, 2-(d/sigma))**2.5
 
-    def solve(self, num_control_intervals, initial_guess):
-        """
-        Solves the trajectory optimization problem for the robots.
-        TODO: This will not work for generic dynamics. It only works for roomba model.
-        I don't know how to handle generic dynamics with casadi yet.
+    def problem_setup(self, N, x_range, y_range):
         """
+        Problem setup for the multi-robot collision resolution traj opt problem
+
+        inputs:
+            - N (int): number of control intervals
+            - x_range (tuple): range of x values
+            - y_range (tuple): range of y values
 
-        N = num_control_intervals
+        outputs:
+            - problem (dict): dictionary containing the optimization problem 
+                              and the decision variables
+        """
         opti = Opti() # Optimization problem
 
         # ---- decision variables --------- #
@@ -75,16 +69,15 @@ class TrajOptResolver(LocalResolver):
         y = pos[1::2,:]
         heading = X[self.num_robots*2:,:]           # heading is the last value
 
-        
-        circle_obs = DM(self.circle_obs)            # make the obstacles casadi objects 
-        
         U = opti.variable(self.num_robots*2, N)     # control trajectory (v, omega)
         vel = U[0::2,:]
         omega = U[1::2,:]
         T = opti.variable()                         # final time
 
-
-        # sum up the cost of distance to obstacles
+        # ---- obstacle setup ------------ #
+        circle_obs = DM(self.circle_obs)            # make the obstacles casadi objects 
+        
+        # ------ Obstacle dist cost ------ #
         # TODO:: Include rectangular obstacles
         dist_to_other_obstacles = 0
         for r in range(self.num_robots):
@@ -92,88 +85,150 @@ class TrajOptResolver(LocalResolver):
                 for c in range(circle_obs.shape[0]):
                     circle = circle_obs[c, :]
                     d = self.dist(pos[2*r : 2*(r+1), k], circle)
-                    dist_to_other_obstacles += self.apply_quadratic_barrier(self.robot_radius + circle[2] + 0.5, d, 1)
-                    # dist_to_other_obstacles += self.log_normal_barrier(5, d, 5)
+                    dist_to_other_obstacles += self.apply_quadratic_barrier(2*(self.robot_radius + circle[2]), d, 5)
 
+        # ------ Robot dist cost ------ #
         dist_to_other_robots = 0
         for k in range(N):
             for r1 in range(self.num_robots):
                 for r2 in range(self.num_robots):
                     if r1 != r2:
-                        # print(f"\n{r1} position1 = {pos[2*r1 : 2*(r1+1), k]}")
-                        # print(f"{r2} position2 = {pos[2*r2 : 2*(r2+1), k]}")
-
-                        # note: using norm 2 here gives an invalid num detected error. 
-                        # Must be the sqrt causing an issue
-                        # d = norm_2(pos[2*r1 : 2*(r1+1), k] - pos[2*r2 : 2*(r2+1), k]) - 2*self.robot_radius
                         d = sumsqr(pos[2*r1 : 2*(r1+1), k] - pos[2*r2 : 2*(r2+1), k]) 
-                        dist_to_other_robots += self.apply_quadratic_barrier(2*self.robot_radius+.5, d, 1)
-                      
+                        dist_to_other_robots += self.apply_quadratic_barrier(2*self.robot_radius, d, 1)
+
+
+        # ---- dynamics constraints ---- #              
         dt = T/N # length of a control interval
 
-        # Ensure that the robot moves according to the dynamics
+        pi = [3.14159]*self.num_robots
+        pi = np.array(pi)
+        pi = DM(pi)
+
         for k in range(N): # loop over control intervals
             dxdt = vel[:,k] * cos(heading[:,k])
             dydt = vel[:,k] * sin(heading[:,k])
             dthetadt = omega[:,k]
             opti.subject_to(x[:,k+1]==x[:,k] + dt*dxdt)
             opti.subject_to(y[:,k+1]==y[:,k] + dt*dydt) 
-            opti.subject_to(heading[:,k+1]==heading[:,k] + dt*dthetadt)
+            opti.subject_to(heading[:,k+1]==fmod(heading[:,k] + dt*dthetadt, 2*pi))
+
+
+        # ------ Control panalty ------ #
+        # Calculate the sum of squared differences between consecutive heading angles
+        heading_diff_penalty = 0
+        for k in range(N-1):
+            heading_diff_penalty += sumsqr(fmod(heading[:,k+1] - heading[:,k] + pi, 2*pi) - pi)
 
+
+        # ------ cost function ------ #
         opti.minimize(self.rob_dist_weight*dist_to_other_robots 
-                      + self.obs_dist_weight*dist_to_other_obstacles 
-                      + self.time_weight*T)
+                    + self.obs_dist_weight*dist_to_other_obstacles 
+                    + self.time_weight*T
+                    + self.control_weight*heading_diff_penalty)
 
 
-        # --- v and omega constraints --- #
+        # ------ control constraints ------ #
         for k in range(N):
             for r in range(self.num_robots):
                 opti.subject_to(sumsqr(vel[r,k]) <= 0.2**2)
-                opti.subject_to(sumsqr(omega[r,k]) <= 0.1**2)
+                opti.subject_to(sumsqr(omega[r,k]) <= 0.2**2)
 
-        # --- position constraints --- #
-        opti.subject_to(opti.bounded(0,x,10))
-        opti.subject_to(opti.bounded(0,y,10))
-        
+        # ------ bound x, y, and time  ------ #
+        opti.subject_to(opti.bounded(x_range[0],x,x_range[1]))
+        opti.subject_to(opti.bounded(y_range[0],y,y_range[1]))
+        opti.subject_to(opti.bounded(0,T,100))
 
-        # ---- start/goal conditions --------
+        # ------ initial conditions ------ #
         for r in range(self.num_robots):
-            # opti.subject_to(vel[r, 0]==0) 
-            opti.subject_to(pos[2*r : 2*(r+1), 0]==self.starts[r])
+            
+            opti.subject_to(heading[r, 0]==self.starts[r][2])
+            opti.subject_to(pos[2*r : 2*(r+1), 0]==self.starts[r][0:2])
             opti.subject_to(pos[2*r : 2*(r+1), -1]==self.goals[r])
 
-        # ---- misc. constraints  ----------
-        opti.subject_to(opti.bounded(0,T,100))
+        return {'opti':opti, 'X':X, 'T':T}
 
-        # ---- initial values for solver ---
-        opti.set_initial(T, 20)
+    def solve_optimization_problem(self, problem, initial_guesses=None, solver_options=None):
+        opti = problem['opti']
         
-        if initial_guess is not None:
-            opti.set_initial(pos,initial_guess)
-        
-        # ---- solve NLP              ------
-        opti.solver("ipopt") # set numerical backend
-        sol = opti.solve()   # actual solve
+        if initial_guesses:
+            for param, value in initial_guesses.items():
+                print(f"param = {param}")
+                print(f"value = {value}")
+                opti.set_initial(problem[param], value)
+
+        # Set numerical backend, with options if provided
+        if solver_options:
+            opti.solver('ipopt', solver_options)
+        else:
+            opti.solver('ipopt')
+
+        try:
+            sol = opti.solve()   # actual solve
+            status = 'succeeded'
+        except:
+            sol = None
+            status = 'failed'
+
+        results = {
+            'status' : status,
+            'solution' : sol,
+        }
+
+        if sol:
+            for var_name, var in problem.items():
+                if var_name != 'opti':
+                    results[var_name] = sol.value(var)
+
+        return results
+    
+    def solve(self, N, x_range, y_range, initial_guesses):
+        """
+        Setup and solve a multi-robot traj opt problem
 
-        # print(f"pos = {opti.debug.value(pos[2:4,:])}")
+        input: 
+            - N (int): the number of control intervals
+            - x_range (tuple): 
+            - y_range (tuple): 
+        """
+        problem = self.problem_setup(N, x_range, y_range)
+        results = self.solve_optimization_problem(problem, initial_guesses)
 
-        return sol,pos
+        X = results['X']
+        sol = results['solution']
 
-    def get_local_controls(self):
+        # Extract the values that we want from the optimizer's solution
+        pos = X[:self.num_robots*2,:]               
+        x_vals = pos[0::2,:]                             
+        y_vals = pos[1::2,:]
+        theta_vals = X[self.num_robots*2:,:]
+
+        return sol,pos, x_vals, y_vals, theta_vals
+
+    def get_local_controls(self, controls):
+        """ 
+        Get the local controls for the robots in the conflict
+        """
+
+        l = self.num_robots
+
+        final_trajs = [None]*l
 
         for c in self.conflicts:
             # Get the robots involved in the conflict
             robots = [self.all_robots[r.label] for r in c]
-            robot_positions = [r.current_position for r in robots]
 
             # Solve the trajectory optimization problem
             initial_guess = None
-            sol, x_opt = self.solve(10, initial_guess)
+            sol, x_opt, vels, omegas, xs,ys = self.solve(20, initial_guess)
+
+            pos_vals = np.array(sol.value(x_opt))
 
             # Update the controls for the robots
-            for r, pos in zip(robots, x_opt):
-                r.next_control = r.tracker.get_next_control(pos)
+            for r, vel, omega, x,y in zip(robots, vels, omegas, xs,ys):
+                controls[r.label] = [vel, omega]
+                final_trajs[r.label] = [x,y]
 
+        return controls, final_trajs
 
     def plot_paths(self, x_opt):
         fig, ax = plt.subplots()
@@ -204,8 +259,8 @@ class TrajOptResolver(LocalResolver):
         ax.legend()
         ax.set_aspect('equal', 'box')
 
-        plt.ylim(0,10)
-        plt.xlim(0,10)
+        plt.ylim(0,640)
+        plt.xlim(0,480)
         plt.title('Robot Paths')
         plt.grid(False)
         plt.show()