#!/usr/bin/env python3

import argparse
import collections
import copy
import json
from pathlib import Path
import pprint
import traceback
from typing import Iterable, List, Optional, Union
import sys

from pdl import ast, core

# Default upper bound, in octets, on the total size of a generated typedef
# array; replaced by the padded size or the size-field capacity when present.
MAX_ARRAY_SIZE = 256
# Default upper bound on the element count of a generated typedef array;
# replaced by the count-field capacity or the static count when present.
MAX_ARRAY_COUNT = 32
# Number of elements generated for a scalar array with no static count,
# count field, or size field.
DEFAULT_ARRAY_COUNT = 3
# Number of filler octets generated for a payload/body without a size field
# (before the payload's size modifier is subtracted).
DEFAULT_PAYLOAD_SIZE = 5


class BitSerializer:
    """Accumulate bit fields and flush them as octets once aligned.

    Appended values are packed LSB-first: each new value is placed above the
    bits appended before it. Whenever the accumulated bit count is a multiple
    of 8, the accumulator is converted to octets using the configured byte
    order, appended to ``stream``, and reset.
    """

    def __init__(self, big_endian: bool):
        self.byteorder = "big" if big_endian else "little"
        self.stream = []  # flushed octets (ints 0..255)
        self.value = 0  # bits not yet flushed
        self.shift = 0  # number of bits currently held in ``value``

    def append(self, value: int, width: int):
        """Append ``width`` bits of ``value`` above the pending bits."""
        self.value |= value << self.shift
        self.shift += width

        # Not octet-aligned yet: keep accumulating.
        if self.shift % 8:
            return

        octet_count = self.shift // 8
        self.stream.extend(self.value.to_bytes(octet_count, byteorder=self.byteorder))
        self.value, self.shift = 0, 0


class Value:
    """A generated field value together with its width in bits.

    ``value`` may be:
      * an ``int``: a scalar emitted with ``width`` bits;
      * a ``list`` of ``Value``: an array whose width is the sum of its
        elements' widths;
      * a ``Packet``: a nested struct or payload whose width is the packet's;
      * a callable taking the enclosing ``Packet``: resolved by ``finalize``;
      * ``None``: an absent optional field, which serializes to nothing.

    ``width`` may likewise be a callable taking the enclosing ``Packet``,
    resolved by ``finalize``.
    """

    def __init__(self, value: object, width: Optional[int] = None):
        """Create a value; ``width`` is required for ints and callables.

        Raises:
            Exception: if ``width`` is omitted for an int/callable value, or
                if ``value`` is of an unsupported kind.
        """
        self.value = value
        if width is not None:
            self.width = width
        elif isinstance(value, int) or callable(value):
            raise Exception("Creating scalar value of unspecified width")
        elif isinstance(value, list):
            self.width = sum(v.width for v in value)
        elif isinstance(value, Packet):
            self.width = value.width
        else:
            raise Exception(f"Malformed value {value}")

    def finalize(self, parent: "Packet"):
        """Resolve callable width/value against ``parent`` and recurse into
        array elements and nested packets."""
        if callable(self.width):
            self.width = self.width(parent)

        if callable(self.value):
            self.value = self.value(parent)
        elif isinstance(self.value, list):
            for v in self.value:
                v.finalize(parent)
        elif isinstance(self.value, Packet):
            self.value.finalize()

    def serialize_(self, serializer: "BitSerializer"):
        """Append this value's bits to ``serializer``.

        Raises:
            Exception: if the value is of an unsupported kind (e.g. a callable
                that was never finalized).
        """
        if self.value is None:
            # Absent optional field: contributes no bits.
            return
        if isinstance(self.value, int):
            serializer.append(self.value, self.width)
        elif isinstance(self.value, list):
            for v in self.value:
                v.serialize_(serializer)
        elif isinstance(self.value, Packet):
            self.value.serialize_(serializer)
        else:
            raise Exception(f"Malformed value {self.value}")

    def show(self, indent: int = 0, name: Optional[str] = None):
        """Print a human-readable dump of the value.

        Args:
            indent: number of leading spaces.
            name: label printed for this value. Defaults to a ``name``
                attribute if one was set on the instance, else ``"value"``.
                (The previous version read ``self.name`` unconditionally,
                which is never assigned, so every call raised AttributeError.)
        """
        label = name if name is not None else getattr(self, "name", "value")
        space = " " * indent
        if isinstance(self.value, int):
            print(f"{space}{label}: {hex(self.value)}")
        elif isinstance(self.value, list):
            print(f"{space}{label}[{len(self.value)}]:")
            for index, item in enumerate(self.value):
                item.show(indent + 2, f"[{index}]")
        elif isinstance(self.value, Packet):
            print(f"{space}{label}:")
            self.value.show(indent + 2)

    def to_json(self) -> object:
        """Return a JSON-compatible form: int, list, dict, or None."""
        if self.value is None:
            return None
        if isinstance(self.value, int):
            return self.value
        if isinstance(self.value, list):
            return [v.to_json() for v in self.value]
        if isinstance(self.value, Packet):
            return self.value.to_json()
        return None


