#!/usr/bin/python3
# Copyright 2024 The ANGLE Project Authors. All rights reserved.
# Use of this source code is governed by a BSD-style license that can be
# found in the LICENSE file.
#
# gen_wgpu_format_table.py:
#  Code generation for wgpu format map. See wgpu_format_map.json for data source.
#  NOTE: don't run this script directly. Run scripts/run_code_generation.py.

import json
import math
import pprint
import os
import re
import sys

sys.path.append('..')
import angle_format

# Template for the whole generated wgpu_format_table_autogen.cpp. It is filled in
# by main() with str.format(), so every literal C++ brace is doubled ({{ / }}).
# Placeholders:
#   script_name, input_file_name, out_file_name: provenance for the header comment.
#   format_case_data: the `case` entries of Format::initialize(), one per ANGLE
#       format (produced by gen_format_case()).
#   image_format_id_cases / buffer_format_id_cases: comma-separated
#       {angle::FormatID, wgpu format} initializers for the FormatMap tables
#       (produced by get_format_id_case()).
#   wgpu_image_format_cases / wgpu_buffer_format_cases: reverse-mapping `case`
#       labels (produced by get_wgpu_format_case()).
template_table_autogen_cpp = """// GENERATED FILE - DO NOT EDIT.
// Generated by {script_name} using data from {input_file_name}
//
// Copyright 2024 The ANGLE Project Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
//
// {out_file_name}:
//   Queries for full WebGPU format information based on GL format.

#include "libANGLE/renderer/wgpu/wgpu_format_utils.h"

#include "image_util/loadimage.h"

using namespace angle;

namespace rx
{{
namespace webgpu
{{

void Format::initialize(const angle::Format &angleFormat)
{{
    switch (angleFormat.id)
    {{
{format_case_data}
        default:
            UNREACHABLE();
            break;
    }}
}}

wgpu::TextureFormat GetWgpuTextureFormatFromFormatID(angle::FormatID formatID)
{{
    static constexpr angle::FormatMap<wgpu::TextureFormat> kMap = {{
{image_format_id_cases}
    }};

    return kMap[formatID];
}}

angle::FormatID GetFormatIDFromWgpuTextureFormat(wgpu::TextureFormat wgpuFormat)
{{
    switch (wgpuFormat)
    {{
{wgpu_image_format_cases}
        default:
            UNREACHABLE();
            return angle::FormatID::NONE;
    }}
}}

wgpu::VertexFormat GetWgpuVertexFormatFromFormatID(angle::FormatID formatID)
{{
    static constexpr angle::FormatMap<wgpu::VertexFormat> kMap = {{
{buffer_format_id_cases}
    }};

    return kMap[formatID];
}}

angle::FormatID GetFormatIDFromWgpuBufferFormat(wgpu::VertexFormat wgpuFormat)
{{
    switch (wgpuFormat)
    {{
{wgpu_buffer_format_cases}
        default:
            UNREACHABLE();
            return angle::FormatID::NONE;
    }}
}}
}}  // namespace webgpu
}}  // namespace rx
"""

# Per-format templates used by gen_format_case() to build the `case` entries of
# Format::initialize(). They are filled in with str.format(), so literal C++
# braces are doubled ({{ / }}).

# Case for an ANGLE format that appears in none of the WebGPU tables.
empty_format_entry_template = """case angle::FormatID::{format_id}:
// This format is not implemented in WebGPU.
break;
"""

# Case for a supported ANGLE format. {image_template} and {buffer_template} receive
# the text of one of the image_* / buffer_* templates below (or are left empty).
format_entry_template = """case angle::FormatID::{format_id}:
mIntendedGLFormat = {internal_format};
{image_template}
{buffer_template}
break;
"""

# Single image format for an ANGLE format whose name contains SNORM; marked
# non-renderable.
image_snorm_template = """mActualImageFormatID = {image};
mImageInitializerFunction = {image_initializer};
mIsRenderable = false;"""

