#!/usr/bin/env python3

import json
import os
import sys

# When True, PrintOrigin dumps the complete GN target as JSON instead of only
# its original label.
PRINT_ORIGINAL_FULL = False

# Flags filtered out of the generated flag lists; FilterFlags treats every
# entry as a prefix. At runtime these flags are augmented with the flags that
# appear in the JSON files but are not present in any .gn or .gni file.
IGNORED_FLAGS = [
    '-D_DEBUG',
    '-Werror',
    '-Xclang',
    '-target-feature',
    '+crc',
    '+crypto',
]
# GN defines (without the leading -D) that are not turned into cflags.
IGNORED_DEFINES = [
    'HAVE_ARM64_CRC32C=1'
]
# Extra cflags always appended to the global section of webrtc_defaults.
DEFAULT_CFLAGS = [
    '-DHAVE_ARM64_CRC32C=0',
    '-DUSE_AURA=1',
    '-DUSE_GLIB=1',
    '-DUSE_NSS_CERTS=1',
    '-DUSE_UDEV',
    '-DUSE_X11=1',
    '-DWEBRTC_ANDROID_PLATFORM_BUILD=1',
    '-DWEBRTC_APM_DEBUG_DUMP=0',
    '-D_FILE_OFFSET_BITS=64',
    '-D_GNU_SOURCE',
    '-D_LARGEFILE64_SOURCE',
    '-D_LARGEFILE_SOURCE',
    '-Wno-global-constructors',
    '-Wno-implicit-const-int-float-conversion',
    '-Wno-missing-field-initializers',
    '-Wno-unreachable-code-aggressive',
    '-Wno-unreachable-code-break',
]

# Extra cflags always appended to the per-architecture sections of
# webrtc_defaults. Its keys also define the supported architectures (see
# ARCH_NAME_MAP below).
DEFAULT_CFLAGS_BY_ARCH = {
        'x86': ['-mavx2', '-mfma', '-msse2', '-msse3'],
        'x64': ['-mavx2', '-mfma', '-msse2', '-msse3'],
        'arm': ['-mthumb'],
        'arm64': [],
        'riscv64': [],
        }

# GN flag-list keys processed by this script.
FLAGS = ['cflags', 'cflags_c', 'cflags_cc', 'asmflags']
# GN flag-list key -> Android.bp property name.
FLAG_NAME_MAP = {
    'cflags': 'cflags',
    'asmflags': 'asflags',
    'cflags_cc': 'cppflags',
    'cflags_c': 'conlyflags',
}

# GN architecture name -> Android.bp architecture name; only x64 is renamed.
ARCH_NAME_MAP = {n: n for n in DEFAULT_CFLAGS_BY_ARCH.keys()}
ARCH_NAME_MAP['x64'] = 'x86_64'

# Architectures handled, in sorted order; a project_<arch>.json is read for each.
ARCHS = sorted(ARCH_NAME_MAP.keys())

def FormatList(l):
    """Render the items of *l* as a sorted JSON array string (Blueprint list syntax)."""
    ordered = sorted(l)
    return json.dumps(ordered)

def IsInclude(name):
    """Return True for header-like files (``.h`` / ``.inc``), which are kept out of srcs."""
    return name.endswith(('.h', '.inc'))

def FilterIncludes(l):
    """Lazily yield the entries of *l* that are not header-like files."""
    return (entry for entry in l if not IsInclude(entry))

def PrintOrigin(target):
    """Print a ``/* From target: ... */`` comment naming the GN target of a module.

    With PRINT_ORIGINAL_FULL set, the whole target dict is dumped as
    pretty-printed, key-sorted JSON; otherwise only ``target['original_name']``
    is printed.
    """
    print('/* From target:')
    print(json.dumps(target, sort_keys = True, indent = 4)
          if PRINT_ORIGINAL_FULL else target['original_name'])
    print('*/')

def MakeRelatives(l):
    """Lazily yield each path of *l* reduced to the text after its last ``//``."""
    return (path.split('//')[-1] for path in l)

def FormatName(name):
    """Map a GN label to an Android module name.

    Only the last ``/``-separated component is kept, ``:`` becomes ``__`` and
    ``webrtc_`` is prepended, e.g. ``//:webrtc`` -> ``webrtc___webrtc``.
    """
    _, _, last_component = name.rpartition('/')
    return 'webrtc_{0}'.format(last_component.replace(':', '__'))

