import numpy as np
import matplotlib.pyplot as plt
from matplotlib.patches import Circle, Rectangle
from casadi import *

from guided_mrmp.conflict_resolvers.local_resolver import LocalResolver

class TrajOptResolver(LocalResolver):
    """
    A class that resolves conflicts using trajectory optimization.
    """
    def __init__(self, conflicts, all_robots, dt, robot_radius, circle_obstacles,
                 rectangle_obstacles, rob_dist_weight, obs_dist_weight, time_weight):
        """
        inputs:
            - conflicts: conflicts to resolve; each conflict is iterated and each
              entry's ``label`` is used as an index into all_robots
            - all_robots (list): all robots; each exposes ``current_position``
            - dt (float): forwarded to LocalResolver
            - robot_radius (float): radius used for every robot
            - circle_obstacles: sequence of [x, y, radius] obstacles
            - rectangle_obstacles: rectangular obstacles (stored only; not yet
              used by solve)
            - rob_dist_weight (float): weight of the robot-robot separation cost
            - obs_dist_weight (float): weight of the robot-obstacle separation cost
            - time_weight (float): weight of the final time T in the cost
        """
        # NOTE(review): self.conflicts / self.all_robots are read later but not
        # set here, so presumably LocalResolver.__init__ stores them — confirm.
        super().__init__(conflicts, all_robots, dt)
        self.num_robots = len(all_robots)
        self.circle_obs = circle_obstacles
        self.rect_obs = rectangle_obstacles
        self.rob_dist_weight = rob_dist_weight
        self.obs_dist_weight = obs_dist_weight
        self.time_weight = time_weight
        self.robot_radius = MX(robot_radius)

        # Start of each robot, in the same order as all_robots (this order
        # defines the row layout of the optimization variables in solve()).
        self.starts = [r.current_position for r in all_robots]
        # TODO: the goals should be some point in the near future ...
        # Until they are assigned, solve() cannot be called.
        self.goals = None

    def dist(self, robot_position, circle):
        """
        Returns the SQUARED Euclidean distance between a robot position and the
        center of a circle (not its boundary).

        The square is used on purpose: the gradient of a sqrt-based norm is
        undefined at zero, which made IPOPT report invalid numbers. Compare the
        result only against squared thresholds.

        params:
            robot_position [x,y]
            circle [x,y,radius]
        """
        return sumsqr(robot_position - transpose(circle[:2]))

    def apply_quadratic_barrier(self, d_max, d, c):
        """
        Applies a quadratic barrier to some given distance. The quadratic barrier
        is a soft barrier function. We are using it for now to avoid any issues with
        invalid initial solutions, which hard barrier functions cannot handle.

        Returns c * max(0, d_max - d)**2, i.e. zero once d >= d_max.

        params:
            d (float):      distance to the obstacle
            c (float):      controls the steepness of curve.
                            higher c --> gets more expensive faster as you move toward obs
            d_max (float):  The threshold distance at which the barrier starts to apply.
                            Must be in the same units as d (solve() passes squared
                            distances for both).
        """
        return c*fmax(0, d_max-d)**2

    def log_normal_barrier(self, sigma, d, c):
        """
        Alternative soft barrier: c * max(0, 2 - d/sigma)**2.5.

        Nonzero only while d < 2*sigma. Not used by solve(); kept for
        experimentation. NOTE(review): despite its name this is a power law,
        not a log-normal function.
        """
        return c*fmax(0, 2-(d/sigma))**2.5

    def _obstacle_cost(self, pos, N):
        """
        Sum of soft-barrier penalties between every robot and every circular
        obstacle over the control intervals 0..N-1.
        """
        # TODO:: Include rectangular obstacles
        if self.circle_obs is None or len(self.circle_obs) == 0:
            return 0

        circle_obs = DM(self.circle_obs)  # make the obstacles casadi objects
        cost = 0
        for r in range(self.num_robots):
            for k in range(N):
                for c in range(circle_obs.shape[0]):
                    circle = circle_obs[c, :]
                    d_sq = self.dist(pos[2*r : 2*(r+1), k], circle)
                    safe_dist = self.robot_radius + circle[2] + 0.5
                    # dist() is squared, so the threshold must be squared too.
                    cost += self.apply_quadratic_barrier(safe_dist**2, d_sq, 1)
        return cost

    def _inter_robot_cost(self, pos, N):
        """
        Sum of soft-barrier penalties between every pair of robots over the
        control intervals 0..N-1.
        """
        safe_dist = 2*self.robot_radius + .5
        cost = 0
        for k in range(N):
            for r1 in range(self.num_robots):
                for r2 in range(r1 + 1, self.num_robots):
                    # note: using norm_2 here gives an invalid num detected error
                    # (the sqrt has no gradient at 0), so squared distances are used.
                    d_sq = sumsqr(pos[2*r1 : 2*(r1+1), k] - pos[2*r2 : 2*(r2+1), k])
                    # The penalty is symmetric in (r1, r2); doubling each unordered
                    # pair keeps the weighting of a sum over ordered pairs.
                    cost += 2*self.apply_quadratic_barrier(safe_dist**2, d_sq, 1)
        return cost

    def solve(self, num_control_intervals, initial_guess):
        """
        Solves the trajectory optimization problem for the robots.

        Each robot has state (x, y, heading) and control (v, omega), integrated
        with a forward-Euler unicycle model over a free final time T.

        TODO: This will not work for generic dynamics. It only works for roomba model.
        I don't know how to handle generic dynamics with casadi yet.

        params:
            num_control_intervals (int): number of control intervals N
            initial_guess: initial value for the positions (2*num_robots x N+1),
                           or None to let the solver use its default

        returns:
            (sol, pos): the CasADi solution and the symbolic position expression.
            Rows 2r and 2r+1 of pos are robot r's x and y; use sol.value(pos)
            to obtain numeric values.

        raises:
            ValueError: if self.goals has not been set.
        """
        if self.goals is None:
            raise ValueError("TrajOptResolver.goals must be set before calling solve()")

        N = num_control_intervals
        opti = Opti()  # Optimization problem

        # ---- decision variables --------- #
        # Rows 0..2R-1 of X are interleaved positions (x0, y0, x1, y1, ...);
        # rows 2R..3R-1 are the headings.
        X = opti.variable(self.num_robots*3, N+1)
        pos = X[:self.num_robots*2, :]
        x = pos[0::2, :]
        y = pos[1::2, :]
        heading = X[self.num_robots*2:, :]

        # Rows of U are interleaved controls (v0, omega0, v1, omega1, ...).
        U = opti.variable(self.num_robots*2, N)
        vel = U[0::2, :]
        omega = U[1::2, :]
        T = opti.variable()  # final time

        dist_to_other_obstacles = self._obstacle_cost(pos, N)
        dist_to_other_robots = self._inter_robot_cost(pos, N)

        dt = T/N  # length of a control interval

        # Ensure that the robot moves according to the dynamics
        for k in range(N):  # loop over control intervals
            dxdt = vel[:, k] * cos(heading[:, k])
            dydt = vel[:, k] * sin(heading[:, k])
            dthetadt = omega[:, k]
            opti.subject_to(x[:, k+1] == x[:, k] + dt*dxdt)
            opti.subject_to(y[:, k+1] == y[:, k] + dt*dydt)
            opti.subject_to(heading[:, k+1] == heading[:, k] + dt*dthetadt)

        opti.minimize(self.rob_dist_weight*dist_to_other_robots
                      + self.obs_dist_weight*dist_to_other_obstacles
                      + self.time_weight*T)

        # --- v and omega constraints --- #
        # |v| <= 0.2 and |omega| <= 0.1, written as linear bounds on all entries.
        opti.subject_to(opti.bounded(-0.2, vel, 0.2))
        opti.subject_to(opti.bounded(-0.1, omega, 0.1))

        # --- position constraints --- #
        opti.subject_to(opti.bounded(0, x, 10))
        opti.subject_to(opti.bounded(0, y, 10))

        # ---- start/goal conditions --------
        for r in range(self.num_robots):
            opti.subject_to(pos[2*r : 2*(r+1), 0] == self.starts[r])
            opti.subject_to(pos[2*r : 2*(r+1), -1] == self.goals[r])

        # ---- misc. constraints  ----------
        opti.subject_to(opti.bounded(0, T, 100))

        # ---- initial values for solver ---
        opti.set_initial(T, 20)

        if initial_guess is not None:
            opti.set_initial(pos, initial_guess)

        # ---- solve NLP              ------
        opti.solver("ipopt")  # set numerical backend
        sol = opti.solve()    # actual solve

        return sol, pos

    def get_local_controls(self):
        """
        For each conflict, solve the trajectory optimization problem and give
        every robot involved in the conflict its own optimized path, through
        the robot's tracker.
        """
        for conflict in self.conflicts:
            # NOTE(review): solve() optimizes over all robots (self.starts),
            # not only the ones in this conflict.
            sol, pos = self.solve(10, None)

            # pos is a symbolic CasADi expression; extract its numeric value.
            pos_val = np.array(sol.value(pos))

            # Update the controls for the robots involved in the conflict.
            # Robot i (index into all_robots) owns rows 2i and 2i+1 of pos.
            for member in conflict:
                idx = member.label
                robot = self.all_robots[idx]
                path = pos_val[2*idx : 2*(idx+1), :]
                # NOTE(review): assumes the tracker accepts a 2 x (N+1) array
                # of [x; y] waypoints — confirm against the tracker API.
                robot.next_control = robot.tracker.get_next_control(path)

    def plot_paths(self, x_opt):
        """
        Plot the robots' paths, starts, goals, and the circular obstacles.

        params:
            x_opt: numeric position trajectory with the same row layout as the
                   ``pos`` returned by solve() (rows 2r and 2r+1 are robot r's
                   x and y), e.g. ``sol.value(pos)``.
        """
        _, ax = plt.subplots()

        # Plot circular obstacles given as [x, y, radius]; Circle needs a 2-D center.
        for obstacle in self.circle_obs:
            ax.add_patch(Circle((obstacle[0], obstacle[1]), obstacle[2], color='red'))
        # TODO: plot rectangular obstacles (self.rect_obs)

        if self.num_robots > 20:
            colors = plt.cm.hsv(np.linspace(0.2, 1.0, self.num_robots))
        elif self.num_robots > 10:
            colors = plt.cm.tab20(np.linspace(0, 1, self.num_robots))
        else:
            colors = plt.cm.tab10(np.linspace(0, 1, self.num_robots))

        # Plot robot paths
        for r, color in enumerate(colors):
            xs, ys = x_opt[2*r, :], x_opt[2*r + 1, :]
            ax.plot(xs, ys, label=f'Robot {r+1}', color=color)
            ax.scatter(xs, ys, color=color, s=10)
            ax.scatter(self.starts[r][0], self.starts[r][1], s=85, color=color)
            if self.goals is not None:
                ax.scatter(self.goals[r][0], self.goals[r][1], s=85,
                           facecolors='none', edgecolors=color)

        ax.set_xlabel('X')
        ax.set_ylabel('Y')
        ax.legend()
        ax.set_aspect('equal', 'box')

        # Same [0, 10] workspace bounds that solve() imposes on positions.
        plt.ylim(0, 10)
        plt.xlim(0, 10)
        plt.title('Robot Paths')
        plt.grid(False)
        plt.show()