#
# Copyright (C) 2020 Collabora, Ltd.
#
# Permission is hereby granted, free of charge, to any person obtaining a
# copy of this software and associated documentation files (the "Software"),
# to deal in the Software without restriction, including without limitation
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
# and/or sell copies of the Software, and to permit persons to whom the
# Software is furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice (including the next
# paragraph) shall be included in all copies or substantial portions of the
# Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
# IN THE SOFTWARE.

# C license header followed by an "Autogenerated file, do not edit" banner;
# presumably prepended to C sources generated from the ISA description.
COPYRIGHT = """/*
 * Copyright (C) 2020 Collabora, Ltd.
 *
 * Permission is hereby granted, free of charge, to any person obtaining a
 * copy of this software and associated documentation files (the "Software"),
 * to deal in the Software without restriction, including without limitation
 * the rights to use, copy, modify, merge, publish, distribute, sublicense,
 * and/or sell copies of the Software, and to permit persons to whom the
 * Software is furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice (including the next
 * paragraph) shall be included in all copies or substantial portions of the
 * Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL
 * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */

/* Autogenerated file, do not edit */

"""

# Parse instruction set XML into a normalized form for processing

import xml.etree.ElementTree as ET
import copy
import itertools
from collections import OrderedDict

def parse_cond(cond, aliased = False):
    """Turn a condition element into a nested-list expression.

    Forms produced:
      * None for a <reserved/> element;
      * ['alias', <inner>] when the element carries a truthy alias attribute
        (the inner expression is the same element parsed with aliased=True);
      * [tag, left, right] when the element has a 'left' attribute;
      * [tag, child_expr, ...] otherwise, one entry per child element.
    """
    tag = cond.tag
    if tag == 'reserved':
        return None

    attrs = cond.attrib

    # Wrap only once: the recursive call passes aliased=True.
    if not aliased and attrs.get('alias', False):
        return ['alias', parse_cond(cond, True)]

    if 'left' not in attrs:
        children = [parse_cond(child) for child in cond.findall('*')]
        return [tag] + children

    return [tag, attrs['left'], attrs['right']]

def parse_exact(obj):
    """Return [mask, exact] parsed from obj's 'mask' and 'exact' attributes.

    Both values are parsed with base 0, so '0x..' hex literals are accepted.
    """
    attrs = obj.attrib
    mask = int(attrs['mask'], 0)
    exact = int(attrs['exact'], 0)
    return [mask, exact]

def parse_derived(obj):
    """Collect obj's <derived start=... size=...> children.

    Each entry is [[start, size], conds], where conds holds one parsed
    condition (see parse_cond) per child element of the <derived> node. The
    list is padded with None, or truncated, to exactly 1 << size entries.
    """
    derived = []

    for node in obj.findall('derived'):
        start = int(node.attrib['start'])
        size = int(node.attrib['size'])
        slots = 1 << size

        conds = [parse_cond(child) for child in node.findall('*')]
        # Negative repeat counts yield [], so this both pads and truncates.
        conds = conds[:slots] + [None] * (slots - len(conds))

        derived.append([[start, size], conds])

    return derived

def parse_modifiers(obj, include_pseudo):
    """Collect obj's <mod> children as modifier descriptors.

    Each descriptor is [[name, start, size], default, opts]:
      * start is an int, or None when the 'start' attribute is absent;
      * size is an int;
      * opts lists each child's text (for <opt>) or tag (anything else). A
        childless <mod opt="X"> is shorthand for ['none', 'X']. The list is
        padded with 'reserved', or truncated, to exactly 1 << size entries;
      * default is the 'default' attribute if given, else 'none' when 'none'
        is among the (unpadded) options, else None.

    Modifiers with a truthy 'pseudo' attribute are skipped unless
    include_pseudo is set.
    """
    result = []

    for node in obj.findall('mod'):
        attrs = node.attrib
        if not include_pseudo and attrs.get('pseudo', False):
            continue

        name = attrs['name']
        raw_start = attrs.get('start', None)
        size = int(attrs['size'])
        start = None if raw_start is None else int(raw_start)

        opts = []
        for child in node.findall('*'):
            opts.append(child.text if child.tag == 'opt' else child.tag)

        if not opts and 'opt' in attrs:
            opts = ['none', attrs['opt']]

        fallback = 'none' if 'none' in opts else None
        default = attrs.get('default', fallback)

        width = 1 << size
        padded = opts[:width] + ['reserved'] * (width - len(opts))
        result.append([[name, start, size], default, padded])

    return result

