import numpy as np
import matplotlib.pyplot as plt
from matplotlib.patches import Circle, Rectangle
from guided_mrmp.optimizer import Optimizer
from casadi import *

class TrajOptResolver:
    """
    A class that resolves conflicts between robots using trajectory optimization.

    Every robot is modelled as a unicycle (state: x, y, heading; controls:
    v, omega), and a single CasADi ``Opti`` problem is built for all robots
    at once. Proximity to other robots and to circular obstacles is
    discouraged through soft (quadratic) barrier penalties in the cost.
    """

    # Added under the square root when converting squared distances to
    # distances, so the gradient stays finite when two points coincide
    # (e.g. when all positions start from the same initial guess).
    _DIST_EPS = 1e-6

    def __init__(self, num_robots, robot_radius, starts, goals, circle_obstacles, rectangle_obstacles,
                 rob_dist_weight, obs_dist_weight, control_weight, time_weight,
                 max_vel=0.2, max_omega=0.2, max_time=100):
        """
        params:
            num_robots (int):           number of robots in the problem
            robot_radius (float):       radius shared by all robots
            starts (list):              per-robot start state [x, y, heading]
            goals (list):               per-robot goal position [x, y]
                                        (entries after the first two are ignored)
            circle_obstacles (list):    circular obstacles, each [x, y, radius]
            rectangle_obstacles (list): rectangular obstacles; stored but not yet
                                        used by the optimization problem
            rob_dist_weight (float):    weight of the robot-robot proximity penalty
            obs_dist_weight (float):    weight of the robot-obstacle proximity penalty
            control_weight (float):     weight of the heading-change penalty
            time_weight (float):        weight of the final time T
            max_vel (float):            bound on |v| for every robot and interval
            max_omega (float):          bound on |omega| for every robot and interval
            max_time (float):           upper bound on the final time T
        """
        self.num_robots = num_robots
        self.starts = starts
        self.goals = goals
        self.circle_obs = circle_obstacles
        self.rect_obs = rectangle_obstacles
        self.rob_dist_weight = rob_dist_weight
        self.obs_dist_weight = obs_dist_weight
        self.control_weight = control_weight
        self.time_weight = time_weight
        self.robot_radius = MX(robot_radius)
        self.max_vel = max_vel
        self.max_omega = max_omega
        self.max_time = max_time

    def dist(self, robot_position, circle):
        """
        Returns the SQUARED Euclidean distance between a robot and a circle's center.

        No square root is taken, so the result must not be compared directly
        against a plain (non-squared) distance threshold.

        params:
            robot_position: 2x1 column [x, y]
            circle:         [x, y, radius]; only the first two entries are used.
                            Any indexable sequence works (list, numpy row, CasADi DM row).
        """
        center = vertcat(circle[0], circle[1])
        return sumsqr(robot_position - center)

    def apply_quadratic_barrier(self, d_max, d, c):
        """
        Applies a quadratic barrier to some given distance. The quadratic barrier
        is a soft barrier function. We are using it for now to avoid any issues with
        invalid initial solutions, which hard barrier functions cannot handle.

        Returns c * max(0, d_max - d)**2, i.e. zero once d >= d_max.

        params:
            d_max (float):  the threshold distance at which the barrier starts to apply
            d (float):      distance to the obstacle (same units as d_max)
            c (float):      controls the steepness of the curve;
                            higher c --> gets more expensive faster as you move toward obs
        """
        return c*fmax(0, d_max-d)**2

    def log_normal_barrier(self, sigma, d, c):
        """
        Alternative soft barrier: c * max(0, 2 - d/sigma)**2.5.

        It is zero once d >= 2*sigma. It is currently not used by problem_setup.
        """
        return c*fmax(0, 2-(d/sigma))**2.5

    def _smooth_distance(self, sq_dist):
        """Convert a squared distance to a distance whose gradient is finite at zero."""
        return sqrt(sq_dist + self._DIST_EPS)

    def _obstacle_cost(self, pos, N):
        """Sum of barrier penalties between every robot and every circular obstacle (steps 0..N-1)."""
        # TODO:: Include rectangular obstacles
        obstacles = [] if self.circle_obs is None else self.circle_obs
        cost = 0
        for obs in obstacles:
            cx, cy, cr = float(obs[0]), float(obs[1]), float(obs[2])
            # Barrier activates within twice the sum of the two radii (safety margin).
            d_max = 2*(self.robot_radius + cr)
            for r in range(self.num_robots):
                for k in range(N):
                    sq = self.dist(pos[2*r:2*(r+1), k], [cx, cy, cr])
                    cost += self.apply_quadratic_barrier(d_max, self._smooth_distance(sq), 5)
        return cost

    def _robot_cost(self, pos, N):
        """Sum of barrier penalties between every pair of robots (steps 0..N-1)."""
        safe_dist = 2*self.robot_radius
        cost = 0
        for k in range(N):
            for r1 in range(self.num_robots):
                p1 = pos[2*r1:2*(r1+1), k]
                for r2 in range(r1+1, self.num_robots):
                    p2 = pos[2*r2:2*(r2+1), k]
                    d = self._smooth_distance(sumsqr(p1 - p2))
                    # The penalty is symmetric, so compute it once per pair and
                    # count it for both robots of the pair.
                    cost += 2*self.apply_quadratic_barrier(safe_dist, d, 1)
        return cost

    def _heading_change_cost(self, heading, N):
        """Sum of squared heading changes over all N intervals, wrapped to [-pi, pi]."""
        diff = heading[:, 1:N+1] - heading[:, 0:N]
        # atan2(sin, cos) wraps correctly for differences of either sign; an
        # fmod-based wrap does not, because fmod keeps the sign of the dividend.
        wrapped = atan2(sin(diff), cos(diff))
        return sumsqr(wrapped)

    def problem_setup(self, N, x_range, y_range):
        """
        Problem setup for the multi-robot collision resolution traj opt problem

        inputs:
            - N (int): number of control intervals
            - x_range (tuple): (min, max) bounds on every robot's x coordinate
            - y_range (tuple): (min, max) bounds on every robot's y coordinate

        outputs:
            - problem (dict): {'opti': the Opti problem, 'X': state trajectory
                              variable of shape (3*num_robots, N+1), 'T': final time}
        """
        opti = Opti() # Optimization problem

        # ---- decision variables --------- #
        X = opti.variable(self.num_robots*3, N+1)   # state trajectory (x,y,heading)
        pos = X[:self.num_robots*2,:]               # rows x_0, y_0, x_1, y_1, ...
        x = pos[0::2,:]
        y = pos[1::2,:]
        heading = X[self.num_robots*2:,:]           # one heading row per robot

        U = opti.variable(self.num_robots*2, N)     # control trajectory (v, omega)
        vel = U[0::2,:]
        omega = U[1::2,:]
        T = opti.variable()                         # final time

        # ------ cost terms ------ #
        dist_to_other_obstacles = self._obstacle_cost(pos, N)
        dist_to_other_robots = self._robot_cost(pos, N)
        heading_diff_penalty = self._heading_change_cost(heading, N)

        # ---- dynamics constraints (forward Euler) ---- #
        dt = T/N # length of a control interval
        two_pi = 2*np.pi
        for k in range(N):
            dxdt = vel[:,k] * cos(heading[:,k])
            dydt = vel[:,k] * sin(heading[:,k])
            dthetadt = omega[:,k]
            opti.subject_to(x[:,k+1]==x[:,k] + dt*dxdt)
            opti.subject_to(y[:,k+1]==y[:,k] + dt*dydt)
            # NOTE(review): fmod is discontinuous at multiples of 2*pi, which
            # gradient-based solvers handle poorly; kept to preserve the
            # heading range returned to callers.
            opti.subject_to(heading[:,k+1]==fmod(heading[:,k] + dt*dthetadt, two_pi))

        # ------ cost function ------ #
        opti.minimize(self.rob_dist_weight*dist_to_other_robots
                    + self.obs_dist_weight*dist_to_other_obstacles
                    + self.time_weight*T
                    + self.control_weight*heading_diff_penalty)

        # ------ control constraints ------ #
        # Linear bounds are equivalent to v**2 <= max**2 but better conditioned.
        opti.subject_to(opti.bounded(-self.max_vel, vel, self.max_vel))
        opti.subject_to(opti.bounded(-self.max_omega, omega, self.max_omega))

        # ------ bound x, y, and time  ------ #
        opti.subject_to(opti.bounded(x_range[0],x,x_range[1]))
        opti.subject_to(opti.bounded(y_range[0],y,y_range[1]))
        opti.subject_to(opti.bounded(0,T,self.max_time))

        # ------ initial / terminal conditions ------ #
        for r in range(self.num_robots):
            opti.subject_to(heading[r, 0]==self.starts[r][2])
            opti.subject_to(pos[2*r : 2*(r+1), 0]==self.starts[r][0:2])
            opti.subject_to(pos[2*r : 2*(r+1), -1]==self.goals[r][0:2])

        return {'opti':opti, 'X':X, 'T':T}

    def solve_optimization_problem(self, problem, initial_guesses=None, solver_options=None):
        """
        Solve a problem built by problem_setup using the project's Optimizer wrapper.

        Returns whatever Optimizer.solve_optimization_problem returns; solve()
        expects it to contain at least the keys 'X' and 'solution'.
        """
        opt = Optimizer(problem)
        results = opt.solve_optimization_problem(initial_guesses, solver_options)
        return results

    def solve(self, N, x_range, y_range, initial_guesses):
        """
        Setup and solve a multi-robot traj opt problem

        input:
            - N (int): the number of control intervals
            - x_range (tuple): (min, max) bounds on every robot's x coordinate
            - y_range (tuple): (min, max) bounds on every robot's y coordinate
            - initial_guesses: forwarded unchanged to the Optimizer

        output:
            (sol, pos, x_vals, y_vals, theta_vals), where sol is results['solution']
            and the rest are row slices of results['X']: all position rows, the
            x rows, the y rows, and the heading rows. Whether results['X'] is
            symbolic or numeric depends on the Optimizer — TODO confirm.
        """
        problem = self.problem_setup(N, x_range, y_range)
        results = self.solve_optimization_problem(problem, initial_guesses)

        X = results['X']
        sol = results['solution']

        # Extract the values that we want from the optimizer's solution
        pos = X[:self.num_robots*2,:]
        x_vals = pos[0::2,:]
        y_vals = pos[1::2,:]
        theta_vals = X[self.num_robots*2:,:]

        return sol,pos, x_vals, y_vals, theta_vals

    def get_local_controls(self, controls):
        """
        Get the local controls for the robots in the conflict

        Writes [vel, omega] into controls[label] and returns (controls, final_trajs).

        NOTE(review): this method cannot currently succeed and needs rework:
          - self.conflicts and self.all_robots are never assigned in this class,
            so a caller would have to set them first;
          - self.solve takes (N, x_range, y_range, initial_guesses) but is
            called below with only two arguments;
          - self.solve returns five values, while six are unpacked, and none of
            them are control trajectories.
        The logic is left as-is until the intended interface is settled.
        """
        l = self.num_robots

        final_trajs = [None]*l

        for c in self.conflicts:
            # Get the robots involved in the conflict
            robots = [self.all_robots[r.label] for r in c]

            # Solve the trajectory optimization problem
            initial_guess = None
            sol, x_opt, vels, omegas, xs,ys = self.solve(20, initial_guess)

            # Update the controls for the robots
            for r, vel, omega, x,y in zip(robots, vels, omegas, xs,ys):
                controls[r.label] = [vel, omega]
                final_trajs[r.label] = [x,y]

        return controls, final_trajs

    def plot_paths(self, x_opt, xlim=(0, 480), ylim=(0, 640)):
        """
        Plot the robot paths, start/goal markers, and circular obstacles.

        params:
            x_opt:  numeric array whose rows 2r and 2r+1 hold robot r's x and y
                    positions (one column per time step)
            xlim:   (min, max) x-axis limits
            ylim:   (min, max) y-axis limits
        """
        _fig, ax = plt.subplots()

        # Plot obstacles (each [x, y, radius]).
        # TODO: also draw self.rect_obs once their format is settled.
        obstacles = [] if self.circle_obs is None else self.circle_obs
        for obstacle in obstacles:
            ax.add_patch(Circle((obstacle[0], obstacle[1]), obstacle[2], color='red'))

        if self.num_robots > 20:
            colors = plt.cm.hsv(np.linspace(0.2, 1.0, self.num_robots))
        elif self.num_robots > 10:
            colors = plt.cm.tab20(np.linspace(0, 1, self.num_robots))
        else:
            colors = plt.cm.tab10(np.linspace(0, 1, self.num_robots))

        # Plot robot paths: filled marker at the start, hollow marker at the goal.
        for r, color in enumerate(colors):
            ax.plot(x_opt[r*2, :], x_opt[r*2+1, :], label=f'Robot {r+1}', color=color)
            ax.scatter(x_opt[r*2, :], x_opt[r*2+1, :], color=color, s=10)
            ax.scatter(self.starts[r][0], self.starts[r][1], s=85, color=color)
            ax.scatter(self.goals[r][0], self.goals[r][1], s=85, facecolors='none', edgecolors=color)

        ax.set_xlabel('X')
        ax.set_ylabel('Y')
        ax.legend()
        ax.set_aspect('equal', 'box')

        ax.set_ylim(*ylim)
        ax.set_xlim(*xlim)
        ax.set_title('Robot Paths')
        ax.grid(False)
        plt.show()