class Field:
    """Pairs a generated ``Value`` with the AST field it was generated for."""

    def __init__(self, value: Value, ref: ast.Field):
        self.ref = ref
        self.value = value

    def finalize(self, parent: "Packet"):
        """Resolve the value's deferred parts against the enclosing packet."""
        self.value.finalize(parent)

    def serialize_(self, serializer: BitSerializer):
        """Append the value's bits to ``serializer``."""
        self.value.serialize_(serializer)

    def clone(self):
        """Return a copy sharing ``ref`` with a shallow copy of the value.

        Because ``Value.finalize`` reassigns ``width``/``value`` on the Value
        object itself, the copy keeps those top-level reassignments private
        to the clone (nested list/packet contents remain shared).
        """
        duplicated = copy.copy(self.value)
        return Field(value=duplicated, ref=self.ref)


class Packet:
    """A generated instance of a PDL packet or struct declaration.

    ``fields`` holds one chosen ``Field`` per declared field; ``ref`` is the
    declaration the packet was generated from.
    """

    def __init__(self, fields: List[Field], ref: ast.Declaration):
        self.ref = ref
        self.fields = fields

    def finalize(self, parent: Optional["Packet"] = None):
        """Resolve every field's deferred value against this packet.

        ``parent`` is accepted for call compatibility but not used.
        """
        for field in self.fields:
            field.finalize(self)

    def serialize_(self, serializer: BitSerializer):
        """Append all fields' bits, in declaration order, to ``serializer``."""
        for field in self.fields:
            field.serialize_(serializer)

    def serialize(self, big_endian: bool) -> bytes:
        """Serialize the packet to octets.

        Raises:
            Exception: if the total bit width is not a multiple of 8.
        """
        serializer = BitSerializer(big_endian)
        self.serialize_(serializer)
        if serializer.shift:
            raise Exception("The packet size is not an integral number of octets")
        return bytes(serializer.stream)

    def show(self, indent: int = 0):
        """Print each field value at the given indentation."""
        for field in self.fields:
            field.value.show(indent)

    def to_json(self) -> dict:
        """Return the packet as a dict.

        Payload/body fields holding a nested ``Packet`` are merged into this
        dict; other payload/body values go under ``"payload"``; any other
        field whose AST node has an ``id`` is keyed by that id.
        """
        out = {}
        for field in self.fields:
            ref = field.ref
            if isinstance(ref, (ast.PayloadField, ast.BodyField)):
                if isinstance(field.value.value, Packet):
                    out.update(field.value.to_json())
                else:
                    out["payload"] = field.value.to_json()
            elif hasattr(ref, "id"):
                out[ref.id] = field.value.to_json()
        return out

    @property
    def width(self) -> int:
        """Total width in bits; finalizes the packet first."""
        self.finalize()
        return sum(field.value.width for field in self.fields)