def parse_copy(enc, existing):
    """Apply enc's <copy name=... start=...> overrides in place.

    For every modifier descriptor in `existing` whose name matches a <copy>
    element, its start slot (descriptor[0][1]) is replaced with the copy's
    'start' attribute. Returns None.

    NOTE(review): the new start is the raw attribute string, whereas
    parse_modifiers stores an int -- confirm consumers accept both.
    """
    for copy_node in enc.findall('copy'):
        target = copy_node.get('name')
        for entry in existing:
            if entry[0][0] != target:
                continue
            entry[0][1] = copy_node.get('start')

# Modifiers implied by marker child elements of an instruction (consumed by
# parse_instruction). Keys are child element tags, the 'name' attribute of a
# <va_mod> child, or 'cmpfi', which parse_instruction substitutes for
# <cmp int_only=...>. Values use the same [[name, start, size], default, opts]
# layout that parse_modifiers produces, with start and size left as None.
#
# NOTE(review): parse_instruction maps <cmp eqne_only=...> to 'cmpfeq', which
# has no entry here (the 'eq' entry only matches a literal <eq> child), so
# such instructions get no comparison modifier -- confirm which is intended.
mod_names = {
    'cmp'     : [['cmpf', None, None], None, ['eq', 'gt', 'ge', 'ne', 'lt', 'le', 'gtlt', 'total']],
    # FIXME: Valhall can accept any integer comparison, but the old IR generator only generated
    #   gt and ge, and left out lt and le. For now follow the old way so we can compare generated files,
    #   but we should switch over to the proper way once everything is working
#    'cmpfi'    : [['cmpf', None, None], None, ['eq', 'ne', 'gt', 'ge', 'lt', 'le']],
    'cmpfi'     : [['cmpf', None, None], None, ['eq', 'ne', 'gt', 'ge']],
    'eq'        : [['cmpf', None, None], 'ne', ['eq', 'ne']],
    'dimension' : [['dimension', None, None], None, ['1d', '2d', '3d', 'cube']],
    'fetch_component' : [['fetch_component', None, None], None, ['gather4_r', 'gather4_g', 'gather4_b', 'gather4_a']],
    'lod_mode': [['va_lod_mode', None, None], 'zero_lod', ['zero_lod', 'computed_lod', 'explicit', 'computed_bias', 'grdesc']],
    'regfmt'  : [['register_format', None, None], None, ['f16', 'f32', 's32', 'u32', 's16', 'u16', 'f64', 'i64', 'auto']],
    'result_type' : [['result_type', None, None], None, ['i1', 'f1', 'm1']],
    'sample'  : [['sample', None, None], 'none', ['center', 'centroid', 'sample', 'explicit', 'none']],
    'update'  : [['update', None, None], None, ['store', 'retrieve', 'conditional', 'clobber']],
    'vecsize' : [['vecsize', None, None], 'none', ['none', 'v2', 'v3', 'v4']],
    'source_format' : [['source_format', None, None], None, ['flat32', 'flat16', 'f32', 'f16']],

    # Two-option flags: 'none' or the flag itself.
    'array_enable': [['array_enable', None, None], 'none', ['none', 'array_enable']], 
    'integer_coordinates': [['integer_coordinates', None, None], 'none', ['none', 'integer_coordinates']],
    'shadow'      : [['shadow', None, None], 'none', ['none', 'shadow']], 
    'skip'        : [['skip', None, None], 'none', ['none', 'skip']], 
    'texel_offset': [['texel_offset', None, None], 'none', ['none', 'texel_offset']], 
    'wide_indices': [['wide_indices', None, None], 'none', ['none', 'wide_indices']], 
    # Option i names the components whose bit is set in i (r=1, g=2, b=4, a=8).
    'write_mask' : [['write_mask', None, None], 'none', ['none', 'r', 'g', 'rg', 'b', 'rb', 'gb', 'rgb', 'a', 'ra', 'ga', 'rga', 'ba', 'rba', 'gba', 'rgba']]
    }

