#!/usr/bin/python3

# Copyright 2018, The Android Open Source Project
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Example generator

Compiles spec files and generates the corresponding C++ TestModel definitions.
Invoked by ml/nn/runtime/test/specs/generate_all_tests.sh; see that script for
details on how this script is used.

"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import sys
import traceback

import test_generator as tg

# Oldest Python interpreter this generator supports.
MIN_PYTHON_VERSION = (3, 7)
# Explicit check rather than an `assert` statement: asserts are stripped when
# Python runs with -O, which would silently disable this guard.
if sys.version_info < MIN_PYTHON_VERSION:
    raise AssertionError("requires Python {} or newer".format(
        ".".join(str(part) for part in MIN_PYTHON_VERSION)))

# Dictionary key whose value ToCpp() emits as a "// ..." comment after the
# opening brace of a struct initializer instead of as a field. See ToCpp().
COMMENT_KEY = "__COMMENT__"

def ParseCmdLine():
    """Parse command-line arguments and register the input/output file lists.

    Adds an -e/--example option (the output example file or directory) to the
    common test_generator argument parser, parses the arguments via
    tg.ParseArgs, and hands the spec and example paths to
    tg.FileNames.InitializeFileLists.
    """
    argParser = tg.ArgumentParser()
    argParser.add_argument(
        "-e", "--example", help="the output example file or directory")
    parsedArgs = tg.ParseArgs(argParser)
    # `spec` is not declared here; presumably tg.ArgumentParser() provides it.
    tg.FileNames.InitializeFileLists(parsedArgs.spec, parsedArgs.example)

def InitializeFiles(example_fd):
    """Write the boilerplate header that starts a generated C++ file.

    The header depends only on the spec file's base name. Nothing is written
    when example_fd is None.

    Args:
        example_fd: Writable text stream for the generated example file, or
            None to skip output.
    """
    specName = os.path.basename(tg.FileNames.specFile)
    if example_fd is None:
        return
    header = """\
// Generated from {spec_file}
// DO NOT EDIT
// clang-format off
#include "TestHarness.h"
using namespace test_helper;  // NOLINT(google-build-using-namespace)
"""
    print(header.format(spec_file=specName), file=example_fd)

def IndentedStr(s, indent):
    """Indent every line of `s` except the first by `indent` spaces.

    The first line is left alone because callers (e.g. ToCpp) splice the
    result in after text already on the current line.
    """
    return s.replace("\n", "\n" + " " * indent)

def ToCpp(var, indent=0):
    """Render a Python value as C++ initializer source text.

    Conversion rules (nested values are converted by recursing into ToCpp):

    * dict -> C++ struct aggregate initialization with designated
      initializers, one field per line, in the dict's iteration order:
          {
              .key0 = value0,
              .key1 = value1,
              ...
          }
      The value stored under COMMENT_KEY, if present and not None, is not
      emitted as a field; it is appended as " // <value>" after the opening
      brace. An empty dict becomes "{}".
    * list or tuple -> C++ list initialization: {value0, value1, ...}
    * bool -> "true" or "false"
    * float (exact type float only) -> tg.PrettyPrintAsFloat(var)
    * anything else -> str(var)

    Args:
        var: The value to convert.
        indent: Number of spaces prepended to every line of a multi-line
            (dict) result except its first line.

    Returns:
        The C++ source text.
    """
    if isinstance(var, dict):
        if not var:
            return "{}"
        note = var.get(COMMENT_KEY)
        opening = "{" if note is None else "{ // %s" % note
        fields = []
        for key in var:
            if key == COMMENT_KEY:
                continue
            fields.append("    .%s = %s" % (key, ToCpp(var[key], indent + 4)))
        return IndentedStr("\n".join([opening, ",\n".join(fields), "}"]), indent)

    if isinstance(var, (list, tuple)):
        elements = [ToCpp(element, indent) for element in var]
        return "{" + ", ".join(elements) + "}"

    if type(var) is bool:
        return "true" if var else "false"

    # Exact type check on purpose: float subclasses fall through to str().
    if type(var) is float:
        return tg.PrettyPrintAsFloat(var)

    return str(var)

def GetSymmPerChannelQuantParams(extraParams):
    """Build the dict for a test_helper::TestSymmPerChannelQuantParams.

    Args:
        extraParams: Per-channel quantization parameters exposing `hide`,
            `scales` and `channelDim`, or None.

    Returns:
        {} when extraParams is None or its `hide` flag is truthy; otherwise
        {"scales": ..., "channelDim": ...}.
    """
    if extraParams is not None and not extraParams.hide:
        return {
            "scales": extraParams.scales,
            "channelDim": extraParams.channelDim,
        }
    return {}

