import pickle, os
import numpy as np

class Library:
    """Lookup table of precomputed multi-robot solutions loaded from disk.

    The library directory ``folder`` is expected to contain:

    * ``sol_to_idx.pickle`` -- a dict mapping a query string (built as in
      :meth:`get_matching_solution`) to a row index.
    * ``all_robots.npz`` -- an archive of arrays read in stored order;
      ``np_data[k]`` is used for queries with ``k + 1`` robots.
    """

    def __init__(self, folder):
        self.folder = folder
        self.key_to_idx = {}  # query string -> row index into np_data[num_robots - 1]
        self.np_data = []  # np_data[k] is looked up for queries with k + 1 robots

    def read_library_from_file(self):
        """Load the key index and solution arrays from ``self.folder``.

        Any previously loaded data is replaced, so calling this more than
        once does not duplicate entries in ``np_data``. Both files are closed
        before returning.
        """
        # NOTE(review): pickle.load can execute arbitrary code embedded in the
        # file -- only load library folders from a trusted source.
        with open(os.path.join(self.folder, "sol_to_idx.pickle"), "rb") as f:
            self.key_to_idx = pickle.load(f)
        # np.load on an .npz returns a lazily-reading NpzFile; materialize
        # every array before the context manager closes the underlying file.
        with np.load(os.path.join(self.folder, "all_robots.npz")) as arrays:
            self.np_data = [arrays[name] for name in arrays.files]

    def get_size(self):
        """Return the number of query keys in the loaded library."""
        return len(self.key_to_idx)

    def get_matching_solution(self, start_nodes, goal_nodes):
        """Return the stored solution for a start/goal assignment, if any.

        Args:
            start_nodes: sequence of ``(x, y)`` start coordinates, one per robot.
            goal_nodes: sequence of ``(x, y)`` goal coordinates, in the same
                robot order as ``start_nodes``.

        Returns:
            A list with one entry per row of the stored solution array. Each
            entry is a list of ``[x, y]`` pairs regrouped from the row's flat
            values. Returns ``None`` if the query is not in the library or no
            robots were given.

        Raises:
            ValueError: if ``start_nodes`` and ``goal_nodes`` differ in length.
        """
        if len(start_nodes) != len(goal_nodes):
            raise ValueError(
                "start_nodes and goal_nodes must have the same length "
                f"({len(start_nodes)} != {len(goal_nodes)})"
            )

        num_robots = len(start_nodes)
        if num_robots == 0:
            return None

        # The key format must match the one used to build sol_to_idx.pickle:
        # per robot, str(start_x) + str(start_y) + str(goal_x) + str(goal_y),
        # concatenated with no separator. (Keys are ambiguous once coordinates
        # reach two digits, but that is fixed by the stored library format.)
        parts = []
        for start, goal in zip(start_nodes, goal_nodes):
            parts.extend((str(start[0]), str(start[1]), str(goal[0]), str(goal[1])))
        query = "".join(parts)

        if query not in self.key_to_idx:
            return None

        idx = self.key_to_idx[query]
        # Use the real robot count: len(query) // 4 is wrong as soon as any
        # coordinate has more than one digit.
        sol = self.np_data[num_robots - 1][idx]

        # Each stored row is flat [v0, v1, v2, v3, ...]; regroup into pairs.
        formatted_sol = []
        for row in sol:
            formatted_sol.append([[row[i], row[i + 1]] for i in range(0, len(row), 2)])
        return formatted_sol


if __name__ == "__main__":
    # Smoke test: load a library from disk, print its size, and look up one
    # start/goal assignment.
    lib = Library("2x3_library")
    lib.read_library_from_file()

    print(lib.get_size())

    # Other start/goal assignments previously tried against this library:
    #   start [[0, 1], [1, 2]]          goal [[1, 2], [0, 1]]
    #   start [[1, 2], [0, 2]]          goal [[0, 1], [1, 2]]
    #   start [[0, 0], [0, 2], [1, 2]]  goal [[0, 0], [1, 2], [0, 2]]
    #   start [[0, 1], [1, 0]]          goal [[0, 2], [1, 2]]
    #   start [[0, 0], [0, 2], [1, 1]]  goal [[0, 1], [1, 0], [1, 2]]
    #   start [(0, 1), (2, 1)]          goal [(2, 1), (0, 1)]
    start_nodes = [[0, 0], [0, 2]]
    goal_nodes = [[0, 2], [0, 0]]
    sol = lib.get_matching_solution(start_nodes, goal_nodes)
    print(sol)