class BitGenerator:
    """Deterministic source of filler bits for generated values.

    Bits are drawn from an 8-bit counter, starting at its low bit; once all
    8 bits of the current counter value are consumed the counter advances.
    """

    def __init__(self):
        self.value = 0  # current counter value
        self.shift = 0  # bits of ``value`` already consumed

    def generate(self, width: int) -> Value:
        """Generate an integer value of the selected width.

        Successive chunks taken from the counter are concatenated with the
        earlier chunks in the more significant positions.
        """
        result = 0
        pending = width
        while pending > 0:
            take = min(8 - self.shift, pending)
            bits = (self.value >> self.shift) & ((1 << take) - 1)
            result = (result << take) | bits
            pending -= take
            self.shift += take
            if self.shift >= 8:
                self.shift = 0
                # NOTE(review): wraps modulo 0xFF (255), so the counter never
                # reaches 0xFF; confirm whether 0x100 was intended.
                self.value = (self.value + 1) % 0xFF
        return Value(result, width)

    def generate_list(self, width: int, count: int) -> List[Value]:
        """Generate ``count`` successive values of ``width`` bits each."""
        return [self.generate(width) for _ in range(count)]


# Shared filler-bit source for generated scalars, scalar arrays and payloads.
# Its counter state advances on every draw, so the emitted bit patterns
# depend on generation order.
generator = BitGenerator()


def generate_cond_field_values(field: ast.ScalarField) -> List[Value]:
    """Generate the value of a field that controls an optional field.

    The value is computed when the enclosing packet is finalized. It is
    ``field.cond_for.cond.value`` when the optional field (``field.cond_for``)
    is present. When the optional field was generated as absent (its value is
    ``None``), it is a different value: 0, or 1 if the condition value is 0.

    Raises:
        Exception: at finalization time, if the optional field is not part of
            the enclosing packet.
    """
    cond_value_present = field.cond_for.cond.value
    cond_value_absent = 0 if cond_value_present != 0 else 1

    def get_cond_value(parent: Packet, optional_field: ast.Field) -> int:
        for f in parent.fields:
            if f.ref is optional_field:
                return cond_value_absent if f.value.value is None else cond_value_present
        # Falling through used to yield None, which serializes as zero bits
        # while the field still declares a non-zero width.
        raise Exception(
            "Field {} not found in packet {}".format(
                getattr(optional_field, "id", optional_field), parent.ref.id
            )
        )

    return [Value(lambda p: get_cond_value(p, field.cond_for), field.width)]


def generate_size_field_values(field: ast.SizeField) -> List[Value]:
    """Generate the value of a size field.

    The size is resolved when the enclosing packet is finalized. It is the
    octet width of the referenced field plus that field's ``size_modifier``,
    if any. The ids ``_payload_`` and ``_body_`` select the payload or body
    field.
    """

    def get_field_size(parent: Packet, field_id: str) -> int:
        target = next(
            (
                f
                for f in parent.fields
                if (field_id == "_payload_" and isinstance(f.ref, ast.PayloadField))
                or (field_id == "_body_" and isinstance(f.ref, ast.BodyField))
                or getattr(f.ref, "id", None) == field_id
            ),
            None,
        )
        if target is None:
            raise Exception(
                "Field {} not found in packet {}".format(field_id, parent.ref.id)
            )
        assert target.value.width % 8 == 0
        modifier = int(getattr(target.ref, "size_modifier", None) or 0)
        return target.value.width // 8 + modifier

    return [Value(lambda packet: get_field_size(packet, field.field_id), field.width)]


