from collections import defaultdict
from os import PathLike
from pathlib import Path
from typing import Dict, Iterable, List, Optional, Tuple, Union

import networkx as nx
import onnx

from . import graph_ir as g
from .onnx_attr import node_attr_to_dict

# Shorthand aliases for the ONNX protobuf message types handled in this module.
GraphT = onnx.GraphProto
NodeT = onnx.NodeProto
# ONNX NodeProto objects are used directly as networkx graph nodes (dict keys),
# so give them identity-based hashing: each NodeProto object is its own node.
# NOTE(review): __eq__ is left as protobuf's field-wise equality, so the
# hash/eq contract is only honoured "by identity" -- confirm this is intended.
NodeT.__hash__ = lambda self: id(self)
# Make repr()/str() of a node return just its name (readable debug output).
NodeT.__repr__ = NodeT.__str__ = lambda self: self.name


class MarkedSubGraph:
    """A subgraph together with instructions for splicing it in place of a node.

    subgraph: the nx.DiGraph that replaces the node.
    entry_edges: edges from nodes "outside" into nodes of `subgraph`. They are
        given as ``(from, to, index)`` triples and stored as
        ``(from, to, {"index": index})`` -- the 3-tuple form accepted by
        ``nx.DiGraph.add_edges_from``.
    exit: the exit node of the subgraph. When this subgraph replaces a node
        `n`, `exit` is connected to whatever `n` was connected to.
    """

    def __init__(self, subgraph: nx.DiGraph, entry_edges, exit) -> None:
        # Every entry edge must land inside the subgraph, and so must the exit.
        assert all(dst in subgraph for _, dst, _ in entry_edges)
        assert exit in subgraph
        self.subgraph = subgraph
        self.exit = exit
        self.entry_edges = [
            (src, dst, {"index": arg_idx}) for src, dst, arg_idx in entry_edges
        ]

    @classmethod
    def idiomatic_1to2(cls, node1, node2, predecessors):
        """Split a 3-argument node into two chained nodes, as follows:

        node(arg1, arg2, arg3) -> node2(node1(arg1, arg2), arg3)

        `predecessors` must contain exactly the three argument nodes, in
        argument order. `node2` is the exit of the returned subgraph.
        """
        arg1, arg2, arg3 = predecessors
        inner = nx.DiGraph()
        # node1's result is argument 0 of node2.
        inner.add_edge(node1, node2, index=0)
        entry = [(arg1, node1, 0), (arg2, node1, 1), (arg3, node2, 1)]
        return cls(inner, entry, node2)


# What `DFG._emit_node` produces for one ONNX node: either a single replacement
# `g.DFGNode`, or a `MarkedSubGraph` when the node expands into several nodes.
EmitNodeT = Union[MarkedSubGraph, g.DFGNode]


