# Copyright 2019 The RE2 Authors.  All Rights Reserved.
# Use of this source code is governed by a BSD-style
# license that can be found in the LICENSE file.
r"""A drop-in replacement for the re module.

It uses RE2 under the hood, of course, so various PCRE features
(e.g. backreferences, look-around assertions) are not supported.
See https://github.com/google/re2/wiki/Syntax for the canonical
reference, but known syntactic "gotchas" relative to Python are:

  * PCRE supports \Z and \z; RE2 supports \z; Python supports \z,
    but calls it \Z. You must rewrite \Z to \z in pattern strings.

Known differences between this module's API and the re module's API:

  * The error class does not provide any error information as attributes.
  * The Options class replaces the re module's flags with RE2's options as
    gettable/settable properties. Please see re2.h for their documentation.
  * The pattern string and the input string do not have to be the same type.
    Any str will be encoded to UTF-8.
  * The pattern string cannot be str if the options specify Latin-1 encoding.

This module's LRU cache contains a maximum of 128 regular expression objects.
Each regular expression object's underlying RE2 object uses a maximum of 8MiB
of memory (by default). Hence, this module's LRU cache uses a maximum of 1GiB
of memory (by default), but in most cases, it should use much less than that.
"""

import codecs
import functools
import itertools

import _re2


# pybind11 translates C++ exceptions to Python exceptions.
# We use that same Python exception class for consistency.
error = _re2.Error


class Options(_re2.RE2.Options):

  __slots__ = ()

  NAMES = (
      'max_mem',
      'encoding',
      'posix_syntax',
      'longest_match',
      'log_errors',
      'literal',
      'never_nl',
      'dot_nl',
      'never_capture',
      'case_sensitive',
      'perl_classes',
      'word_boundary',
      'one_line',
  )


def compile(pattern, options=None):
  if isinstance(pattern, _Regexp):
    if options:
      raise error('pattern is already compiled, so '
                  'options may not be specified')
    pattern = pattern._pattern
  options = options or Options()
  values = tuple(getattr(options, name) for name in Options.NAMES)
  return _Regexp._make(pattern, values)


def search(pattern, text, options=None):
  return compile(pattern, options=options).search(text)


def match(pattern, text, options=None):
  return compile(pattern, options=options).match(text)


def fullmatch(pattern, text, options=None):
  return compile(pattern, options=options).fullmatch(text)


def finditer(pattern, text, options=None):
  return compile(pattern, options=options).finditer(text)


def findall(pattern, text, options=None):
  return compile(pattern, options=options).findall(text)


def split(pattern, text, maxsplit=0, options=None):
  return compile(pattern, options=options).split(text, maxsplit)


def subn(pattern, repl, text, count=0, options=None):
  return compile(pattern, options=options).subn(repl, text, count)


def sub(pattern, repl, text, count=0, options=None):
  return compile(pattern, options=options).sub(repl, text, count)


def _encode(t):
  return t.encode(encoding='utf-8')


def _decode(b):
  return b.decode(encoding='utf-8')


def escape(pattern):
  if isinstance(pattern, str):
    encoded_pattern = _encode(pattern)
    escaped = _re2.RE2.QuoteMeta(encoded_pattern)
    decoded_escaped = _decode(escaped)
    return decoded_escaped
  else:
    escaped = _re2.RE2.QuoteMeta(pattern)
    return escaped


def purge():
  return _Regexp._make.cache_clear()


_Anchor = _re2.RE2.Anchor
_NULL_SPAN = (-1, -1)


