#!/usr/bin/env python3
import argparse
import fnmatch
import html
import io
import json
import os
import pathlib
import subprocess
import types
import sys
class Graph:
    """Dependency graph over module records, keyed by module id.

    ``self.nodes`` maps an id to its ``Node``. For every dependency record
    ``d`` of a module, the module's node gets the dependency's node in
    ``deps``, the dependency's node gets the module's node in ``rdeps``, and
    ``d`` itself is appended to ``node.dep_tags[dep_node]``.
    """

    def __init__(self, modules):
        """Build the graph.

        Args:
            modules: mapping whose values are module records exposing ``id``
                and ``deps``; every dep record exposes ``id``.
        """
        def get_or_make_node(dictionary, node_id, module):
            node = dictionary.get(node_id)
            if node is None:
                node = Node(node_id, module)
                dictionary[node_id] = node
            elif module and not node.module:
                # The node was created earlier as somebody's dependency,
                # before its own module record was seen; fill it in now.
                node.module = module
            return node

        self.nodes = {}
        for module in modules.values():
            node = get_or_make_node(self.nodes, module.id, module)
            for d in module.deps:
                dep = get_or_make_node(self.nodes, d.id, None)
                node.deps.add(dep)
                dep.rdeps.add(node)
                node.dep_tags.setdefault(dep, []).append(d)

    @staticmethod
    def _reachable(start, neighbors, edge_ok):
        """Return every node reachable from start over one or more accepted edges.

        Iterative, so graph depth is not limited by the recursion limit.
        ``start`` itself is only included if a cycle leads back to it.
        """
        seen = set()
        pending = [start]
        while pending:
            current = pending.pop()
            for candidate in neighbors(current):
                if candidate not in seen and edge_ok(current, candidate):
                    seen.add(candidate)
                    pending.append(candidate)
        return seen

    def find_paths(self, id1, id2, tag_filter):
        """Return the set of nodes lying on any dependency path between two ids.

        The direction "id2 (transitively) depends on id1" is tried first, then
        the opposite one. Only edges accepted by ``tag_filter`` are followed;
        an empty filter accepts every edge. If no path exists in either
        direction the result contains just the node for ``id2``.

        Raises:
            KeyError: if one of the ids isn't found.
        """
        node1 = self.nodes[id1]
        node2 = self.nodes[id2]

        def edge_ok(user, dependency):
            # Skip the per-edge tag lookup entirely when there is no filter.
            return not tag_filter or matches_tag(user, dependency, tag_filter)

        def between(lower, upper):
            # Nodes that (transitively) depend on `lower` ...
            above = self._reachable(
                lower, lambda n: n.rdeps, lambda cur, user: edge_ok(user, cur))
            if upper not in above:
                return set()
            # ... intersected with the nodes `upper` (transitively) depends on.
            # Computing both sides independently keeps a node that reaches
            # `upper` only through a node some other branch already explored.
            below = self._reachable(upper, lambda n: n.deps, edge_ok)
            above.add(lower)
            below.add(upper)
            return above & below

        # Take either direction
        p = between(node1, node2)
        if p:
            return p
        p = between(node2, node1)
        p.add(node2)
        return p
class Node:
    """One vertex of the module graph.

    Instances are compared and hashed by identity, so they can be stored in
    the ``deps`` / ``rdeps`` sets and used as ``dep_tags`` keys.
    """

    def __init__(self, id, module):
        # Identity: the graph-wide id plus the module record (may be None).
        self.id, self.module = id, module
        # Outgoing edges and incoming edges, both initially empty.
        self.deps, self.rdeps = set(), set()
        # Maps a dependency node to the list of dep records behind that edge.
        self.dep_tags = dict()
