#!/usr/bin/env python3

# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Tests the generated python backend against standard PDL
# constructs, with matching input vectors.

import dataclasses
import enum
import json
import typing
import unittest
from importlib import resources

# (le|be)_backend are the names of the modules generated from the canonical
# little endian and big endian test grammars. The purpose of this module
# is to validate the generated parsers against the set of pre-generated
# test vectors in canonical/(le|be)_test_vectors.json.
import le_backend
import be_backend


# Names of test-vector entries (matched against each entry's 'packet' key)
# that the vector-driven parser and serializer tests in this module skip.
# NOTE(review): every entry is an array with variable element size --
# presumably not yet supported by the generated python backend; confirm
# before removing an entry.
SKIPPED_TESTS = [
    "Packet_Array_Field_VariableElementSize_ConstantSize",
    "Packet_Array_Field_VariableElementSize_VariableSize",
    "Packet_Array_Field_VariableElementSize_VariableCount",
    "Packet_Array_Field_VariableElementSize_UnknownSize",
]


def match_object(self, left, right):
    """Recursively match a python class object against a reference
    json object.

    Args:
        self: the running unittest.TestCase; its assert methods are used
            to report mismatches.
        left: value produced by the code under test (object, sequence or
            scalar).
        right: reference value decoded from JSON.

    A reference dict is matched attribute by attribute: only the keys it
    lists are checked, extra attributes of `left` are ignored. A reference
    list must have the same length as `left` and is matched element-wise.
    Any other reference value (int, but also null, string or float) must
    compare equal to `left`.
    """
    if isinstance(right, dict):
        for name, expected in right.items():
            self.assertTrue(
                hasattr(left, name),
                f"{type(left).__name__} object has no attribute {name!r}")
            match_object(self, getattr(left, name), expected)
    elif isinstance(right, list):
        self.assertEqual(len(left), len(right))
        for index, expected in enumerate(right):
            match_object(self, left[index], expected)
    else:
        # Scalar reference. Previously only ints were compared, so a null
        # or string reference silently matched any value.
        self.assertEqual(left, right)


def create_object(typ, value):
    """Build an object of the selected type using the input value.

    Args:
        typ: type annotation describing the object to build. Supported
            annotations are a dataclass, typing.List[...],
            typing.Optional[...] / typing.Union[...], bytes, bytearray,
            an enum.Enum subclass exposing `from_int`, and int.
        value: JSON-decoded value (dict, list, int or None) to convert.

    Returns:
        `value` converted to `typ`, recursing into dataclass fields and
        list elements. None is returned as-is for Union annotations.

    Raises:
        Exception: if `typ` is not one of the supported annotations.
        KeyError: if a key of `value` does not name a dataclass field.
    """
    if dataclasses.is_dataclass(typ):
        field_types = {f.name: f.type for f in dataclasses.fields(typ)}
        values = {
            name: create_object(field_types[name], field_value)
            for name, field_value in value.items()
        }
        return typ(**values)

    origin = typing.get_origin(typ)
    if origin is list:
        element_type = typing.get_args(typ)[0]
        return [create_object(element_type, v) for v in value]
    if origin is typing.Union:
        if value is None:
            return None
        # typing.Optional[int] expands to typing.Union[int, None]; build
        # the first member that is not NoneType so that the order of the
        # Union arguments does not matter.
        member = next(
            arg for arg in typing.get_args(typ) if arg is not type(None))
        return create_object(member, value)

    if typ is bytes:
        return bytes(value)
    if typ is bytearray:
        return bytearray(value)
    # issubclass() rejects non-class arguments with a TypeError; guard it
    # so that such annotations reach the explicit error below instead.
    if isinstance(typ, type) and issubclass(typ, enum.Enum):
        from_int = getattr(typ, 'from_int')
        return from_int(value)
    if typ is int:
        return value
    raise Exception(f"unsupported type annotation {typ}")


