# Copyright (C) 2022 The Android Open Source Project
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

from perfetto.trace_uri_resolver.util import parse_trace_uri
from perfetto.trace_uri_resolver.util import to_list
from perfetto.trace_uri_resolver.util import _cs_list
from perfetto.trace_uri_resolver.util import and_list
from perfetto.trace_uri_resolver.util import or_list
from perfetto.trace_uri_resolver.resolver import _args_dict_from_uri
from perfetto.trace_uri_resolver.resolver import Constraint
from perfetto.trace_uri_resolver.resolver import ConstraintClass
from perfetto.trace_uri_resolver.resolver import TraceUriResolver
from perfetto.trace_uri_resolver.registry import ResolverRegistry


class SimpleResolver(TraceUriResolver):
  """Minimal resolver (PREFIX 'simple') used as a fixture by the tests.

  Stores two optional arguments, `foo` and `bar`, and exposes each of them
  as a single-item generator of bytes.
  """
  PREFIX = 'simple'

  def __init__(self, foo=None, bar=None):
    self.foo = foo
    self.bar = bar

  def foo_gen(self):
    """Yields `foo` encoded to bytes, or b'' when `foo` is falsy."""
    if self.foo:
      yield self.foo.encode()
    else:
      yield b''

  def bar_gen(self):
    """Yields `bar` encoded to bytes, or b'' when `bar` is falsy."""
    if self.bar:
      yield self.bar.encode()
    else:
      yield b''

  def resolve(self):
    """Returns two results: the foo generator, then the bar generator.

    Only the second result carries metadata, which records the raw `foo`
    and `bar` values this resolver was constructed with.
    """
    foo_result = TraceUriResolver.Result(self.foo_gen())
    bar_metadata = {'foo': self.foo, 'bar': self.bar}
    bar_result = TraceUriResolver.Result(self.bar_gen(), metadata=bar_metadata)
    return [foo_result, bar_result]


class RecursiveResolver(SimpleResolver):
  """Fixture resolver (PREFIX 'recursive') whose results nest other traces.

  Besides two direct generator results, `resolve` also hands back a
  'simple:' URI string and a SimpleResolver instance.
  """
  PREFIX = 'recursive'

  def __init__(self, foo=None, bar=None):
    super().__init__(foo=foo, bar=bar)

  def resolve(self):
    """Returns four results: two generators, a URI string and a resolver."""
    # Direct results; the metadata here is fixed text rather than the
    # argument values.
    direct_foo = TraceUriResolver.Result(self.foo_gen())
    fixed_metadata = {'foo': 'foo', 'bar': 'bar'}
    direct_bar = TraceUriResolver.Result(
        self.bar_gen(), metadata=fixed_metadata)

    # Nested results: the same arguments expressed once as a URI string and
    # once as a resolver object.
    nested_uri = f'simple:foo={self.foo};bar={self.bar}'
    nested_resolver = SimpleResolver(foo=self.foo, bar=self.bar)

    return [
        direct_foo,
        direct_bar,
        TraceUriResolver.Result(nested_uri),
        TraceUriResolver.Result(nested_resolver),
    ]


