#!/usr/bin/env python3
#
# Copyright © 2024 Igalia S.L.
# SPDX-License-Identifier: MIT

import argparse
import ctypes
import logging
import re
from datetime import datetime
from mako.template import Template
from pycparser import parse_file, c_ast

# Module logger. main() configures DEBUG-level logging when --verbose is given.
logger = logging.getLogger('hwdb')

# Mako template for the generated C header.
#
# Render parameters (as passed by main()):
#   struct  - iterable of (C type, member declarator) pairs describing the
#             merged gcsFEATURE_DATABASE layout, e.g. ('uint32_t', 'foo:1')
#             or ('uint32_t', 'bar[9]').
#   entries - list of entries; each entry is a list of (field name, C literal)
#             pairs, emitted in member order as positional initializers of
#             one gChipInfo[] element.
#
# Lines starting with '%' are Mako control lines and ${...} are Mako
# substitutions; everything else is copied verbatim into the output.
#
# gcQueryFeatureDB() first scans for an entry with formalRelease set whose
# chipID, chipVersion, productID, ecoID and customerID all match exactly.
# Failing that, it accepts an entry without formalRelease whose chipVersion
# matches after masking with 0xFFF0 (the other IDs must still match exactly).
# It returns 0 (a null pointer) when nothing matches.
template = """/*
 * Copyright © 2024 Igalia S.L.
 * SPDX-License-Identifier: MIT
 */

#pragma once

#include <stdint.h>

typedef struct
{
% for type, name in struct:
   ${type} ${name};
% endfor
} gcsFEATURE_DATABASE;

static gcsFEATURE_DATABASE gChipInfo[] = {
% for entry in entries:
   {
% for name, value in entry:
      ${value}, /* ${name} */
% endfor
   },
% endfor
};

static inline gcsFEATURE_DATABASE *
gcQueryFeatureDB(
    uint32_t ChipID,
    uint32_t ChipVersion,
    uint32_t ProductID,
    uint32_t EcoID,
    uint32_t CustomerID
    )
{
    int entryNum = sizeof(gChipInfo) / sizeof(gChipInfo[0]);

    /* check formal release entries first */
    for (int i = 0; i < entryNum; ++i)
    {
        if ((gChipInfo[i].chipID == ChipID)
            && (gChipInfo[i].chipVersion == ChipVersion)
            && (gChipInfo[i].productID == ProductID)
            && (gChipInfo[i].ecoID == EcoID)
            && (gChipInfo[i].customerID == CustomerID)
            && (gChipInfo[i].formalRelease)
           )
        {
            return &gChipInfo[i];
        }
    }

    /* check informal release entries if we dont find in formal entries */
    for (int i = 0; i < entryNum; ++i)
    {
        if ((gChipInfo[i].chipID == ChipID)
            && ((gChipInfo[i].chipVersion & 0xFFF0) == (ChipVersion & 0xFFF0))
            && (gChipInfo[i].productID == ProductID)
            && (gChipInfo[i].ecoID == EcoID)
            && (gChipInfo[i].customerID == CustomerID)
            && (!gChipInfo[i].formalRelease)
           )
        {
            return &gChipInfo[i];
        }
    }

    return 0;
}
"""