class _Regexp(object):

  __slots__ = ('_pattern', '_regexp')

  @classmethod
  @functools.lru_cache(typed=True)
  def _make(cls, pattern, values):
    options = Options()
    for name, value in zip(Options.NAMES, values):
      setattr(options, name, value)
    return cls(pattern, options)

  def __init__(self, pattern, options):
    self._pattern = pattern
    if isinstance(self._pattern, str):
      if options.encoding == Options.Encoding.LATIN1:
        raise error('string type of pattern is str, but '
                    'encoding specified in options is LATIN1')
      encoded_pattern = _encode(self._pattern)
      self._regexp = _re2.RE2(encoded_pattern, options)
    else:
      self._regexp = _re2.RE2(self._pattern, options)
    if not self._regexp.ok():
      raise error(self._regexp.error())

  def __getstate__(self):
    options = {name: getattr(self.options, name) for name in Options.NAMES}
    return self._pattern, options

  def __setstate__(self, state):
    pattern, options = state
    values = tuple(options[name] for name in Options.NAMES)
    other = _Regexp._make(pattern, values)
    self._pattern = other._pattern
    self._regexp = other._regexp

  def _match(self, anchor, text, pos=None, endpos=None):
    pos = 0 if pos is None else max(0, min(pos, len(text)))
    endpos = len(text) if endpos is None else max(0, min(endpos, len(text)))
    if pos > endpos:
      return
    if isinstance(text, str):
      encoded_text = _encode(text)
      encoded_pos = _re2.CharLenToBytes(encoded_text, 0, pos)
      if endpos == len(text):
        # This is the common case.
        encoded_endpos = len(encoded_text)
      else:
        encoded_endpos = encoded_pos + _re2.CharLenToBytes(
            encoded_text, encoded_pos, endpos - pos)
      decoded_offsets = {0: 0}
      last_offset = 0
      while True:
        spans = self._regexp.Match(anchor, encoded_text, encoded_pos,
                                   encoded_endpos)
        if spans[0] == _NULL_SPAN:
          break

        # This algorithm is linear in the length of encoded_text. Specifically,
        # no matter how many groups there are for a given regular expression or
        # how many iterations through the loop there are for a given generator,
        # this algorithm uses a single, straightforward pass over encoded_text.
        offsets = sorted(set(itertools.chain(*spans)))
        if offsets[0] == -1:
          offsets = offsets[1:]
        # Discard the rest of the items because they are useless now - and we
        # could accumulate one item per str offset in the pathological case!
        decoded_offsets = {last_offset: decoded_offsets[last_offset]}
        for offset in offsets:
          decoded_offsets[offset] = (
              decoded_offsets[last_offset] +
              _re2.BytesToCharLen(encoded_text, last_offset, offset))
          last_offset = offset

        def decode(span):
          if span == _NULL_SPAN:
            return span
          return decoded_offsets[span[0]], decoded_offsets[span[1]]

        decoded_spans = [decode(span) for span in spans]
        yield _Match(self, text, pos, endpos, decoded_spans)
        if encoded_pos == encoded_endpos:
          break
        elif encoded_pos == spans[0][1]:
          # We matched the empty string at encoded_pos and would be stuck, so
          # in order to make forward progress, increment the str offset.
          encoded_pos += _re2.CharLenToBytes(encoded_text, encoded_pos, 1)
        else:
          encoded_pos = spans[0][1]
    else:
      while True:
        spans = self._regexp.Match(anchor, text, pos, endpos)
        if spans[0] == _NULL_SPAN:
          break
        yield _Match(self, text, pos, endpos, spans)
        if pos == endpos:
          break
        elif pos == spans[0][1]:
          # We matched the empty string at pos and would be stuck, so in order
          # to make forward progress, increment the bytes offset.
          pos += 1
        else:
          pos = spans[0][1]

  def search(self, text, pos=None, endpos=None):
    return next(self._match(_Anchor.UNANCHORED, text, pos, endpos), None)

  def match(self, text, pos=None, endpos=None):
    return next(self._match(_Anchor.ANCHOR_START, text, pos, endpos), None)

  def fullmatch(self, text, pos=None, endpos=None):
    return next(self._match(_Anchor.ANCHOR_BOTH, text, pos, endpos), None)

  def finditer(self, text, pos=None, endpos=None):
    return self._match(_Anchor.UNANCHORED, text, pos, endpos)

  def findall(self, text, pos=None, endpos=None):
    empty = type(text)()
    items = []
    for match in self.finditer(text, pos, endpos):
      if not self.groups:
        item = match.group()
      elif self.groups == 1:
        item = match.groups(default=empty)[0]
      else:
        item = match.groups(default=empty)
      items.append(item)
    return items

  def _split(self, cb, text, maxsplit=0):
    if maxsplit < 0:
      return [text], 0
    elif maxsplit > 0:
      matchiter = itertools.islice(self.finditer(text), maxsplit)
    else:
      matchiter = self.finditer(text)
    pieces = []
    end = 0
    numsplit = 0
    for match in matchiter:
      pieces.append(text[end:match.start()])
      pieces.extend(cb(match))
      end = match.end()
      numsplit += 1
    pieces.append(text[end:])
    return pieces, numsplit

  def split(self, text, maxsplit=0):
    cb = lambda match: [match[group] for group in range(1, self.groups + 1)]
    pieces, _ = self._split(cb, text, maxsplit)
    return pieces

  def subn(self, repl, text, count=0):
    cb = lambda match: [repl(match) if callable(repl) else match.expand(repl)]
    empty = type(text)()
    pieces, numsplit = self._split(cb, text, count)
    joined_pieces = empty.join(pieces)
    return joined_pieces, numsplit

  def sub(self, repl, text, count=0):
    joined_pieces, _ = self.subn(repl, text, count)
    return joined_pieces

  @property
  def pattern(self):
    return self._pattern

  @property
  def options(self):
    return self._regexp.options()

  @property
  def groups(self):
    return self._regexp.NumberOfCapturingGroups()

  @property
  def groupindex(self):
    groups = self._regexp.NamedCapturingGroups()
    if isinstance(self._pattern, str):
      decoded_groups = [(_decode(group), index) for group, index in groups]
      return dict(decoded_groups)
    else:
      return dict(groups)

  @property
  def programsize(self):
    return self._regexp.ProgramSize()

  @property
  def reverseprogramsize(self):
    return self._regexp.ReverseProgramSize()

  @property
  def programfanout(self):
    return self._regexp.ProgramFanout()

  @property
  def reverseprogramfanout(self):
    return self._regexp.ReverseProgramFanout()

  def possiblematchrange(self, maxlen):
    ok, min, max = self._regexp.PossibleMatchRange(maxlen)
    if not ok:
      raise error('failed to compute match range')
    return min, max


