#!/usr/bin/python3
"""Read intermediate tensors generated by the DumpAllTensors activity.

Tools for reading and parsing intermediate tensors.
"""

import argparse
import datetime
import numpy as np
import os
import pandas as pd
import tensorflow as tf
import json
import seaborn as sns
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import matplotlib.animation as animation
import multiprocessing

from matplotlib.pylab import *
from tqdm import tqdm
# Raise matplotlib's 'animation.embed_limit' (the size cap for animations
# embedded in HTML, as produced by to_jshtml() further down). 2**128 is far
# beyond any realistic size, so this effectively removes the limit.
matplotlib.rcParams['animation.embed_limit'] = 2**128
# Enable eager execution so TensorFlow tensors can be converted to numpy with
# tensor.numpy().
tf.compat.v1.enable_eager_execution()


############################ Helper Functions ############################
def reshape_to_matrix(array):
  """Reshape a tensor into a near-square 2-D float matrix for heatmaps.

  The input is flattened, converted to float and padded at the end with
  np.nan so that it fills a grid of ``width = ceil(sqrt(n))`` rows and
  ``height = ceil(n / width)`` columns (``height <= width``). The result is
  only square when ``n`` allows it.

  # Arguments
    array: array-like of numbers; multi-dimensional input is flattened.

  # Returns
    A float ndarray of shape (width, height). An empty input returns an
    empty (0, 0) array.
  """
  # Imported explicitly: relying on the `from matplotlib.pylab import *`
  # star import for `math` breaks with newer numpy releases.
  import math

  flat = np.asarray(array, dtype=float).ravel()
  size = flat.size
  if size == 0:
    return np.empty((0, 0), dtype=float)
  width = math.ceil(size ** 0.5)
  height = math.ceil(size / width)
  padded = np.pad(flat,
                  pad_width=(0, width * height - size),
                  mode='constant',
                  constant_values=np.nan)
  return padded.reshape(width, height)

def save_ani_to_video(ani, save_video_path, video_fps=5, dpi=250):
  """Render a matplotlib animation to a video file with the ffmpeg writer.

  # Arguments
    ani: the matplotlib animation to render.
    save_video_path: destination video file path.
    video_fps: frames per second of the output video.
    dpi: resolution of the rendered frames (defaults to 250).
  """
  writer_cls = animation.writers['ffmpeg']
  writer = writer_cls(fps=video_fps)
  # Save the movie
  ani.save(save_video_path, writer=writer, dpi=dpi)

def save_ani_to_html(ani, save_html_path):
  """Write the animation's ``to_jshtml()`` output to ``save_html_path``.

  # Arguments
    ani: object exposing ``to_jshtml()`` (e.g. a matplotlib FuncAnimation).
    save_html_path: destination HTML file; it is truncated if it exists.
  """
  with open(save_html_path, 'w') as html_file:
    html_text = ani.to_jshtml()
    html_file.write(html_text)