class HeaderFile(c_ast.NodeVisitor):
    """Class representing a complete header file.

    Construction reads the "Auto created on" timestamp from the raw text,
    then runs the file through the C preprocessor and pycparser and walks
    the resulting AST:

    * every ``typedef struct`` with at least one mappable member becomes a
      dynamically created ``ctypes.Structure`` subclass in :attr:`structs`,
      keyed by the typedef name;
    * every initialized array declaration whose element type is one of those
      typedefs is converted into struct instances appended to :attr:`data`.

    Attributes:
        filename: Path of the parsed header.
        structs: Mapping of typedef name to generated ``ctypes.Structure``.
        data: Struct instances built from array initializers.
        date_time: ``datetime`` parsed from the generation comment, or
            ``None`` when the comment is absent.
        database_struct: The structure generated for ``gcsFEATURE_DATABASE``.
    """

    # Regular expression to match the date and time in the comment
    _DATE_RE = re.compile(r'/\*Auto created on (\d{4}-\d{2}-\d{2} \d{2}:\d{2})\*/')

    # Trailing C integer-constant suffix (u/U, l/L, ll/LL, in any order).
    _INT_SUFFIX_RE = re.compile(r'[uUlL]+$')

    def __init__(self, filename):
        self.filename = filename
        self.structs = {}
        self.data = []
        self.date_time = None
        self.database_struct = None

        self._read_date()
        self._parse()

        logger.debug('Parsed %s (autogenerated at %s, %u struct members, %u entries)', self.filename, self.date_time, len(self.database_struct._fields_), len(self.data))

    def _read_date(self):
        """Set :attr:`date_time` from the "Auto created on" comment, if present."""
        # Read the content of the file and search for pattern
        with open(self.filename, 'r', encoding="utf-8") as file:
            file_content = file.read()

        match = self._DATE_RE.search(file_content)

        if match:
            self.date_time = datetime.strptime(match.group(1), '%Y-%m-%d %H:%M')

    def _parse(self):
        """Preprocess and parse the header, then collect its structs and entries.

        Raises:
            KeyError: if the header does not typedef ``gcsFEATURE_DATABASE``.
        """
        # NOTE(review): the fake libc include path is relative to the current
        # working directory — confirm the build always runs from the right place.
        ast = parse_file(self.filename, use_cpp=True, cpp_args=['-E', r'-I./utils/fake_libc_include', '-DgctUINT32=unsigned int', '-DgctINT=int'])
        self.visit(ast)

        self.database_struct = self.structs['gcsFEATURE_DATABASE']

    def visit_Typedef(self, node):
        """Turn ``typedef struct {...} Name;`` into a ctypes structure called ``Name``."""
        if isinstance(node.type, c_ast.TypeDecl) and isinstance(node.type.type, c_ast.Struct):
            struct_node = node.type.type
            struct_name = node.name  # Typedef name as the struct name
            fields = self._extract_fields_from_struct(struct_node)
            if fields is not None:
                # Create the ctypes.Structure subclass and add it to the structures dictionary
                self.structs[struct_name] = type(struct_name, (ctypes.Structure,), {'_fields_': fields})

    def _extract_fields_from_struct(self, struct_node):
        """Return the ``(name, ctypes type)`` list of a struct, or ``None`` if empty.

        Members whose type :meth:`_map_type_to_ctypes` cannot map are skipped.
        """
        fields = []
        for decl in (struct_node.decls or []):
            if isinstance(decl.type, (c_ast.TypeDecl, c_ast.PtrDecl)):
                field_name = decl.name
                field_type = self._map_type_to_ctypes(decl.type.type, decl.bitsize)
                if field_type:
                    fields.append((field_name, field_type))
            elif isinstance(decl.type, c_ast.ArrayDecl):
                # Handle array type
                field_name = decl.type.type.declname
                element_type = self._map_type_to_ctypes(decl.type.type.type)
                # Assumes dim is a Constant node. Parse it with C rules so hex
                # (0x10), octal (010) and suffixed (9u) sizes are read correctly.
                array_size = self._parse_int_constant(decl.type.dim.value)
                if element_type:
                    fields.append((field_name, element_type * array_size))

        return fields if fields else None

    def _map_type_to_ctypes(self, type_node, bitsize=None):
        """Function returning a ctype type based node type.

        One-bit bit-fields map to ``c_bool``; pointer targets (``TypeDecl``)
        map to ``c_char_p``; unknown types map to ``None``.
        """
        type_mappings = {
            'bool': ctypes.c_bool,
            'unsigned int': ctypes.c_uint,
            'int': ctypes.c_int,
        }

        if isinstance(type_node, c_ast.IdentifierType):
            c_type = ' '.join(type_node.names)

            if bitsize and bitsize.value == '1':
                c_type = 'bool'

            return type_mappings.get(c_type)
        elif isinstance(type_node, c_ast.TypeDecl):
            return ctypes.c_char_p

        return None

    def visit_Decl(self, node):
        """Collect entries from an initialized array of a known struct type."""
        array_decl = node.type
        # Only `T name[...]` declarations are of interest.
        if not (isinstance(array_decl, c_ast.ArrayDecl) and isinstance(array_decl.type, c_ast.TypeDecl)):
            return
        element_type = array_decl.type.type
        # `struct tag name[]` yields a c_ast.Struct here, which has no `names`.
        if not isinstance(element_type, c_ast.IdentifierType):
            return
        struct_name = element_type.names[0]
        # Declarations without an initializer (e.g. `extern T name[];`) hold no entries.
        if struct_name in self.structs and node.init is not None:
            elements = self._parse_array_initializer(node.init, self.structs[struct_name])
            self.data.extend(elements)

    def _parse_initializer(self, initializer, struct):
        """Return the positional constructor arguments for one struct initializer.

        Each expression is converted with :meth:`_parse_expr`, passing the
        length of the matching field when that field is an array. Strings are
        encoded to UTF-8 ``bytes`` as ``ctypes.c_char_p`` requires.
        """
        params = []
        for index, expr in enumerate(initializer.exprs):
            field_type = struct._fields_[index][1]
            param = self._parse_expr(expr, getattr(field_type, '_length_', None))
            if isinstance(param, str):
                param = param.encode('utf-8')
            params.append(param)
        return params

    def _parse_array_initializer(self, init, struct_class):
        """Return one ``struct_class`` instance per element of the brace list *init*.

        Raises:
            ValueError: if *init* is not a brace-enclosed initializer list.
        """
        # Explicit check rather than `assert`, which `python -O` strips.
        if not isinstance(init, c_ast.InitList):
            raise ValueError(f'{self.filename}: expected a brace-enclosed initializer list, got {type(init).__name__}')
        return [struct_class(*self._parse_initializer(initializer, struct_class)) for initializer in init.exprs]

    @classmethod
    def _parse_int_constant(cls, text):
        """Return the value of the C integer constant spelled *text*.

        This handles two cases ``int(text, 0)`` rejects: C integer suffixes
        (``0x10u``, ``1UL``) and C octal written with a bare leading zero
        (``017`` == 15).

        Raises:
            ValueError: if *text* is not a valid integer constant.
        """
        digits = cls._INT_SUFFIX_RE.sub('', text)
        if digits[:2].lower() in ('0x', '0b'):
            return int(digits, 0)
        if len(digits) > 1 and digits.startswith('0'):
            return int(digits, 8)
        return int(digits, 10)

    def _parse_expr(self, expr, expected_size=None):
        """Convert an initializer expression into a Python/ctypes value.

        Args:
            expr: pycparser expression node.
            expected_size: Element count of the destination array field, or
                ``None`` when the destination is not an array.

        Returns:
            ``int`` for integer constants; ``str`` for string literals (the
            enclosing quotes are removed, C escapes are kept as written); a
            zero-padded ``ctypes.c_uint`` array for brace lists when
            *expected_size* is given; ``None`` for anything else.
        """
        if isinstance(expr, c_ast.Constant):
            # pycparser reports integer constants as 'int' or, in versions that
            # track suffixes, 'unsigned int', 'long int', ... — all integers.
            if expr.type == 'int' or expr.type.endswith(' int'):
                return self._parse_int_constant(expr.value)
            if expr.type == "string":
                # Drop only the enclosing quotes; strip('"') would also eat an
                # escaped quote at the end of the literal (e.g. "a\"").
                return expr.value[1:-1]

        # Handling arrays with potential conversion to ctypes arrays
        elif isinstance(expr, c_ast.InitList) and expected_size is not None:
            element_list = [self._parse_expr(e) for e in expr.exprs]

            # Ensure the list matches expected size, filling with zeros if necessary
            if len(element_list) < expected_size:
                element_list.extend([0] * (expected_size - len(element_list)))

            return (ctypes.c_uint * expected_size)(*element_list)

        # Fallback or default return for unhandled cases
        return None