# Single image format for every other ANGLE format; marked renderable.
image_basic_template = """mActualImageFormatID = {image};
mImageInitializerFunction = {image_initializer};
mIsRenderable = true;"""

# One {format, initializer} element of an ImageFormatInitInfo array.
image_struct_template = "{{{image}, {image_initializer}}}"

# Several candidate image formats: emits a kInfo array and hands it to
# initImageFallback().
image_fallback_template = """{{
static constexpr ImageFormatInitInfo kInfo[] = {{{image_list}}};
initImageFallback(kInfo, ArraySize(kInfo));
}}"""

# Single buffer (vertex) format together with its vertex load function.
buffer_basic_template = """mActualBufferFormatID = {buffer};
mVertexLoadFunction = {vertex_load_function};
mVertexLoadRequiresConversion = {vertex_load_converts};"""

# One {format, load function, converts} element of a BufferFormatInitInfo array.
buffer_struct_template = """{{{buffer}, {vertex_load_function}, {vertex_load_converts}}}"""

# Several candidate buffer formats: emits a kInfo array and hands it to
# initBufferFallback().
buffer_fallback_template = """{{
static constexpr BufferFormatInitInfo kInfo[] = {{{buffer_list}}};
initBufferFallback(kInfo, ArraySize(kInfo));
}}"""


def verify_wgpu_image_map_keys(angle_to_gl, wgpu_json_data):
    """Verify that the keys in WebGPU format tables exist in the ANGLE table.

    Despite the name, the keys of all three tables ("image_map", "buffer_map" and
    "fallbacks") are checked against angle_to_gl. A key that is missing from the
    ANGLE table means the entry in the WebGPU file is incorrect and needs to be
    fixed. Every invalid key is printed, not only the first one, so a single run
    reports all problems.

    Returns:
        True if every key is known to the ANGLE table, False otherwise.
    """
    no_error = True
    for table in ["image_map", "buffer_map", "fallbacks"]:
        # Deliberately not called `angle_format`: that would shadow the imported
        # angle_format module inside this function.
        for format_name in wgpu_json_data[table]:
            if format_name not in angle_to_gl:
                print("Invalid format " + format_name + " in wgpu_format_map.json in " + table)
                no_error = False

    return no_error


def get_vertex_copy_function(src_format, dst_format, wgpu_format):
    """Return the C++ vertex copy function used to load `src_format` into `dst_format`.

    R10G10B10A2 sources get a CopyXYZ10W2ToXYZWFloatVertexData instantiation, so the
    packed type can be converted to one the vertex buffer can use. Every other
    format is delegated to angle_format.get_vertex_copy_function().

    `wgpu_format` is accepted but not used by this function.
    """
    if 'R10G10B10A2' not in src_format:
        return angle_format.get_vertex_copy_function(src_format, dst_format)

    # Template arguments, in order: isSigned, normalized, toFloat, toHalf.
    unsigned_source = any(tag in src_format for tag in ('UINT', 'UNORM', 'USCALED'))
    integer_source = 'INT' in src_format
    template_args = [
        'false' if unsigned_source else 'true',
        'true' if 'NORM' in src_format else 'false',
        'false' if integer_source else 'true',
    ]
    # toHalf always follows toFloat.
    template_args.append(template_args[2])
    return 'CopyXYZ10W2ToXYZWFloatVertexData<%s>' % ', '.join(template_args)