class DFG(object):
    """ONNX model translated into DFG with `DFGNode`s.

    This class has a DFG, input/output information, and a clear traverse order
    (think dominant tree), and is easier for CodeGen classes to work with.

    Attributes:
        graph: ``nx.DiGraph`` of `g.DFGNode`s. Each edge carries an ``index``
            attribute: the argument position at which the target consumes
            the source.
        inputs: the `g.InputTensor` nodes of `graph`.
        output: the unique node of `graph` that has no successors.
    """

    def __init__(self, graph: GraphT):
        self._check_model(graph)
        self._var_count = 0
        # Build explicit DFG with ONNX nodes
        onnx_graph = self._build_onnx_dfg(graph)
        # Convert ONNX dfg into DFGNode DFG
        self.graph = self._build_dfg(onnx_graph)
        # Find out input nodes and output node (unique)
        # removing dead nodes along the way if any
        self.inputs, self.output = self._dce_get_io_info()

    ################ Interfaces:

    @property
    def traverse_order(self) -> List[g.DFGNode]:
        """Get topological order of computational graph by use-def relation."""
        return list(nx.topological_sort(self.graph))

    def node_args(self, node: g.DFGNode) -> List[g.DFGNode]:
        """Get input arguments of `node`, ordered by argument position."""
        return sorted_inputs(self.graph, node)

    def dump_weights(self, output_dir: PathLike) -> None:
        """Dump every `WeightTensor` to ``output_dir / (new_name + "_path.bin")``."""
        output_dir = Path(output_dir)
        for node in self.graph.nodes:
            if not isinstance(node, g.WeightTensor):
                continue
            node.dump_weight(output_dir / (node.new_name + "_path.bin"))

    ################ Internal methods (high-level):

    @staticmethod
    def _check_model(onnx_graph: GraphT):
        """Check model validity and single output (which is our limitation).

        Failures of ONNX's own checker are only reported as warnings.

        Raises:
            ValueError: if any node, or the graph itself, has more than one output.
        """

        import warnings
        from onnx import checker, onnx_cpp2py_export

        # try use onnx's own model checker before converting any model
        try:
            checker.check_graph(onnx_graph)
        except onnx_cpp2py_export.checker.ValidationError as e:
            warnings.warn(str(e))
        if any(len(n.output) > 1 for n in onnx_graph.node):
            raise ValueError("All node must have single output")
        if len(onnx_graph.output) > 1:
            raise ValueError("Graph must have single output")

    @staticmethod
    def _build_onnx_dfg(graph: GraphT) -> nx.DiGraph:
        """Creates a DiGraph (by use-def relation) of onnx nodes from onnx GraphProto.
        DiGraph is easier to use as a graph compared to GraphProto where use-def is implicit.

        Nodes are the graph's ``NodeProto``s plus the tensor nodes built by
        `extract_tensors_from_graph` for initializers and graph inputs. Each
        edge carries the consumer's argument position as ``index``.

        Raises:
            ValueError: if a node consumes a value that no node, initializer or
                graph input defines.
        """

        ret_graph = nx.DiGraph()
        onnx_defs, onnx_uses = def_use(graph.node)
        tensors = extract_tensors_from_graph(graph)
        ret_graph.add_nodes_from(graph.node)
        for onnx_value_name, use_nodes in onnx_uses.items():
            def_node = onnx_defs.get(onnx_value_name)
            if def_node is None:
                def_node = tensors.get(onnx_value_name)
            if def_node is None:
                # E.g. an omitted optional input (empty name) or a dangling
                # reference: report it instead of failing with a bare KeyError.
                users = [use_node.name for use_node, _ in use_nodes]
                raise ValueError(
                    f"Value {onnx_value_name!r} used by {users} is not defined "
                    "by any node, initializer or graph input"
                )
            for use_node, used_at_narg in use_nodes:
                ret_graph.add_edge(def_node, use_node, index=used_at_narg)
        return ret_graph

    def _build_dfg(self, onnx_graph: nx.DiGraph) -> nx.DiGraph:
        """Translate _build_onnx_dfg output into DFGNode DFG.

        First run some passes to process subgraphs that needs to be
        processed together, then each unprocessed node is generated into
        1 or more nodes.

        Raises:
            ValueError: listing (up to 10 of) the nodes whose operators are
                not supported.
        """

        # Remove subgraphs that can be a single Flatten instead
        onnx_graph = detect_flatten(onnx_graph)
        # Remove subgraphs that look like padding but does nothing
        onnx_graph = remove_no_padding(onnx_graph)
        # For each onnx node, generate our nodes
        node_to_nodes, error_nodes = {}, []
        for onnx_node in nx.topological_sort(onnx_graph):
            our_nodes = self._emit_node(onnx_graph, onnx_node)
            if our_nodes is None:
                error_nodes.append(onnx_node)
            else:
                node_to_nodes[onnx_node] = our_nodes
        if error_nodes:
            error_repr = [f"{n.name}({n.op_type})" for n in error_nodes]
            if len(error_nodes) > 10:  # Magic number
                raise ValueError(f"Unsupported operators (first 10): {error_repr[:10]}")
            else:
                raise ValueError(f"Unsupported operators: {error_repr}")
        # Apply node_to_nodes replacement on onnx_graph to create a new DFG
        return build_graph_with_mapping(onnx_graph, node_to_nodes)

    def _dce_get_io_info(self) -> Tuple[List[g.InputTensor], g.DFGNode]:
        """Remove dead nodes from `self.graph` and return ``(inputs, output)``.

        A node is kept iff it lies in the same weakly connected component as
        at least one `g.InputTensor`. The output is then the unique remaining
        node with no successors.

        Raises:
            ValueError: if the remaining graph does not have exactly one node
                with out-degree 0 (an explicit check, unlike an `assert`, is
                not stripped under ``python -O``).
        """
        inputs = [n for n in self.graph if isinstance(n, g.InputTensor)]
        inputs_set = set(inputs)
        reachables = set()
        # Same components as connected_components(to_undirected()), without
        # building an undirected copy of the graph.
        for component in nx.weakly_connected_components(self.graph):
            # If any inputs goes into this subgraph, it's alive.
            if not component.isdisjoint(inputs_set):
                reachables.update(component)
        unreachables = set(self.graph) - reachables
        # Remove nodes not connected to any input
        self.graph.remove_nodes_from(unreachables)
        # Then outputs are nodes with out_degree = 0
        outputs = [n for n in self.graph if self.graph.out_degree[n] == 0]
        if len(outputs) != 1:
            raise ValueError(
                "Graph must have exactly one output after dead-code elimination, "
                f"found {len(outputs)}: {outputs}"
            )
        return inputs, outputs[0]

    @staticmethod
    def _emit_node(in_graph: nx.DiGraph, node: NodeT) -> Optional[EmitNodeT]:
        """Translate one node of the ONNX DFG into our node(s).

        Returns:
            - `node` itself if it already is a `g.DFGNode` (e.g. tensors);
            - a single new `g.DFGNode` for one-to-one operators;
            - a `MarkedSubGraph` for Conv / MatMul / Gemm with a third (bias)
              argument, split into the op followed by a `g.BiasAddNode`;
            - None if the operator (or this Conv variant) is not supported.
        """
        if isinstance(node, g.DFGNode):
            # Directly add node into return graph.
            return node
        predec = sorted_inputs(in_graph, node)

        attrs = node_attr_to_dict(node)
        if node.op_type == "Conv":
            weight_tensor = predec[1]
            assert isinstance(weight_tensor, g.WeightTensor)
            if len(weight_tensor.shape) != 4:
                return None  # Only supports 2D conv
            conv_node = g.Conv2DNode(node.name, attrs)
            if len(predec) == 2:
                return conv_node
            # Split into conv followed by an addition
            bias_node = g.BiasAddNode(f"Bias_{node.name.split('_')[-1]}")
            return MarkedSubGraph.idiomatic_1to2(conv_node, bias_node, predec)
        elif node.op_type in ("MatMul", "Gemm"):
            mul_node = g.MatMulNode(node.name, attrs)
            if node.op_type == "Gemm":
                mul_node.gemm_transpose(node, predec)
            if len(predec) == 2:
                return mul_node
            # Split into mul followed by an addition
            bias_node = g.BiasAddNode(f"Bias_{node.name.split('_')[-1]}")
            return MarkedSubGraph.idiomatic_1to2(mul_node, bias_node, predec)
        one_to_one_nodes = {
            "MaxPool": g.MaxPool2DNode,
            "AveragePool": g.AveragePool2DNode,
            "Add": g.AddNode,
            "Softmax": g.SoftMaxNode,
            "Relu": g.ReluNode,
            "Tanh": g.TanhNode,
            "BatchNormalization": g.BatchNormalizationNode,
            "Pad": g.PadNode,
            "Identity": g.IdentityNode,
            "Flatten": g.FlattenNode,
        }
        if node.op_type in one_to_one_nodes:
            return one_to_one_nodes[node.op_type](node.name, attrs)
        return None