def FormatNames(target):
    """Rename *target* and its deps in place to Android module names.

    The GN label is preserved in ``original_name``; ``deps`` is replaced by
    the sorted list of formatted dependency names. Returns *target*.
    """
    label = target['name']
    target['original_name'] = label
    target['name'] = FormatName(label)
    target['deps'] = sorted(FormatName(dep) for dep in target['deps'])
    return target

def FilterFlags(flags, to_skip=None):
    """Return *flags* without the entries that start with an ignored prefix.

    Every entry of the module-level IGNORED_FLAGS list (read at call time, so
    later rebinding of that global is honoured) and of *to_skip* is treated as
    a prefix: a flag is dropped when it starts with any of them. The input is
    not modified and the relative order of the kept flags is preserved.

    Args:
      flags: iterable of flag strings.
      to_skip: optional iterable of additional prefixes to drop.

    Returns:
      A new list with the remaining flags.
    """
    # Avoid a shared mutable default; str.startswith accepts a tuple of prefixes.
    prefixes = tuple(set(IGNORED_FLAGS).union(to_skip or ()))
    return [x for x in flags if not x.startswith(prefixes)]

def PrintHeader():
    """Print the fixed preamble of the generated Android.bp file.

    The preamble consists of the package default-license declaration, the
    large-scale-change license comment, the "do not edit" banner and the
    ``external_webrtc_license`` license module.
    """
    header_lines = (
        'package {',
        '    default_applicable_licenses: ["external_webrtc_license"],',
        '}',
        '',
        '// Added automatically by a large-scale-change that took the approach of',
        '// \'apply every license found to every target\'. While this makes sure we respect',
        '// every license restriction, it may not be entirely correct.',
        '//',
        '// e.g. GPL in an MIT project might only apply to the contrib/ directory.',
        '//',
        '// Please consider splitting the single license below into multiple licenses,',
        '// taking care not to lose any license_kind information, and overriding the',
        '// default license using the \'licenses: [...]\' property on targets as needed.',
        '//',
        '// For unused files, consider creating a \'fileGroup\' with "//visibility:private"',
        '// to attach the license to, and including a comment whether the files may be',
        '// used in the current project.',
        '//',
        '// large-scale-change included anything that looked like it might be a license',
        '// text as a license_text. e.g. LICENSE, NOTICE, COPYING etc.',
        '//',
        '// Please consider removing redundant or irrelevant files from \'license_text:\'.',
        '// See: http://go/android-license-faq',
        '',
        '///////////////////////////////////////////////////////////////////////////////',
        '// Do not edit this file directly, it\'s automatically generated by a script. //',
        '// Modify android_tools/generate_android_bp.py and run that instead.         //',
        '///////////////////////////////////////////////////////////////////////////////',
        '',
        'license {',
        '    name: "external_webrtc_license",',
        '    visibility: [":__subpackages__"],',
        '    license_kinds: [',
        '        "SPDX-license-identifier-Apache-2.0",',
        '        "SPDX-license-identifier-BSD",',
        '        "SPDX-license-identifier-MIT",',
        '        "SPDX-license-identifier-Zlib",',
        '        "legacy_notice",',
        '        "legacy_unencumbered",',
        '    ],',
        '    license_text: [',
        '        "LICENSE",',
        '        "PATENTS",',
        '        "license_template.txt",',
        '    ],',
        '}',
    )
    # One print of the joined text emits exactly the same bytes as one
    # print() per line.
    print('\n'.join(header_lines))