class PacketParserTest(unittest.TestCase):
    """Validate the generated parser against pre-generated test
       vectors in canonical/(le|be)_test_vectors.json"""

    def _check_parser(self, backend, vectors):
        """Parse all test vectors of one file with one generated module.

        Args:
            backend: generated module (le_backend or be_backend) in which
                the packet classes are looked up by name.
            vectors: name of the JSON test vector file inside the
                'tests.canonical' package.
        """
        # JSON text is read as UTF-8 explicitly rather than with the
        # locale-dependent default encoding.
        path = resources.files('tests.canonical').joinpath(vectors)
        with path.open('r', encoding='utf-8') as f:
            reference = json.load(f)

        for item in reference:
            # 'packet' is the name of the packet being tested,
            # 'tests' lists input vectors that must match the
            # selected packet.
            packet = item['packet']
            tests = item['tests']

            if packet in SKIPPED_TESTS:
                continue

            with self.subTest(packet=packet):
                # Retrieve the class object from the generated
                # module, in order to invoke the proper parse
                # method for this test.
                cls = getattr(backend, packet)
                for test in tests:
                    result = cls.parse_all(bytes.fromhex(test['packed']))
                    match_object(self, result, test['unpacked'])

    def testLittleEndian(self):
        self._check_parser(le_backend, 'le_test_vectors.json')

    def testBigEndian(self):
        self._check_parser(be_backend, 'be_test_vectors.json')


class PacketSerializerTest(unittest.TestCase):
    """Validate the generated serializer against pre-generated test
       vectors in canonical/(le|be)_test_vectors.json"""

    def _check_serializer(self, backend, vectors):
        """Serialize all test vectors of one file with one generated module.

        Args:
            backend: generated module (le_backend or be_backend) in which
                the packet classes are looked up by name.
            vectors: name of the JSON test vector file inside the
                'tests.canonical' package.
        """
        # JSON text is read as UTF-8 explicitly rather than with the
        # locale-dependent default encoding.
        path = resources.files('tests.canonical').joinpath(vectors)
        with path.open('r', encoding='utf-8') as f:
            reference = json.load(f)

        for item in reference:
            # 'packet' is the name of the packet being tested,
            # 'tests' lists input vectors that must match the
            # selected packet.
            packet = item['packet']
            tests = item['tests']

            if packet in SKIPPED_TESTS:
                continue

            with self.subTest(packet=packet):
                for test in tests:
                    # Retrieve the class object from the generated
                    # module, in order to invoke the proper constructor
                    # for this test. A test vector may select its own
                    # class through an optional 'packet' key; the name
                    # of the enclosing entry is used otherwise.
                    cls = getattr(backend, test.get('packet', packet))
                    obj = create_object(cls, test['unpacked'])
                    result = obj.serialize()
                    self.assertEqual(result, bytes.fromhex(test['packed']))

    def testLittleEndian(self):
        self._check_serializer(le_backend, 'le_test_vectors.json')

    def testBigEndian(self):
        self._check_serializer(be_backend, 'be_test_vectors.json')


class CustomPacketParserTest(unittest.TestCase):
    """Hand-written parser checks for declarations holding custom fields."""

    def testCustomField(self):
        """Parse the one-element input [1] with each custom field
        declaration of both backends and check the decoded value is 1."""
        # Packets expose the custom field directly as `a`.
        packet_names = (
            'Packet_Custom_Field_ConstantSize',
            'Packet_Custom_Field_VariableSize',
        )
        # These declarations reach the custom field through `s.a`.
        struct_names = (
            'Struct_Custom_Field_ConstantSize',
            'Struct_Custom_Field_VariableSize',
        )

        # Little endian is exercised first, then big endian; within each
        # backend the packets come before the structs.
        for backend in (le_backend, be_backend):
            for name in packet_names:
                parsed = getattr(backend, name).parse_all([1])
                self.assertEqual(parsed.a.value, 1)

            for name in struct_names:
                parsed = getattr(backend, name).parse_all([1])
                self.assertEqual(parsed.s.a.value, 1)


# Run every test case of this module, with verbose output, when the file
# is executed directly as a script.
if __name__ == '__main__':
    unittest.main(verbosity=3)
