import sys
import os
import shutil
import subprocess
from keras.utils.np_utils import to_categorical
from keras.models import load_model
from frontend.approxhpvm_translator import translate_to_approxhpvm
from frontend.weight_utils import dumpCalibrationData
from frontend.weight_utils import reloadHPVMWeights


# Every CNN benchmark must inherit from the Benchmark class, which defines the
# common interface and the virtual methods that child classes override.
class Benchmark:
    """Base class for CNN benchmarks driven through the ApproxHPVM Keras frontend.

    Child classes override ``buildModel``, ``data_preprocess`` and
    ``trainModel`` (and optionally ``inference``); ``run`` wires them
    together according to the command-line arguments.
    """

    def __init__(self, name, reload_dir, keras_model_file, data_dir, src_dir, num_classes, batch_size=500):
        """Store the benchmark configuration.

        Args:
            name: Benchmark name.
            reload_dir: Directory passed to ``reloadHPVMWeights`` in
                ``hpvm_reload`` mode; in that mode it also replaces
                ``data_dir`` for the frontend.
            keras_model_file: Keras weights file used by ``hpvm_reload``,
                ``keras_reload`` and ``keras_dump``.
            data_dir: Directory argument passed to ``translate_to_approxhpvm``.
            src_dir: Source directory argument passed to ``translate_to_approxhpvm``.
            num_classes: Number of classes; used to one-hot encode test labels.
            batch_size: Batch size forwarded to the frontend.
        """
        self.name = name
        self.reload_dir = reload_dir
        self.keras_model_file = keras_model_file
        self.data_dir = data_dir
        self.src_dir = src_dir
        self.num_classes = num_classes
        self.batch_size = batch_size

    def buildModel(self):
        """Build and return the Keras model (override in child classes)."""
        return None

    def data_preprocess(self):
        """Return ``(X_train, y_train, X_test, y_test, X_tuner, y_tuner)``.

        Override in child classes; ``run`` unpacks exactly six values.
        """
        return None

    def trainModel(self, model, X_train, y_train, X_test, y_test):
        """Train ``model`` and return it (override in child classes).

        The signature matches the call made by ``run``, which passes the
        freshly built model first.
        """
        return None

    def inference(self):
        """Optional inference hook (override in child classes)."""
        return None

    def compileSource(self, working_dir, approx_conf_file="tuner_confs.txt"):
        """Compile the frontend-generated HPVM-C source in ``working_dir``.

        Runs the ``approxhpvm.py`` script (which must be on PATH) to build
        ``<working_dir>/approxhpvm_src.cc`` into ``<working_dir>/HPVM_binary``
        using the tensor target.

        Args:
            working_dir: Directory returned by the frontend, relative to the
                current working directory or absolute.
            approx_conf_file: Configuration file passed via ``--conf-file``.

        Exits the process with status 1 if the script cannot be run or the
        compilation fails.
        """
        # os.path.join also handles an absolute working_dir correctly.
        base_dir = os.path.join(os.getcwd(), working_dir)
        src_file = os.path.join(base_dir, "approxhpvm_src.cc")
        target_binary = os.path.join(base_dir, "HPVM_binary")

        # Probe for the script first so a missing PATH entry gets a targeted
        # message instead of a generic compilation failure.
        try:
            subprocess.run(["approxhpvm.py", "-h"], check=True, stdout=subprocess.DEVNULL)
        except (OSError, subprocess.CalledProcessError):
            print("\n\n ERROR: Could not find approxhpvm.py (HPVM compile script)!! \n\n")
            print("To Compile, Must set PATH to include approxhpvm.py script. Do the following: ")
            print("**** export PATH=${PATH_TO_YOUR_HPVM_INSTALLATION}/build/bin/:$PATH *****")
            sys.exit(1)

        try:
            subprocess.run([
                "approxhpvm.py", src_file, target_binary,
                "-t", "tensor", "--conf-file", approx_conf_file,
            ], check=True)
        except (OSError, subprocess.CalledProcessError):
            print("\n\n ERROR: HPVM Compilation Failed!! \n\n")
            sys.exit(1)

    def printUsage(self):
        """Print command-line usage and exit with status 0."""
        print("Usage: python ${benchmark.py} [hpvm_reload|keras_reload|train] [frontend|keras_dump] [compile]")
        sys.exit(0)

    def run(self, argv):
        """Drive the benchmark from command-line arguments.

        ``argv[1]`` selects how the weights are obtained:
            hpvm_reload  -- reload HPVM weights from ``reload_dir``
            keras_reload -- load Keras weights from ``keras_model_file``
            train        -- train via ``trainModel``
        Optional ``argv[2]``:
            frontend     -- translate the model with the ApproxHPVM frontend;
                            an ``argv[3]`` of ``compile`` then compiles it
            keras_dump   -- save the model weights to ``keras_model_file``

        The test accuracy (as a percentage) is written to ``final_accuracy``.
        Missing or unrecognized arguments print usage and exit.
        """
        valid_modes = ("hpvm_reload", "keras_reload", "train")
        # Validate before the potentially expensive model build and preprocessing.
        if len(argv) < 2 or argv[1] not in valid_modes:
            self.printUsage()
        mode = argv[1]

        print("Build Model ...")
        # Virtual method call implemented by each CNN
        model = self.buildModel()

        print("Data Preprocess... \n")
        # Virtual method call to preprocess test and train data
        X_train, y_train, X_test, y_test, X_tuner, y_tuner = self.data_preprocess()

        if mode == "hpvm_reload":
            print("loading weights .....\n\n")
            model = reloadHPVMWeights(model, self.reload_dir, self.keras_model_file)
        elif mode == "keras_reload":
            model.load_weights(self.keras_model_file)
            model.compile(loss='categorical_crossentropy',
                          optimizer='adam',
                          metrics=['accuracy'])
        else:  # mode == "train"
            print("Train Model ...")
            model = self.trainModel(model, X_train, y_train, X_test, y_test)

        score = model.evaluate(X_test, to_categorical(y_test, self.num_classes), verbose=0)
        print('Test accuracy2:', score[1])

        with open("final_accuracy", "w+") as f:
            f.write(str(score[1] * 100))

        if len(argv) <= 2:
            return

        action = argv[2]
        if action == "frontend":
            # When reloading HPVM weights, the generated HPVM-C source loads from reload_dir.
            if mode == "hpvm_reload":
                self.data_dir = self.reload_dir

            # Main call to ApproxHPVM-Keras Frontend
            working_dir = translate_to_approxhpvm(model,
                                                  self.data_dir, self.src_dir,
                                                  X_test, y_test,
                                                  X_tuner, y_tuner,
                                                  self.batch_size,  # FIXIT
                                                  self.num_classes,
                                                  (mode == "hpvm_reload"))  # Do not redump HPVM weights if `hpvm_reload` used

            # A 4th argument is optional; only an unrecognized one is a usage error.
            if len(argv) > 3:
                if argv[3] == "compile":
                    self.compileSource(working_dir)
                else:
                    self.printUsage()
        elif action == "keras_dump":
            model.save_weights(self.keras_model_file)
        else:
            self.printUsage()