def GatherDefaultFlags(targets_by_arch):
    """Compute the flags shared by every static library, globally and per arch.

    For each architecture, the flags of every type in FLAGS that *all*
    static_library targets carry are collected. A list instead of a set is
    used to keep the order of the first such target, since order may be
    significant (e.g. -Wno-shadow has to come after -Wshadow). IGNORED_FLAGS
    are then filtered out and DEFAULT_CFLAGS_BY_ARCH is appended to 'cflags'.
    The flags common to every architecture, plus DEFAULT_CFLAGS, become the
    global defaults and are removed from the per-architecture lists.

    Args:
      targets_by_arch: {arch: {target_name: target_dict}}. The target dicts
        are only read, never modified.

    Returns:
      {flag_type: [flags], ..., 'arch': {arch: {flag_type: [flags]}}}.
      Every flag type in FLAGS is present at both levels, possibly empty.
    """
    arch_default_flags = {}
    for arch, targets in targets_by_arch.items():
        common = {}
        for target in targets.values():
            if target['type'] != 'static_library':
                continue
            for flag_type in FLAGS:
                target_flags = target.get(flag_type, [])
                if flag_type not in common:
                    # Copy: the running intersection must never alias (and
                    # later mutate) a target's own flag list.
                    common[flag_type] = list(target_flags)
                else:
                    keep = set(target_flags)
                    common[flag_type] = [x for x in common[flag_type] if x in keep]
        per_arch = {ft: FilterFlags(common.get(ft, [])) for ft in FLAGS}
        # Add in the hardcoded extra default cflags
        per_arch['cflags'] += DEFAULT_CFLAGS_BY_ARCH.get(arch, [])
        arch_default_flags[arch] = per_arch

    # Flags that are the same for all targets in all architectures.
    default_flags = {}
    for flagset in arch_default_flags.values():
        for flag_type, arch_flags in flagset.items():
            if flag_type not in default_flags:
                default_flags[flag_type] = list(arch_flags)
            else:
                keep = set(arch_flags)
                default_flags[flag_type] = [x for x in default_flags[flag_type] if x in keep]
    for flag_type in FLAGS:
        default_flags.setdefault(flag_type, [])
    # Add in the hardcoded extra default cflags
    default_flags['cflags'] += DEFAULT_CFLAGS

    # Remove the global default flags from the per-architecture default flags
    for flagset in arch_default_flags.values():
        for flag_type in list(flagset):
            provided = set(default_flags.get(flag_type, []))
            flagset[flag_type] = [x for x in flagset[flag_type] if x not in provided]

    default_flags['arch'] = arch_default_flags
    return default_flags

def GenerateDefault(targets_by_arch):
    """Print the shared ``webrtc_defaults`` cc_defaults module.

    The flags shared by all static libraries (see GatherDefaultFlags) are
    emitted globally and per architecture, next to the fixed include dirs,
    library dependencies and visibility. Afterwards every flag the defaults
    module already provides is removed from each target, together with
    IGNORED_FLAGS; flag lists that become empty are deleted.

    Args:
      targets_by_arch: {arch: {target_name: target_dict}}; modified in place.

    Returns:
      The dict produced by GatherDefaultFlags().
    """
    in_default = GatherDefaultFlags(targets_by_arch)
    arch_defaults = in_default.get('arch', {})

    def print_flags(indent, flag_type, flags):
        # Emit one Android.bp list property; json.dumps escapes both quotes
        # and backslashes, consistent with FormatList.
        if not flags:
            return
        print('{0}{1}: ['.format(indent, FLAG_NAME_MAP[flag_type]))
        for flag in flags:
            print('{0}    {1},'.format(indent, json.dumps(flag)))
        print('{0}],'.format(indent))

    print('cc_defaults {')
    print('    name: "webrtc_defaults",')
    print('    local_include_dirs: [')
    print('      ".",')
    print('      "webrtc",')
    print('      "third_party/crc32c/src/include",')
    print('    ],')
    for typ in sorted(in_default.keys() - {'arch'}):
        print_flags('    ', typ, in_default[typ])
    print('    static_libs: [')
    print('        "libabsl",')
    print('        "libaom",')
    print('        "libevent",')
    print('        "libopus",')
    print('        "libsrtp2",')
    print('        "libvpx",')
    print('        "libyuv",')
    print('        "libpffft",')
    print('        "rnnoise_rnn_vad",')
    print('    ],')
    print('    shared_libs: [')
    print('        "libcrypto",')
    print('        "libprotobuf-cpp-full",')
    print('        "libprotobuf-cpp-lite",')
    print('        "libssl",')
    print('    ],')
    print('    host_supported: true,')
    print('    // vendor needed for libpreprocessing effects.')
    print('    vendor: true,')
    print('    target: {')
    print('        darwin: {')
    print('            enabled: false,')
    print('        },')
    print('    },')
    print('    arch: {')
    for a in ARCHS:
        print('        {0}: {{'.format(ARCH_NAME_MAP[a]))
        for typ in FLAGS:
            print_flags('            ', typ, arch_defaults.get(a, {}).get(typ, []))
        print('        },')
    print('    },')
    print('    visibility: [')
    print('        "//frameworks/av/media/libeffects/preprocessing:__subpackages__",')
    print('        "//device/google/cuttlefish/host/frontend/webrtc:__subpackages__",')
    print('    ],')
    print('}')

    # The flags in the default entry can be safely removed from the targets.
    # They are removed by exact match: FilterFlags() matches prefixes, which
    # would also drop unrelated flags such as '-DFOO_BAR' when '-DFOO' is a
    # default.
    for arch, targets in targets_by_arch.items():
        for flag_type in FLAGS:
            provided = (set(in_default.get(flag_type, [])) |
                        set(arch_defaults.get(arch, {}).get(flag_type, [])))
            for target in targets.values():
                remaining = [f for f in FilterFlags(target.get(flag_type, []))
                             if f not in provided]
                if remaining:
                    target[flag_type] = remaining
                else:
                    target.pop(flag_type, None)

    return in_default


