import numpy as np

class QueryDatabase:
    """Look up precomputed solutions for small multi-robot subproblems.

    One solution library is held per subproblem shape. Each library is queried
    through ``get_matching_solution(starts, goals)``, which is expected to return
    a list of paths (one per robot, in the library's canonical start order) or
    ``None`` when it has no matching entry.
    """

    def __init__(self, lib2x3, lib3x3, lib5x2, obstacle_map=None):
        """
        Inputs:
            lib2x3 - library queried for subproblems with ``type == 23``
            lib3x3 - library queried for subproblems with ``type == 33``
            lib5x2 - library queried for subproblems with ``type == 25``
            obstacle_map - optional grid indexed as ``obstacle_map[row][col]``.
                It is only read for ``type == 25`` subproblems and may also be
                assigned later as ``db.obstacle_map``.
        """
        self.lib2x3 = lib2x3
        self.lib3x3 = lib3x3
        self.lib5x2 = lib5x2
        self.obstacle_map = obstacle_map

    def order_query(self, starts, goals):
        """
        Order the starts and goals in the way that the library expects to see them.
        Ordering is determined by the start node, using the key
        ``3 * start[0] + start[1]``. This is row-major ordering for rows of width 3:
         ___ ___ ___
        |_6_|_7_|_8_|
        |_3_|_4_|_5_|
        |_0_|_1_|_2_|

        Inputs:
            starts - the starts of each robot to be reordered
            goals - the goals of each robot to be reordered (same length as starts)
        Outputs:
            ordered_starts - the starts sorted by the key above
            ordered_goals - the goals, permuted exactly like the starts
            new_to_old - list whose entry ``j`` is the position of original robot
                ``j`` in the ordered lists. Pass it to ``restore_original_order`` to
                recover the original order.
        """
        keys = np.asarray([3 * start[0] + start[1] for start in starts])
        # Stable sort: robots with equal keys keep their input order, so the
        # library sees a deterministic ordering.
        order = np.argsort(keys, kind="stable")

        ordered_starts = [starts[i] for i in order]
        ordered_goals = [goals[i] for i in order]

        # Invert the sorting permutation: original index -> position after sorting.
        new_to_old = [0] * len(order)
        for new_pos, old_idx in enumerate(order):
            new_to_old[old_idx] = new_pos

        return ordered_starts, ordered_goals, new_to_old

    def restore_original_order(self, ordered, new_to_old_idx):
        """
        Reorder a solution according to its original start and goal ordering.

        Inputs:
            ordered - list of paths, given in the order used by the database
            new_to_old_idx - the mapping returned by ``order_query``. Entry ``i`` is
                the position of original robot ``i`` within ``ordered``.
        Outputs:
            og_order - the same solution, with entry ``i`` belonging to original robot ``i``
        """
        return [ordered[new_to_old_idx[i]] for i in range(len(ordered))]

    def query_library(self, s):
        """
        Query the library to get a solution for the subproblem s.

        If the starts already equal the goals, a trivial solution is returned
        without consulting any library: each robot's path is just its goal. Otherwise
        the library matching ``s.type`` is queried in the subproblem's current
        orientation. On a miss, ``s.rotate()`` is called and the query is retried,
        for up to 8 attempts. Then ``s.flip()`` is called and another 8 rotated
        attempts are made. ``s`` is left in the orientation that produced the match.

        Inputs:
            s - (instance of subproblem class) The subproblem we need a solution for.
        Outputs:
            sol - (list of paths) A valid solution for s, one path per robot, in the
                order returned by ``s.get_starts()`` for the matching orientation.
        Raises:
            Exception - if no orientation yields a solution.
        """
        start_nodes = s.get_starts()
        goal_nodes = s.get_goals()
        if start_nodes == goal_nodes:
            return [[goal] for goal in goal_nodes]

        # Two passes (the second one after a flip) of 8 rotated attempts each.
        # NOTE(review): 8 rotations may cover each orientation more than once if
        # rotate() is a quarter turn -- confirm against the subproblem class.
        for _flip in range(2):
            for _rotation in range(8):
                sol = self._query_current_orientation(s)
                if sol is not None:
                    return sol
                s.rotate()
            s.flip()

        raise Exception("Error. Failed to get solution from database")

    def _query_current_orientation(self, s):
        """Query the library for ``s.type`` once, in ``s``'s current orientation.

        Returns the solution in ``s.get_starts()`` order, or None when the library
        has no match, the type is unknown, or a type-25 layout fails the obstacle check.
        """
        start_nodes = s.get_starts()
        goal_nodes = s.get_goals()

        if s.type == 23:
            lib = self.lib2x3
        elif s.type == 33:
            lib = self.lib3x3
        elif s.type == 25:
            if not self._has_expected_obstacle(s):
                return None
            lib = self.lib5x2
        else:
            return None

        return self._query_single_library(lib, start_nodes, goal_nodes)

    def _has_expected_obstacle(self, s):
        """Return True when a type-25 layout has 5 rows and its cell ``layout[2][1][0]`` is truthy in ``obstacle_map``."""
        layout = s.subproblem_layout
        if len(layout) != 5:
            return False
        if self.obstacle_map is None:
            raise ValueError("obstacle_map must be set to query type 25 subproblems")
        obs_locat = layout[2][1][0]
        return bool(self.obstacle_map[obs_locat[0]][obs_locat[1]])

    def _query_single_library(self, lib, starts, goals):
        """Query ``lib`` in canonical order and map the answer back to the input robot order (None on a miss)."""
        ordered_starts, ordered_goals, new_to_old = self.order_query(starts, goals)
        sol = lib.get_matching_solution(ordered_starts, ordered_goals)
        if sol is None:
            return None
        return self.restore_original_order(sol, new_to_old)