############################ ModelMetaDataManager ############################
class ModelMetaDataManager(object):
  """Maps NNAPI model names to their TFLite graph metadata, loaded lazily.

  # Arguments
    android_build_top: root directory of the Android source tree. The model
      list test/mlts/models/assets/models_list/mobilenet_topk_aosp.json is
      read from here to map NNAPI model names to TFLite model file names.
    dump_dir: directory containing intermediate tensors pulled from device,
      with one sub-directory per NNAPI model.
    tflite_model_json_dir: directory containing the intermediate JSON output
      of the model visualization tool
      (third_party/tensorflow/lite/tools:visualize), one
      `<tflite_model_name>.json` file per model. The tool always writes its
      JSON output to /tmp, which is the default.
  """
  ############################ ModelMetaData ############################
  class ModelMetaData(object):
    """Store graph information of a model.

    # Arguments
      tflite_model_json_path: path of the model's visualizer JSON file.
    """

    def __init__(self, tflite_model_json_path='/tmp'):
      with open(tflite_model_json_path, 'rb') as f:
        model_json = json.load(f)
      # Only the first subgraph is inspected.
      self.operators = model_json['subgraphs'][0]['operators']
      self.operator_codes = [item['builtin_code']
                             for item in model_json['operator_codes']]
      # One dict per operator, in the order they appear in the graph.
      self.output_meta_data = []
      self.load_output_meta_data()

    def load_output_meta_data(self):
      """Append output index, fused activation and op code per operator."""
      for operator in self.operators:
        # Each operator can only have one output
        assert len(operator['outputs']) == 1
        options = operator.get('builtin_options') or {}
        self.output_meta_data.append({
            'output_tensor_index': operator['outputs'][0],
            'fused_activation_function':
                options.get('fused_activation_function', ''),
            'operator_code': self.operator_codes[operator['opcode_index']],
        })

  def __init__(self, android_build_top, dump_dir, tflite_model_json_dir='/tmp'):
    # key: nnapi model name, value: ModelMetaData (filled lazily)
    self.models = dict()
    # Directory attributes keep a trailing '/' because paths are built from
    # them by plain string concatenation (e.g. ModelData uses DUMP_DIR).
    self.ANDROID_BUILD_TOP = android_build_top + "/"
    self.TFLITE_MODEL_JSON_DIR = tflite_model_json_dir + "/"
    self.DUMP_DIR = dump_dir + "/"
    self.nnapi_to_tflite_name = dict()
    self.tflite_to_nnapi_name = dict()
    self.__load_mobilenet_topk_aosp()
    self.model_names = sorted(os.listdir(dump_dir))

  def __load_mobilenet_topk_aosp(self):
    """Load the NNAPI <-> TFLite model name mapping from the AOSP list."""
    json_path = self.ANDROID_BUILD_TOP + \
        'test/mlts/models/assets/models_list/mobilenet_topk_aosp.json'
    with open(json_path, 'rb') as f:
      topk_aosp = json.load(f)
    for model in topk_aosp['models']:
      self.nnapi_to_tflite_name[model['name']] = model['modelFile']
      self.tflite_to_nnapi_name[model['modelFile']] = model['name']

  def __get_model_json_path(self, tflite_model_name):
    """Return the path of the TFLite model's visualizer JSON file."""
    return '{}{}.json'.format(self.TFLITE_MODEL_JSON_DIR, tflite_model_name)

  def __load_model(self, tflite_model_name):
    """Initialize a ModelMetaData for this model."""
    model = self.ModelMetaData(self.__get_model_json_path(tflite_model_name))
    nnapi_model_name = self.model_name_tflite_to_nnapi(tflite_model_name)
    self.models[nnapi_model_name] = model

  def model_name_nnapi_to_tflite(self, nnapi_model_name):
    """Map an NNAPI model name to its TFLite name (identity if unknown)."""
    return self.nnapi_to_tflite_name.get(nnapi_model_name, nnapi_model_name)

  def model_name_tflite_to_nnapi(self, tflite_model_name):
    """Map a TFLite model name to its NNAPI name (identity if unknown)."""
    return self.tflite_to_nnapi_name.get(tflite_model_name, tflite_model_name)

  def get_model_meta_data(self, nnapi_model_name):
    """Retrieve the ModelMetaData with lazy initialization."""
    tflite_model_name = self.model_name_nnapi_to_tflite(nnapi_model_name)
    if nnapi_model_name not in self.models:
      self.__load_model(tflite_model_name)
    return self.models[nnapi_model_name]

  def generate_animation_html(self, output_file_path, model_names=None, heatmap=True):
    """Generate a html file containing the hist and heatmap animation of all models.

    # Arguments
      output_file_path: destination HTML file.
      model_names: models to process; defaults to every model in dump_dir.
      heatmap: also generate heatmap animations when True.
    """
    model_names = self.model_names if model_names is None else model_names
    html_data = ''
    for model_name in tqdm(model_names):
      print(datetime.datetime.now(), 'Processing', model_name)
      html_data += '<h3>{}</h3>'.format(model_name)
      model_data = ModelData(nnapi_model_name=model_name, manager=self)
      ani = model_data.gen_error_hist_animation()
      html_data += ani.to_jshtml()
      if heatmap:
        ani = model_data.gen_heatmap_animation()
        html_data += ani.to_jshtml()
    with open(output_file_path, 'w') as f:
      f.write(html_data)

  def generate_hist_animation_html(self, model_name):
    """Generate a html hist animation for a model, used for multiprocessing.

    The result is stored in self.return_dict[model_name + "-hist"]; the
    shared dict is created by multiprocessing_generate_animation_html.
    """
    html_data = '<h3>{}</h3>'.format(model_name)
    model_data = ModelData(nnapi_model_name=model_name, manager=self)
    ani = model_data.gen_error_hist_animation()
    html_data += ani.to_jshtml()
    print(datetime.datetime.now(), "Done histogram for", model_name)
    self.return_dict[model_name + "-hist"] = html_data

  def generate_heatmap_animation_html(self, model_name):
    """Generate a html heatmap animation for a model, used for multiprocessing.

    The result is stored in self.return_dict[model_name + "-heatmap"]; the
    shared dict is created by multiprocessing_generate_animation_html.
    """
    model_data = ModelData(nnapi_model_name=model_name, manager=self)
    ani = model_data.gen_heatmap_animation()
    html_data = ani.to_jshtml()
    print(datetime.datetime.now(), "Done heatmap for", model_name)
    self.return_dict[model_name + "-heatmap"] = html_data

  def multiprocessing_generate_animation_html(self, output_file_path,
                                              model_names=None, heatmap=True):
    """Generate the same HTML as generate_animation_html with multiple processes.

    One process is started per (model, animation) pair. Each stores its HTML
    in a shared dict; the results are then written in model order, histogram
    first and heatmap second.

    # Arguments
      output_file_path: destination HTML file.
      model_names: models to process; defaults to every model in dump_dir.
      heatmap: also generate heatmap animations when True.

    # Raises
      RuntimeError: if a worker process did not produce its animation. The
        output file is not written in that case.
    """
    model_names = self.model_names if model_names is None else model_names
    target_funcs = [(self.generate_hist_animation_html, '-hist')]
    if heatmap:
      target_funcs.append((self.generate_heatmap_animation_html, '-heatmap'))
    result_keys = [model_name + suffix
                   for model_name in model_names
                   for _, suffix in target_funcs]

    # The context manager shuts the manager's server process down on exit.
    with multiprocessing.Manager() as manager:
      # Shared dict the worker methods write their HTML into.
      self.return_dict = manager.dict()
      jobs = []
      for model_name in model_names:
        for target_func, _ in target_funcs:
          job = multiprocessing.Process(target=target_func, args=(model_name,))
          jobs.append(job)
          job.start()
      # wait for completion
      for job in jobs:
        job.join()
      # Copy into a plain dict before the manager (and its proxy) goes away.
      results = self.return_dict.copy()
    self.return_dict = results

    missing = [key for key in result_keys if key not in results]
    if missing:
      raise RuntimeError(
          'Animation generation failed for: {}'.format(', '.join(missing)))
    with open(output_file_path, 'w') as f:
      for key in result_keys:
        f.write(results[key])