def generate_count_field_values(field: ast.CountField) -> List[Value]:
    """Generate the value of a count field.

    Resolved when the enclosing packet is finalized: the number of elements
    in the array field whose id is ``field.field_id``.
    """

    def get_array_count(parent: Packet, field_id: str) -> int:
        target = next(
            (f for f in parent.fields if getattr(f.ref, "id", None) == field_id),
            None,
        )
        if target is None:
            raise Exception(
                "Field {} not found in packet {}".format(field_id, parent.ref.id)
            )
        elements = target.value.value
        assert isinstance(elements, list)
        return len(elements)

    return [Value(lambda packet: get_array_count(packet, field.field_id), field.width)]


def generate_checksum_field_values(field: ast.TypedefField) -> List[Value]:
    """Generate the value of a checksum field.

    The checksum covers every field after the ``ChecksumField`` marker whose
    ``field_id`` names this field, up to this field. It is computed when the
    enclosing packet is finalized: the covered fields are serialized and
    their octets summed modulo 256. Only 8-bit checksums are supported.

    Raises:
        Exception: at finalization time, if the checksum field is missing,
            appears before its start marker, or the covered range is not an
            integral number of octets.
    """
    field_width = core.get_field_size(field)

    def basic_checksum(data: bytes, width: int) -> int:
        assert width == 8
        return sum(data) % 256

    def compute_checksum(parent: Packet, field_id: str) -> int:
        serializer = None
        for f in parent.fields:
            if isinstance(f.ref, ast.ChecksumField) and f.ref.field_id == field_id:
                # Start of the covered range.
                serializer = BitSerializer(
                    f.ref.parent.file.endianness.value == "big_endian"
                )
            elif isinstance(f.ref, ast.TypedefField) and f.ref.id == field_id:
                if serializer is None:
                    raise Exception(
                        f"checksum field {field_id} precedes its checksum start marker"
                    )
                if serializer.shift != 0:
                    # Trailing bits would otherwise be silently dropped.
                    raise Exception(
                        f"checksum range for {field_id} is not an integral number of octets"
                    )
                return basic_checksum(serializer.stream, field_width)
            elif serializer is not None:
                f.value.serialize_(serializer)
        raise Exception("malformed checksum")

    return [Value(lambda p: compute_checksum(p, field.id), field_width)]


def generate_padding_field_values(field: ast.PaddingField) -> List[Value]:
    """Generate the zero filler emitted by a padding field.

    The padding width depends on the finalized width of the padded field, so
    it is supplied as a callable resolved when the enclosing packet is
    finalized: ``8 * field.size`` bits minus the padded field's width.
    """
    padded_field_id = field.padded_field.id

    def get_padding(parent: Packet, field_id: str, width: int) -> int:
        """Return the number of padding bits needed after ``field_id``."""
        for f in parent.fields:
            if (
                (field_id == "_payload_" and isinstance(f.ref, ast.PayloadField))
                or (field_id == "_body_" and isinstance(f.ref, ast.BodyField))
                or (getattr(f.ref, "id", None) == field_id)
            ):
                assert f.value.width % 8 == 0, "padded field is not octet-aligned"
                assert f.value.width <= width, "padded field exceeds its padded size"
                return width - f.value.width
        raise Exception(
            "Field {} not found in packet {}".format(field_id, parent.ref.id)
        )

    return [Value(0, lambda p: get_padding(p, padded_field_id, 8 * field.size))]


def generate_payload_field_values(
    field: Union[ast.PayloadField, ast.BodyField]
) -> List[Value]:
    """Generate candidate contents for a payload/body without child packets.

    Two candidates are produced: an empty payload, and a payload of filler
    octets. When ``core.get_payload_field_size`` returns a field, the filler
    payload is the largest that field's width can describe. Otherwise it is
    DEFAULT_PAYLOAD_SIZE octets. In both cases the payload's size modifier
    is subtracted.
    """
    size_field = core.get_payload_field_size(field)
    size_modifier = int(getattr(field, "size_modifier", None) or 0)

    if size_field:
        octet_count = (1 << size_field.width) - 1
    else:
        octet_count = DEFAULT_PAYLOAD_SIZE
    octet_count -= size_modifier

    assert octet_count > 0
    filler = generator.generate_list(8, octet_count)
    return [Value([]), Value(filler)]