def TransitiveDependencies(name, dep_type, targets):
    """Collect every target of type *dep_type* reachable from target *name*.

    The result is memoized on each visited target under the key
    ``'transitive_' + dep_type``, and repeated calls return that same dict.
    The target itself is included when its type is *dep_type*. Dependencies
    reached through ``target['arch'][arch]['deps']`` are recorded only for
    that architecture.

    Args:
      name: key of the start target in *targets*.
      dep_type: target type to collect (e.g. 'static_library').
      targets: {name: target_dict}; memoization fields are written into it.

    Returns:
      {'global': set, <arch>: set for each arch in ARCHS}. A name in
      'global' never also appears in a per-arch set.
    """
    target = targets[name]
    field = 'transitive_' + dep_type
    if field in target:
        return target[field]
    result = {'global': set()}
    for a in ARCHS:
        result[a] = set()
    # Store the memo before recursing so a revisited node is not re-expanded.
    target[field] = result
    if target['type'] == dep_type:
        result['global'].add(name)
    for d in target.get('deps', []):
        # The dependency's own result already contains d when d matches dep_type.
        sub = TransitiveDependencies(d, dep_type, targets)
        result['global'] |= sub['global']
        for a in ARCHS:
            result[a] |= sub[a]
    for a, arch_props in target.get('arch', {}).items():
        for d in arch_props.get('deps', []):
            sub = TransitiveDependencies(d, dep_type, targets)
            result.setdefault(a, set())
            result[a] |= sub['global'] | sub[a]
    # Deduplicate against 'global' for every arch, not only for archs this
    # target lists: per-arch entries inherited from deps may already be global.
    for a in result:
        if a != 'global':
            result[a] -= result['global']

    return target[field]

def GenerateGroup(target):
    """Emit nothing for a GN group target.

    Groups get no Android.bp module, and printing their origin comment
    (PrintOrigin) is currently disabled.
    """
    return None

def GenerateStaticLib(target, targets):
    """Print a cc_library_static module for one merged static_library target.

    The module is preceded by the PrintOrigin comment and uses the
    webrtc_defaults defaults. Top-level srcs and flag lists are printed via
    FormatList (sorted). The optional ``arch`` map gets per-architecture
    asflags/cflags/conlyflags/cppflags, srcs and the ``enabled`` switch.

    Args:
      target: merged target dict (as produced by Merge/MergeAll).
      targets: all merged targets; unused here, kept for interface
        compatibility.

    Returns:
      The Android module name that was printed.
    """
    def emit(indent, bp_name, values, always=False):
        # Print one list-valued property; empty lists are skipped unless
        # *always* is set.
        if always or len(values) > 0:
            print('{0}{1}: {2},'.format(indent, bp_name, FormatList(values)))

    PrintOrigin(target)
    name = target['name']
    print('cc_library_static {')
    print('    name: "{0}",'.format(name))
    print('    defaults: ["webrtc_defaults"],')
    emit('    ', 'srcs', target.get('sources', []), always=True)
    print('    host_supported: true,')
    emit('    ', 'asflags', target.get('asmflags', []))
    # cflags is printed whenever the key exists, even when empty.
    if 'cflags' in target:
        emit('    ', 'cflags', target['cflags'], always=True)
    emit('    ', 'conlyflags', target.get('cflags_c', []))
    emit('    ', 'cppflags', target.get('cflags_cc', []))
    if 'arch' in target:
        inner = '            '
        print('    arch: {')
        for arch_name in ARCHS:
            if arch_name not in target['arch']:
                continue
            arch = target['arch'][arch_name]
            print('        ' + ARCH_NAME_MAP[arch_name] + ': {')
            # Merge can leave arch-specific asmflags here; emit them as well
            # instead of silently dropping them.
            emit(inner, 'asflags', arch.get('asmflags', []))
            if 'cflags' in arch:
                emit(inner, 'cflags', arch['cflags'], always=True)
            emit(inner, 'conlyflags', arch.get('cflags_c', []))
            emit(inner, 'cppflags', arch.get('cflags_cc', []))
            if 'sources' in arch:
                emit(inner, 'srcs', arch['sources'], always=True)
            if 'enabled' in arch:
                print('{0}enabled: {1},'.format(inner, arch['enabled']))
            print('        },')
        print('    },')
    print('}')
    return name