def parse_instruction(ins, include_pseudo):
    """Parse one <ins> (or <group>) element into its encoding variants.

    Args:
        ins: the ElementTree element describing the instruction or group.
        include_pseudo: when false, <src>, <immediate>, <imm> and <mod>
            children with a truthy 'pseudo' attribute are skipped.

    Returns:
        A list of [cond, desc] pairs. If the element has no <encoding>
        children the result is [[None, desc]]. Otherwise there is one pair
        per <encoding>: cond is parsed from the encoding's first child
        element, and desc is a deep copy of the shared description with that
        encoding's 'exact', 'derived' and <copy> overrides applied.

    desc keys: srcs, modifiers, immediates, swaps, derived, staging,
    staging_count, dests, variable_dests, variable_srcs, unused, pseudo,
    message, last, table, plus 'exact' when known. Flag-like attributes are
    stored as the raw XML string (or False when absent), so any non-empty
    value -- including "false" -- counts as set.
    """
    # The 'staging' attribute has the form "<mode>=<count>". When it is
    # absent the mode is '' and the count '0'; a value without '=' is
    # tolerated (count '0') instead of raising IndexError.
    staging_parts = ins.attrib.get('staging', '').split('=')
    staging_count = staging_parts[1] if len(staging_parts) > 1 else '0'

    common = {
            'srcs': [],
            'modifiers': [],
            'immediates': [],
            'swaps': [],
            'derived': [],
            'staging': staging_parts[0],
            'staging_count': staging_count,
            'dests': int(ins.attrib.get('dests', '1')),
            'variable_dests': ins.attrib.get('variable_dests', False),
            'variable_srcs': ins.attrib.get('variable_srcs', False),
            'unused': ins.attrib.get('unused', False),
            'pseudo': ins.attrib.get('pseudo', False),
            'message': ins.attrib.get('message', 'none'),
            'last': ins.attrib.get('last', False),
            'table': ins.attrib.get('table', False),
    }

    if 'exact' in ins.attrib:
        common['exact'] = parse_exact(ins)

    # Modifiers implied by sources and marker children; appended after the
    # explicit <mod> modifiers further down.
    extra_modifiers = []

    # Sources: [start, mask]. src_num only counts kept sources, so skipped
    # pseudo sources shift neither default positions nor neg<N>/abs<N> names.
    src_num = 0
    for src in ins.findall('src'):
        if src.attrib.get('pseudo', False) and not include_pseudo:
            continue

        mask = int(src.attrib['mask'], 0) if ('mask' in src.attrib) else 0xFF
        if src.attrib.get('start') is not None:
            common['srcs'].append([int(src.attrib['start'], 0), mask])
        else:
            # Without an explicit start, sources are placed 3 bits apart.
            common['srcs'].append([src_num*3, mask])
        if src.attrib.get('absneg', False):
            extra_modifiers.append([['neg'+str(src_num), '0', '1'], 'neg',['none', 'neg']])
            extra_modifiers.append([['abs'+str(src_num), '0', '1'], 'abs',['none', 'abs']])
        src_num += 1

    # Immediates: [name, start-or-None, size].
    for imm in ins.findall('immediate'):
        if imm.attrib.get('pseudo', False) and not include_pseudo:
            continue

        start = int(imm.attrib['start']) if 'start' in imm.attrib else None
        common['immediates'].append([imm.attrib['name'], start, int(imm.attrib['size'])])

    # FIXME valhall ISA.xml uses <imm/> instead of <immediate/>
    for imm in ins.findall('imm'):
        if imm.attrib.get('pseudo', False) and not include_pseudo:
            continue

        # 'ir_name' renames the immediate for the IR; an empty ir_name drops it.
        base_name = imm.attrib['name']
        name = imm.attrib.get('ir_name', base_name)
        if not name:
            continue

        start = int(imm.attrib['start']) if 'start' in imm.attrib else None
        common['immediates'].append([name, start, int(imm.attrib['size'])])

    # <sr> children adjust the staging mode taken from the 'staging'
    # attribute: any read resets it to 'r', any write appends 'w'.
    staging_read = False
    staging_write = False
    for sr in ins.findall('sr'):
        if sr.attrib.get('read', False):
            staging_read = True
        if sr.attrib.get('write', False):
            staging_write = True

    if staging_read:
        common['staging'] = 'r'
    if staging_write:
        # NOTE(review): a 'staging' attribute already containing 'w' would
        # become 'ww' here -- confirm the XML never combines both forms.
        common['staging'] += 'w'
    # <sr_count> overrides the count; without a 'count' attribute the
    # literal string 'sr_count' is stored.
    for sr in ins.findall('sr_count'):
        common['staging_count'] = sr.attrib.get('count', 'sr_count')

    # Marker children that imply a modifier from the mod_names table.
    for m in ins.findall('*'):
        name = m.tag
        if name == 'cmp':
            if m.attrib.get('int_only', False):
                name = 'cmpfi'
            elif m.attrib.get('eqne_only', False):
                # NOTE(review): mod_names has no 'cmpfeq' key, so this
                # currently adds no modifier -- confirm the intended mapping.
                name = 'cmpfeq'
        if name == 'va_mod':
            name = m.attrib.get('name', '')
        if name in mod_names:
            # Copy so descriptions never alias the module-level table; the
            # no-encoding path below returns `common` without a deep copy.
            extra_modifiers.append(copy.deepcopy(mod_names[name]))

    common['derived'] = parse_derived(ins)
    common['modifiers'] = parse_modifiers(ins, include_pseudo) + extra_modifiers

    # Swaps: [[left, right], cond, {field: {from: to}}], where cond comes from
    # the swap's first child element.
    for swap in ins.findall('swap'):
        lr = [int(swap.get('left')), int(swap.get('right'))]
        cond = parse_cond(swap.findall('*')[0])
        rewrites = {}

        for rw in swap.findall('rewrite'):
            mp = {}

            for m in rw.findall('map'):
                mp[m.attrib['from']] = m.attrib['to']

            rewrites[rw.attrib['name']] = mp

        common['swaps'].append([lr, cond, rewrites])

    encodings = ins.findall('encoding')
    if not encodings:
        return [[None, common]]

    variants = []
    for enc in encodings:
        variant = copy.deepcopy(common)
        # With encodings, derived fields must be specified per encoding.
        assert(len(variant['derived']) == 0)

        variant['exact'] = parse_exact(enc)
        variant['derived'] = parse_derived(enc)
        parse_copy(enc, variant['modifiers'])

        cond = parse_cond(enc.findall('*')[0])
        variants.append([cond, variant])

    return variants