class TestResolver(unittest.TestCase):
  """Tests for registry resolution, URI parsing and the util list helpers."""

  def test_simple_resolve(self):
    registry = ResolverRegistry([SimpleResolver])

    # A bare URI string produces exactly the two SimpleResolver results.
    results = registry.resolve('simple:foo=x;bar=y')
    self.assertEqual(len(results), 2)

    (first, second) = results
    self._check_resolver_result(first, second)

    # The same URI wrapped in a list resolves identically.
    (first, second) = registry.resolve(['simple:foo=x;bar=y'])
    self._check_resolver_result(first, second)

    # Resolver instances are accepted both bare and inside a list.
    resolver = SimpleResolver(foo='x', bar='y')
    for target in (resolver, [resolver]):
      (first, second) = registry.resolve(target)
      self._check_resolver_result(first, second)

    # Mixing a URI and an instance keeps the results in input order.
    mixed = registry.resolve(['simple:foo=a;bar=b', resolver])
    (first_a, second_b, first_x, second_y) = mixed
    self._check_resolver_result(first_a, second_b, foo='a', bar='b')
    self._check_resolver_result(first_x, second_y)

  def test_simple_resolve_missing_arg(self):
    registry = ResolverRegistry([SimpleResolver])

    # Each URI is paired with the arguments it leaves out, which are then
    # expected to come through as None.
    cases = (
        ('simple:foo=x', {
            'bar': None
        }),
        ('simple:bar=y', {
            'foo': None
        }),
        ('simple:', {
            'foo': None,
            'bar': None
        }),
    )
    for uri, omitted in cases:
      (first, second) = registry.resolve(uri)
      self._check_resolver_result(first, second, **omitted)

  def test_recursive_resolve(self):
    registry = ResolverRegistry([SimpleResolver])
    registry.register(RecursiveResolver)

    # Four results come back from RecursiveResolver.resolve; the URI string
    # and the resolver object among them each expand into two more.
    results = registry.resolve('recursive:foo=x;bar=y')
    self.assertEqual(len(results), 6)

    (direct_foo, direct_bar, from_uri_foo, from_uri_bar, from_obj_foo,
     from_obj_bar) = results

    self._check_resolver_result(
        direct_foo, direct_bar, foo_metadata='foo', bar_metadata='bar')
    self._check_resolver_result(from_uri_foo, from_uri_bar)
    self._check_resolver_result(from_obj_foo, from_obj_bar)

  def test_parse_trace_uri(self):
    # (input, expected parse result) pairs.
    expectations = (
        ('/foo/bar', (None, '/foo/bar')),
        ('foo/bar', (None, 'foo/bar')),
        ('/foo/b:ar', (None, '/foo/b:ar')),
        ('./foo/b:ar', (None, './foo/b:ar')),
        ('foo/b:ar', ('foo/b', 'ar')),
    )
    for uri, parsed in expectations:
      self.assertEqual(parse_trace_uri(uri), parsed)

  def test_to_list(self):
    # (input, expected output) pairs.
    expectations = (
        (None, None),
        (1, [1]),
        ('1', ['1']),
        ([], []),
        ([1], [1]),
    )
    for value, listed in expectations:
      self.assertEqual(to_list(value), listed)

  def test_cs_list(self):
    fn = 'col = {}'.format
    sep = ' || '

    # (values, third argument, expected string) triples.
    expectations = (
        (None, 'FALSE', 'TRUE'),
        (None, 'TRUE', 'TRUE'),
        ([], 'FALSE', 'FALSE'),
        ([], 'TRUE', 'TRUE'),
        ([1], 'FALSE', '(col = 1)'),
        ([1, 2], 'FALSE', '(col = 1 || col = 2)'),
    )
    for values, when_empty, rendered in expectations:
      self.assertEqual(_cs_list(values, fn, when_empty, sep), rendered)

  def test_and_list(self):
    rendered = and_list([1, 2], 'col != {}'.format, 'FALSE')
    self.assertEqual(rendered, '(col != 1 AND col != 2)')

  def test_or_list(self):
    rendered = or_list([1, 2], 'col = {}'.format, 'FALSE')
    self.assertEqual(rendered, '(col = 1 OR col = 2)')

  def test_args_dict_from_uri(self):
    # Plain '=' arguments parsed without any type hints. Every row gets its
    # own fresh hints dict, as in a sequence of independent calls.
    untyped = (
        ('foo:', {}, {}),
        ('foo:bar=baz', {}, {
            'bar': 'baz',
        }),
        ('foo:key=v1,v2', {}, {
            'key': ['v1', 'v2']
        }),
        ('foo:bar=baz;key=v1,v2', {}, {
            'bar': 'baz',
            'key': ['v1', 'v2']
        }),
    )
    for uri, hints, parsed in untyped:
      self.assertEqual(_args_dict_from_uri(uri, hints), parsed)

    # Each of these URI / hints combinations is expected to be rejected.
    rejected = (
        ('foo:=v1', {}),
        ('foo:key', {}),
        ('foo:key<', {}),
        ('foo:key<v1', {}),
        ('foo:key<v1', {
            'key': str
        }),
    )
    for uri, hints in rejected:
      with self.assertRaises(ValueError):
        _args_dict_from_uri(uri, hints)

    # With a Constraint type hint, every comparison operator is accepted.
    type_hints = {'key': Constraint[str]}
    op = ConstraintClass.Op
    single_constraints = (
        ('foo:key=v1', op.EQ),
        ('foo:key!=v1', op.NE),
        ('foo:key<=v1', op.LE),
        ('foo:key>=v1', op.GE),
        ('foo:key>v1', op.GT),
        ('foo:key<v1', op.LT),
    )
    for uri, expected_op in single_constraints:
      self.assertEqual(
          _args_dict_from_uri(uri, type_hints),
          {'key': ConstraintClass('v1', expected_op)})

    # Repeating the key collects the constraints into a list, in URI order.
    two_sided = _args_dict_from_uri('foo:key>v1;key<=v2', type_hints)
    self.assertEqual(
        two_sided, {
            'key': [
                ConstraintClass('v1', op.GT),
                ConstraintClass('v2', op.LE),
            ]
        })

    with_exclusions = _args_dict_from_uri(
        'foo:key>=v1;key<v4;key!=v2;key!=v3', type_hints)
    self.assertEqual(
        with_exclusions, {
            'key': [
                ConstraintClass('v1', op.GE),
                ConstraintClass('v4', op.LT),
                ConstraintClass('v2', op.NE),
                ConstraintClass('v3', op.NE),
            ]
        })

  def _check_resolver_result(self,
                             foo_res,
                             bar_res,
                             foo='x',
                             bar='y',
                             foo_metadata=None,
                             bar_metadata=None):
    """Asserts that a (foo, bar) result pair matches the expected arguments.

    Each result's generator must yield exactly one item: the argument
    encoded to bytes, or b'' when the argument is falsy. The metadata on
    `bar_res` must hold `foo_metadata` / `bar_metadata` when those are
    truthy, and the plain `foo` / `bar` values otherwise.
    """
    foo_bytes = foo.encode() if foo else b''
    self.assertEqual(tuple(foo_res.generator), (foo_bytes,))

    bar_bytes = bar.encode() if bar else b''
    self.assertEqual(tuple(bar_res.generator), (bar_bytes,))

    expected_metadata = {
        'foo': foo_metadata or foo,
        'bar': bar_metadata or bar,
    }
    self.assertEqual(bar_res.metadata, expected_metadata)
