# Copyright 2024 The Pigweed Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not
# use this file except in compliance with the License. You may obtain a copy of
# the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations under
# the License.
"""Tests for pw_config_loader."""

from pathlib import Path
import tempfile
from typing import Any
import unittest

from pw_config_loader import yaml_config_loader_mixin
import yaml

# pylint: disable=no-member,no-self-use


class YamlConfigLoader(yaml_config_loader_mixin.YamlConfigLoaderMixin):
    """Minimal concrete loader that exposes the mixin's config for tests."""

    @property
    def config(self) -> dict[str, Any]:
        """Return the dict stored in ``self._config``."""
        loaded: dict[str, Any] = self._config
        return loaded


class TestOneFile(unittest.TestCase):
    """Loads a config section from a single user file."""

    def setUp(self):
        self._title = 'title'

    def init(self, config: dict[str, Any]) -> dict[str, Any]:
        """Write ``config`` to a temporary YAML file and load it.

        The file is passed as ``user_file`` and the section is selected by
        ``self._title``.

        Returns:
            The loader's ``config`` after ``config_init()``.
        """
        loader = YamlConfigLoader()
        with tempfile.TemporaryDirectory() as tmp_dir:
            yaml_file = Path(tmp_dir) / 'foo.yaml'
            payload = yaml.safe_dump(config).encode()
            yaml_file.write_bytes(payload)
            loader.config_init(
                user_file=yaml_file,
                config_section_title=self._title,
            )
            return loader.config

    def test_normal(self):
        """The section nested under the title key is loaded."""
        content = {'a': 1, 'b': 2}
        config = self.init({self._title: content})
        for key in ('a', 'b'):
            self.assertEqual(content[key], config[key])

    def test_config_title(self):
        """A file whose top-level ``config_title`` matches is used as-is."""
        content = {'a': 1, 'b': 2, 'config_title': self._title}
        config = self.init(content)
        for key in ('a', 'b'):
            self.assertEqual(content[key], config[key])


class TestMultipleFiles(unittest.TestCase):
    """Tests for merging one config section across several files."""

    def init(
        self,
        project_config: dict[str, Any],
        project_user_config: dict[str, Any],
        user_config: dict[str, Any],
    ) -> dict[str, Any]:
        """Write each config under a shared title, then load them together.

        Args:
            project_config: Section contents for the project file.
            project_user_config: Section contents for the project-user file.
            user_config: Section contents for the user file.

        Returns:
            The loader's merged ``config``.
        """
        title = 'title'
        loader = YamlConfigLoader()

        with tempfile.TemporaryDirectory() as folder:
            root = Path(folder)
            written = []
            # Same write order and file names for each layer.
            for file_name, section in (
                ('user.yaml', user_config),
                ('project_user.yaml', project_user_config),
                ('project.yaml', project_config),
            ):
                target = root / file_name
                target.write_text(yaml.safe_dump({title: section}))
                written.append(target)
            user_path, project_user_path, project_path = written

            loader.config_init(
                user_file=user_path,
                project_user_file=project_user_path,
                project_file=project_path,
                config_section_title=title,
            )

        # Read after the temp dir is gone; presumably config_init() has
        # already parsed the files into memory.
        return loader.config

    def test_user_override(self):
        """The user file takes precedence over both project files."""
        merged = self.init(
            user_config={'a': 1},
            project_user_config={'a': 2},
            project_config={'a': 3},
        )
        self.assertEqual(merged['a'], 1)

    def test_project_user_override(self):
        """The project-user file takes precedence over the project file."""
        merged = self.init(
            user_config={},
            project_user_config={'a': 2},
            project_config={'a': 3},
        )
        self.assertEqual(merged['a'], 2)

    def test_not_overridden(self):
        """A project value survives when no other file sets the key."""
        merged = self.init(
            user_config={},
            project_user_config={},
            project_config={'a': 3},
        )
        self.assertEqual(merged['a'], 3)

    def test_different_keys(self):
        """Distinct keys from all three files are combined."""
        merged = self.init(
            user_config={'a': 1},
            project_user_config={'b': 2},
            project_config={'c': 3},
        )
        for key, expected in (('a', 1), ('b', 2), ('c', 3)):
            self.assertEqual(merged[key], expected)