def ins_name(ins, group = False):
    """Return ins's name prefixed with a one-character unit tag.

    A historical artifact: the first character tags the execution unit --
    '*' when the unit is 'fma', '+' otherwise. Bifrost has only those two
    units; Valhall has others, but for those the tag does not matter, so
    they are all treated as '+'.

    Args:
        ins: element supplying the 'name' attribute.
        group: element supplying the 'unit' attribute (typically the
            enclosing <group>). When omitted (False) or None, ins is used.
    """
    # Compare against the sentinels explicitly: an Element's truth value
    # reflects whether it has children (and is deprecated), so `not group`
    # would silently discard a childless group element.
    if group is None or group is False:
        group = ins

    prefix = '*' if group.attrib['unit'] == 'fma' else '+'
    return prefix + ins.attrib['name']

def parse_instructions(xml, include_unused = False, include_pseudo = False):
    """Parse an ISA XML description into {tagged_name: variants}.

    Args:
        xml: a filename or file object, as accepted by ET.parse.
        include_unused: keep groups/instructions with a truthy 'unused'
            attribute (disassembly-only instructions).
        include_pseudo: keep groups/instructions (and operands) with a truthy
            'pseudo' attribute (IR-only instructions).

    Returns:
        A dict mapping each unit-tagged name (see ins_name) to the list of
        [cond, desc] variants from parse_instruction. Every <ins> inside a
        <group> receives its own deep copy of the group's parse; a top-level
        <ins> with the same tagged name overrides the group entry.
    """
    # Parse once and walk the tree twice. Calling ET.parse again would
    # re-read the input and, for a file object, fail with a ParseError
    # because the stream is already exhausted.
    root = ET.parse(xml).getroot()
    final = {}

    # Grouped instructions share the description parsed from their <group>.
    for gr in root.findall('group'):
        if gr.attrib.get('unused', False) and not include_unused:
            continue
        if gr.attrib.get('pseudo', False) and not include_pseudo:
            continue
        group_base = parse_instruction(gr, include_pseudo)
        for ins in gr.findall('ins'):
            parsed = copy.deepcopy(group_base)
            tagged_name = ins_name(ins, gr)
            final[tagged_name] = parsed

    # Stand-alone instructions.
    for ins in root.findall('ins'):
        parsed = parse_instruction(ins, include_pseudo)

        # Some instructions are for useful disassembly only and can be stripped
        # out of the compiler, particularly useful for release builds
        if parsed[0][1]["unused"] and not include_unused:
            continue

        # On the other hand, some instructions are only for the IR, not disassembly
        if parsed[0][1]["pseudo"] and not include_pseudo:
            continue

        tagged_name = ins_name(ins)
        final[tagged_name] = parsed

    return final