class _Match(object):

  __slots__ = ('_regexp', '_text', '_pos', '_endpos', '_spans')

  def __init__(self, regexp, text, pos, endpos, spans):
    self._regexp = regexp
    self._text = text
    self._pos = pos
    self._endpos = endpos
    self._spans = spans

  # Python prioritises three-digit octal numbers over group escapes.
  # For example, \100 should not be handled the same way as \g<10>0.
  _OCTAL_RE = compile('\\\\[0-7][0-7][0-7]')

  # Python supports \1 through \99 (inclusive) and \g<...> syntax.
  _GROUP_RE = compile('\\\\[1-9][0-9]?|\\\\g<\\w+>')

  @classmethod
  @functools.lru_cache(typed=True)
  def _split(cls, template):
    if isinstance(template, str):
      backslash = '\\'
    else:
      backslash = b'\\'
    empty = type(template)()
    pieces = [empty]
    index = template.find(backslash)
    while index != -1:
      piece, template = template[:index], template[index:]
      pieces[-1] += piece
      octal_match = cls._OCTAL_RE.match(template)
      group_match = cls._GROUP_RE.match(template)
      if (not octal_match) and group_match:
        index = group_match.end()
        piece, template = template[:index], template[index:]
        pieces.extend((piece, empty))
      else:
        # 2 isn't enough for \o, \x, \N, \u and \U escapes, but none of those
        # should contain backslashes, so break them here and then fix them at
        # the beginning of the next loop iteration or right before returning.
        index = 2
        piece, template = template[:index], template[index:]
        pieces[-1] += piece
      index = template.find(backslash)
    pieces[-1] += template
    return pieces

  def expand(self, template):
    if isinstance(template, str):
      unescape = codecs.unicode_escape_decode
    else:
      unescape = codecs.escape_decode
    empty = type(template)()
    # Make a copy so that we don't clobber the cached pieces!
    pieces = list(self._split(template))
    for index, piece in enumerate(pieces):
      if not index % 2:
        pieces[index], _ = unescape(piece)
      else:
        if len(piece) <= 3:  # \1 through \99 (inclusive)
          group = int(piece[1:])
        else:  # \g<...>
          group = piece[3:-1]
          try:
            group = int(group)
          except ValueError:
            pass
        pieces[index] = self.__getitem__(group) or empty
    joined_pieces = empty.join(pieces)
    return joined_pieces

  def __getitem__(self, group):
    if not isinstance(group, int):
      try:
        group = self._regexp.groupindex[group]
      except KeyError:
        raise IndexError('bad group name')
    if not 0 <= group <= self._regexp.groups:
      raise IndexError('bad group index')
    span = self._spans[group]
    if span == _NULL_SPAN:
      return None
    return self._text[span[0]:span[1]]

  def group(self, *groups):
    if not groups:
      groups = (0,)
    items = (self.__getitem__(group) for group in groups)
    return next(items) if len(groups) == 1 else tuple(items)

  def groups(self, default=None):
    items = []
    for group in range(1, self._regexp.groups + 1):
      item = self.__getitem__(group)
      items.append(default if item is None else item)
    return tuple(items)

  def groupdict(self, default=None):
    items = []
    for group, index in self._regexp.groupindex.items():
      item = self.__getitem__(index)
      items.append((group, default) if item is None else (group, item))
    return dict(items)

  def start(self, group=0):
    if not 0 <= group <= self._regexp.groups:
      raise IndexError('bad group index')
    return self._spans[group][0]

  def end(self, group=0):
    if not 0 <= group <= self._regexp.groups:
      raise IndexError('bad group index')
    return self._spans[group][1]

  def span(self, group=0):
    if not 0 <= group <= self._regexp.groups:
      raise IndexError('bad group index')
    return self._spans[group]

  @property
  def re(self):
    return self._regexp

  @property
  def string(self):
    return self._text

  @property
  def pos(self):
    return self._pos

  @property
  def endpos(self):
    return self._endpos

  @property
  def lastindex(self):
    max_end = -1
    max_group = None
    # We look for the rightmost right parenthesis by keeping the first group
    # that ends at max_end because that is the leftmost/outermost group when
    # there are nested groups!
    for group in range(1, self._regexp.groups + 1):
      end = self._spans[group][1]
      if max_end < end:
        max_end = end
        max_group = group
    return max_group

  @property
  def lastgroup(self):
    max_group = self.lastindex
    if not max_group:
      return None
    for group, index in self._regexp.groupindex.items():
      if max_group == index:
        return group
    return None