def merge_structures(structures):
    """Return a new ``ctypes.Structure`` subclass holding the union of all fields.

    Fields are taken in order of first appearance: every field of the first
    structure, then each field of later structures whose name has not been
    seen yet. When several structures declare the same field name, the type
    from the earliest structure wins.

    Args:
        structures: Iterable of ``ctypes.Structure`` subclasses.

    Returns:
        A dynamically created ``ctypes.Structure`` subclass named
        ``MergedStructure``.
    """
    combined_fields = []
    seen_names = set()
    for struct in structures:
        for field_name, field_type in struct._fields_:
            # A set makes the duplicate check O(1), instead of rebuilding and
            # scanning the list of already-merged names for every field.
            if field_name not in seen_names:
                seen_names.add(field_name)
                combined_fields.append((field_name, field_type))

    # Create a new structure dynamically
    return type("MergedStructure", (ctypes.Structure,), {'_fields_': combined_fields})


def get_field_type(c, field_name):
    """Return the ctypes type of the field called *field_name* in *c*.

    *c* may be a ``ctypes.Structure`` subclass or instance; only its
    ``_fields_`` sequence is consulted. The first matching entry wins, and
    ``None`` is returned when no field has that name.
    """
    matches = (field[1] for field in c._fields_ if field[0] == field_name)
    return next(matches, None)