############################ TensorDict ############################
class TensorDict(dict):
  """A dict holding the 'cpu' and 'nnapi' intermediate tensors of a model.

  self['cpu'] and self['nnapi'] each map a tensor file name (a "layer") to
  the 1-D numpy array decoded from that file.

  # Arguments
    model_dir: directory containing intermediate tensors pulled from device.
      It must end with '/' and contain 'cpu' and 'nnapi' sub-directories.
  """
  def __init__(self, model_dir):
    super().__init__()
    for backend in ['cpu', 'nnapi']:
      dir_path = model_dir + backend + "/"
      self[backend] = self.read_tensors_from_dir(dir_path)
    self.tensor_sanity_check()
    self.max_absolute_diff, self.min_absolute_diff = 0.0, 0.0
    self.max_relative_diff, self.min_relative_diff = 0.0, 0.0
    self.layers = sorted(self['cpu'].keys())
    self.calc_range()

  def bytes_to_numpy_tensor(self, file_path):
    """Load bytes output by DumpIntermediateTensor into a 1-D numpy array.

    Paths containing 'quant' or '8bit' are decoded as int8, all others as
    float32. Both are read as little-endian, matching the default of
    TensorFlow's decode_raw, which this replaces. (`tf.decode_raw` no longer
    exists in TensorFlow 2.)

    Prints a warning if the tensor contains inf or nan.
    """
    if 'quant' in file_path or '8bit' in file_path:
      dtype = np.dtype('<i1')
    else:
      dtype = np.dtype('<f4')
    with open(file_path, mode='rb') as f:
      tensor_bytes = f.read()
    # np.frombuffer returns a read-only view of the bytes; copy so callers
    # get an ordinary writable array.
    tensor = np.frombuffer(tensor_bytes, dtype=dtype).copy()
    if not np.all(np.isfinite(tensor)):
      print('WARNING: tensor contains inf or nan')
    return tensor

  def read_tensors_from_dir(self, dir_path):
    """Return {file name: decoded tensor} for every file in dir_path."""
    return {tensor_file: self.bytes_to_numpy_tensor(dir_path + tensor_file)
            for tensor_file in os.listdir(dir_path)}

  def tensor_sanity_check(self):
    # Make sure the cpu tensors and nnapi tensors have the same outputs
    assert set(self['cpu'].keys()) == set(self['nnapi'].keys())
    print('Tensor sanity check passed')

  def calc_range(self):
    """Track the min/max absolute diff over all layers and its symmetric range."""
    for layer in self.layers:
      diff = self.calc_diff(layer, relative_error=False)
      if diff.size == 0:
        # np.max/np.min raise on empty arrays.
        continue
      # update absolute max, min
      self.max_absolute_diff = max(self.max_absolute_diff, np.max(diff))
      self.min_absolute_diff = min(self.min_absolute_diff, np.min(diff))
    # Computed after the loop so it is also defined (0.0) with no layers.
    self.absolute_range = max(abs(self.min_absolute_diff),
                              abs(self.max_absolute_diff))

  @staticmethod
  def _widen(tensor):
    """Return integer tensors as int64 so arithmetic cannot wrap; others as-is."""
    if np.issubdtype(tensor.dtype, np.integer):
      return tensor.astype(np.int64)
    return tensor

  def calc_diff(self, layer, relative_error=True):
    """Return the element-wise CPU-minus-NNAPI difference for one layer.

    # Arguments
      layer: tensor file name present in both self['cpu'] and self['nnapi'].
      relative_error: if False, return the raw difference (integer tensors
        are widened to int64 first, so int8 values cannot wrap around).
        If True, return diff / max(|cpu|, |nnapi|) as float64, with 0 where
        both values are 0.

    The relative error lies in [-2, 2]. It stays within [-1, 1] whenever the
    CPU and NNAPI values share a sign.
    """
    cpu_tensor = self['cpu'][layer]
    nnapi_tensor = self['nnapi'][layer]
    assert cpu_tensor.shape == nnapi_tensor.shape
    diff = self._widen(cpu_tensor) - self._widen(nnapi_tensor)
    if not relative_error:
      return diff
    diff = diff.astype(float)
    # Convert to float before abs(): np.abs(np.int8(-128)) overflows to -128.
    max_cpu_nnapi_tensor = np.maximum(np.abs(cpu_tensor.astype(float)),
                                      np.abs(nnapi_tensor.astype(float)))
    return np.divide(diff, max_cpu_nnapi_tensor, out=np.zeros_like(diff),
                     where=max_cpu_nnapi_tensor > 0)

  def gen_tensor_diff_stats(self, relative_error=True, return_df=True, plot_diff=False):
    """Compute per-layer min/max/mean/median of the diff.

    Returns a pandas DataFrame when return_df is True, otherwise None.
    plot_diff additionally draws a histogram of every layer's diff.
    """
    stats = []
    for layer in self.layers:
      diff = self.calc_diff(layer, relative_error)
      if plot_diff:
        self.plot_tensor_diff(diff)
      if return_df:
        stats.append({
          'layer': layer,
          'min': np.min(diff),
          'max': np.max(diff),
          'mean': np.mean(diff),
          'median': np.median(diff)
        })
    if return_df:
      return pd.DataFrame(stats)

  @staticmethod
  def plot_tensor_diff(diff):
    """Plot a 50-bin log-scale histogram of diff in a new figure."""
    # staticmethod: the original signature lacked `self`, so
    # self.plot_tensor_diff(diff) raised TypeError.
    plt.figure()
    plt.hist(diff, bins=50, log=True)
    plt.plot()


