import matplotlib.pyplot as plt
import matplotlib.cm as cm
import numpy as np

# Simulator class:
# Responsible for managing the simulation of the robots in the environment
class Simulator:
    """Manage the simulation of the robots in the environment."""

    def __init__(self, robots, dynamics_models, env, policy, settings):
        """
        Set up a multi-robot simulation.

        robots: list of Robot objects. Each must expose ``current_position`` (used as the
            initial state) and ``goal`` (indexed as ``goal[0]``, ``goal[1]``).
        dynamics_models: list of DynamicsModel objects, one per robot. They are stored on
            the instance; note that ``run`` steps robots with ``policy.dynamics``.
        env: environment exposing ``circle_obs`` (list of tuples (x, y, radius)),
            ``rect_obs`` (list of tuples (x, y, width, height)) and ``boundary``
            (``boundary[0]`` is the x range, ``boundary[1]`` the y range, each (min, max)).
        policy: the controller. It must provide ``advance(screen, state)`` returning
            ``(x_mpc, controls)``, as well as ``DT``, ``dynamics``, ``radius`` and ``paths``.
        settings: nested dict; ``settings['simulator']['scaling_factor']`` is read.

        After construction:
        state: one state per robot, taken from ``robot.current_position``.
        time: the current simulation time, starting at 0.
        """
        self.robots = robots
        self.env = env
        self.circ_obstacles = env.circle_obs
        self.rect_obstacles = env.rect_obs
        self.policy = policy

        self.state = [robot.current_position for robot in robots]

        self.num_robots = len(robots)
        self.dynamics_models = dynamics_models
        self.time = 0

        self.scaling_factor = settings['simulator']['scaling_factor']

        # Helper variables to keep track of the sim
        self.sim_time = 0
        self.x_history = [[] for _ in range(self.num_robots)]
        self.y_history = [[] for _ in range(self.num_robots)]
        self.h_history = [[] for _ in range(self.num_robots)]
        self.optimized_trajectories_hist = [[] for _ in range(self.num_robots)]
        self.optimized_trajectory = None

    def all_robots_at_goal(self, tol=0.5):
        """
        Return True when every robot's (x, y) lies within ``tol`` of its goal.

        Args:
            tol (float): maximum Euclidean distance counted as "at goal". Defaults to 0.5.
        """
        return all(
            np.hypot(state[0] - robot.goal[0], state[1] - robot.goal[1]) <= tol
            for state, robot in zip(self.state, self.robots)
        )

    def advance(self, screen, dt):
        """
        Query the policy for controls and advance the simulation clock by ``dt``.

        Robot states are NOT updated here; ``run`` applies the returned controls.

        Args:
            screen: forwarded unchanged as the first argument of ``policy.advance``.
            dt (float): time increment added to ``self.time``.

        Returns:
            tuple: ``(x_mpc, controls)`` exactly as returned by ``policy.advance``.
        """
        x_mpc, controls = self.policy.advance(screen, self.state)
        self.time += dt
        return x_mpc, controls

    def _record_state(self):
        """Append each robot's current state columns 0, 1, 2 to the x/y/h histories."""
        for xs, ys, hs, robot_state in zip(self.x_history, self.y_history,
                                           self.h_history, self.state):
            xs.append(robot_state[0])
            ys.append(robot_state[1])
            hs.append(robot_state[2])

    def run(self, show_plots=False, max_steps=None):
        """
        Run the path tracker until every robot reaches its goal.

        Parameters:
        - show_plots (bool): If True, draw the world once per recorded state (each draw
          ends with a blocking ``plt.show()``). Default is False: nothing is plotted.
        - max_steps (int or None): optional cap on the number of control steps. ``None``
          (the default) loops until all robots reach their goals.

        Returns:
        - numpy.ndarray: ``np.asarray([x_history, y_history, h_history])``. When all
          histories have equal length, it is indexed as [coordinate, robot, step].
        """
        # Add the initial state to the histories
        self.state = np.array(self.state)
        self._record_state()

        steps = 0
        while True:
            # Plot before the goal check so both the initial and final states are shown once.
            if show_plots:
                self.plot_current_world_state()

            if self.all_robots_at_goal():
                print("Success! Goal Reached")
                return np.asarray([self.x_history, self.y_history, self.h_history])

            if max_steps is not None and steps >= max_steps:
                print(f"Stopped after {steps} steps without reaching all goals")
                return np.asarray([self.x_history, self.y_history, self.h_history])

            # NOTE(review): the state array is passed as the policy's `screen` argument;
            # kept as-is because policy.advance's expectations are not visible here.
            _, controls = self.advance(self.state, self.policy.DT)

            self.state = np.array([
                self.policy.dynamics.next_state(self.state[i], controls[i], self.policy.DT)
                for i in range(self.num_robots)
            ])
            self._record_state()
            steps += 1

    def plot_current_world_state(self):
        """
        Draw the obstacles, robots and reference paths, then call ``plt.show()``.

        Robots are drawn from the most recent entries of the history lists, so at least
        one state must already have been recorded. For each robot, rows 0 and 1 of
        ``policy.paths[i]`` are read as x and y; the path's last point is marked with a
        dot and a circle of radius ``policy.radius``.
        """
        ax = plt.gca()
        colors = cm.rainbow(np.linspace(0, 1, self.num_robots))

        # Circular obstacles as filled black discs.
        # NOTE(review): rectangular obstacles (self.rect_obstacles) are not drawn.
        for obs in self.circ_obstacles:
            ax.add_patch(plt.Circle((obs[0], obs[1]), obs[2], color='k', fill=True))

        for i, color in enumerate(colors):
            path = self.policy.paths[i]

            # Current robot pose (outline body plus heading arrow).
            self.plot_roomba(self.x_history[i][-1], self.y_history[i][-1],
                             self.h_history[i][-1], color, False, self.policy.radius)

            # Last point of the reference path.
            end_x, end_y = path[0, -1], path[1, -1]
            ax.plot(end_x, end_y, 'o', color=color)
            ax.add_patch(plt.Circle((end_x, end_y), self.policy.radius, color=color, fill=False))

            # Full reference path.
            ax.plot(path[0, :], path[1, :], '--', color=color)

        x_range = self.env.boundary[0]
        y_range = self.env.boundary[1]
        ax.set_xlim(x_range[0], x_range[1])
        ax.set_ylim(y_range[0], y_range[1])

        # force equal aspect ratio
        ax.set_aspect('equal', adjustable='box')
        plt.tight_layout()

        # Blocking display; for a non-blocking animation use plt.draw(), plt.pause(), plt.clf().
        plt.show()

    def plot_roomba(self, x, y, yaw, color, fill, radius, arrow_length=1.0):
        """
        Draw a circular robot body with a heading arrow on the current axes.

        Args:
            x (float): x coordinate of the body centre.
            y (float): y coordinate of the body centre.
            yaw (float): heading angle in radians (fed to np.cos / np.sin).
            color: matplotlib colour for the body circle.
            fill (bool): if True the body is filled with alpha 0.3; otherwise it is an
                opaque outline.
            radius (float): body radius.
            arrow_length (float): length of the heading arrow. Defaults to 1.
        """
        ax = plt.gca()
        alpha = 0.3 if fill else 1
        ax.add_patch(plt.Circle((x, y), radius, color=color, fill=fill, alpha=alpha))

        # Plot direction marker
        dx = arrow_length * np.cos(yaw)
        dy = arrow_length * np.sin(yaw)
        ax.arrow(x, y, dx, dy, head_width=0.1, head_length=0.1, fc='r', ec='r')