def def_use(nodes: Iterable) -> Tuple[dict, dict]:
    """Compute the def/use relation of a collection of nodes.

    Duck-typed: works on any objects exposing ``.input`` and ``.output``
    sequences of value names. `nodes` is traversed exactly once.

    Returns:
        ``(defs, uses)``. ``defs`` maps each output value name to the node that
        produces it (a later producer overrides an earlier one). ``uses`` is a
        ``defaultdict(list)`` mapping each input value name to the
        ``(node, argument_position)`` pairs consuming it, in encounter order.
    """
    producers = {}
    consumers = defaultdict(list)
    for node in nodes:
        for position, value_name in enumerate(node.input):
            consumers[value_name].append((node, position))
        producers.update(dict.fromkeys(node.output, node))
    return producers, consumers


def remove_no_padding(graph: nx.DiGraph) -> nx.DiGraph:
    """Remove subgraphs that look like padding but does nothing.

    A `Pad` node whose pads argument (argument 1) is a `Constant` node holding
    only zeros is bypassed: its data argument (argument 0) is connected to
    every consumer of the Pad, at the same argument index. The Pad is then
    removed, together with its pads `Constant` if nothing else uses it.

    Pads that do not match this shape (no second argument, e.g. pads given as
    an attribute, or pads not coming from a `Constant` node) are left alone.
    `graph` is modified in place and returned.
    """
    for node in list(graph.nodes):
        if node.op_type != "Pad":
            continue
        input_args = sorted_inputs(graph, node)
        if len(input_args) < 2:
            continue  # No pads argument to inspect.
        data, pads_node = input_args[0], input_args[1]
        # Only a Constant node carries the pads amounts as its "value" attribute.
        if not isinstance(pads_node, NodeT) or pads_node.op_type != "Constant":
            continue
        padding = node_attr_to_dict(pads_node)["value"]
        if any(p != 0 for p in padding):
            continue
        # Connect input of Pad to where output of Pad goes
        for _, to, index in list(graph.out_edges(node, "index")):
            graph.add_edge(data, to, index=index)
        graph.remove_node(node)
        # Drop the pads constant only if no other node still consumes it,
        # otherwise those consumers would silently lose an input edge.
        if graph.out_degree(pads_node) == 0:
            graph.remove_node(pads_node)
    return graph


