#!/usr/bin/env python3

from __future__ import print_function
import collections
import os
import sys
import logging

# First line written into every generated file.
BANNER = "Auto-generated by generate-wrappers.py script. Do not modify"

# Preprocessor guards shared by the entries of WRAPPER_SRC_NAMES.
_ARM32_ONLY = "defined(__arm__)"
_AARCH64_ONLY = "defined(__aarch64__)"
_ANY_ARM = "defined(__arm__) || defined(__aarch64__)"
_ANY_X86 = "defined(__i386__) || defined(__i686__) || defined(__x86_64__)"
_ANY_RISCV = "defined(__riscv) || defined(__riscv__)"

# Maps each CMake source-list variable to the preprocessor condition that
# guards the #include in its generated wrapper files; None means the include
# is emitted unconditionally. Insertion order is the order in which the lists
# are written to xnnpack_wrapper_defs.bzl.
WRAPPER_SRC_NAMES = {
    "PROD_SCALAR_MICROKERNEL_SRCS": None,
    "PROD_FMA_MICROKERNEL_SRCS": _ANY_RISCV,
    "PROD_ARMSIMD32_MICROKERNEL_SRCS": _ARM32_ONLY,
    "PROD_FP16ARITH_MICROKERNEL_SRCS": _ARM32_ONLY,
    "PROD_NEON_MICROKERNEL_SRCS": _ANY_ARM,
    "PROD_NEONFP16_MICROKERNEL_SRCS": _ANY_ARM,
    "PROD_NEONFMA_MICROKERNEL_SRCS": _ANY_ARM,
    "PROD_NEON_AARCH64_MICROKERNEL_SRCS": _AARCH64_ONLY,
    "PROD_NEONV8_MICROKERNEL_SRCS": _ANY_ARM,
    "PROD_NEONFP16ARITH_MICROKERNEL_SRCS": _ANY_ARM,
    "PROD_NEONFP16ARITH_AARCH64_MICROKERNEL_SRCS": _AARCH64_ONLY,
    "PROD_NEONDOT_MICROKERNEL_SRCS": _ANY_ARM,
    "PROD_NEONDOT_AARCH64_MICROKERNEL_SRCS": _AARCH64_ONLY,
    "PROD_NEONDOTFP16ARITH_MICROKERNEL_SRCS": _ANY_ARM,
    "PROD_NEONDOTFP16ARITH_AARCH64_MICROKERNEL_SRCS": _AARCH64_ONLY,
    "PROD_NEONI8MM_MICROKERNEL_SRCS": _ANY_ARM,
    "PROD_SSE_MICROKERNEL_SRCS": _ANY_X86,
    "PROD_SSE2_MICROKERNEL_SRCS": _ANY_X86,
    "PROD_SSSE3_MICROKERNEL_SRCS": _ANY_X86,
    "PROD_SSE41_MICROKERNEL_SRCS": _ANY_X86,
    "PROD_AVX_MICROKERNEL_SRCS": _ANY_X86,
    "PROD_F16C_MICROKERNEL_SRCS": _ANY_X86,
    "PROD_XOP_MICROKERNEL_SRCS": _ANY_X86,
    "PROD_FMA3_MICROKERNEL_SRCS": _ANY_X86,
    "PROD_AVX2_MICROKERNEL_SRCS": _ANY_X86,
    "PROD_AVX512F_MICROKERNEL_SRCS": _ANY_X86,
    "PROD_AVX512SKX_MICROKERNEL_SRCS": _ANY_X86,
    "PROD_AVX512VBMI_MICROKERNEL_SRCS": _ANY_X86,
    "PROD_AVX512VNNI_MICROKERNEL_SRCS": _ANY_X86,
    "PROD_AVX512VNNIGFNI_MICROKERNEL_SRCS": _ANY_X86,
    "PROD_RVV_MICROKERNEL_SRCS": _ANY_RISCV,
    "PROD_AVXVNNI_MICROKERNEL_SRCS": _ANY_X86,
    "AARCH32_ASM_MICROKERNEL_SRCS": _ARM32_ONLY,
    "AARCH64_ASM_MICROKERNEL_SRCS": _AARCH64_ONLY,

    # add non-prod microkernel sources here:
}