def generate_scalar_array_field_values(field: ast.ArrayField) -> List[Value]:
    """Generate candidate arrays for an array of scalar elements.

    - Static count: one array with exactly that many elements.
    - Count field: an empty array and one with the maximum count the count
      field can represent.
    - Size field: an empty array and one with as many elements as fit in the
      maximum size the size field can represent (minus the size modifier).
    - Otherwise: an empty array and one of DEFAULT_ARRAY_COUNT elements.

    Raises:
        Exception: if the element width is not a multiple of 8 bits.
    """
    # Previously this check referenced ``element_width`` before assignment
    # and raised UnboundLocalError instead of the intended error.
    if field.width % 8 != 0:
        raise Exception("Array element size is not a multiple of 8")

    array_size = core.get_array_field_size(field)
    element_width = field.width // 8  # element size in octets

    # TODO
    # The array might also be bounded if it is included in the sized payload
    # of a packet.

    size_modifier = int(getattr(field, "size_modifier", None) or 0)

    if isinstance(array_size, int):
        return [Value(generator.generate_list(field.width, array_size))]

    if isinstance(array_size, ast.CountField):
        max_count = (1 << array_size.width) - 1
        return [Value([]), Value(generator.generate_list(field.width, max_count))]

    if isinstance(array_size, ast.SizeField):
        max_size = (1 << array_size.width) - 1 - size_modifier
        max_count = max_size // element_width
        return [Value([]), Value(generator.generate_list(field.width, max_count))]

    return [
        Value([]),
        Value(generator.generate_list(field.width, DEFAULT_ARRAY_COUNT)),
    ]


def generate_typedef_array_field_values(field: ast.ArrayField) -> List[Value]:
    """Generate candidate arrays for an array of typedef elements.

    Values from ``generate_typedef_values`` for the element type are packed,
    in order, into consecutive chunks. A chunk is closed (becomes a candidate
    array) when it already holds ``max_count`` elements, or when the next
    element would push it past ``max_size`` octets.

    Bounds:
    - ``max_size`` is MAX_ARRAY_SIZE, the padded size if any, or the
      size-field capacity minus the size modifier.
    - ``max_count`` is MAX_ARRAY_COUNT, the count-field capacity, or the
      static count.

    Element values are regenerated until at least one chunk has been closed.
    A trailing partial chunk is discarded. When the minimum count is 0, an
    empty array is appended as an extra candidate.

    Raises:
        Exception: if the static element size is not a multiple of 8 bits,
            or if the element type yields no values (which previously looped
            forever).
    """
    array_size = core.get_array_field_size(field)
    element_width = core.get_array_element_size(field)
    if element_width and element_width % 8 != 0:
        raise Exception("Array element size is not a multiple of 8")

    type_decl = field.parent.file.typedef_scope[field.type_id]

    # TODO
    # The array might also be bounded if it is included in the sized payload
    # of a packet.

    size_modifier = int(getattr(field, "size_modifier", None) or 0)

    max_size = MAX_ARRAY_SIZE
    min_count = 0
    max_count = MAX_ARRAY_COUNT

    if field.padded_size:
        max_size = field.padded_size

    if isinstance(array_size, ast.SizeField):
        max_size = (1 << array_size.width) - 1 - size_modifier
    elif isinstance(array_size, ast.CountField):
        max_count = (1 << array_size.width) - 1
    elif isinstance(array_size, int):
        min_count = array_size
        max_count = array_size

    values = []
    chunk = []
    chunk_size = 0

    while not values:
        element_values = generate_typedef_values(type_decl)
        if not element_values:
            raise Exception(
                f"No values can be generated for array element type {field.type_id}"
            )
        for element_value in element_values:
            element_size = element_value.width // 8

            if len(chunk) >= max_count or chunk_size + element_size > max_size:
                # NOTE(review): for a statically sized array this fails if
                # ``array_size`` elements exceed ``max_size`` octets.
                assert len(chunk) >= min_count
                values.append(Value(chunk))
                chunk = []
                chunk_size = 0

            chunk.append(element_value)
            chunk_size += element_size

    if min_count == 0:
        values.append(Value([]))

    return values