class Set(object):
  """A Pythonic wrapper around RE2::Set."""

  __slots__ = ('_set')

  def __init__(self, anchor, options=None):
    options = options or Options()
    self._set = _re2.Set(anchor, options)

  @classmethod
  def SearchSet(cls, options=None):
    return cls(_Anchor.UNANCHORED, options=options)

  @classmethod
  def MatchSet(cls, options=None):
    return cls(_Anchor.ANCHOR_START, options=options)

  @classmethod
  def FullMatchSet(cls, options=None):
    return cls(_Anchor.ANCHOR_BOTH, options=options)

  def Add(self, pattern):
    if isinstance(pattern, str):
      encoded_pattern = _encode(pattern)
      index = self._set.Add(encoded_pattern)
    else:
      index = self._set.Add(pattern)
    if index == -1:
      raise error('failed to add %r to Set' % pattern)
    return index

  def Compile(self):
    if not self._set.Compile():
      raise error('failed to compile Set')

  def Match(self, text):
    if isinstance(text, str):
      encoded_text = _encode(text)
      matches = self._set.Match(encoded_text)
    else:
      matches = self._set.Match(text)
    return matches or None


class Filter(object):
  """A Pythonic wrapper around FilteredRE2."""

  __slots__ = ('_filter', '_patterns')

  def __init__(self):
    self._filter = _re2.Filter()
    self._patterns = []

  def Add(self, pattern, options=None):
    options = options or Options()
    if isinstance(pattern, str):
      encoded_pattern = _encode(pattern)
      index = self._filter.Add(encoded_pattern, options)
    else:
      index = self._filter.Add(pattern, options)
    if index == -1:
      raise error('failed to add %r to Filter' % pattern)
    self._patterns.append(pattern)
    return index

  def Compile(self):
    if not self._filter.Compile():
      raise error('failed to compile Filter')

  def Match(self, text, potential=False):
    if isinstance(text, str):
      encoded_text = _encode(text)
      matches = self._filter.Match(encoded_text, potential)
    else:
      matches = self._filter.Match(text, potential)
    return matches or None

  def re(self, index):
    if not 0 <= index < len(self._patterns):
      raise IndexError('bad index')
    proxy = object.__new__(_Regexp)
    proxy._pattern = self._patterns[index]
    proxy._regexp = self._filter.GetRE2(index)
    return proxy