class TestNestedTitle(unittest.TestCase):
    """Loads a config section addressed by a multi-part (nested) title."""

    def setUp(self):
        self._title = ('title', 'subtitle', 'subsubtitle', 'subsubsubtitle')

    def init(self, config: dict[str, Any]) -> dict[str, Any]:
        """Write ``config`` to a temp file and load the nested section."""
        loader = YamlConfigLoader()
        with tempfile.TemporaryDirectory() as tmp_dir:
            yaml_file = Path(tmp_dir) / 'foo.yaml'
            yaml_file.write_bytes(yaml.safe_dump(config).encode())
            loader.config_init(
                user_file=yaml_file,
                config_section_title=self._title,
            )
            return loader.config

    def test_normal(self):
        """Values nested under every title component are found."""
        nested: dict[str, Any] = {'a': 1, 'b': 2}
        # Wrap from the innermost title outward.
        for part in self._title[::-1]:
            nested = {part: nested}
        config = self.init(nested)
        self.assertEqual(config['a'], 1)
        self.assertEqual(config['b'], 2)

    def test_config_title(self):
        """A dotted ``config_title`` matches the nested title tuple."""
        dotted_title = '.'.join(self._title)
        config = self.init({'a': 1, 'b': 2, 'config_title': dotted_title})
        for key, expected in (('a', 1), ('b', 2)):
            self.assertEqual(config[key], expected)


class CustomOverloadYamlConfigLoader(
    yaml_config_loader_mixin.YamlConfigLoaderMixin
):
    """Test loader whose handle_overloaded_value() merges by key name.

    Keys with special handling:
      extend:          append the overriding value to a truthy original.
      extend_sort:     like ``extend``, then sort the result.
      do_not_override: keep the original value when it is truthy.
      max:             keep the larger of the two values.
    Any other key takes the overriding value.
    """

    @property
    def config(self) -> dict[str, Any]:
        """Return the dict stored in ``self._config``."""
        return self._config

    def handle_overloaded_value(  # pylint: disable=no-self-use
        self,
        key: str,
        stage: yaml_config_loader_mixin.Stage,
        original_value: Any,
        overriding_value: Any,
    ):
        """Combine ``original_value`` with ``overriding_value`` for ``key``.

        ``stage`` is unused here; the policy depends only on ``key``.
        """
        if key in ('extend', 'extend_sort'):
            # A falsy original (e.g. empty) is replaced rather than extended.
            merged = (
                original_value + overriding_value
                if original_value
                else overriding_value
            )
            return sorted(merged) if key == 'extend_sort' else merged

        if key == 'do_not_override':
            return original_value or overriding_value

        if key == 'max':
            return max(original_value, overriding_value)

        return overriding_value


class TestOverloading(unittest.TestCase):
    """Tests for custom handle_overloaded_value() merging across files."""

    def init(
        self,
        project_config: dict[str, Any],
        project_user_config: dict[str, Any],
        user_config: dict[str, Any],
    ) -> dict[str, Any]:
        """Write config files then read and parse them.

        Uses CustomOverloadYamlConfigLoader, so values for the same key in
        multiple files are combined by its handle_overloaded_value().

        Args:
            project_config: Section contents for the project file.
            project_user_config: Section contents for the project-user file.
            user_config: Section contents for the user file.

        Returns:
            The loader's merged ``config``.
        """

        loader = CustomOverloadYamlConfigLoader()
        title = 'title'

        with tempfile.TemporaryDirectory() as folder:
            path = Path(folder)

            user_path = path / 'user.yaml'
            user_path.write_text(yaml.safe_dump({title: user_config}))

            project_user_path = path / 'project_user.yaml'
            project_user_path.write_text(
                yaml.safe_dump({title: project_user_config})
            )

            project_path = path / 'project.yaml'
            project_path.write_text(yaml.safe_dump({title: project_config}))

            loader.config_init(
                user_file=user_path,
                project_user_file=project_user_path,
                project_file=project_path,
                config_section_title=title,
            )

        return loader.config

    def test_lists(self):
        """List values are extended, sorted, preserved, or overridden by key.

        The expected results imply the merge order project -> project_user
        -> user.
        """
        config = self.init(
            project_config={
                'extend': list('abc'),
                'extend_sort': list('az'),
                'do_not_override': ['persists'],
                'override': ['hidden'],
            },
            project_user_config={
                'extend': list('def'),
                'extend_sort': list('by'),
                'do_not_override': ['ignored'],
                'override': ['ignored'],
            },
            user_config={
                'extend': list('ghi'),
                'extend_sort': list('cx'),
                'do_not_override': ['ignored_2'],
                'override': ['overrides'],
            },
        )
        self.assertEqual(config['extend'], list('abcdefghi'))
        self.assertEqual(config['extend_sort'], list('abcxyz'))
        self.assertEqual(config['do_not_override'], ['persists'])
        self.assertEqual(config['override'], ['overrides'])

    def test_scalars(self):
        """Scalar values are concatenated (``extend``) or maxed (``max``)."""
        config = self.init(
            project_config={'extend': 'abc', 'max': 1},
            project_user_config={'extend': 'def', 'max': 3},
            user_config={'extend': 'ghi', 'max': 2},
        )
        self.assertEqual(config['extend'], 'abcdefghi')
        self.assertEqual(config['max'], 3)


if __name__ == '__main__':
    # Run the test cases in this module when the file is executed directly.
    unittest.main()