def DFS(seed, targets):
    """Return the names of all targets reachable from *seed*, including *seed*.

    Both the top-level ``deps`` list and every per-architecture
    ``arch[<arch>]['deps']`` list are followed. Missing ``deps`` keys are
    treated as empty.

    Args:
      seed: name of the start target (must be a key of *targets*).
      targets: {name: target_dict}.

    Returns:
      A set of target names.
    """
    visited = set()
    stack = [seed]
    while stack:
        nxt = stack.pop()
        if nxt in visited:
            continue
        visited.add(nxt)
        target = targets[nxt]
        stack.extend(target.get('deps', []))
        # Iterate the per-arch dicts (values), not the arch-name keys.
        for arch_props in target.get('arch', {}).values():
            stack.extend(arch_props.get('deps', []))
    return visited

def Preprocess(project):
    """Normalize one architecture's GN project description.

    The target dicts inside *project* are modified in place. Steps, in order:
      1. Record each target's GN label in 'name', turn shared_library and
         source_set targets into static_library, fold 'defines' into '-D'
         cflags (skipping IGNORED_DEFINES), and keep only non-header
         'sources', each reduced to the text after its last '//' (the
         'sources' key is dropped when nothing remains).
      2. Drop the targets provided by AOSP, every abseil-cpp target and every
         action/action_foreach target, plus all deps on them.
      3. Drop static libraries without sources (except //:webrtc), splicing
         their deps into every target that depended on them.
      4. Keep only targets reachable from //:webrtc or
         //modules/audio_processing:audio_processing.

    Args:
      project: parsed GN JSON; assumed to hold a 'targets' mapping of
        label -> target dict in which 'type' and 'deps' are always present
        (both are indexed directly below) -- TODO confirm against GN output.

    Returns:
      {FormatName(label): target}, each target renamed by FormatNames.
    """
    targets = {}
    for name, target in project['targets'].items():
        target['name'] = name
        targets[name] = target
        if target['type'] == 'shared_library':
            # Don't bother creating shared libraries
            target['type'] = 'static_library'
        if target['type'] == 'source_set':
            # Convert source_sets to static libraries to avoid recompiling sources multiple times.
            target['type'] = 'static_library'
        if 'defines' in target:
            target['cflags'] = target.get('cflags', []) + ['-D{0}'.format(d) for d in target['defines'] if d not in IGNORED_DEFINES]
            target.pop('defines')
        if 'sources' not in target:
            continue
        sources = list(MakeRelatives(FilterIncludes(target['sources'])))
        if len(sources) > 0:
            target['sources'] = sources
        else:
            target.pop('sources')

    # These dependencies are provided by aosp
    ignored_targets = {
            '//third_party/libaom:libaom',
            '//third_party/libevent:libevent',
            '//third_party/opus:opus',
            '//third_party/libsrtp:libsrtp',
            '//third_party/libvpx:libvpx',
            '//third_party/libyuv:libyuv',
            '//third_party/pffft:pffft',
            '//third_party/rnnoise:rnn_vad',
            '//third_party/boringssl:boringssl',
            '//third_party/android_ndk:cpu_features',
            '//buildtools/third_party/libunwind:libunwind',
            '//buildtools/third_party/libc++:libc++',
        }
    for name, target in targets.items():
        # Skip all "action" targets
        if target['type'] in {'action', 'action_foreach'}:
            ignored_targets.add(name)

    # abseil-cpp is matched by label prefix rather than listed one by one
    # (the defaults module links "libabsl" instead).
    def is_ignored(target):
        if target.startswith('//third_party/abseil-cpp'):
            return True
        return target in ignored_targets

    targets = {name: target for name, target in targets.items() if not is_ignored(name)}

    for target in targets.values():
        # Don't depend on ignored targets
        target['deps'] = [d for d in target['deps'] if not is_ignored(d) ]

    # Ignore empty static libraries
    empty_libs = set()
    for name, target in targets.items():
        if target['type'] == 'static_library' and 'sources' not in target and name != '//:webrtc':
            empty_libs.add(name)
    for empty_lib in empty_libs:
        # Read the deps at this point of the loop: an earlier iteration may
        # already have spliced another empty library's deps into this one.
        empty_lib_deps = targets[empty_lib].get('deps', [])
        for target in targets.values():
            target['deps'] = FlattenEmptyLibs(target['deps'], empty_lib, empty_lib_deps)
    for s in empty_libs:
        targets.pop(s)

    # Select libwebrtc, libaudio_processing and its dependencies
    selected = set()
    selected |= DFS('//:webrtc', targets)
    selected |= DFS('//modules/audio_processing:audio_processing', targets)

    return {FormatName(n): FormatNames(targets[n]) for n in selected}