# List of provider type strings.
# NOTE(review): nothing in the visible part of this file reads PROVIDERS;
# confirm whether it is still used before relying on or removing it.
PROVIDERS = [
    "android/soong/java.JarJarProviderData",
    "android/soong/java.BaseJarJarProviderData",
]
def format_dep_label(node, dep):
    """Return the Graphviz label value for the edge from node to dep.

    The label lists, one per row, the distinct dependency tag types recorded
    in ``node.dep_tags[dep]``, shortened to the part after the last "/" and
    sorted. The result is meant to be interpolated directly after ``label=``
    in dot output.

    Returns:
        An HTML-like label (``<...>``) with one table row per tag type, or an
        empty quoted label (``""``) when there are no tags for ``dep``.

    NOTE(review): the string literals of this function were split across
    physical lines (a syntax error), leaving only the outer ``<``/``>``
    delimiters and one cell per label recoverable. The table markup below is
    a reconstruction; confirm the attributes against the intended rendering.
    """
    tags = node.dep_tags.get(dep) or []
    labels = sorted({tag.tag_type.split("/")[-1] for tag in tags})
    if not labels:
        # Falling off the end would return None, which the caller's f-string
        # would turn into the literal label text "None".
        return '""'
    # Escape so a tag type containing <, > or & cannot break the label markup.
    rows = "".join(f"<tr><td>{html.escape(label)}</td></tr>" for label in labels)
    return ('<<table border="0" cellborder="0" cellspacing="0" cellpadding="0">'
            + rows + "</table>>")
def format_node_label(node, module_formatter):
    """Return the Graphviz HTML-like label (``<...>``) for a graph node.

    Rows, top to bottom: the module name (or the node id when the node has no
    module record), then - only when a module record exists - the module type
    followed by every row returned by ``module_formatter(node.module)``.

    Args:
        node: object exposing ``id`` and ``module`` (``module`` may be falsy).
        module_formatter: callable taking the module and returning a list of
            rows.

    NOTE(review): the string literals of this function were split across
    physical lines (a syntax error), leaving only the outer ``<``/``>``
    delimiters and one cell per row recoverable. The table markup below is a
    reconstruction; any per-row styling that may have existed is not restored.
    """
    rows = [node.module.name if node.module else node.id]
    if node.module:
        rows.append(node.module.type)
        # module_formatter will return a list of rows
        rows.extend(module_formatter(node.module))
    # Every cell is escaped (name and type included) and coerced to str so a
    # non-string row cannot crash html.escape.
    cells = "".join(f"<tr><td>{html.escape(str(row))}</td></tr>" for row in rows)
    return ('<<table border="0" cellborder="0" cellspacing="0" cellpadding="0">'
            + cells + "</table>>")
def format_source_pos(file, lineno):
    """Return ``file``, with ``:lineno`` appended when lineno is truthy."""
    if not lineno:
        return file
    return file + f":{lineno}"
# Leading parts of a provider type that format_provider removes for display.
STRIP_TYPE_PREFIXES = [
    "android/soong/",
    "github.com/google/",
]


def format_provider(provider):
    """Return a display string for a provider record.

    The first entry of STRIP_TYPE_PREFIXES that ``provider.type`` starts with
    is removed; if nothing matches, or removing it would leave an empty
    string, the full type is used. A truthy ``provider.debug`` is appended in
    parentheses.
    """
    full_type = provider.type
    matched = next(
        (prefix for prefix in STRIP_TYPE_PREFIXES if full_type.startswith(prefix)), "")
    text = full_type[len(matched):] or full_type
    if provider.debug:
        text += " (" + provider.debug + ")"
    return text
def load_soong_debug():
    """Load the Soong debug JSON and attach an ``id`` to every module and dep.

    The file at SOONG_DEBUG_DATA_FILENAME is parsed with every JSON object
    turned into a ``types.SimpleNamespace``. Each entry of ``info.modules``
    and each entry of its ``deps`` then gets an ``id`` attribute.

    Returns:
        The parsed top-level namespace.

    Exits the process with status 1 (after writing an error to stderr) when
    the file cannot be opened.
    """
    # Read the json
    try:
        with open(SOONG_DEBUG_DATA_FILENAME) as f:
            info = json.load(f, object_hook=lambda d: types.SimpleNamespace(**d))
    except IOError:
        sys.stderr.write(f"error: Unable to open {SOONG_DEBUG_DATA_FILENAME}. Make sure you have"
                         + " built with GENERATE_SOONG_DEBUG.\n")
        sys.exit(1)

    # Construct IDs: the bare name when only one module has that name,
    # otherwise name + "@@" + variant (when the record has a variant).
    name_counts = {}
    for m in info.modules:
        name_counts[m.name] = name_counts.get(m.name, 0) + 1

    def get_id(m):
        # .get(): a dep may name a module that has no entry in info.modules;
        # Graph copes with such deps, so don't fail with KeyError here.
        if name_counts.get(m.name, 0) > 1 and m.variant:
            return m.name + "@@" + m.variant
        return m.name

    for m in info.modules:
        m.id = get_id(m)
        for dep in m.deps:
            dep.id = get_id(dep)
    return info