############################ Model Data ############################
class ModelData(object):
  """A class to store all relevant information of a model.

  # Arguments
    nnapi_model_name: the name of the model; also the name of its
      sub-directory under manager.DUMP_DIR.
    manager: ModelMetaDataManager providing DUMP_DIR and graph metadata.
    seq_limit: if truthy, animations show at most
      len(mmd.output_meta_data) * seq_limit layers; if falsy, all layers.
  """
  def __init__(self, nnapi_model_name, manager, seq_limit=10):
    self.nnapi_model_name = nnapi_model_name
    self.manager = manager
    self.model_dir = self.get_target_model_dir(manager.DUMP_DIR,
                                               nnapi_model_name)
    self.tensor_dict = TensorDict(self.model_dir)
    self.mmd = manager.get_model_meta_data(nnapi_model_name)
    self.stats = self.tensor_dict.gen_tensor_diff_stats(relative_error=True,
                                                        return_df=True)
    self.layers = sorted(self.tensor_dict['cpu'].keys())
    # Diverging colormap shared by all heatmaps of this model.
    self.cmap = sns.diverging_palette(220, 20, sep=20, as_cmap=True)
    self.seq_limit = seq_limit

  def get_target_model_dir(self, dump_dir, target_model_name):
    """Return the model directory path: dump_dir + name + '/'."""
    # dump_dir is expected to already end with '/'.
    return dump_dir + target_model_name + "/"

  def __get_operation(self, i):
    """Return the operator code shown for frame i ('' without metadata)."""
    output_meta_data = self.mmd.output_meta_data
    if not output_meta_data:
      return ''
    # Use % because there may be multiple testing samples
    return output_meta_data[i % len(output_meta_data)]['operator_code']

  def __sns_distplot(self, layer, bins, ax, range, relative_error):
    sns.distplot(self.tensor_dict.calc_diff(layer, relative_error=relative_error),
                 bins=bins, hist_kws={"range": range, "log": True}, ax=ax,
                 kde=False)

  def __plt_hist(self, layer, bins, ax, range, relative_error):
    ax.hist(self.tensor_dict.calc_diff(layer, relative_error=relative_error),
            bins=bins, range=range, log=True)

  def __get_layer_num(self):
    """Number of animation frames: all layers, capped when seq_limit is set."""
    if self.seq_limit:
      return min(len(self.layers),
                 len(self.mmd.output_meta_data) * self.seq_limit)
    return len(self.layers)

  def update_hist_data(self, i, fig, ax1, ax2, bins=50, plot_library='sns'):
    """Draw frame i of the error histogram animation.

    ax1 shows the relative error over the range (-1, 1); ax2 shows the
    absolute error over the model-wide symmetric absolute range.
    plot_library='matplotlib' uses ax.hist, anything else seaborn distplot.
    """
    layer = self.layers[i]
    fig.suptitle('{} | {}\n{}'.format(self.nnapi_model_name, layer,
                                      self.__get_operation(i)),
                 fontsize='x-large')
    for ax in (ax1, ax2):
      ax.clear()
    ax1.set_title('Relative Error')
    ax2.set_title('Absolute Error')
    absolute_range = self.tensor_dict.absolute_range

    # Determine underlying plotting library
    hist_func = self.__plt_hist if plot_library == 'matplotlib' else self.__sns_distplot
    hist_func(layer=layer, bins=bins, ax=ax1,
              range=(-1, 1), relative_error=True)
    hist_func(layer=layer, bins=bins, ax=ax2,
              range=(-absolute_range, absolute_range), relative_error=False)

  def gen_error_hist_animation(self, save_video_path=None, video_fps=10):
    """Build (and optionally save to video) the per-layer histogram animation."""
    fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(12, 9))
    ani = animation.FuncAnimation(fig, self.update_hist_data,
                                  self.__get_layer_num(),
                                  fargs=(fig, ax1, ax2),
                                  interval=200, repeat=False)
    if save_video_path:
      save_ani_to_video(ani, save_video_path, video_fps)
    # close before return to avoid dangling plot
    plt.close(fig)
    return ani

  def __sns_heatmap(self, data, ax, cbar_ax, **kwargs):
    return sns.heatmap(data, cmap=self.cmap, cbar=True, ax=ax, cbar_ax=cbar_ax,
                       cbar_kws={"orientation": "horizontal"}, center=0,
                       **kwargs)

  def update_heatmap_data(self, i, fig, axs):
    """Draw frame i of the heatmap animation (diff, CPU and NNAPI tensors).

    axs[0] holds the three heatmap axes and axs[1] their colorbar axes.
    """
    layer = self.layers[i]
    fig.suptitle('{} | {}\n{}\n'.format(self.nnapi_model_name, layer,
                                        self.__get_operation(i)),
                 fontsize='x-large')
    # Clear all the axs and redraw
    # It's important to clear the colorbars as well to avoid duplicate colorbars
    for ax_row in axs:
      for ax in ax_row:
        ax.clear()
    axs[0][0].set_title('Diff')
    axs[0][1].set_title('CPU Tensor')
    axs[0][2].set_title('NNAPI Tensor')

    reshaped_diff = reshape_to_matrix(
        self.tensor_dict.calc_diff(layer, relative_error=False))
    reshaped_cpu = reshape_to_matrix(self.tensor_dict['cpu'][layer])
    reshaped_nnapi = reshape_to_matrix(self.tensor_dict['nnapi'][layer])
    absolute_range = self.tensor_dict.absolute_range
    self.__sns_heatmap(data=reshaped_diff, ax=axs[0][0], cbar_ax=axs[1][0],
                       vmin=-absolute_range, vmax=absolute_range)
    self.__sns_heatmap(data=reshaped_cpu, ax=axs[0][1], cbar_ax=axs[1][1])
    self.__sns_heatmap(data=reshaped_nnapi, ax=axs[0][2], cbar_ax=axs[1][2])

  def gen_heatmap_animation(self, save_video_path=None, video_fps=10, figsize=(13,6)):
    """Build (and optionally save to video) the per-layer heatmap animation."""
    fig = plt.figure(constrained_layout=True, figsize=figsize)
    widths = [1, 1, 1]
    heights = [7, 1]
    spec = fig.add_gridspec(ncols=3, nrows=2, width_ratios=widths,
                            height_ratios=heights)
    # Row 0: heatmaps, row 1: their horizontal colorbars.
    axs = [[fig.add_subplot(spec[row, col]) for col in range(3)]
           for row in range(2)]

    ani = animation.FuncAnimation(fig, self.update_heatmap_data,
                                  self.__get_layer_num(),
                                  fargs=(fig, axs),
                                  interval=200, repeat=False)
    if save_video_path:
      save_ani_to_video(ani, save_video_path, video_fps)
    # close before return to avoid dangling plot
    plt.close(fig)
    return ani

  def plot_error_heatmap(self, target_layer, vmin=None, vmax=None):
    """Plot the CPU-minus-NNAPI diff of one layer as a heatmap.

    # Arguments
      target_layer: layer (tensor file name) to plot.
      vmin, vmax: optional colormap limits, forwarded to seaborn.heatmap.
    """
    # calc_diff widens integer tensors, so int8 diffs do not wrap around.
    target_diff = self.tensor_dict.calc_diff(target_layer, relative_error=False)
    reshaped_target_diff = reshape_to_matrix(target_diff)
    _, ax = plt.subplots(figsize=(9, 9))
    ax.set_title('Heat Map of Error between CPU and NNAPI')
    sns.heatmap(reshaped_target_diff,
                cmap=self.cmap,
                mask=np.isnan(reshaped_target_diff),
                center=0,
                vmin=vmin,
                vmax=vmax,
                ax=ax)
    plt.show()