def generate_array_field_values(field: ast.ArrayField) -> List[Value]:
    """Dispatch on element kind: typedef elements when ``field.width`` is
    None, scalar elements otherwise."""
    if field.width is None:
        return generate_typedef_array_field_values(field)
    return generate_scalar_array_field_values(field)


def generate_typedef_field_values(
    field: ast.TypedefField, constraints: List[ast.Constraint]
) -> List[Value]:
    """Generate candidate values for a typedef-typed field.

    An enum field constrained by the first constraint whose id matches the
    field yields only the constrained tag's value. A checksum-typed field is
    delegated to ``generate_checksum_field_values``. Every other field
    yields all values of its type.

    Raises:
        Exception: if a constraint names a tag the enum does not define.
    """
    type_decl = field.parent.file.typedef_scope[field.type_id]

    if isinstance(type_decl, ast.EnumDeclaration):
        constraint = next((c for c in constraints if c.id == field.id), None)
        if constraint is not None:
            tag = next((t for t in type_decl.tags if t.id == constraint.tag_id), None)
            if tag is None:
                raise Exception("undefined enum tag")
            return [Value(tag.value, type_decl.width)]

    # Checksum values depend on the covered range of the enclosing packet.
    if isinstance(type_decl, ast.ChecksumDeclaration):
        return generate_checksum_field_values(field)

    return generate_typedef_values(type_decl)


def generate_field_values(
    field: ast.Field, constraints: List[ast.Constraint], payload: Optional[List[Packet]]
) -> List[Value]:
    """Generate the candidate values for one declared field.

    The checks are order-sensitive: the first matching rule wins.

    Raises:
        Exception: for an undefined enum tag or an unsupported field kind.
    """
    # A field controlling an optional field derives its value from presence.
    if field.cond_for:
        return generate_cond_field_values(field)

    # Checksum fields are zero-width markers.
    if isinstance(field, ast.ChecksumField):
        return [Value(0, 0)]

    if isinstance(field, ast.PaddingField):
        return generate_padding_field_values(field)

    if isinstance(field, ast.SizeField):
        return generate_size_field_values(field)

    if isinstance(field, ast.CountField):
        return generate_count_field_values(field)

    if isinstance(field, (ast.BodyField, ast.PayloadField)):
        if payload:
            # Candidate packets supplied by a derived declaration.
            return [Value(p) for p in payload]
        return generate_payload_field_values(field)

    if isinstance(field, ast.FixedField):
        if field.enum_id:
            enum_decl = field.parent.file.typedef_scope[field.enum_id]
            tag = next((t for t in enum_decl.tags if t.id == field.tag_id), None)
            if tag is None:
                raise Exception("undefined enum tag")
            return [Value(tag.value, enum_decl.width)]
        if field.width:
            return [Value(field.value, field.width)]

    if isinstance(field, ast.ReservedField):
        return [Value(0, field.width)]

    if isinstance(field, ast.ArrayField):
        return generate_array_field_values(field)

    if isinstance(field, ast.ScalarField):
        constraint = next((c for c in constraints if c.id == field.id), None)
        if constraint is not None:
            return [Value(constraint.value, field.width)]
        all_ones = (1 << field.width) - 1
        return [
            Value(0, field.width),
            Value(all_ones, field.width),
            generator.generate(field.width),
        ]

    if isinstance(field, ast.TypedefField):
        return generate_typedef_field_values(field, constraints)

    raise Exception("unsupported field kind")