def load_modules():
    """Return a dict mapping module id to module record.

    Modules with a falsy name are left out. When several records share an id,
    the last one in file order wins.
    """
    return {m.id: m for m in load_soong_debug().modules if m.name}
def load_graph():
    """Load the named modules and return the Graph built from them."""
    return Graph(load_modules())
def module_selection_args(parser):
    """Add the module-selection options (modules, --provider, --dep) to parser."""
    parser.add_argument("modules", nargs="*",
                        help="Modules to match. Can be glob-style wildcards.")
    parser.add_argument("--provider", nargs="+",
                        help="Match the given providers.")
    # This help text used to be a copy of the --provider help.
    parser.add_argument("--dep", nargs="+",
                        help="Match modules with a direct dependency on any of the given"
                        + " module ids.")
def load_and_filter_modules(args):
    """Yield the modules selected by args, sorted by (name, variant).

    A module is yielded only if it satisfies every criterion that was given:
      * args.modules:  its name matches any of the glob patterns
      * args.provider: any of its provider types ends with any given suffix
      * args.dep:      any of its deps has an id equal to any given id

    This is a generator, so nothing happens until the first item is
    requested. At that point, if no criterion was given, an error is written
    to stderr and the process exits with status 1 before any data is loaded.
    """
    def name_matches(m):
        return any(fnmatch.fnmatchcase(m.name, pattern) for pattern in args.modules)

    def has_provider(m):
        return any(p.type.endswith(suffix)
                   for suffix in args.provider for p in m.providers)

    def has_dep(m):
        return any(d.id == wanted for wanted in args.dep for d in m.deps)

    # Which modules are printed
    criteria = (
        (args.modules, name_matches),
        (args.provider, has_provider),
        (args.dep, has_dep),
    )
    matchers = [predicate for given, predicate in criteria if given]

    if not matchers:
        sys.stderr.write("error: At least one module matcher must be supplied\n")
        sys.exit(1)

    info = load_soong_debug()
    for m in sorted(info.modules, key=lambda m: (m.name, m.variant)):
        if all(matcher(m) for matcher in matchers):
            yield m
def print_args(parser):
    """Add the graph-printing options to parser.

    Adds --label (repeatable), --deptags, --tag (repeatable, defaults to an
    empty list) and an "output formats" group in which --dot and --svg are
    mutually exclusive.
    """
    parser.add_argument(
        "--label", action="append", metavar="JQ_FILTER",
        help="jq query for each module metadata")
    parser.add_argument(
        "--deptags", action="store_true",
        help="show dependency tags (makes the graph much more complex)")
    parser.add_argument(
        "--tag", action="append", default=[],
        help="Limit output to these dependency tags.")

    formats = parser.add_argument_group(
        "output formats",
        "If no format is provided, a dot file will be written to stdout.")
    exclusive = formats.add_mutually_exclusive_group()
    # Both format options take a file name.
    filename_option = dict(type=str, metavar="FILENAME")
    exclusive.add_argument(
        "--dot", help="Write the graph to this file as dot (graphviz format)",
        **filename_option)
    exclusive.add_argument(
        "--svg", help="Write the graph to this file as svg",
        **filename_option)
