# Copyright 2017 The Abseil Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Testing that flags validators framework does work.

This file tests that each flag validator called when it should be, and that
failed validator will throw an exception, etc.
"""

import warnings

from absl.flags import _defines
from absl.flags import _exceptions
from absl.flags import _flagvalues
from absl.flags import _validators
from absl.testing import absltest


class SingleFlagValidatorTest(absltest.TestCase):
  """Tests for _validators.register_validator() and @_validators.validator."""

  def setUp(self):
    super(SingleFlagValidatorTest, self).setUp()
    self.flag_values = _flagvalues.FlagValues()
    # Values passed to checkers built by _recorder(), in call order.
    self.call_args = []

  def _recorder(self, judge):
    """Wraps judge so each value it sees is first logged to call_args."""

    def checker(value):
      self.call_args.append(value)
      return judge(value)

    return checker

  def _define_test_flag(self, default=None, help_text='Usual integer flag'):
    """Defines integer flag 'test_flag' on self.flag_values; returns holder."""
    return _defines.DEFINE_integer(
        'test_flag', default, help_text, flag_values=self.flag_values)

  def _register(self, flag, checker):
    """Registers checker on flag with message 'Errors happen'."""
    _validators.register_validator(
        flag, checker, message='Errors happen', flag_values=self.flag_values)

  def _expect_none_then_two(self):
    """Parses no args, then sets test_flag=2; checker must see both values."""
    self.flag_values(('./program',))
    self.assertIsNone(self.flag_values.test_flag)
    self.flag_values.test_flag = 2
    self.assertEqual(2, self.flag_values.test_flag)
    self.assertEqual([None, 2], self.call_args)

  def test_success(self):
    self._define_test_flag()
    self._register('test_flag', self._recorder(lambda _: True))
    self._expect_none_then_two()

  def test_success_holder(self):
    holder = self._define_test_flag()
    self._register(holder, self._recorder(lambda _: True))
    self._expect_none_then_two()

  def test_success_holder_infer_flagvalues(self):
    holder = self._define_test_flag()
    # flag_values is deliberately omitted here.
    _validators.register_validator(
        holder, self._recorder(lambda _: True), message='Errors happen')
    self._expect_none_then_two()

  def test_default_value_not_used_success(self):
    self._define_test_flag()
    self._register('test_flag', self._recorder(lambda _: True))
    self.flag_values(('./program', '--test_flag=1'))
    self.assertEqual(1, self.flag_values.test_flag)
    # Only the command-line value is checked, never the None default.
    self.assertEqual([1], self.call_args)

  def test_validator_not_called_when_other_flag_is_changed(self):
    self._define_test_flag(default=1)
    _defines.DEFINE_integer(
        'other_flag', 2, 'Other integer flag', flag_values=self.flag_values)
    self._register('test_flag', self._recorder(lambda _: True))
    self.flag_values(('./program',))
    self.assertEqual(1, self.flag_values.test_flag)
    self.flag_values.other_flag = 3
    self.assertEqual([1], self.call_args)

  def test_exception_raised_if_checker_fails(self):
    self._define_test_flag()
    self._register('test_flag', self._recorder(lambda value: value == 1))
    self.flag_values(('./program', '--test_flag=1'))
    with self.assertRaises(_exceptions.IllegalFlagValueError) as cm:
      self.flag_values.test_flag = 2
    self.assertEqual('flag --test_flag=2: Errors happen', str(cm.exception))
    self.assertEqual([1, 2], self.call_args)

  def test_exception_raised_if_checker_raises_exception(self):
    def accept_only_one(value):
      if value != 1:
        raise _exceptions.ValidationError('Specific message')
      return True

    self._define_test_flag()
    self._register('test_flag', self._recorder(accept_only_one))
    self.flag_values(('./program', '--test_flag=1'))
    with self.assertRaises(_exceptions.IllegalFlagValueError) as cm:
      self.flag_values.test_flag = 2
    # The ValidationError text is reported instead of the registered message.
    self.assertEqual('flag --test_flag=2: Specific message', str(cm.exception))
    self.assertEqual([1, 2], self.call_args)

  def test_error_message_when_checker_returns_false_on_start(self):
    self._define_test_flag()
    self._register('test_flag', self._recorder(lambda _: False))
    with self.assertRaises(_exceptions.IllegalFlagValueError) as cm:
      self.flag_values(('./program', '--test_flag=1'))
    self.assertEqual('flag --test_flag=1: Errors happen', str(cm.exception))
    self.assertEqual([1], self.call_args)

  def test_error_message_when_checker_raises_exception_on_start(self):
    def always_raise(unused_value):
      raise _exceptions.ValidationError('Specific message')

    self._define_test_flag()
    self._register('test_flag', self._recorder(always_raise))
    with self.assertRaises(_exceptions.IllegalFlagValueError) as cm:
      self.flag_values(('./program', '--test_flag=1'))
    self.assertEqual('flag --test_flag=1: Specific message', str(cm.exception))
    self.assertEqual([1], self.call_args)

  def test_validators_checked_in_order(self):
    seen = []

    def required(x):
      seen.append('required')
      return x is not None

    def even(x):
      seen.append('even')
      return x % 2 == 0

    # Registration order must equal invocation order, whichever comes first.
    for pair, expected_order in (((required, even), ['required', 'even']),
                                 ((even, required), ['even', 'required'])):
      del seen[:]
      self._define_flag_and_validators(*pair)
      self.assertEqual(expected_order, seen)

  def _define_flag_and_validators(self, first_validator, second_validator):
    """Registers two validators on a fresh FlagValues, then parses it."""
    scratch_flags = _flagvalues.FlagValues()
    _defines.DEFINE_integer(
        'test_flag', 2, 'test flag', flag_values=scratch_flags)
    for check in (first_validator, second_validator):
      _validators.register_validator(
          'test_flag', check, message='', flag_values=scratch_flags)
    scratch_flags(('./program',))

  def test_validator_as_decorator(self):
    self._define_test_flag(help_text='Simple integer flag')

    @_validators.validator('test_flag', flag_values=self.flag_values)
    def checker(x):
      self.call_args.append(x)
      return True

    self._expect_none_then_two()
    # The decorator must hand back the original, still-callable function.
    self.assertTrue(checker(3))
    self.assertEqual([None, 2, 3], self.call_args)

  def test_mismatching_flagvalues(self):
    foreign_holder = _defines.DEFINE_integer(
        'test_flag',
        None,
        'Usual integer flag',
        flag_values=_flagvalues.FlagValues())
    expected = (
        'flag_values must not be customized when operating on a FlagHolder')
    with self.assertRaisesWithLiteralMatch(ValueError, expected):
      self._register(foreign_holder, self._recorder(lambda _: True))


class MultiFlagsValidatorTest(absltest.TestCase):
  """Tests for validators that check several flags together."""

  def setUp(self):
    super(MultiFlagsValidatorTest, self).setUp()
    self.flag_values = _flagvalues.FlagValues()
    # flags_dict arguments seen by the recording checkers, in call order.
    self.call_args = []
    self.foo_holder = _defines.DEFINE_integer(
        'foo', 1, 'Usual integer flag', flag_values=self.flag_values)
    self.bar_holder = _defines.DEFINE_integer(
        'bar', 2, 'Usual integer flag', flag_values=self.flag_values)

  def _accept_all(self, flags_dict):
    """Checker that records flags_dict and always passes."""
    self.call_args.append(flags_dict)
    return True

  def _require_distinct(self, flags_dict):
    """Checker that records flags_dict and passes only if no value repeats."""
    self.call_args.append(flags_dict)
    values = list(flags_dict.values())
    return len(set(values)) == len(values)

  def _assert_foo_change_revalidates(self):
    """Parses --bar=2, then sets foo=3; the checker must see both states."""
    self.flag_values(('./program', '--bar=2'))
    self.assertEqual(1, self.flag_values.foo)
    self.assertEqual(2, self.flag_values.bar)
    self.assertEqual([{'foo': 1, 'bar': 2}], self.call_args)
    self.flag_values.foo = 3
    self.assertEqual(3, self.flag_values.foo)
    self.assertEqual([{'foo': 1, 'bar': 2}, {'foo': 3, 'bar': 2}],
                     self.call_args)

  def _assert_bar_collision_rejected(self, expected_message):
    """Parses defaults, then sets bar=1 (equal to foo) and expects failure."""
    self.flag_values(('./program',))
    with self.assertRaises(_exceptions.IllegalFlagValueError) as cm:
      self.flag_values.bar = 1
    self.assertEqual(expected_message, str(cm.exception))
    self.assertEqual([{'foo': 1, 'bar': 2}, {'foo': 1, 'bar': 1}],
                     self.call_args)

  def test_success(self):
    _validators.register_multi_flags_validator(
        ['foo', 'bar'], self._accept_all, flag_values=self.flag_values)
    self._assert_foo_change_revalidates()

  def test_success_holder(self):
    _validators.register_multi_flags_validator(
        [self.foo_holder, self.bar_holder],
        self._accept_all,
        flag_values=self.flag_values)
    self._assert_foo_change_revalidates()

  def test_success_holder_infer_flagvalues(self):
    _validators.register_multi_flags_validator(
        [self.foo_holder, self.bar_holder], self._accept_all)
    self._assert_foo_change_revalidates()

  def test_validator_not_called_when_other_flag_is_changed(self):
    _defines.DEFINE_integer(
        'other_flag', 3, 'Other integer flag', flag_values=self.flag_values)
    _validators.register_multi_flags_validator(
        ['foo', 'bar'], self._accept_all, flag_values=self.flag_values)

    self.flag_values(('./program',))
    # Use a value different from the default (3) so other_flag really changes;
    # assigning the default back would make this test vacuous.
    self.flag_values.other_flag = 4
    self.assertEqual(4, self.flag_values.other_flag)
    self.assertEqual([{'foo': 1, 'bar': 2}], self.call_args)

  def test_exception_raised_if_checker_fails(self):
    _validators.register_multi_flags_validator(
        ['foo', 'bar'],
        self._require_distinct,
        message='Errors happen',
        flag_values=self.flag_values)
    self._assert_bar_collision_rejected('flags foo=1, bar=1: Errors happen')

  def test_exception_raised_if_checker_raises_exception(self):
    def checker(flags_dict):
      self.call_args.append(flags_dict)
      values = list(flags_dict.values())
      # Make sure all the flags have different values.
      if len(set(values)) != len(values):
        raise _exceptions.ValidationError('Specific message')
      return True

    _validators.register_multi_flags_validator(
        ['foo', 'bar'],
        checker,
        message='Errors happen',
        flag_values=self.flag_values)
    self._assert_bar_collision_rejected(
        'flags foo=1, bar=1: Specific message')

  def test_decorator(self):
    @_validators.multi_flags_validator(
        ['foo', 'bar'], message='Errors happen', flag_values=self.flag_values)
    def checker(flags_dict):  # pylint: disable=unused-variable
      return self._require_distinct(flags_dict)

    self._assert_bar_collision_rejected('flags foo=1, bar=1: Errors happen')

  def test_mismatching_flagvalues(self):
    other_holder = _defines.DEFINE_integer(
        'other_flag',
        3,
        'Other integer flag',
        flag_values=_flagvalues.FlagValues())
    expected = (
        'multiple FlagValues instances used in invocation. '
        'FlagHolders must be registered to the same FlagValues instance as '
        'do flag names, if provided.')
    with self.assertRaisesWithLiteralMatch(ValueError, expected):
      _validators.register_multi_flags_validator(
          [self.foo_holder, self.bar_holder, other_holder],
          self._require_distinct,
          message='Errors happen',
          flag_values=self.flag_values)


class MarkFlagsAsMutualExclusiveTest(absltest.TestCase):
  """Tests for _validators.mark_flags_as_mutual_exclusive()."""

  def setUp(self):
    super(MarkFlagsAsMutualExclusiveTest, self).setUp()
    self.flag_values = _flagvalues.FlagValues()

    self.flag_one_holder = _defines.DEFINE_string(
        'flag_one', None, 'flag one', flag_values=self.flag_values)
    self.flag_two_holder = _defines.DEFINE_string(
        'flag_two', None, 'flag two', flag_values=self.flag_values)
    # The remaining flags are only referenced by name, so no holders are kept.
    extra_flags = (
        (_defines.DEFINE_string, 'flag_three', None, 'flag three'),
        (_defines.DEFINE_integer, 'int_flag_one', None, 'int flag one'),
        (_defines.DEFINE_integer, 'int_flag_two', None, 'int flag two'),
        (_defines.DEFINE_multi_string, 'multi_flag_one', None,
         'multi flag one'),
        (_defines.DEFINE_multi_string, 'multi_flag_two', None,
         'multi flag two'),
        (_defines.DEFINE_boolean, 'flag_not_none', False, 'false default'),
    )
    for define, name, default, help_text in extra_flags:
      define(name, default, help_text, flag_values=self.flag_values)

  def _mark_flags_as_mutually_exclusive(self, flag_names, required):
    """Applies mark_flags_as_mutual_exclusive against self.flag_values."""
    _validators.mark_flags_as_mutual_exclusive(
        flag_names, required=required, flag_values=self.flag_values)

  def _assert_parse_rejected(self, argv, expected):
    """Asserts that parsing argv fails with exactly `expected`."""
    self.assertRaisesWithLiteralMatch(_exceptions.IllegalFlagValueError,
                                      expected, self.flag_values, argv)

  def _assert_one_and_two_unset_after_parse(self, flags):
    """Marks flags as optional-exclusive, parses no args, expects both None."""
    self._mark_flags_as_mutually_exclusive(flags, False)
    self.flag_values(('./program',))
    self.assertIsNone(self.flag_values.flag_one)
    self.assertIsNone(self.flag_values.flag_two)

  def test_no_flags_present(self):
    self._assert_one_and_two_unset_after_parse(['flag_one', 'flag_two'])

  def test_no_flags_present_holder(self):
    self._assert_one_and_two_unset_after_parse(
        [self.flag_one_holder, self.flag_two_holder])

  def test_no_flags_present_mixed(self):
    self._assert_one_and_two_unset_after_parse(
        [self.flag_one_holder, 'flag_two'])

  def test_no_flags_present_required(self):
    self._mark_flags_as_mutually_exclusive(['flag_one', 'flag_two'], True)
    self._assert_parse_rejected(
        ('./program',),
        'flags flag_one=None, flag_two=None: '
        'Exactly one of (flag_one, flag_two) must have a value other than '
        'None.')

  def test_one_flag_present(self):
    self._mark_flags_as_mutually_exclusive(['flag_one', 'flag_two'], False)
    self.flag_values(('./program', '--flag_one=1'))
    self.assertEqual('1', self.flag_values.flag_one)

  def test_one_flag_present_required(self):
    self._mark_flags_as_mutually_exclusive(['flag_one', 'flag_two'], True)
    self.flag_values(('./program', '--flag_two=2'))
    self.assertEqual('2', self.flag_values.flag_two)

  def test_one_flag_zero_required(self):
    # 0 is falsy but is not None, so it satisfies the requirement.
    self._mark_flags_as_mutually_exclusive(
        ['int_flag_one', 'int_flag_two'], True)
    self.flag_values(('./program', '--int_flag_one=0'))
    self.assertEqual(0, self.flag_values.int_flag_one)

  def test_mutual_exclusion_with_extra_flags(self):
    self._mark_flags_as_mutually_exclusive(['flag_one', 'flag_two'], True)
    self.flag_values(('./program', '--flag_two=2', '--flag_three=3'))
    self.assertEqual('2', self.flag_values.flag_two)
    self.assertEqual('3', self.flag_values.flag_three)

  def test_mutual_exclusion_with_zero(self):
    self._mark_flags_as_mutually_exclusive(
        ['int_flag_one', 'int_flag_two'], False)
    self._assert_parse_rejected(
        ('./program', '--int_flag_one=0', '--int_flag_two=0'),
        'flags int_flag_one=0, int_flag_two=0: '
        'At most one of (int_flag_one, int_flag_two) must have a value other '
        'than None.')

  def test_multiple_flags_present(self):
    self._mark_flags_as_mutually_exclusive(
        ['flag_one', 'flag_two', 'flag_three'], False)
    self._assert_parse_rejected(
        ('./program', '--flag_one=1', '--flag_two=2', '--flag_three=3'),
        'flags flag_one=1, flag_two=2, flag_three=3: '
        'At most one of (flag_one, flag_two, flag_three) must have a value '
        'other than None.')

  def test_multiple_flags_present_required(self):
    self._mark_flags_as_mutually_exclusive(
        ['flag_one', 'flag_two', 'flag_three'], True)
    self._assert_parse_rejected(
        ('./program', '--flag_one=1', '--flag_two=2', '--flag_three=3'),
        'flags flag_one=1, flag_two=2, flag_three=3: '
        'Exactly one of (flag_one, flag_two, flag_three) must have a value '
        'other than None.')

  def test_no_multiflags_present(self):
    self._mark_flags_as_mutually_exclusive(
        ['multi_flag_one', 'multi_flag_two'], False)
    self.flag_values(('./program',))
    self.assertIsNone(self.flag_values.multi_flag_one)
    self.assertIsNone(self.flag_values.multi_flag_two)

  def test_no_multistring_flags_present_required(self):
    self._mark_flags_as_mutually_exclusive(
        ['multi_flag_one', 'multi_flag_two'], True)
    self._assert_parse_rejected(
        ('./program',),
        'flags multi_flag_one=None, multi_flag_two=None: '
        'Exactly one of (multi_flag_one, multi_flag_two) must have a value '
        'other than None.')

  def test_one_multiflag_present(self):
    self._mark_flags_as_mutually_exclusive(
        ['multi_flag_one', 'multi_flag_two'], True)
    self.flag_values(('./program', '--multi_flag_one=1'))
    self.assertEqual(['1'], self.flag_values.multi_flag_one)

  def test_one_multiflag_present_repeated(self):
    # Repeating the same multi flag still counts as a single present flag.
    self._mark_flags_as_mutually_exclusive(
        ['multi_flag_one', 'multi_flag_two'], True)
    self.flag_values(('./program', '--multi_flag_one=1', '--multi_flag_one=1b'))
    self.assertEqual(['1', '1b'], self.flag_values.multi_flag_one)

  def test_multiple_multiflags_present(self):
    self._mark_flags_as_mutually_exclusive(
        ['multi_flag_one', 'multi_flag_two'], False)
    self._assert_parse_rejected(
        ('./program', '--multi_flag_one=1', '--multi_flag_two=2'),
        "flags multi_flag_one=['1'], multi_flag_two=['2']: "
        'At most one of (multi_flag_one, multi_flag_two) must have a value '
        'other than None.')

  def test_multiple_multiflags_present_required(self):
    self._mark_flags_as_mutually_exclusive(
        ['multi_flag_one', 'multi_flag_two'], True)
    self._assert_parse_rejected(
        ('./program', '--multi_flag_one=1', '--multi_flag_two=2'),
        "flags multi_flag_one=['1'], multi_flag_two=['2']: "
        'Exactly one of (multi_flag_one, multi_flag_two) must have a value '
        'other than None.')

  def test_flag_default_not_none_warning(self):
    with warnings.catch_warnings(record=True) as recorded:
      warnings.simplefilter('always')
      self._mark_flags_as_mutually_exclusive(['flag_one', 'flag_not_none'],
                                             False)
    self.assertLen(recorded, 1)
    self.assertIn('--flag_not_none has a non-None default value',
                  str(recorded[0].message))

  def test_multiple_flagvalues(self):
    foreign_holder = _defines.DEFINE_boolean(
        'other_flagvalues',
        False,
        'other ',
        flag_values=_flagvalues.FlagValues())
    expected = (
        'multiple FlagValues instances used in invocation. '
        'FlagHolders must be registered to the same FlagValues instance as '
        'do flag names, if provided.')
    with self.assertRaisesWithLiteralMatch(ValueError, expected):
      self._mark_flags_as_mutually_exclusive(
          [self.flag_one_holder, foreign_holder], False)


class MarkBoolFlagsAsMutualExclusiveTest(absltest.TestCase):
  """Tests for _validators.mark_bool_flags_as_mutual_exclusive()."""

  def setUp(self):
    super(MarkBoolFlagsAsMutualExclusiveTest, self).setUp()
    self.flag_values = _flagvalues.FlagValues()

    def define_bool(name, default, help_text):
      return _defines.DEFINE_boolean(
          name, default, help_text, flag_values=self.flag_values)

    self.false_1_holder = define_bool('false_1', False, 'default false 1')
    self.false_2_holder = define_bool('false_2', False, 'default false 2')
    self.true_1_holder = define_bool('true_1', True, 'default true 1')
    # A deliberately non-boolean flag, used to exercise the type check.
    self.non_bool_holder = _defines.DEFINE_integer(
        'non_bool', None, 'non bool', flag_values=self.flag_values)

  def _mark_bool_flags_as_mutually_exclusive(self, flag_names, required):
    """Applies mark_bool_flags_as_mutual_exclusive to self.flag_values."""
    _validators.mark_bool_flags_as_mutual_exclusive(
        flag_names, required=required, flag_values=self.flag_values)

  def _assert_parse_rejected(self, argv, expected):
    """Asserts that parsing argv fails with exactly `expected`."""
    self.assertRaisesWithLiteralMatch(_exceptions.IllegalFlagValueError,
                                      expected, self.flag_values, argv)

  def _assert_both_false_after_parse(self, flags):
    """Marks flags as optional-exclusive, parses no args, expects both False."""
    self._mark_bool_flags_as_mutually_exclusive(flags, False)
    self.flag_values(('./program',))
    self.assertEqual(False, self.flag_values.false_1)
    self.assertEqual(False, self.flag_values.false_2)

  def test_no_flags_present(self):
    self._assert_both_false_after_parse(['false_1', 'false_2'])

  def test_no_flags_present_holder(self):
    self._assert_both_false_after_parse(
        [self.false_1_holder, self.false_2_holder])

  def test_no_flags_present_mixed(self):
    self._assert_both_false_after_parse([self.false_1_holder, 'false_2'])

  def test_no_flags_present_required(self):
    self._mark_bool_flags_as_mutually_exclusive(['false_1', 'false_2'], True)
    self._assert_parse_rejected(
        ('./program',),
        'flags false_1=False, false_2=False: '
        'Exactly one of (false_1, false_2) must be True.')

  def test_no_flags_present_with_default_true_required(self):
    # true_1 defaults to True, so the requirement holds with an empty argv.
    self._mark_bool_flags_as_mutually_exclusive(['false_1', 'true_1'], True)
    self.flag_values(('./program',))
    self.assertEqual(False, self.flag_values.false_1)
    self.assertEqual(True, self.flag_values.true_1)

  def test_two_flags_true(self):
    self._mark_bool_flags_as_mutually_exclusive(['false_1', 'false_2'], False)
    self._assert_parse_rejected(
        ('./program', '--false_1', '--false_2'),
        'flags false_1=True, false_2=True: At most one of (false_1, '
        'false_2) must be True.')

  def test_non_bool_flag(self):
    expected = ('Flag --non_bool is not Boolean, which is required for flags '
                'used in mark_bool_flags_as_mutual_exclusive.')
    # The type check fires when marking, before any parsing takes place.
    with self.assertRaisesWithLiteralMatch(_exceptions.ValidationError,
                                           expected):
      self._mark_bool_flags_as_mutually_exclusive(['false_1', 'non_bool'],
                                                  False)

  def test_multiple_flagvalues(self):
    foreign_bool_holder = _defines.DEFINE_boolean(
        'other_bool', False, 'other bool', flag_values=_flagvalues.FlagValues())
    expected = (
        'multiple FlagValues instances used in invocation. '
        'FlagHolders must be registered to the same FlagValues instance as '
        'do flag names, if provided.')
    with self.assertRaisesWithLiteralMatch(ValueError, expected):
      self._mark_bool_flags_as_mutually_exclusive(
          [self.false_1_holder, foreign_bool_holder], False)


class MarkFlagAsRequiredTest(absltest.TestCase):
  """Tests for _validators.mark_flag_as_required()."""

  def setUp(self):
    super(MarkFlagAsRequiredTest, self).setUp()
    self.flag_values = _flagvalues.FlagValues()

  def _define_string_flag(self, default):
    """Defines 'string_flag' on self.flag_values and returns its holder."""
    return _defines.DEFINE_string(
        'string_flag', default, 'string flag', flag_values=self.flag_values)

  def _parse_value_and_check(self):
    """Parses --string_flag=value and checks that the value was stored."""
    self.flag_values(('./program', '--string_flag=value'))
    self.assertEqual('value', self.flag_values.string_flag)

  def test_success(self):
    self._define_string_flag(None)
    _validators.mark_flag_as_required(
        'string_flag', flag_values=self.flag_values)
    self._parse_value_and_check()

  def test_success_holder(self):
    holder = self._define_string_flag(None)
    _validators.mark_flag_as_required(holder, flag_values=self.flag_values)
    self._parse_value_and_check()

  def test_success_holder_infer_flagvalues(self):
    holder = self._define_string_flag(None)
    # flag_values is deliberately omitted here.
    _validators.mark_flag_as_required(holder)
    self._parse_value_and_check()

  def test_catch_none_as_default(self):
    self._define_string_flag(None)
    _validators.mark_flag_as_required(
        'string_flag', flag_values=self.flag_values)
    pattern = (
        r'flag --string_flag=None: Flag --string_flag must have a value other '
        r'than None\.')
    with self.assertRaisesRegex(_exceptions.IllegalFlagValueError, pattern):
      self.flag_values(('./program',))

  def test_catch_setting_none_after_program_start(self):
    self._define_string_flag('value')
    _validators.mark_flag_as_required(
        'string_flag', flag_values=self.flag_values)
    self.flag_values(('./program',))
    self.assertEqual('value', self.flag_values.string_flag)
    expected = ('flag --string_flag=None: Flag --string_flag must have a value '
                'other than None.')
    # Assigning None after parsing must also trip the required check.
    with self.assertRaises(_exceptions.IllegalFlagValueError) as cm:
      self.flag_values.string_flag = None
    self.assertEqual(expected, str(cm.exception))

  def test_flag_default_not_none_warning(self):
    _defines.DEFINE_string(
        'flag_not_none', '', 'empty default', flag_values=self.flag_values)
    with warnings.catch_warnings(record=True) as recorded:
      warnings.simplefilter('always')
      _validators.mark_flag_as_required(
          'flag_not_none', flag_values=self.flag_values)

    self.assertLen(recorded, 1)
    self.assertIn('--flag_not_none has a non-None default value',
                  str(recorded[0].message))

  def test_mismatching_flagvalues(self):
    foreign_holder = _defines.DEFINE_string(
        'string_flag',
        'value',
        'string flag',
        flag_values=_flagvalues.FlagValues())
    expected = (
        'flag_values must not be customized when operating on a FlagHolder')
    with self.assertRaisesWithLiteralMatch(ValueError, expected):
      _validators.mark_flag_as_required(
          foreign_holder, flag_values=self.flag_values)


class MarkFlagsAsRequiredTest(absltest.TestCase):
  """Testing _validators.mark_flags_as_required() method.

  Also covers how required-flag failures interact with other registered
  validators when several of them would fail.
  """

  def setUp(self):
    super(MarkFlagsAsRequiredTest, self).setUp()
    self.flag_values = _flagvalues.FlagValues()

  def test_success(self):
    """Required flags supplied on the command line parse normally."""
    _defines.DEFINE_string(
        'string_flag_1', None, 'string flag 1', flag_values=self.flag_values)
    _defines.DEFINE_string(
        'string_flag_2', None, 'string flag 2', flag_values=self.flag_values)
    flag_names = ['string_flag_1', 'string_flag_2']
    _validators.mark_flags_as_required(flag_names, flag_values=self.flag_values)
    argv = ('./program', '--string_flag_1=value_1', '--string_flag_2=value_2')
    self.flag_values(argv)
    self.assertEqual('value_1', self.flag_values.string_flag_1)
    self.assertEqual('value_2', self.flag_values.string_flag_2)

  def test_success_holders(self):
    """Same as test_success, but flags are passed as FlagHolder objects."""
    flag_1_holder = _defines.DEFINE_string(
        'string_flag_1', None, 'string flag 1', flag_values=self.flag_values)
    flag_2_holder = _defines.DEFINE_string(
        'string_flag_2', None, 'string flag 2', flag_values=self.flag_values)
    _validators.mark_flags_as_required([flag_1_holder, flag_2_holder],
                                       flag_values=self.flag_values)
    argv = ('./program', '--string_flag_1=value_1', '--string_flag_2=value_2')
    self.flag_values(argv)
    self.assertEqual('value_1', self.flag_values.string_flag_1)
    self.assertEqual('value_2', self.flag_values.string_flag_2)

  def test_catch_none_as_default(self):
    """A required flag left at its None default fails parsing."""
    _defines.DEFINE_string(
        'string_flag_1', None, 'string flag 1', flag_values=self.flag_values)
    _defines.DEFINE_string(
        'string_flag_2', None, 'string flag 2', flag_values=self.flag_values)
    _validators.mark_flags_as_required(
        ['string_flag_1', 'string_flag_2'], flag_values=self.flag_values)
    # Only string_flag_1 is given, so string_flag_2 stays None.
    argv = ('./program', '--string_flag_1=value_1')
    expected = (
        r'flag --string_flag_2=None: Flag --string_flag_2 must have a value '
        r'other than None\.')
    with self.assertRaisesRegex(_exceptions.IllegalFlagValueError, expected):
      self.flag_values(argv)

  def test_catch_setting_none_after_program_start(self):
    """Assigning None to a required flag after parsing is rejected."""
    _defines.DEFINE_string(
        'string_flag_1',
        'value_1',
        'string flag 1',
        flag_values=self.flag_values)
    _defines.DEFINE_string(
        'string_flag_2',
        'value_2',
        'string flag 2',
        flag_values=self.flag_values)
    _validators.mark_flags_as_required(
        ['string_flag_1', 'string_flag_2'], flag_values=self.flag_values)
    argv = ('./program', '--string_flag_1=value_1')
    self.flag_values(argv)
    self.assertEqual('value_1', self.flag_values.string_flag_1)
    expected = (
        'flag --string_flag_1=None: Flag --string_flag_1 must have a value '
        'other than None.')
    with self.assertRaises(_exceptions.IllegalFlagValueError) as cm:
      self.flag_values.string_flag_1 = None
    self.assertEqual(expected, str(cm.exception))

  def test_catch_multiple_flags_as_none_at_program_start(self):
    """Every missing required flag is reported, one per line."""
    _defines.DEFINE_float(
        'float_flag_1',
        None,
        'float flag 1',
        flag_values=self.flag_values)
    _defines.DEFINE_float(
        'float_flag_2',
        None,
        'float flag 2',
        flag_values=self.flag_values)
    _validators.mark_flags_as_required(
        ['float_flag_1', 'float_flag_2'], flag_values=self.flag_values)
    argv = ('./program', '')
    expected = (
        'flag --float_flag_1=None: Flag --float_flag_1 must have a value '
        'other than None.\n'
        'flag --float_flag_2=None: Flag --float_flag_2 must have a value '
        'other than None.')
    with self.assertRaises(_exceptions.IllegalFlagValueError) as cm:
      self.flag_values(argv)
    self.assertEqual(expected, str(cm.exception))

  def test_fail_fast_single_flag_and_skip_remaining_validators(self):
    """A missing required flag's error is the only one reported.

    The other validators registered on flag_1 would raise if their errors
    were surfaced; the expected message shows that they are not.
    """
    def raise_unexpected_error(x):
      del x
      raise _exceptions.ValidationError('Should not be raised.')
    _defines.DEFINE_float(
        'flag_1', None, 'flag 1', flag_values=self.flag_values)
    _defines.DEFINE_float(
        'flag_2', 4.2, 'flag 2', flag_values=self.flag_values)
    _validators.mark_flag_as_required('flag_1', flag_values=self.flag_values)
    _validators.register_validator(
        'flag_1', raise_unexpected_error, flag_values=self.flag_values)
    _validators.register_multi_flags_validator(['flag_2', 'flag_1'],
                                               raise_unexpected_error,
                                               flag_values=self.flag_values)
    argv = ('./program', '')
    expected = (
        'flag --flag_1=None: Flag --flag_1 must have a value other than None.')
    with self.assertRaises(_exceptions.IllegalFlagValueError) as cm:
      self.flag_values(argv)
    self.assertEqual(expected, str(cm.exception))

  def test_fail_fast_multi_flag_and_skip_remaining_validators(self):
    """Only the first-registered failing multi-flag validator is reported.

    Every validator registered afterwards on flag_1/flag_2 would also raise,
    but the expected message contains only the first one's error.
    """
    def raise_expected_error(x):
      del x
      raise _exceptions.ValidationError('Expected error.')
    def raise_unexpected_error(x):
      del x
      raise _exceptions.ValidationError('Got unexpected error.')
    _defines.DEFINE_float(
        'flag_1', 5.1, 'flag 1', flag_values=self.flag_values)
    _defines.DEFINE_float(
        'flag_2', 10.0, 'flag 2', flag_values=self.flag_values)
    _validators.register_multi_flags_validator(['flag_1', 'flag_2'],
                                               raise_expected_error,
                                               flag_values=self.flag_values)
    _validators.register_multi_flags_validator(['flag_2', 'flag_1'],
                                               raise_unexpected_error,
                                               flag_values=self.flag_values)
    _validators.register_validator(
        'flag_1', raise_unexpected_error, flag_values=self.flag_values)
    _validators.register_validator(
        'flag_2', raise_unexpected_error, flag_values=self.flag_values)
    argv = ('./program', '')
    expected = 'flags flag_1=5.1, flag_2=10.0: Expected error.'
    with self.assertRaises(_exceptions.IllegalFlagValueError) as cm:
      self.flag_values(argv)
    self.assertEqual(expected, str(cm.exception))


if __name__ == '__main__':
  # Entry point when run directly: hand control to absl's test runner.
  absltest.main()