############################ ModelDataComparison ############################
class ModelDataComparison:
  """Load one model from several dump directories and compare their diffs.

  # Arguments
    dump_dir_list: list of dump directories pulled from device; each must
      contain a sub-directory named after model_name.
    android_build_top: root directory of the Android source tree.
    tflite_model_json_dir: directory containing the visualizer JSON files.
    model_name: NNAPI model name to compare. Change it with set_model_name(),
      which reloads all data.
  """
  def __init__(self, dump_dir_list, android_build_top, tflite_model_json_dir, model_name):
    self.dump_dir_list = dump_dir_list
    self.android_build_top = android_build_top
    self.tflite_model_json_dir = tflite_model_json_dir
    self.set_model_name(model_name)

  def set_model_name(self, model_name):
    # Set model to be compared and load/ reload all model data
    self.model_name = model_name
    self.__load_data()

  def __load_data(self):
    """Load one ModelMetaDataManager and ModelData per dump directory."""
    self.manager_list = []
    self.model_data_list = []
    for dump_dir in self.dump_dir_list:
      manager = ModelMetaDataManager(self.android_build_top,
                                     dump_dir,
                                     tflite_model_json_dir=self.tflite_model_json_dir)
      model_data = ModelData(nnapi_model_name=self.model_name, manager=manager)
      self.manager_list.append(manager)
      self.model_data_list.append(model_data)
    self.sanity_check()

  def sanity_check(self):
    """Check that there is at least one dump and all dumps share their layers.

    On success sets self.layers and self.mmd from the first ModelData.
    """
    assert len(self.model_data_list) >= 1
    sample_model_data = self.model_data_list[0]
    sample_layers = set(sample_model_data.tensor_dict['cpu'].keys())
    for model_data in self.model_data_list[1:]:
      assert sample_layers == set(model_data.tensor_dict['cpu'].keys())
    print('Sanity Check Passed')
    self.layers = sample_model_data.layers
    self.mmd = sample_model_data.mmd

  @staticmethod
  def _dump_dir_label(dump_dir):
    """Return the last path component of dump_dir (trailing '/' optional)."""
    return os.path.basename(os.path.normpath(dump_dir))

  def update_hist_comparison_data(self, i, fig, axs, bins=50):
    """Draw frame i: a shared diff histogram plus one heatmap per dump.

    axs[0][0] is the histogram axis, axs[1] the heatmap axes and axs[2]
    their colorbar axes.
    """
    sample_model_data = self.model_data_list[0]
    # Use % because there may be multiple testing samples
    operation = self.mmd.output_meta_data[i % len(self.mmd.output_meta_data)]['operator_code']
    layer = self.layers[i]
    fig.suptitle('{} | {}\n{}'.format(sample_model_data.nnapi_model_name,
                                      layer, operation),
                 fontsize='x-large')
    for row in axs:
      for ax in row:
        ax.clear()

    hist_ax = axs[0][0]
    hist_ax.set_title('Diff Histogram')
    labels = [self._dump_dir_label(dump_dir) for dump_dir in self.dump_dir_list]
    cmap = sns.diverging_palette(220, 20, sep=20, as_cmap=True)
    for col, heatmap_ax in enumerate(axs[1]):
      model_data = self.model_data_list[col]
      diff = model_data.tensor_dict.calc_diff(layer, relative_error=False)
      heatmap_ax.set_title(labels[col])
      sns.heatmap(reshape_to_matrix(diff), cmap=cmap, cbar=True, ax=heatmap_ax,
                  cbar_ax=axs[2][col], cbar_kws={"orientation": "horizontal"},
                  center=0)
      sns.distplot(diff, bins=bins, hist_kws={"log": True}, ax=hist_ax,
                   kde=False)
    hist_ax.legend(labels)

  def gen_error_hist_comparison_animation(self, save_video_path=None, video_fps=10):
    """Build (and optionally save to video) the comparison animation."""
    num_dumps = len(self.model_data_list)
    widths = [1] * num_dumps
    heights = [num_dumps * 0.7, 1, 0.2]
    fig = plt.figure(figsize=(5 * num_dumps, 4 * num_dumps))
    gs = fig.add_gridspec(3, num_dumps, width_ratios=widths, height_ratios=heights)
    # Row 0: one histogram spanning all columns; row 1: heatmaps;
    # row 2: heatmap colorbars.
    axs = [[fig.add_subplot(gs[0, :])], [], []]
    for col in range(num_dumps):
      axs[1].append(fig.add_subplot(gs[1, col]))
      axs[2].append(fig.add_subplot(gs[2, col]))
    ani = animation.FuncAnimation(fig, self.update_hist_comparison_data,
                                  len(self.layers),
                                  fargs=(fig, axs),
                                  interval=200, repeat=False)
    if save_video_path:
      save_ani_to_video(ani, save_video_path, video_fps)
    # close before return to avoid dangling plot
    plt.close(fig)
    return ani