def generate_fields(
    decl: ast.Declaration,
    constraints: List[ast.Constraint],
    payload: Optional[List[Packet]],
) -> List[List[Field]]:
    """Return, for each field of ``decl``, the list of candidate ``Field``s.

    Optional fields (``cond`` set) get an extra leading "absent" candidate
    whose value is ``None`` with zero width.
    """
    options = []
    for declared in decl.fields:
        candidates = [
            Field(value, declared)
            for value in generate_field_values(declared, constraints, payload)
        ]
        if declared.cond:
            candidates.insert(0, Field(Value(None, 0), declared))
        options.append(candidates)
    return options


def generate_fields_recursive(
    scope: dict,
    decl: ast.Declaration,
    constraints: Optional[List[ast.Constraint]] = None,
    payload: Optional[List[Packet]] = None,
) -> List[List[Field]]:
    """Generate per-field candidates for ``decl``, folding in its ancestors.

    If ``decl`` has a parent (``decl.parent_id``), the combinations of its
    own field candidates are wrapped into ``Packet``s. These are passed as
    payload candidates to the parent declaration (looked up in ``scope``),
    with ``decl.constraints`` appended to ``constraints``. The recursion ends
    at a declaration without a parent, whose field candidates are returned.

    Args:
        scope: mapping from declaration id to declaration.
        decl: declaration to generate.
        constraints: constraints accumulated from derived declarations;
            ``None`` means no constraints. (This replaces a shared mutable
            ``[]`` default.)
        payload: candidate packets for ``decl``'s payload/body field.
    """
    if constraints is None:
        constraints = []

    fields = generate_fields(decl, constraints, payload)

    if not decl.parent_id:
        return fields

    packets = [Packet(combination, decl) for combination in product(fields)]
    parent_decl = scope[decl.parent_id]
    return generate_fields_recursive(
        scope, parent_decl, constraints + decl.constraints, payload=packets
    )


def generate_struct_values(decl: ast.StructDeclaration) -> List[Packet]:
    """Generate sample instances of a struct declaration.

    Field candidates come from the typedef scope and are combined with
    ``product``.
    """
    options = generate_fields_recursive(decl.file.typedef_scope, decl)
    return [Packet(combination, decl) for combination in product(options)]


def generate_packet_values(decl: ast.PacketDeclaration) -> List[Packet]:
    """Generate sample instances of a packet declaration.

    Field candidates come from the packet scope and are combined with
    ``product``.
    """
    options = generate_fields_recursive(decl.file.packet_scope, decl)
    return [Packet(combination, decl) for combination in product(options)]


def generate_typedef_values(decl: ast.Declaration) -> List[Value]:
    """Return all candidate values for a typedef declaration.

    Enums yield one value per tag; structs yield one value per generated
    struct instance.

    Raises:
        Exception: for checksum, custom-field, or unsupported declarations.
    """
    if isinstance(decl, ast.EnumDeclaration):
        return [Value(tag.value, decl.width) for tag in decl.tags]
    if isinstance(decl, ast.ChecksumDeclaration):
        raise Exception("ChecksumDeclaration handled in typedef field")
    if isinstance(decl, ast.CustomFieldDeclaration):
        raise Exception("TODO custom field")
    if isinstance(decl, ast.StructDeclaration):
        return [Value(instance) for instance in generate_struct_values(decl)]
    raise Exception("unsupported typedef declaration type")