def _FlattenEmptyLibs(deps, empty_lib, empty_lib_deps):
    """Yield *deps*, expanding each occurrence of *empty_lib* into *empty_lib_deps*."""
    for dep in deps:
        if dep != empty_lib:
            yield dep
            continue
        for replacement in empty_lib_deps:
            yield replacement

def FlattenEmptyLibs(deps, empty_lib, empty_lib_deps):
    """Return *deps* as a new list with *empty_lib* replaced by its own deps."""
    flattened = []
    flattened.extend(_FlattenEmptyLibs(deps, empty_lib, empty_lib_deps))
    return flattened

def NonNoneFrom(l):
    """Return the first element of *l* that is not None, or None if there is none."""
    return next((candidate for candidate in l if candidate is not None), None)

def MergeListField(target, f, target_by_arch):
    """Split list field *f* of the per-arch targets into common and per-arch parts.

    Only enabled architectures (non-empty target dicts) take part. Entries
    present for every enabled architecture are stored in ``target[f]``. They
    keep the order of the first enabled architecture's list, with duplicates
    removed, so the result is deterministic; the key is set only when
    non-empty. The remaining entries of each architecture are stored, sorted,
    in ``target['arch'][arch][f]`` (only when non-empty).

    Args:
      target: merged target being built; ``target['arch']`` must already hold
        a dict for every enabled architecture.
      f: name of the list field, e.g. 'sources' or 'cflags'.
      target_by_arch: {arch: target_dict}; ``{}`` marks a disabled arch.
    """
    lists_by_arch = {a: t.get(f, []) for a, t in target_by_arch.items() if len(t) > 0}
    if not lists_by_arch:
        # We only care about enabled archs
        return
    sets_by_arch = {a: set(v) for a, v in lists_by_arch.items()}
    common = set.intersection(*sets_by_arch.values())

    if common:
        # Every common entry appears in the first list, so filtering it keeps a
        # stable order (unlike list(set), whose order is arbitrary).
        first = next(iter(lists_by_arch.values()))
        target[f] = [x for x in dict.fromkeys(first) if x in common]
    for a, s in sets_by_arch.items():
        rest = s - common
        if rest:
            target['arch'][a][f] = sorted(rest)