def detect_flatten(graph: nx.DiGraph) -> nx.DiGraph:
    """Look for a shape-gather-unsqueeze-concat-reshape chain and replace that with flatten.

    Matched pattern: ``Reshape(x, Concat(Unsqueeze(Gather(Shape(x), idx)), ...))``,
    where each link is the unique consumer of the previous node's unique output,
    and Reshape's data argument is the same node `x` that feeds Shape.
    The chain is replaced by a `g.FlattenNode` that takes `x` and feeds every
    consumer of the Reshape. The Gather index `idx` is removed too when nothing
    else uses it. `graph` is modified in place and returned.
    """

    for node in list(graph.nodes):
        if node.op_type != "Shape":
            continue
        ng = get_next_in_chain(graph, "Gather", node)
        if ng is None:
            # Bail out early: sorted_inputs(graph, None) would list *all* edges.
            continue
        gather_args = sorted_inputs(graph, ng)
        if len(gather_args) < 2:
            continue
        # The second input argument to Gather (the index being gathered).
        nct = gather_args[1]
        nu = get_next_in_chain(graph, "Unsqueeze", ng)
        nc = get_next_in_chain(graph, "Concat", nu)
        nr = get_next_in_chain(graph, "Reshape", nc)
        if nr is None:
            continue
        shape_args = sorted_inputs(graph, node)
        reshape_args = sorted_inputs(graph, nr)
        # It is only a flatten if Reshape reshapes the tensor whose shape was taken.
        if not shape_args or not reshape_args or reshape_args[0] is not shape_args[0]:
            continue
        # Same naming scheme as the Bias_ nodes in DFG._emit_node; tolerates
        # names with zero or several underscores.
        suffix = node.name.split("_")[-1]
        gen_node = g.FlattenNode(f"Flatten_{suffix}")
        replace_chain_with_node_(graph, [node, ng, nu, nc, nr], gen_node)
        # Take the Gather index away as well, unless something else uses it.
        if graph.out_degree(nct) == 0:
            graph.remove_node(nct)
    return graph


def get_next_in_chain(
    graph: nx.DiGraph, type_: str, node: Optional[NodeT]
) -> Optional[NodeT]:
    """Follow `node`'s single output to its single consumer of type `type_`.

    Returns that consumer, or None -- so calls can be chained -- when `node`
    is None, does not have exactly one output, does not have exactly one
    successor in `graph`, or that successor's op_type is not `type_`.
    """
    if node is not None and len(node.output) == 1:
        consumers = list(graph.neighbors(node))
        if len(consumers) == 1 and consumers[0].op_type == type_:
            return consumers[0]
    return None