def GetOperandStruct(operand):
    """Build the dict that ToCpp renders as a test_helper::TestOperand.

    Keys are inserted in the order ToCpp emits them as designated
    initializers. Presumably this matches TestOperand's field order, since
    C++ requires designated initializers in declaration order. Keep it in
    sync with the struct declared via TestHarness.h.

    Args:
        operand: A test_generator operand. Reads name, type (type,
            dimensions, scale, zeroPoint, extraParams, GetCppTypeString()),
            outs, lifetime and GetListInitialization().

    Returns:
        dict keyed by TestOperand field name, plus COMMENT_KEY mapped to the
        operand's name.
    """
    struct = {COMMENT_KEY: operand.name}
    operandType = operand.type
    struct["type"] = "TestOperandType::" + operandType.type
    struct["dimensions"] = operandType.dimensions
    # NOTE(review): presumably one `outs` entry per consuming operation.
    struct["numberOfConsumers"] = len(operand.outs)
    struct["scale"] = operandType.scale
    struct["zeroPoint"] = operandType.zeroPoint
    struct["lifetime"] = "TestOperandLifeTime::" + operand.lifetime
    struct["channelQuant"] = GetSymmPerChannelQuantParams(operandType.extraParams)
    struct["isIgnored"] = isinstance(operand, tg.IgnoredOutput)
    struct["data"] = "TestBuffer::createFromVector<{cpp_type}>({data})".format(
        cpp_type=operandType.GetCppTypeString(),
        data=operand.GetListInitialization(),
    )
    return struct

def GetOperationStruct(operation):
    """Build the dict that ToCpp renders as a test_helper::TestOperation.

    Args:
        operation: A test_generator operation exposing `optype` and the
            operand sequences `ins` / `outs` (each operand has `model_index`).

    Returns:
        dict with "type" ("TestOperationType::<optype>"), "inputs" and
        "outputs" (operand model indexes, in order).
    """
    def modelIndexes(operands):
        return [operand.model_index for operand in operands]

    opType = "TestOperationType::" + operation.optype
    return {
        "type": opType,
        "inputs": modelIndexes(operation.ins),
        "outputs": modelIndexes(operation.outs),
    }

def GetSubgraphStruct(subgraph):
    """Build the dict that ToCpp renders as a test_helper::TestSubgraph.

    Args:
        subgraph: A test_generator model/subgraph exposing name, operands,
            operations, GetInputs() and GetOutputs().

    Returns:
        dict with the operand and operation structs, the model indexes of
        the subgraph's inputs and outputs, and COMMENT_KEY mapped to the
        subgraph's name.
    """
    name = subgraph.name
    operandStructs = [GetOperandStruct(operand) for operand in subgraph.operands]
    operationStructs = [GetOperationStruct(operation)
                        for operation in subgraph.operations]
    inputIndexes = [operand.model_index for operand in subgraph.GetInputs()]
    outputIndexes = [operand.model_index for operand in subgraph.GetOutputs()]
    return {
        COMMENT_KEY: name,
        "operands": operandStructs,
        "operations": operationStructs,
        "inputIndexes": inputIndexes,
        "outputIndexes": outputIndexes,
    }

def GetModelStruct(example):
    """Build the dict that ToCpp renders as a test_helper::TestModel.

    Args:
        example: A test_generator example. Reads example.model (main
            subgraph, referenced models, isRelaxed, version) plus
            expectedMultinomialDistributionTolerance and expectFailure.

    Returns:
        dict with the main and referenced subgraph structs and model-level
        flags. minSupportedVersion is "TestHalVersion::<version>", using
        "UNKNOWN" when the model's version is None.
    """
    model = example.model
    mainSubgraph = GetSubgraphStruct(model)
    referencedSubgraphs = [GetSubgraphStruct(referenced)
                           for referenced in model.GetReferencedModels()]
    halVersion = "UNKNOWN" if model.version is None else model.version
    return {
        "main": mainSubgraph,
        "referenced": referencedSubgraphs,
        "isRelaxed": model.isRelaxed,
        "expectedMultinomialDistributionTolerance":
                example.expectedMultinomialDistributionTolerance,
        "expectFailure": example.expectFailure,
        "minSupportedVersion": "TestHalVersion::%s" % halVersion,
    }

def DumpExample(example, example_fd):
    """Write the C++ definition of one example's TestModel to example_fd.

    Inside namespace generated_tests::<spec name>, emits a getter
    get_<examplesName>() that returns a function-local static TestModel
    initialized from GetModelStruct(example). It also emits a namespace-scope
    constant that passes that model to TestModelManager::get().add(...)
    under the name example.testName.

    Args:
        example: A test_generator example whose model must already be
            compiled (asserted below).
        example_fd: Writable text stream for the generated file.
    """
    assert example.model.compiled
    snippet = """\
namespace generated_tests::{spec_name} {{

const TestModel& get_{example_name}() {{
    static TestModel model = {aggregate_init};
    return model;
}}

const auto dummy_{example_name} = TestModelManager::get().add("{test_name}", get_{example_name}());

}}  // namespace generated_tests::{spec_name}
"""
    substitutions = dict(
        spec_name=tg.FileNames.specName,
        test_name=str(example.testName),
        example_name=str(example.examplesName),
        aggregate_init=ToCpp(GetModelStruct(example), indent=4),
    )
    print(snippet.format(**substitutions), file=example_fd)


if __name__ == "__main__":
    # Parse arguments first (this populates tg.FileNames), then let
    # test_generator drive the run, calling back into this module's
    # header and per-example writers.
    ParseCmdLine()
    tg.Run(DumpExample=DumpExample, InitializeFiles=InitializeFiles)