def print_nodes(args, nodes, module_formatter):
    """Render nodes as a Graphviz digraph and write it out.

    One box is emitted per node, plus one edge for every dependency whose
    target is also in ``nodes``. With args.deptags, each edge is routed
    through an extra ellipse node carrying the dependency-tag label.

    The text goes to the file args.dot if set; otherwise it is piped through
    ``dot -Tsvg`` into args.svg if set; otherwise it is written to stdout.
    """
    # Generate the graphviz
    chunks = ["digraph {\n", "node [shape=box];"]
    edge_index = 0  # numbers the intermediate tag nodes; bumped for every edge
    for node in nodes:
        chunks.append(f'"{node.id}" [label={format_node_label(node, module_formatter)}];\n')
        for dep in node.deps:
            if dep not in nodes:
                continue
            if args.deptags:
                tag_node = f'"__dep_tag_{edge_index}"'
                chunks.append(f'"{node.id}" -> {tag_node} [ arrowhead=none ];\n')
                chunks.append(f'{tag_node} -> "{dep.id}";\n')
                chunks.append(
                    f"{tag_node}[label={format_dep_label(node, dep)} shape=ellipse"
                    ' color="#666666" fontcolor="#666666"];\n')
            else:
                chunks.append(f'"{node.id}" -> "{dep.id}";\n')
            edge_index += 1
    chunks.append("}\n")
    text = "".join(chunks)

    # Write it somewhere
    if args.dot:
        with open(args.dot, "w") as f:
            f.write(text)
        return
    if args.svg:
        subprocess.run(["dot", "-Tsvg", "-o", args.svg],
                       input=text, text=True, check=True)
        return
    sys.stdout.write(text)
def matches_tag(node, dep, tag_filter):
    """Return whether the edge node -> dep passes tag_filter.

    A falsy filter accepts everything. Otherwise the edge passes when at
    least one record in ``node.dep_tags[dep]`` has a ``tag_type`` that is in
    the filter set.
    """
    if tag_filter:
        edge_tag_types = (t.tag_type for t in node.dep_tags[dep])
        return len(tag_filter.intersection(edge_tag_types)) > 0
    return True
def get_deps(nodes, root, maxdepth, reverse, tag_filter):
    """Add root and everything within maxdepth edges of it to ``nodes``.

    Args:
        nodes: set that is updated in place. It doubles as the visited set:
            anything already in it is neither re-added nor expanded, and if
            ``root`` itself is already present the call does nothing.
        root: node to start from.
        maxdepth: number of edge levels to follow; 0 adds only root, a
            negative value means unlimited.
        reverse: follow ``rdeps`` instead of ``deps``.
        tag_filter: set of tag types an edge must carry; empty accepts all.
    """
    if root in nodes:
        return
    nodes.add(root)

    # Walk level by level so every node is first reached at its smallest
    # distance from root. A depth-first walk could reach a node first over a
    # long path, mark it visited with no depth left and never expand it, so
    # nodes inside the limit went missing depending on iteration order.
    frontier = [root]
    depth_left = maxdepth
    while frontier and depth_left != 0:
        next_frontier = []
        for current in frontier:
            for neighbor in (current.rdeps if reverse else current.deps):
                if neighbor in nodes:
                    continue
                # dep_tags lives on the depending node and is keyed by the
                # dependency, so in reverse mode the roles are swapped;
                # looking up current.dep_tags[neighbor] there is a KeyError.
                if reverse:
                    user, dependency = neighbor, current
                else:
                    user, dependency = current, neighbor
                if tag_filter and not matches_tag(user, dependency, tag_filter):
                    continue
                nodes.add(neighbor)
                next_frontier.append(neighbor)
        frontier = next_frontier
        depth_left -= 1
def new_module_formatter(args):
    """Return a function that turns a module record into a list of label rows.

    The returned formatter serializes the module to JSON and runs it through
    ``jq`` once per filter in args.label. Per filter output:
      * a JSON list contributes one row per truthy element,
      * a JSON object contributes jq's raw output as a single row,
      * any other truthy value contributes its string form.
    Without any --label filter the formatter returns an empty list and jq is
    never run.

    Raises (from the returned formatter):
        subprocess.CalledProcessError: if jq exits with a non-zero status.
    """
    def module_formatter(module):
        if not args.label:
            return []
        result = []
        text = json.dumps(module, default=lambda o: o.__dict__)
        for jq_filter in args.label:
            proc = subprocess.run(["jq", jq_filter],
                                  input=text, text=True, check=True, stdout=subprocess.PIPE)
            if not proc.stdout:
                continue
            o = json.loads(proc.stdout)
            if isinstance(o, list):
                # Rows are later passed to html.escape, which only accepts
                # strings, so numbers and nested values must be stringified.
                result.extend(str(row) for row in o if row)
            elif isinstance(o, dict):
                result.append(proc.stdout.strip())
            elif o:
                result.append(str(o).strip())
        return result
    return module_formatter