def replace_chain_with_node_(graph: nx.DiGraph, chain: list, node) -> nx.DiGraph:
    """Replace the nodes of `chain` in `graph` by the single node `node`, in place.

    `node` receives the inputs of ``chain[0]`` (re-indexed 0, 1, ... in their
    argument order) and takes over every out-edge of ``chain[-1]`` with its
    original ``index``. All nodes in `chain` are then removed. Returns `graph`.
    """
    head, tail = chain[0], chain[-1]
    head_args = sorted_inputs(graph, head)
    outgoing = graph.out_edges(tail, "index")
    for position, arg in enumerate(head_args):
        graph.add_edge(arg, node, index=position)
    for _, consumer, arg_index in outgoing:
        graph.add_edge(node, consumer, index=arg_index)
    graph.remove_nodes_from(chain)
    return graph


def build_graph_with_mapping(
    graph: nx.DiGraph, node_mapping: Dict[NodeT, EmitNodeT]
) -> nx.DiGraph:
    """Return a copy of `graph` with nodes replaced according to `node_mapping`.

    - A `g.DFGNode` value replaces its key one-to-one; edges are kept.
    - A `MarkedSubGraph` value replaces its key by the whole subgraph: its
      entry edges become the in-edges, and its exit node takes over every
      out-edge (with the same ``index``) of the replaced node.

    The input `graph` is not modified.
    """
    graph = graph.copy()
    single_node, multi_node = {}, {}
    for replace_node, by_node in node_mapping.items():
        if isinstance(by_node, g.DFGNode):
            single_node[replace_node] = by_node
        else:
            multi_node[replace_node] = by_node
    # We do one-to-many replacements first
    # because their predecessors are specified as onnx nodes.
    # A predecessor may itself already have been replaced by a subgraph (e.g.
    # Conv+bias feeding Conv+bias directly); its edges must then start from
    # that subgraph's exit, otherwise the removed onnx node would be re-added
    # as a dangling source and the real producer would lose its consumer.
    replaced_exits = {}
    for replace_node, subgraph in multi_node.items():
        # Add subgraph itself (in place, instead of rebuilding the whole graph)
        graph.add_nodes_from(subgraph.subgraph.nodes(data=True))
        graph.add_edges_from(subgraph.subgraph.edges(data=True))
        # Add in edges, redirecting sources that were already replaced
        for src, dst, attrs in subgraph.entry_edges:
            graph.add_edge(replaced_exits.get(src, src), dst, **attrs)
        # Add out edges
        for _, to, index in list(graph.out_edges(replace_node, "index")):
            graph.add_edge(subgraph.exit, to, index=index)
        # Remove old node
        graph.remove_node(replace_node)
        replaced_exits[replace_node] = subgraph.exit
    # Then do all one-to-one replacements.
    return nx.relabel_nodes(graph, single_node)


def extract_tensors_from_graph(onnx_graph: GraphT) -> Dict[str, g.TensorNode]:
    """Map ONNX value names to tensor nodes for initializers and graph inputs.

    Initializers become `g.WeightTensor`s named ``weight_0``, ``weight_1``, ...
    in initializer order. Graph inputs whose name is not already in the
    mapping (i.e. not an initializer or an earlier input of the same name)
    become `g.InputTensor`s named ``input_0``, ``input_1``, ...
    """
    tensors = {}
    for weight_idx, initializer in enumerate(onnx_graph.initializer):
        tensors[initializer.name] = g.WeightTensor(initializer, f"weight_{weight_idx}")
    n_inputs = 0
    for value_info in onnx_graph.input:
        if value_info.name not in tensors:
            tensors[value_info.name] = g.InputTensor(value_info, f"input_{n_inputs}")
            n_inputs += 1
    return tensors


def sorted_inputs(graph: nx.DiGraph, node):
    """Return the predecessors of `node`, ordered by the ``index`` edge
    attribute (the argument position at which `node` consumes them)."""
    in_edges = list(graph.in_edges(node, "index"))
    in_edges.sort(key=lambda edge: edge[2])
    return [src for src, _, _ in in_edges]


def draw_graph(graph: nx.DiGraph, output_to):
    """Lay `graph` out with Graphviz ``dot`` and write the drawing to `output_to`.

    networkx's `to_agraph` needs the optional pygraphviz package, hence the
    import is deferred to call time.
    """
    from networkx.drawing.nx_agraph import to_agraph

    drawing = to_agraph(graph)
    drawing.layout("dot")
    drawing.draw(output_to)