def opname_to_c(name):
    """Turn a unit-tagged opcode name into a C-safe identifier.

    The name is lower-cased, the unit tag '*' / '+' becomes 'fma_' / 'add_',
    and '.' becomes '_'.
    """
    table = str.maketrans({'*': 'fma_', '+': 'add_', '.': '_'})
    return name.lower().translate(table)

def expand_states(instructions):
    """Give every encoding state of every instruction its own entry.

    Maps '<name>.<i>' -> (name, cond, desc) for the i-th [cond, desc] state
    of an instruction with several states, or just '<name>' when there is
    only one. A None condition is replaced by the placeholder [].
    """
    expanded = {}

    for opcode, states in instructions.items():
        multi = len(states) > 1
        for index, (test, desc) in enumerate(states):
            state_name = '{}.{}'.format(opcode, index) if multi else opcode
            expanded[state_name] = (opcode, [] if test is None else test, desc)

    return expanded

def simplify_to_ir(ins):
    """Reduce a parsed instruction description to its IR-relevant shape.

    Packing details are dropped so variants can be compared for equality:
    sources become a count, modifiers become [name, options] pairs and
    immediates become their names.
    """
    staging = ins['staging']
    num_srcs = len(ins['srcs'])
    dests = ins['dests']
    variable_dests = ins['variable_dests']
    variable_srcs = ins['variable_srcs']
    modifiers = [[mod[0][0], mod[2]] for mod in ins['modifiers']]
    immediates = [imm[0] for imm in ins['immediates']]

    return {
            'staging': staging,
            'srcs': num_srcs,
            'dests': dests,
            'variable_dests': variable_dests,
            'variable_srcs': variable_srcs,
            'modifiers': modifiers,
            'immediates': immediates
        }

# Conversions to integers default to rounding towards zero; all other
# opcodes default to rounding to nearest even.
def default_round_to_zero(name):
    """Return True when opcode `name` converts to an integer type.

    Matching is a case-sensitive substring test for '_TO_U' / '_TO_S',
    optionally with a V2 or V4 vector prefix on the destination.
    """
    int_targets = ('_TO_U', '_TO_S', '_TO_V2U', '_TO_V2S', '_TO_V4U', '_TO_V4S')
    for marker in int_targets:
        if marker in name:
            return True
    return False

