"""
Database-guided multi-robot motion planning
"""
import random

import sys
sys.path.append("c:\\Users\\rmoan2\\guided_mrmp_24")

from SingleAgentPlanners.RRT import *
from SingleAgentPlanners.RRTStar import *
from Utils import Conflict, Robot

class GuidedMRMP:
    def __init__(self, env, robots=None):
        """
        Set up the planner for a given environment.

        inputs:
            - env (Env): the environment
            - robots (list, optional): initial list of Robot class objects.
              The list is copied, so robots added later (for example by
              initialize_robots_with_paths) do not mutate the caller's list.
              Defaults to an empty list.
        """
        self.robots = [] if robots is None else list(robots)
        self.env = env

    def plan_decoupled_path(self, start, goal, solver="RRT*",
                            step_length=.5, goal_sample_rate=.05, num_samples=10000, r=10):
        """
        Plan a single-robot path from start to goal, ignoring all other robots.

        inputs:
            - start (tuple): (x,y) location of start
            - goal (tuple): (x,y) location of goal
            - solver (string): "RRT" or "RRT*" (default); anything else is
              reported as unsupported and yields None
            - step_length (float): forwarded to the planner (presumably the
              tree-extension step size; confirm against the RRT classes)
            - goal_sample_rate (float): forwarded to the planner (presumably
              the probability of sampling the goal directly)
            - num_samples (int): forwarded to the planner (presumably the
              sampling budget)
            - r (float): forwarded to RRT* only (presumably its rewiring
              radius); ignored for plain RRT
        output:
            - path: whatever the chosen planner's run() returns, or None when
              the solver name is not supported

        NOTE: self.env is not passed to either planner.
        """
        if solver not in ("RRT", "RRT*"):
            print(f"Solver {solver} is not yet implemented. Choose something else.")
            return None

        if solver == "RRT":
            planner = RRT(start, goal, step_length, goal_sample_rate, num_samples)
        else:
            planner = RRTStar(start, goal, step_length, goal_sample_rate, num_samples, r)
        return planner.run()

    def initialize_robots_with_paths(self, starts, goals):
        """
        Create one Robot per (start, goal) pair and append it to self.robots.

        Each robot gets its own decoupled path from plan_decoupled_path, using
        that method's default solver and parameters. Pairs are formed with zip,
        so any extra entries in the longer of the two sequences are ignored.

        inputs:
            - starts (iterable): start locations; each item must expose .x and .y
            - goals (iterable): goal locations; each item must expose .x and .y

        NOTE: this function (and plan_decoupled_path) could just exist as
        helper functions elsewhere to make this class easier to understand.
        """
        for origin, target in zip(starts, goals):
            origin_xy = (origin.x, origin.y)
            target_xy = (target.x, target.y)
            route = self.plan_decoupled_path(origin_xy, target_xy)
            self.robots.append(Robot(origin, target, route))


            


    def find_all_conflicts(self, t):
        """
        Find every node conflict and edge conflict among the robots' next moves.

        A node conflict is reported when two or more robots share the same next
        node. An edge conflict is reported when robots would swap positions (see
        find_edge_conflicts). Every robot involved in a conflict rediscovers it,
        so each (robot set, conflict type) pair is reported only once.

        inputs:
            - t (int): current timestep; conflicts are stamped with t+1
        output:
            - conflicts (list): Conflict objects, in discovery order, each built
              from the sorted list of involved robot labels
        """
        conflicts = []
        # Keys of conflicts already recorded. Deduping on our own key avoids
        # relying on Conflict defining __eq__, and keeps node and edge
        # conflicts over the same robots distinct.
        seen = set()

        def record(labels, kind):
            # Canonical order so the same robot set always yields the same key.
            labels.sort()
            key = (tuple(labels), kind)
            if key not in seen:
                seen.add(key)
                conflicts.append(Conflict(labels, t + 1, kind))

        for r in self.robots:
            # Node conflicts: find_node_conflicts includes r itself when its
            # next node matches, so a real conflict needs more than one robot.
            robots_at_node = self.find_node_conflicts(r.get_next_node(), r.get_label())
            if len(robots_at_node) > 1:
                record(robots_at_node, 'node')

            # Edge conflicts: find_edge_conflicts excludes r, so add it back.
            robots_with_edge_conflict = self.find_edge_conflicts(r)
            if robots_with_edge_conflict:
                robots_with_edge_conflict.append(r.get_label())
                record(robots_with_edge_conflict, 'edge')

        return conflicts

    def find_edge_conflicts(self, robot):
        """
        Return the labels of robots that would swap places with `robot`.

        Robots A and B have an edge conflict when A's current node is B's next
        node and B's current node is A's next node. Robots whose label equals
        robot.label (including robot itself) are skipped.

        Input:
            robot - the robot for which we are searching for edge conflicts
        Outputs:
            a list of get_label() values of the conflicting robots, in
            self.robots order
        """
        here = robot.get_current_position()
        heading_to = robot.get_next_node()

        def swaps_with(other):
            other_here = other.get_current_position()
            other_heading_to = other.get_next_node()
            return here == other_heading_to and other_here == heading_to

        return [
            other.get_label()
            for other in self.robots
            if other.label != robot.label and swaps_with(other)
        ]
    
    def find_node_conflicts(self, node, excluded_r_label):
        """
        Return the labels of every robot whose next node equals `node`.

        Inputs:
            node - the node (r,c) to check; compared with == against each
                robot's get_next_node()
            excluded_r_label - currently unused. The querying robot is NOT
                excluded, so it appears in the result whenever its own next
                node is `node`.
        Outputs:
            a list of robot labels (the .label attribute), in self.robots order
        """
        return [bot.label for bot in self.robots if bot.get_next_node() == node]

    def conflict_already_found(self, new_c, C):
        """
        Return True if some conflict in C involves exactly the same robots as new_c.

        Only the robot index lists (get_robot_idxs()) are compared; conflict
        type and timestep are ignored. The lists are compared with ==, so
        callers should build them in a canonical (e.g. sorted) order.

        inputs:
            - new_c (Conflict): candidate conflict
            - C (iterable of Conflict): conflicts found so far
        output:
            - (bool): True if a matching conflict exists, otherwise False
        """
        target = new_c.get_robot_idxs()
        return any(c.get_robot_idxs() == target for c in C)

    def all_robots_at_goal(self):
        """
        Intended to report whether every robot has reached its goal.

        Not implemented yet. It raises instead of silently returning None,
        because a None result makes `while not self.all_robots_at_goal()` in
        run() independent of robot progress, so the loop could never end.

        raises:
            - NotImplementedError: always, until a goal check is written
        """
        raise NotImplementedError("all_robots_at_goal is not implemented yet")

    def run(self):
        """
        Main planning loop (work in progress).

        On each pass of the loop it finds conflicts among the robots' next
        moves and handles each one not already covered:
          - If place_subproblem yields a subproblem, the subproblem is recorded
            in S and the solution library is queried. Each involved robot's
            next node is then updated from that solution.
          - Otherwise, the robots involved are added to waiting_robots. For a
            node conflict, one randomly chosen robot is meant to be exempt.

        NOTE(review): as written this cannot run yet:
            - find_all_conflicts requires a timestep `t` but is called with no
              arguments below (TypeError on the first iteration).
            - place_subproblem and query_library are not defined in the visible
              part of this class, and place_subproblem is not passed the
              conflict `c` it is supposed to cover.
            - all_robots_at_goal is not implemented yet, so the loop condition
              does not reflect robot progress.
            - S and waiting_robots are filled but not read in the lines visible
              here.
        """
        while not self.all_robots_at_goal():

            conflicts = self.find_all_conflicts()
            # Conflicts that place_subproblem reports as covered; skipped below.
            # NOTE(review): `c in conflicts_covered` relies on Conflict equality
            # (identity unless Conflict defines __eq__) — confirm.
            conflicts_covered = []
            # Subproblems placed during this pass.
            S = []
            # Robots selected to wait during this pass.
            waiting_robots = []

            for c in conflicts:
                if c in conflicts_covered:
                    continue
                else:
                    new_conflicts_covered, s = self.place_subproblem()
                    conflicts_covered += new_conflicts_covered

                    if s is None:
                        # No subproblem could be placed. Choose one robot to advance;
                        # all others in the conflict wait.
                        # NOTE(review): self.robots is indexed by robot label, which
                        # assumes labels equal positions in self.robots — confirm.
                        if c.type == 'node':    
                            # Index of the robot allowed to move.
                            # NOTE(review): random.randint includes both endpoints, so
                            # num can equal len(...). In that case no robot is exempt
                            # and they all wait; random.randrange(len(...)) is likely
                            # what was intended.
                            num = random.randint(0,len(c.get_robot_idxs()))               
                            for i in range(len(c.get_robot_idxs())):
                                if i != num:
                                    if self.robots[c.get_robot_idxs()[i]] not in waiting_robots: 
                                        waiting_robots.append(self.robots[c.get_robot_idxs()[i]])
                        else:
                            # Edge conflict with no subproblem: every involved robot waits.
                            for i in range(len(c.get_robot_idxs())):
                                if self.robots[c.get_robot_idxs()[i]] not in waiting_robots: 
                                    waiting_robots.append(self.robots[c.get_robot_idxs()[i]])
                    else:
                        # We found a valid subproblem; query the database for its solution.
                        S.append(s)
                        soln = self.query_library(s)

                        # Position in soln; assumes soln is ordered like
                        # s.all_robots_involved_in_subproblem — TODO confirm.
                        sol_idx = 0
                        for robot in s.all_robots_involved_in_subproblem: 
                            # Swap the subproblem's robot for this planner's Robot with the same label.
                            robot = self.robots[robot.get_label()]
                            
                            # Presumably a sequence of (row, col) cells in subproblem
                            # coordinates — TODO confirm. With a single entry, or a second
                            # entry whose first value is -1, next_node comes from
                            # solution[0]; otherwise it comes from solution[1].
                            solution = soln[sol_idx]
                            if len(solution) == 1:
                                next_node = s.get_world_coordinates(solution[0][0], solution[0][1])
                            elif solution[1][0] == -1:
                                next_node = s.get_world_coordinates(solution[0][0], solution[0][1])
                            # NOTE(review): this condition is always true when reached (the
                            # previous test already indexed solution[1]); an empty solution
                            # raises IndexError on that previous test instead.
                            elif len(solution) > 1:
                                next_node = s.get_world_coordinates(solution[1][0], solution[1][1])
                            
                            robot.update_next_node(next_node)
                            sol_idx +=1