class BetweenCommand:
    """The "between" subcommand: graph of the paths between two modules."""

    help = "Print the module graph between two nodes."

    def args(self, parser):
        """Register the two positional module ids plus the printing options."""
        parser.add_argument("module", nargs=2,
                            help="the two modules")
        print_args(parser)

    def run(self, args):
        """Load the graph and print every node on a path between the modules.

        Exits with status 1 after reporting each module id that is not in the
        graph, instead of letting Graph.find_paths raise KeyError.
        """
        graph = load_graph()
        missing = [module_id for module_id in args.module if module_id not in graph.nodes]
        if missing:
            for module_id in missing:
                sys.stderr.write(f"error: Can't find module: {module_id}\n")
            sys.exit(1)
        print_nodes(args, graph.find_paths(args.module[0], args.module[1], set(args.tag)),
                    new_module_formatter(args))
class DepsCommand:
    """The "deps" subcommand: dependency graph of one or more modules."""

    help = "Print the module graph of dependencies of one or more modules"

    def args(self, parser):
        """Register the module ids, --reverse, --depth and printing options."""
        parser.add_argument(
            "module", nargs="+",
            help="Module to print dependencies of")
        parser.add_argument(
            "--reverse", action="store_true",
            help="traverse reverse dependencies")
        parser.add_argument(
            "--depth", type=int, default=-1,
            help="max depth of dependencies (can keep the graph size reasonable)")
        print_args(parser)

    def run(self, args):
        """Collect the (reverse) dependencies of every requested module and print them.

        Every unknown module id is reported on stderr; if there was any, the
        process exits with status 1 without printing a graph.
        """
        graph = load_graph()
        tag_filter = set(args.tag)
        collected = set()
        any_missing = False
        for module_id in args.module:
            root = graph.nodes.get(module_id)
            if root:
                get_deps(collected, root, args.depth, args.reverse, tag_filter)
            else:
                sys.stderr.write(f"error: Can't find root: {module_id}\n")
                any_missing = True
        if any_missing:
            sys.exit(1)
        print_nodes(args, collected, new_module_formatter(args))
class IdCommand:
    """The "id" subcommand: list the ids of the selected modules."""

    help = "Print the id (name + variant) of matching modules"

    def args(self, parser):
        """Register the module-selection options."""
        module_selection_args(parser)

    def run(self, args):
        """Print one id per line for every selected module."""
        for selected in load_and_filter_modules(args):
            print(selected.id)
class JsonCommand:
    """The "json" subcommand: dump the selected module records as JSON."""

    help = "Print metadata about modules in json format"

    def args(self, parser):
        """Register the module-selection options and --list."""
        module_selection_args(parser)
        parser.add_argument(
            "--list", action="store_true",
            help="Print the results in a json list. If not set and multiple"
            + " modules are matched, the output won't be valid json.")

    def run(self, args):
        """Write the selected modules to stdout, indented by 4.

        With --list a single JSON list is written; otherwise each module is
        written as its own document followed by a newline.
        """
        def as_dict(o):
            # Records are attribute bags; serialize them via their __dict__.
            return o.__dict__

        selected = load_and_filter_modules(args)
        if not args.list:
            for m in selected:
                json.dump(m, sys.stdout, indent=4, default=as_dict)
                print()
            return
        json.dump(list(selected), sys.stdout, indent=4, default=as_dict)