def Merge(target_by_arch):
    """Merge the per-architecture versions of one target into a single target.

    Only 'original_name', 'name', 'type' and the list fields below are
    carried over. Memoization fields from the per-arch inputs (e.g. those
    added by TransitiveDependencies) are therefore never copied. List entries
    common to all enabled architectures live at the top level and the rest
    under ``target['arch'][arch]`` (see MergeListField). Architectures for
    which the target does not exist get ``enabled: 'false'``.

    Args:
      target_by_arch: {arch: target_dict}; ``{}`` means the target does not
        exist for that architecture.

    Returns:
      The merged target dict.
    """
    target = {}
    inputs = list(target_by_arch.values())
    for f in ['original_name', 'name', 'type']:
        target[f] = NonNoneFrom([t.get(f) for t in inputs])

    target['arch'] = {}
    for a, t in target_by_arch.items():
        target['arch'][a] = {} if len(t) > 0 else {'enabled': 'false'}

    list_fields = ['sources',
                   'deps',
                   'cflags',
                   'cflags_c',
                   'cflags_cc',
                   'asmflags']
    for lf in list_fields:
        MergeListField(target, lf, target_by_arch)

    # Static libraries should be depended on at the root level and disabled for
    # the corresponding architectures. A dep shared by several (but not all)
    # archs is hoisted once, not once per arch.
    hoisted = []
    for arch in target['arch'].values():
        hoisted += arch.pop('deps', [])
    if hoisted or 'deps' in target:
        target['deps'] = sorted(set(target.get('deps', [])) | set(hoisted))

    # Remove empty sets
    for a in list(target['arch']):
        if len(target['arch'][a]) == 0:
            target['arch'].pop(a)
    if len(target['arch']) == 0:
        target.pop('arch')

    return target

def DisabledArchs4Target(target):
    """Return the set of ARCHS whose per-arch entry says ``enabled: 'false'``."""
    per_arch = target.get('arch', {})
    return {a for a in ARCHS
            if a in per_arch and per_arch[a].get('enabled', 'true') == 'false'}


def HandleDisabledArchs(targets):
    """Move deps on partially disabled targets into per-arch dep lists.

    For every target disabled on some architectures, each dependent whose
    disabled set differs has the dependency removed from its top-level
    'deps' (which is re-sorted). The dependency is added to
    ``arch[<arch>]['deps']`` for each architecture where the dependency is
    enabled. *targets* is modified in place.
    """
    for dep_name, dep_target in targets.items():
        if 'arch' not in dep_target:
            continue
        disabled = DisabledArchs4Target(dep_target)
        if not disabled:
            continue
        # Fix targets that depend on this one
        for dependent in targets.values():
            if DisabledArchs4Target(dependent) == disabled:
                # With the same disabled archs there is no need to move dependencies
                continue
            if dep_name not in dependent.get('deps', []):
                continue
            # Remove the dependency from the high level list
            dependent['deps'] = sorted(set(dependent['deps']) - {dep_name})
            per_arch = dependent.setdefault('arch', {})
            for a in ARCHS:
                if a in disabled:
                    continue
                per_arch.setdefault(a, {}).setdefault('deps', []).append(dep_name)

def MergeAll(targets_by_arch):
    """Merge the per-architecture target maps into one map of merged targets.

    Every target name seen for any architecture is merged with Merge(). A
    missing per-arch version is passed as ``{}``. HandleDisabledArchs() then
    fixes dependencies on partially disabled targets.

    Returns:
      {name: merged_target}.
    """
    names = set()
    for per_arch_targets in targets_by_arch.values():
        names |= per_arch_targets.keys()
    merged = {}
    for name in names:
        versions = {}
        for a, per_arch_targets in targets_by_arch.items():
            versions[a] = per_arch_targets.get(name, {})
        merged[name] = Merge(versions)

    HandleDisabledArchs(merged)

    return merged

def GatherAllFlags(obj):
    """Recursively collect every flag stored anywhere inside *obj*.

    At each dict level, the lists under the keys in FLAGS are gathered, and
    every value of the dict is searched recursively. Non-dict values
    (lists, strings, ...) contribute nothing.

    Returns:
      A set of flag strings.
    """
    if not isinstance(obj, dict):
        return set()
    found = set()
    for f in FLAGS:
        found.update(obj.get(f, []))
    for v in obj.values():
        found |= GatherAllFlags(v)
    return found