############################ NumpyEncoder ############################
class NumpyEncoder(json.JSONEncoder):
  """Enable numpy array and numpy scalar serialization in a dictionary.

  Arrays become (nested) lists. Numpy scalars such as np.int64, np.float32
  and np.bool_ become the equivalent Python scalars. Anything else falls back
  to json.JSONEncoder (which raises TypeError).

  Usage:
    a = np.array([[1, 2, 3], [4, 5, 6]])
    json.dumps({'a': a, 'aa': [2, (2, 3, 4), a], 'bb': [2]}, cls=NumpyEncoder)
  """
  def default(self, obj):
    if isinstance(obj, np.ndarray):
      return obj.tolist()
    if isinstance(obj, np.generic):
      return obj.item()
    return json.JSONEncoder.default(self, obj)

def main(args):
  """Generate the intermediate-tensor animation HTML for parsed CLI args.

  A single process is used when --no_parallel or --model_name is given;
  otherwise one process is started per animation. The output defaults to
  /tmp/intermediate.html.
  """
  output_file_path = args.output_file_path or '/tmp/intermediate.html'

  manager = ModelMetaDataManager(
    args.android_build_top,
    args.dump_dir,
    tflite_model_json_dir='/tmp')

  if args.no_parallel or args.model_name:
    generation_func = manager.generate_animation_html
  else:
    generation_func = manager.multiprocessing_generate_animation_html

  if args.model_name:
    # Print the loaded tensors of the requested model for quick inspection.
    model_data = ModelData(nnapi_model_name=args.model_name, manager=manager)
    print(model_data.tensor_dict)
    generation_func(output_file_path=output_file_path, model_names=[args.model_name])
  else:
    generation_func(output_file_path=output_file_path)


if __name__ == '__main__':
  # Example usage
  # python tensor_utils.py ~/android/master/ ~/android/master/intermediate/ \
  #   --model_name tts_float
  parser = argparse.ArgumentParser(description='Utilities for parsing intermediate tensors.')
  parser.add_argument('android_build_top', help='Your Android build top path.')
  parser.add_argument('dump_dir', help='The dump dir pulled from the device.')
  parser.add_argument('--model_name', help='NNAPI model name. Run all models if not specified.')
  parser.add_argument('--output_file_path', help='Animation HTML path.')
  # nargs='?' + const=True lets the flag be given bare ("--no_parallel")
  # while still accepting an explicit value as before.
  parser.add_argument('--no_parallel', nargs='?', const=True, default=None,
                      help='Run on a single process instead of multiple processes.')
  args = parser.parse_args()
  main(args)