def create_merged_struct(c, entry):
    """Return an instance of the merged structure *c* populated from *entry*.

    A field is copied from *entry* only when *entry* declares a field with the
    same name and the very same ctypes type. Every other field gets a zero
    default. That covers fields the entry lacks, and fields whose type differs,
    e.g. a scalar ``VIP_SRAM_SIZE_ARRAY`` in the entry versus
    ``VIP_SRAM_SIZE_ARRAY[9]`` in the merged layout. The defaults are:

    * ``ctypes.c_char_p`` -> ``b''`` (an empty string rather than a null
      pointer, so the value can still be ``.decode()``-d when rendered);
    * ctypes arrays -> a zero-filled array of the field's exact array type;
    * any other scalar (``c_uint``, ``c_bool``, ``c_int``, ...) -> ``0``.

    Args:
        c: The merged ``ctypes.Structure`` subclass.
        entry: A ``ctypes.Structure`` instance parsed from one header.

    Returns:
        A new instance of *c*.
    """
    # Field name -> declared type for the entry. Built once instead of
    # rescanning entry._fields_ for every merged field; the first declaration wins.
    entry_types = {}
    for field in entry._fields_:
        entry_types.setdefault(field[0], field[1])

    fields = []
    for field_name, field_type in c._fields_:
        if entry_types.get(field_name) is field_type:
            fields.append(getattr(entry, field_name))
        elif field_type is ctypes.c_char_p:
            fields.append(b'')
        elif issubclass(field_type, ctypes.Array):
            # Instantiating the array type yields a zero-filled array whose
            # element type matches the field (c_uint, c_int, ...).
            fields.append(field_type())
        else:
            fields.append(0)

    return c(*fields)


def enumerate_struct(obj):
    """Yield ``(c_type, declarator)`` pairs describing the members of *obj*.

    These pairs feed the ``typedef struct`` part of the header template:

    * ``ctypes.c_char_p`` fields become ``const char *``. Every other field,
      including array elements, is declared ``uint32_t``.
    * ``ctypes.c_bool`` fields become 1-bit bit-fields (``name:1``).
    * ctypes array fields get their length appended (``name[N]``).

    Args:
        obj: A ``ctypes.Structure`` subclass or instance.
    """
    for field_name, field_type in obj._fields_:
        # Named c_type rather than `type` so the builtin is not shadowed.
        # NOTE(review): signed ctypes.c_int members are declared uint32_t
        # as well — confirm that is intended.
        c_type = 'const char *' if field_type == ctypes.c_char_p else 'uint32_t'
        declarator = field_name

        if field_type == ctypes.c_bool:
            declarator = declarator + ':1'

        if hasattr(field_type, '_length_'):
            declarator = f'{declarator}[{field_type._length_}]'

        yield c_type, declarator


def enumerate_values(obj):
    """Yield ``(field_name, c_literal)`` for every field of the struct instance *obj*.

    ``c_uint`` and ``c_bool`` values are rendered in hexadecimal. ``c_char_p``
    values become a double-quoted string. ctypes arrays become a brace-enclosed,
    comma-separated list of decimal elements. Any other value is yielded
    unchanged.
    """
    def as_c_literal(ctype, raw):
        # Translate one field value into the text placed in the initializer.
        if ctype in (ctypes.c_uint, ctypes.c_bool):
            return hex(raw)
        if ctype == ctypes.c_char_p:
            return '"{}"'.format(raw.decode('utf-8'))
        if isinstance(raw, ctypes.Array):
            return '{{{}}}'.format(', '.join(map(str, raw)))
        return raw

    for name, ctype in obj._fields_:
        yield name, as_c_literal(ctype, getattr(obj, name))


def main():
    """Merge the feature-database headers named on the command line into one C header.

    Each positional ``header`` is parsed into a :class:`HeaderFile`, and their
    ``gcsFEATURE_DATABASE`` layouts are merged with :func:`merge_structures`.
    Every parsed entry is then re-expressed in the merged layout, and
    ``template`` is rendered into the ``--output`` file. ``--verbose``/``-v``
    turns on DEBUG logging.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument('--output', required=True, type=str, action="store",
                     help='output C header file')
    cli.add_argument('headers', metavar='header', type=str, nargs='+',
                     help='gc database header to process')
    cli.add_argument('--verbose', "-v", action="store_true",
                     help='be verbose')
    opts = cli.parse_args()

    if opts.verbose:
        logging.basicConfig(level=logging.DEBUG)

    parsed_headers = [HeaderFile(path) for path in opts.headers]
    merged = merge_structures([hdr.database_struct for hdr in parsed_headers])
    logger.debug('merged struct: %u members', len(merged._fields_))

    entries = []
    for hdr in parsed_headers:
        logger.debug('Processing %s', hdr.filename)
        # Re-express every entry of this header in the merged layout, then
        # flatten it into (field name, C literal) pairs for the template.
        entries.extend(
            list(enumerate_values(create_merged_struct(merged, entry)))
            for entry in hdr.data
        )

    logger.debug('Total entries: %u', len(entries))

    with open(opts.output, "w", encoding="UTF-8") as out:
        rendered = Template(template).render(struct=enumerate_struct(merged), entries=entries)
        print(rendered, file=out)


if __name__ == '__main__':
    main()