def FilterFlagsInUse(flags, directory):
    """Return the flags from *flags* that are not mentioned in any GN file.

    Every file below *directory* whose name contains ``.gn`` (``BUILD.gn``,
    ``*.gni``, ...) is read once. For ``-D`` flags only the macro name (the
    text between ``-D`` and the first ``=``) is searched for; other flags are
    searched verbatim. Matching is a literal substring search on the raw
    file bytes. No shell is involved, so quotes or metacharacters in flags
    are harmless, and the result does not depend on how many files exist.

    Args:
      flags: iterable of flag strings.
      directory: root of the source tree to scan.

    Returns:
      A list of the flags (in iteration order) that were not found.

    Raises:
      FileNotFoundError: if *directory* is not an existing directory.
    """
    if not os.path.isdir(directory):
        raise FileNotFoundError('not a directory: {0}'.format(directory))

    contents = []
    for root, _dirs, files in os.walk(directory):
        for fname in files:
            if '.gn' not in fname:
                continue
            try:
                with open(os.path.join(root, fname), 'rb') as fh:
                    contents.append(fh.read())
            except OSError:
                # Unreadable files are skipped silently.
                continue
    # NUL never occurs in a flag, so joining cannot create cross-file matches.
    corpus = b'\0'.join(contents)

    unused = []
    for f in flags:
        nf = f
        if nf.startswith('-D'):
            nf = nf[2:]
            i = nf.find('=')
            if i > 0:
                nf = nf[:i]
        if nf.encode('utf-8') not in corpus:
            # couldn't find the flag in *.gn or *.gni
            unused.append(f)
    return unused

# Usage: <script> <json_dir>, where json_dir holds project_<arch>.json for
# every arch in ARCHS and its parent is the root of the GN source tree.
if len(sys.argv) != 2:
    print('wrong number of arguments', file = sys.stderr)
    sys.exit(1)

json_dir = sys.argv[1]

targets_by_arch = {}
flags = set()
for arch in ARCHS:
    path = "{0}/project_{1}.json".format(json_dir, arch)
    # Close each file as soon as it has been parsed.
    with open(path, 'r', encoding='utf-8') as json_file:
        targets_by_arch[arch] = Preprocess(json.load(json_file))
    flags |= GatherAllFlags(targets_by_arch[arch])

# Flags never mentioned in any .gn/.gni file are treated like IGNORED_FLAGS.
# FilterFlags reads the global at call time, so rebinding it here affects
# every later call.
unusedFlags = FilterFlagsInUse(flags, f"{json_dir}/..")
IGNORED_FLAGS = sorted(set(IGNORED_FLAGS) | set(unusedFlags))

PrintHeader()

# Must run before MergeAll: it strips the default flags from every target.
GenerateDefault(targets_by_arch)

targets = MergeAll(targets_by_arch)

print('\n\n')

for name, target in sorted(targets.items()):
    typ = target['type']
    if typ == 'static_library':
        GenerateStaticLib(target, targets)
    elif typ == 'group':
        GenerateGroup(target)
    else:
        print('Unknown type: {0} ({1})'.format(typ, target['name']), file = sys.stderr)
        # sys.exit rather than the site-provided exit(), which is meant for
        # the interactive interpreter only.
        sys.exit(1)
    print('\n\n')

def _PrintWholeArchiveLibrary(lib_name, root_label, export_include_lines):
    """Print a cc_library_static that whole-links the static libs reachable from a root.

    Args:
      lib_name: Android module name to emit.
      root_label: GN label of the root target (e.g. '//:webrtc').
      export_include_lines: pre-formatted export_include_dirs line(s), printed
        verbatim.

    Returns:
      The TransitiveDependencies() result for the root target.
    """
    libs = TransitiveDependencies(FormatName(root_label), 'static_library', targets)
    print('cc_library_static {')
    print('    name: "{0}",'.format(lib_name))
    print('    defaults: ["webrtc_defaults"],')
    for line in export_include_lines:
        print(line)
    whole_libs = sorted(libs['global']) + ['libpffft', 'rnnoise_rnn_vad']
    print('    whole_static_libs: {0},'.format(FormatList(whole_libs)))
    print('    arch: {')
    for a in ARCHS:
        if libs[a]:
            print('        {0}: {{'.format(ARCH_NAME_MAP[a]))
            print('            whole_static_libs: {0},'.format(FormatList(sorted(libs[a]))))
            print('        },')
    print('    },')
    print('}')
    return libs


webrtc_libs = _PrintWholeArchiveLibrary(
    'libwebrtc',
    '//:webrtc',
    ['    export_include_dirs: ["."],'])

print('\n\n')

audio_proc_libs = _PrintWholeArchiveLibrary(
    'webrtc_audio_processing',
    '//modules/audio_processing:audio_processing',
    ['    export_include_dirs: [',
     '        ".",',
     '        "modules/include",',
     '        "modules/audio_processing/include",',
     '    ],'])