# Names of the CMake source-list variables that are written to
# xnnpack_src_defs.bzl. Expressed as the union of two groups purely for
# readability; the result is a single flat set.
SRC_NAMES = {
    # Non-microkernel source lists.
    "OPERATOR_SRCS",
    "SUBGRAPH_SRCS",
    "LOGGING_SRCS",
    "XNNPACK_SRCS",
    "TABLE_SRCS",
    "JIT_SRCS",
} | {
    # Microkernel source lists.
    "PROD_SCALAR_MICROKERNEL_SRCS",
    "PROD_FMA_MICROKERNEL_SRCS",
    "PROD_ARMSIMD32_MICROKERNEL_SRCS",
    "PROD_FP16ARITH_MICROKERNEL_SRCS",
    "PROD_NEON_MICROKERNEL_SRCS",
    "PROD_NEONFP16_MICROKERNEL_SRCS",
    "PROD_NEONFMA_MICROKERNEL_SRCS",
    "PROD_NEON_AARCH64_MICROKERNEL_SRCS",
    "PROD_NEONV8_MICROKERNEL_SRCS",
    "PROD_NEONFP16ARITH_MICROKERNEL_SRCS",
    "PROD_NEONFP16ARITH_AARCH64_MICROKERNEL_SRCS",
    "PROD_NEONDOT_MICROKERNEL_SRCS",
    "PROD_NEONDOT_AARCH64_MICROKERNEL_SRCS",
    "PROD_NEONDOTFP16ARITH_MICROKERNEL_SRCS",
    "PROD_NEONDOTFP16ARITH_AARCH64_MICROKERNEL_SRCS",
    "PROD_NEONI8MM_MICROKERNEL_SRCS",
    "PROD_SSE_MICROKERNEL_SRCS",
    "PROD_SSE2_MICROKERNEL_SRCS",
    "PROD_SSSE3_MICROKERNEL_SRCS",
    "PROD_SSE41_MICROKERNEL_SRCS",
    "PROD_AVX_MICROKERNEL_SRCS",
    "PROD_F16C_MICROKERNEL_SRCS",
    "PROD_XOP_MICROKERNEL_SRCS",
    "PROD_FMA3_MICROKERNEL_SRCS",
    "PROD_AVX2_MICROKERNEL_SRCS",
    "PROD_AVX512F_MICROKERNEL_SRCS",
    "PROD_AVX512SKX_MICROKERNEL_SRCS",
    "PROD_AVX512VBMI_MICROKERNEL_SRCS",
    "PROD_AVX512VNNI_MICROKERNEL_SRCS",
    "PROD_AVX512VNNIGFNI_MICROKERNEL_SRCS",
    "PROD_RVV_MICROKERNEL_SRCS",
    "PROD_AVXVNNI_MICROKERNEL_SRCS",
    "AARCH32_ASM_MICROKERNEL_SRCS",
    "AARCH64_ASM_MICROKERNEL_SRCS",

    # add non-prod microkernel sources here:
}

def handle_singleline_parse(line):
    """Parse a one-line CMake command of the form ``SET(NAME src/a.c src/b.c)``.

    Args:
        line: Text of the command. Only the part between the first ``(`` and
            the following ``)`` is examined; if there is no ``)`` the rest of
            the line is used.

    Returns:
        A ``(name, files)`` tuple. ``name`` is the first whitespace-separated
        token inside the parentheses (``""`` if there is none) and ``files``
        lists the remaining tokens, each with a leading ``src/`` removed.
    """
    start_index = line.find("(")
    end_index = line.find(")", start_index + 1)
    if end_index == -1:
        # No closing parenthesis: keep the whole remainder rather than
        # silently dropping the final character via a -1 slice bound.
        end_index = len(line)

    # split() without arguments collapses runs of spaces/tabs/newlines, so
    # irregular spacing cannot produce empty tokens.
    tokens = line[start_index + 1:end_index].split()
    if not tokens:
        return "", []

    # Only strip the prefix when it is really there; blindly cutting four
    # characters would mangle any token that does not start with "src/".
    files = [tok[4:] if tok.startswith("src/") else tok for tok in tokens[1:]]
    return tokens[0], files

def update_sources(xnnpack_path, cmakefile="XNNPACK/CMakeLists.txt"):
    """Collect the source lists assigned by ``SET(...)`` commands in a CMake file.

    Two layouts are recognised, both only when ``SET`` starts the line:

    * one-line commands that mention ``src/`` on the ``SET`` line itself; these
      are handed to ``handle_singleline_parse`` whatever the variable name is;
    * multi-line commands whose variable is a key of ``WRAPPER_SRC_NAMES`` or a
      member of ``SRC_NAMES``, with the paths on the following lines up to and
      including the first line that contains ``)``.

    Args:
        xnnpack_path: Directory that ``cmakefile`` is resolved against.
        cmakefile: Path of the CMake file, relative to ``xnnpack_path``.

    Returns:
        A ``collections.defaultdict(list)`` mapping each variable name to its
        paths in file order, each with a leading ``src/`` removed.
    """
    sources = collections.defaultdict(list)
    # The tracked names do not change while parsing, so build the lookup set
    # once instead of once per SET line.
    tracked_names = set(WRAPPER_SRC_NAMES) | set(SRC_NAMES)

    with open(os.path.join(xnnpack_path, cmakefile)) as cmake:
        lines = cmake.readlines()

    i = 0
    while i < len(lines):
        line = lines[i]
        i += 1
        if not line.startswith("SET"):
            continue

        if "src/" in line:
            name, val = handle_singleline_parse(line)
            sources[name].extend(val)
            continue

        # A line that starts with "SET" but has no '(' cannot name a list;
        # indexing the split result blindly would raise IndexError.
        pieces = line.split('(')
        if len(pieces) < 2:
            continue
        name = pieces[1].strip(' \t\n\r')
        if name not in tracked_names:
            continue

        # Consume the list body. Blank lines, comment-only lines and a lone
        # ')' contribute no entries instead of empty or mangled ones.
        while i < len(lines):
            text = lines[i].split('#', 1)[0]  # drop a CMake line comment
            is_last = ')' in text
            if is_last:
                text = text.split(')', 1)[0]
            for entry in text.split():
                # Strip the "src/" prefix only when it is actually present.
                sources[name].append(entry[4:] if entry.startswith("src/") else entry)
            if is_last:
                # Leave i on the closing line; the outer loop looks at it
                # once more and skips it unless it starts another SET.
                break
            i += 1
    return sources