def gen_format_case(angle, internal_format, wgpu_json_data):
    """Generate the C++ `case` for one ANGLE format inside Format::initialize().

    Args:
        angle: ANGLE format ID name (a key of the ANGLE format table).
        internal_format: GL internal format that `angle` maps to. It is assigned to
            mIntendedGLFormat and used to select the image initializer.
        wgpu_json_data: parsed wgpu_format_map.json. Its "image_map", "buffer_map"
            and "fallbacks" tables are read.

    Returns:
        The C++ text of the case label and body. A format that appears in none of
        the three tables gets a stub case that only breaks.
    """
    wgpu_image_map = wgpu_json_data["image_map"]
    wgpu_buffer_map = wgpu_json_data["buffer_map"]
    wgpu_fallbacks = wgpu_json_data["fallbacks"]

    if ((angle not in wgpu_image_map) and (angle not in wgpu_buffer_map) and
        (angle not in wgpu_fallbacks)):
        return empty_format_entry_template.format(format_id=angle)

    # get_formats returns override format (if any) + fallbacks
    # this was necessary to support D32_UNORM. There is no appropriate override that allows
    # us to fallback to D32_FLOAT, so now we leave the image override empty and function will
    # give us the fallbacks.
    def get_formats(format_id, kind):
        """Return the formats to try for `kind` ("image" or "buffer"), in priority order."""
        fallbacks = wgpu_fallbacks.get(format_id, {}).get(kind, [])
        # A single fallback may be written in the JSON without an enclosing list.
        if not isinstance(fallbacks, list):
            fallbacks = [fallbacks]

        if ((format_id in wgpu_image_map and kind == "image") or
                (format_id in wgpu_buffer_map and kind == "buffer")):
            assert format_id not in fallbacks
            fallbacks = [format_id] + fallbacks

        return fallbacks

    def image_args(format_id):
        """Placeholder values for the image_* templates."""
        return dict(
            image="angle::FormatID::" + format_id,
            image_initializer=angle_format.get_internal_format_initializer(
                internal_format, format_id))

    def buffer_args(format_id):
        """Placeholder values for the buffer_* templates."""
        wgpu_buffer_format = wgpu_buffer_map[format_id]
        return dict(
            buffer="angle::FormatID::" + format_id,
            vertex_load_function=get_vertex_copy_function(angle, format_id, wgpu_buffer_format),
            # Loading needs a conversion whenever the actual format differs from the intended one.
            vertex_load_converts='false' if angle == format_id else 'true',
        )

    # Each sub-template is formatted exactly once, then inserted into the entry
    # template as plain text. Already-substituted values are never re-parsed as
    # format strings.
    image_text = ""
    images = get_formats(angle, "image")
    if len(images) == 1:
        image_template = image_snorm_template if 'SNORM' in angle else image_basic_template
        image_text = image_template.format(**image_args(images[0]))
    elif len(images) > 1:
        image_list = ", ".join(image_struct_template.format(**image_args(i)) for i in images)
        image_text = image_fallback_template.format(image_list=image_list)

    buffer_text = ""
    buffers = get_formats(angle, "buffer")
    if len(buffers) == 1 and buffers[0] in wgpu_buffer_map:
        buffer_text = buffer_basic_template.format(**buffer_args(buffers[0]))
    elif len(buffers) > 1:
        buffer_list = ", ".join(buffer_struct_template.format(**buffer_args(i)) for i in buffers)
        buffer_text = buffer_fallback_template.format(buffer_list=buffer_list)

    return format_entry_template.format(
        format_id=angle,
        internal_format=internal_format,
        image_template=image_text,
        buffer_template=buffer_text)


def get_format_id_case(format_id, format_type, wgpu_format):
    """Return one `{angle::FormatID::X, wgpu::Type::Y}` initializer for a FormatMap table."""
    # Dawn replaced wgpu::VertexFormat::Undefined with wgpu::VertexFormat(0u)
    # (https://dawn-review.googlesource.com/c/dawn/+/193360), so that value has to be
    # spelled as an explicit cast of zero.
    undefined_vertex_format = 'VertexFormat' in format_type and 'Undefined' in wgpu_format
    return ("{angle::FormatID::%s, wgpu::%s(0u)}" % (format_id, format_type)
            if undefined_vertex_format else
            "{angle::FormatID::%s, wgpu::%s::%s}" % (format_id, format_type, wgpu_format))