class QueryCommand:
    """The "query" subcommand: human-readable details of the selected modules."""

    help = "Query details about modules"

    def args(self, parser):
        """Register the module-selection options."""
        module_selection_args(parser)

    def run(self, args):
        """Print id, type, source location, providers and deps of each module."""
        for selected in load_and_filter_modules(args):
            location = format_source_pos(selected.source_file, selected.source_line)
            print(selected.id)
            print(f"    type:     {selected.type}")
            print(f"    location: {location}")
            for provider in selected.providers:
                print(f"    provider: {format_provider(provider)}")
            for dependency in selected.deps:
                print(f"    dep:      {dependency.id}")
class StarCommand:
    """The "star" subcommand: both dependencies and reverse dependencies."""

    help = "Print the dependencies and reverse dependencies of a module"

    def args(self, parser):
        """Register the module ids, the mandatory --depth and printing options."""
        parser.add_argument(
            "module", nargs="+",
            help="Module to print dependencies of")
        parser.add_argument(
            "--depth", type=int, required=True,
            help="max depth of dependencies")
        print_args(parser)

    def run(self, args):
        """Collect deps and rdeps around every requested module and print them.

        Every unknown module id is reported on stderr; if there was any, the
        process exits with status 1 without printing a graph.
        """
        graph = load_graph()
        tag_filter = set(args.tag)
        collected = set()
        any_missing = False
        for module_id in args.module:
            root = graph.nodes.get(module_id)
            if not root:
                sys.stderr.write(f"error: Can't find root: {module_id}\n")
                any_missing = True
                continue
            # Forward pass first ...
            get_deps(collected, root, args.depth, False, tag_filter)
            # ... then take root out again, because get_deps returns
            # immediately for a root that is already collected.
            collected.remove(root)
            get_deps(collected, root, args.depth, True, tag_filter)
        if any_missing:
            sys.exit(1)
        print_nodes(args, collected, new_module_formatter(args))
# Subcommand name -> command object. main() registers one argparse subparser
# per entry (using the object's `help` and `args(parser)`) and dispatches the
# chosen subcommand to its `run(args)`.
COMMANDS = {
    "between": BetweenCommand(),
    "deps": DepsCommand(),
    "id": IdCommand(),
    "json": JsonCommand(),
    "query": QueryCommand(),
    "star": StarCommand(),
}
def assert_env(name):
    """Return the value of environment variable ``name``, warning if it is unset.

    When the variable is missing or empty a message is written to stderr.
    Despite the name this does not abort: the (falsy) value is still returned
    and the caller decides how to proceed.
    """
    val = os.getenv(name)
    if not val:
        # Terminate the line so later stderr output doesn't run into it.
        sys.stderr.write(f"{name} not set. please make sure you've run lunch.\n")
    return val
# Environment-derived configuration, resolved once when the file is loaded.
ANDROID_BUILD_TOP = assert_env("ANDROID_BUILD_TOP")
TARGET_PRODUCT = assert_env("TARGET_PRODUCT")
# OUT_DIR defaults to "out" when unset or empty.
OUT_DIR = os.getenv("OUT_DIR") or "out"
if not OUT_DIR.startswith("/"):
    # A relative OUT_DIR is anchored at the tree root. assert_env only warns,
    # so ANDROID_BUILD_TOP can be None here; fall back to the current
    # directory rather than dying with a TypeError from pathlib.Path(None).
    OUT_DIR = pathlib.Path(ANDROID_BUILD_TOP or "").joinpath(OUT_DIR)
# File read by load_soong_debug().
SOONG_DEBUG_DATA_FILENAME = pathlib.Path(OUT_DIR).joinpath("soong/soong-debug-info.json")
def main():
    """Parse the command line and run the selected subcommand.

    One subparser is registered per COMMANDS entry, in name order. The
    process always ends through sys.exit: status 0 after a subcommand
    returns normally.
    """
    parser = argparse.ArgumentParser()
    subparsers = parser.add_subparsers(required=True, dest="command")
    for name, command in sorted(COMMANDS.items()):
        command.args(subparsers.add_parser(name, help=command.help))
    parsed = parser.parse_args()
    COMMANDS[parsed.command].run(parsed)
    sys.exit(0)


if __name__ == "__main__":
    main()