def product(fields: List[List["Field"]]) -> List[List["Field"]]:
    """Perform a cartesian product of generated options for packet field values.

    Every returned ``Field`` is a fresh ``clone()``. Packet finalization
    overwrites callable values in place, so sharing one Field between packets
    would leak values computed for one packet into the others. The sampled
    branch previously returned shared, uncloned Fields.

    If the full product has at most 32 combinations, all are returned.
    Otherwise ``max_len + 1`` samples are returned, where ``max_len`` is the
    largest option count. Sample ``idx`` picks option ``idx % len(options)``
    of every field, so each option of every field is used at least once.
    """

    def aux(vec: List[List["Field"]]) -> List[List["Field"]]:
        if len(vec) == 0:
            return [[]]
        # aux(vec[1:]) is re-evaluated per item on purpose: each evaluation
        # yields fresh clones, so no two combinations share a Field.
        return [[item.clone()] + items for item in vec[0] for items in aux(vec[1:])]

    count = 1
    max_len = 0
    for options in fields:
        count *= len(options)
        max_len = max(max_len, len(options))

    # Limit products to 32 elements to prevent combinatorial explosion.
    if count <= 32:
        return aux(fields)

    # Too many combinations: sample so that every option is exercised.
    return [
        [options[idx % len(options)].clone() for options in fields]
        for idx in range(0, max_len + 1)
    ]


def serialize_values(file: "ast.File", values: List["Packet"]) -> List[dict]:
    """Finalize and serialize generated packets into test-vector records.

    Args:
        file: the PDL file; ``file.endianness.value == "big_endian"`` selects
            big-endian serialization.
        values: generated packets. (These are Packets, which provide
            ``finalize()``, ``serialize``, ``to_json`` and ``ref``; the
            previous ``List[Value]`` annotation was wrong.)

    Returns:
        One dict per packet with ``"packed"`` (lower-case hex of the
        serialized octets) and ``"unpacked"`` (the packet's JSON form).
        Packets whose declaration has a ``parent_id`` also carry
        ``"packet"``, the declaration id.
    """
    big_endian = file.endianness.value == "big_endian"
    results = []
    for packet in values:
        packet.finalize()
        packed = packet.serialize(big_endian)
        result = {
            "packed": bytes(packed).hex(),
            "unpacked": packet.to_json(),
        }
        if packet.ref.parent_id:
            result["packet"] = packet.ref.id
        results.append(result)
    return results


def run(input: Path, packet: List[str]):
    """Generate test vectors from a PDL-JSON file and dump them to stdout.

    The output is a JSON list of ``{"packet": <ancestor id>, "tests": [...]}``
    entries. Declarations for which ``core.get_derived_packets`` returns
    something are skipped. When ``packet`` is non-empty, only the listed
    declarations are processed. Declarations whose generation fails are
    reported on stderr and skipped.
    """
    with open(input) as source:
        file = ast.File.from_json(json.load(source))
    core.desugar(file)

    tests_by_packet = dict()
    for decl in file.packet_scope.values():
        if core.get_derived_packets(decl):
            continue
        if packet and decl.id not in packet:
            continue

        try:
            values = generate_packet_values(decl)
            ancestor = core.get_packet_ancestor(decl)
            tests = serialize_values(file, values)
            tests_by_packet.setdefault(ancestor.id, []).extend(tests)
        except Exception as exn:
            print(
                f"Skipping packet {decl.id}; cannot generate values: {exn}",
                file=sys.stderr,
            )

    output = [{"packet": name, "tests": tests} for name, tests in tests_by_packet.items()]
    json.dump(output, sys.stdout, indent=2)


def main() -> int:
    """Generate test vectors for top-level PDL packets."""
    # ``__doc__`` alone would resolve to the module docstring (absent here),
    # leaving --help without a description; use this function's docstring.
    parser = argparse.ArgumentParser(description=main.__doc__)
    parser.add_argument(
        "--input", type=Path, required=True, help="Input PDL-JSON source"
    )
    parser.add_argument(
        "--packet",
        type=lambda x: x.split(","),
        required=False,
        action="extend",
        default=[],
        help="Select PDL packet to test",
    )
    args = parser.parse_args()
    run(**vars(args))
    # run() reports failures per packet on stderr; the process still succeeds.
    return 0


if __name__ == "__main__":
    sys.exit(main())