def get_wgpu_format_case(format_type, format_id, wgpu_format):
    """Return a reverse-mapping `case` (wgpu format -> angle::FormatID), or '' to skip it."""
    # Dawn replaced wgpu::VertexFormat::Undefined with wgpu::VertexFormat(0u)
    # (https://dawn-review.googlesource.com/c/dawn/+/193360), so there is no
    # enumerator to write a case label for.
    is_undefined_vertex = 'VertexFormat' in format_type and 'Undefined' in wgpu_format
    # The external format slots all map to the undefined wgpu format. A reverse
    # mapping for them would clash with NONE, so none is generated.
    is_external_slot = 'EXTERNAL' in format_id
    if is_undefined_vertex or is_external_slot:
        return ''

    return """\
        case wgpu::%s::%s:
            return angle::FormatID::%s;
""" % (format_type, wgpu_format, format_id)


def main():
    """Entry point, invoked via scripts/run_code_generation.py.

    With a command-line argument, print the script's 'inputs' or 'outputs' as a
    comma-separated list; any other argument is rejected. Without arguments, read
    ../angle_format_map.json and wgpu_format_map.json, validate them, and write
    wgpu_format_table_autogen.cpp to the current working directory.

    Returns:
        Process exit code: 0 on success, 1 on an invalid argument or invalid
        entries in wgpu_format_map.json.
    """
    input_file_name = 'wgpu_format_map.json'
    out_file_name = 'wgpu_format_table_autogen.cpp'

    # auto_script parameters.
    if len(sys.argv) > 1:
        inputs = ['../angle_format.py', '../angle_format_map.json', input_file_name]
        outputs = [out_file_name]

        if sys.argv[1] == 'inputs':
            print(','.join(inputs))
        elif sys.argv[1] == 'outputs':
            print(','.join(outputs))
        else:
            print('Invalid script parameters')
            return 1
        return 0

    angle_to_gl = angle_format.load_inverse_table(os.path.join('..', 'angle_format_map.json'))
    wgpu_json_data = angle_format.load_json(input_file_name)

    if not verify_wgpu_image_map_keys(angle_to_gl, wgpu_json_data):
        return 1

    # Sort each map once. Both the forward table and the reverse switch iterate it,
    # and sorting keeps the generated file stable.
    image_items = sorted(wgpu_json_data["image_map"].items())
    buffer_items = sorted(wgpu_json_data["buffer_map"].items())

    image_format_id_cases = [
        get_format_id_case(format_id, "TextureFormat", wgpu_format)
        for format_id, wgpu_format in image_items
    ]

    wgpu_image_format_cases = [
        get_wgpu_format_case("TextureFormat", format_id, wgpu_format)
        for format_id, wgpu_format in image_items
    ]

    buffer_format_id_cases = [
        get_format_id_case(format_id, "VertexFormat", wgpu_format)
        for format_id, wgpu_format in buffer_items
    ]

    wgpu_buffer_format_cases = [
        get_wgpu_format_case("VertexFormat", format_id, wgpu_format)
        for format_id, wgpu_format in buffer_items
    ]

    wgpu_cases = [
        gen_format_case(angle, gl, wgpu_json_data) for angle, gl in sorted(angle_to_gl.items())
    ]

    output_cpp = template_table_autogen_cpp.format(
        format_case_data="\n".join(wgpu_cases),
        image_format_id_cases=",\n".join(image_format_id_cases),
        wgpu_image_format_cases="".join(wgpu_image_format_cases),
        buffer_format_id_cases=",\n".join(buffer_format_id_cases),
        wgpu_buffer_format_cases="".join(wgpu_buffer_format_cases),
        script_name=os.path.basename(__file__),
        out_file_name=out_file_name,
        input_file_name=input_file_name)

    # The `with` block closes the file; no explicit close() is needed.
    with open(out_file_name, 'wt') as out_file:
        out_file.write(output_cpp)
    return 0


if __name__ == '__main__':
    # main() returns the process exit status; hand it to the interpreter.
    exit_status = main()
    sys.exit(exit_status)