def combine_ir_variants(instructions, key):
    """Merge all unit/encoding variants of mnemonic `key` into one IR record.

    Every opcode whose name minus its leading unit tag equals `key`
    contributes its simplified states (see simplify_to_ir). The states must
    agree on srcs, dests, immediates and staging (checked with assert).
    Modifier options are concatenated per modifier name, without
    deduplication. Raises IndexError when no opcode matches `key`.
    """
    ir_variants = []
    for opcode in instructions:
        if opcode[1:] != key:
            continue
        ir_variants.extend(simplify_to_ir(state[1]) for state in instructions[opcode])

    first = ir_variants[0]
    modifiers = {}

    for variant in ir_variants:
        # All variants must describe the same IR operation.
        for field in ('srcs', 'dests', 'immediates', 'staging'):
            assert(variant[field] == first[field])

        for mod_name, options in variant['modifiers']:
            if mod_name in modifiers:
                modifiers[mod_name] += options
            else:
                modifiers[mod_name] = copy.deepcopy(options)

    return {
            'key': key,
            'srcs': first['srcs'],
            'dests': first['dests'],
            'variable_dests': first['variable_dests'],
            'variable_srcs': first['variable_srcs'],
            'staging': first['staging'],
            'immediates': sorted(first['immediates']),
            'modifiers': modifiers,
            'v': len(ir_variants),
            'ir': ir_variants,
            'rtz': default_round_to_zero(key)
        }

def partition_mnemonics(instructions):
    """Group instructions by mnemonic and merge each group's variants.

    Opcodes that differ only in their leading unit tag share a mnemonic.
    Returns {mnemonic: combine_ir_variants(instructions, mnemonic)}, with
    mnemonics inserted in sorted order.
    """
    mnemonics = sorted({opcode[1:] for opcode in instructions})
    return {mnemonic: combine_ir_variants(instructions, mnemonic)
            for mnemonic in mnemonics}

def order_modifiers(ir_instructions):
    """Assign each modifier a canonical, deduplicated list of options.

    Options are accumulated over all IR instructions, visited in sorted
    name order, and deduplicated keeping first appearance, so the enum
    values derived from the lists are stable. A trailing digit 0-3 on a
    modifier name (per-source copies such as neg0/neg1) is stripped so the
    copies share one list. Preserving the original orderings where possible
    keeps later instruction encoding a little simpler.

    Two fix-ups are applied:
      * a two-option list ending in "none" is flipped, so "none" is the
        false value for boolean modifiers;
      * for "table", "none" (asserted to be the third entry) is swapped to
        index 0 as the zero sentinel needed to materialize DTSEL.

    THIS MUST BE DETERMINISTIC
    """
    merged = {}

    for key in sorted(ir_instructions):
        per_ins = ir_instructions[key]["modifiers"]
        for mod_name in per_ins:
            base = mod_name[:-1] if mod_name[-1] in "0123" else mod_name
            merged.setdefault(base, []).extend(per_ins[mod_name])

    ordered = {}
    for base, options in merged.items():
        unique = list(dict.fromkeys(options))

        if len(unique) == 2 and unique[1] == "none":
            unique = unique[::-1]
        elif base == "table":
            assert(unique[2] == "none")
            unique[0], unique[2] = "none", unique[0]

        ordered[base] = unique

    return ordered

def src_count(op):
    """Number of sources of a simplified (IR) instruction.

    A staging register that is read (staging mode "r" or "rw") counts as one
    extra source.
    """
    reads_staging = op["staging"] in ("r", "rw")
    return op["srcs"] + (1 if reads_staging else 0)

def typesize(opcode):
    """Return the operand bit size encoded at the end of an opcode name.

    Suffixes '128', '48' and a final '8' are recognised explicitly. Otherwise
    the last two characters are parsed as an integer (e.g. '16', '32', '64').
    Names without a numeric suffix default to 32.
    """
    if opcode[-3:] == '128':
        return 128
    if opcode[-2:] == '48':
        return 48
    if opcode[-1] == '8':
        return 8

    try:
        return int(opcode[-2:])
    except ValueError:
        # No two-digit numeric suffix (e.g. '.i1', '.v2'): assume 32-bit.
        return 32