def _write_wrapper(filepath, filename, condition):
    """Write one wrapper file that #includes ``filename``, guarded by ``condition`` if given."""
    wrapper_dir = os.path.dirname(filepath)
    if not os.path.isdir(wrapper_dir):
        os.makedirs(wrapper_dir)
    with open(filepath, "w") as wrapper:
        print("/* {} */".format(BANNER), file=wrapper)
        print(file=wrapper)

        # Architecture- or platform-dependent preprocessor flags can be
        # defined here. Note: platform_preprocessor_flags can't be used
        # because they are ignored by arc focus & buck project.

        if condition is None:
            print("#include <%s>" % filename, file=wrapper)
        else:
            # Include source file only if condition is satisfied
            print("#if %s" % condition, file=wrapper)
            print("#include <%s>" % filename, file=wrapper)
            print("#endif /* %s */" % condition, file=wrapper)


def _write_defs(path, names, sources, prefix):
    """Write a .bzl file with one list per name; each entry is ``prefix`` + source path."""
    with open(path, 'w') as defs:
        print('"""', file=defs)
        print(BANNER, file=defs)
        print('"""', file=defs)
        for name in names:
            print('\n' + name + ' = [', file=defs)
            for file_name in sources[name]:
                print('    "{}{}",'.format(prefix, file_name), file=defs)
            print(']', file=defs)


def gen_wrappers(xnnpack_path):
    """Generate the wrapper sources and the two ``.bzl`` definition files.

    Source lists are read from ``XNNPACK/CMakeLists.txt`` and
    ``XNNPACK/cmake/microkernels.cmake`` under ``xnnpack_path``; a list found
    in the latter replaces the one of the same name from the former. For every
    file in a list named by ``WRAPPER_SRC_NAMES`` a wrapper is written to
    ``<xnnpack_path>/xnnpack_wrappers/<file>``. ``xnnpack_wrapper_defs.bzl``
    and ``xnnpack_src_defs.bzl`` are written to the directory containing this
    script.

    Args:
        xnnpack_path: Directory containing the ``XNNPACK`` checkout; also the
            parent of the generated ``xnnpack_wrappers`` folder.
    """
    sources = update_sources(xnnpack_path)
    sources.update(update_sources(xnnpack_path, "XNNPACK/cmake/microkernels.cmake"))

    # Group the wrapped files by the preprocessor condition guarding them.
    xnnpack_sources = collections.defaultdict(list)
    for name in WRAPPER_SRC_NAMES:
        xnnpack_sources[WRAPPER_SRC_NAMES[name]].extend(sources[name])

    for condition, filenames in xnnpack_sources.items():
        print(condition)
        for filename in filenames:
            filepath = os.path.join(xnnpack_path, "xnnpack_wrappers", filename)
            _write_wrapper(filepath, filename, condition)

    script_dir = os.path.dirname(__file__)

    # update xnnpack_wrapper_defs.bzl file under the same folder
    _write_defs(
        os.path.join(script_dir, "xnnpack_wrapper_defs.bzl"),
        WRAPPER_SRC_NAMES,
        sources,
        "xnnpack_wrappers/",
    )

    # update xnnpack_src_defs.bzl file under the same folder
    # SRC_NAMES is a set of strings, whose iteration order changes between
    # interpreter runs (hash randomization); sort it so the generated file is
    # reproducible and does not produce spurious diffs.
    _write_defs(
        os.path.join(script_dir, "xnnpack_src_defs.bzl"),
        sorted(SRC_NAMES),
        sources,
        "XNNPACK/src/",
    )


def main(argv):
    """Run ``gen_wrappers`` on ``argv[0]``, or on "." when ``argv`` is None or empty."""
    target_dir = argv[0] if argv else "."
    gen_wrappers(target_dir)

# The first argument is the directory that contains the XNNPACK checkout; the
# "xnnpack_wrappers" folder is created inside that same directory.
# Running the script without arguments uses the current directory instead.
# The two .bzl files are always written to the directory containing this script.
if __name__ == "__main__":
    main(sys.argv[1:])
