from __future__ import annotations

from copy import copy
from typing import Any

from tomlkit.exceptions import ParseError
from tomlkit.exceptions import UnexpectedCharError
from tomlkit.toml_char import TOMLChar


class _State:
    """
    Context manager that snapshots a Source's cursor on entry.

    On exit the cursor (character iterator, index and current character) is
    rolled back when ``restore`` is truthy or when an exception left the
    ``with`` body. The marker is rolled back too, but only if ``save_marker``
    is truthy. ``restore`` is a public attribute and is read at exit time.
    """

    def __init__(
        self,
        source: Source,
        save_marker: bool | None = False,
        restore: bool | None = False,
    ) -> None:
        self._source = source
        self._save_marker = save_marker
        self.restore = restore

    def __enter__(self) -> _State:
        # Entering this context manager - save the state.
        # The iterator is copied so that advancing the source afterwards
        # does not move the saved position.
        self._chars = copy(self._source._chars)
        self._idx = self._source._idx
        self._current = self._source._current
        self._marker = self._source._marker

        return self

    def __exit__(self, exception_type, exception_val, trace):
        # Exiting this context manager - restore the prior state.
        # Nothing is returned, so an exception is never suppressed here.
        if self.restore or exception_type:
            self._source._chars = self._chars
            self._source._idx = self._idx
            self._source._current = self._current
            if self._save_marker:
                self._source._marker = self._marker


class _StateHandler:
    """
    State preserver for the Parser.

    Calling the handler builds a ``_State`` with the given options. Using the
    handler itself as a context manager builds a ``_State`` with default
    options and keeps it on a stack so that nested ``with`` blocks unwind in
    order.
    """

    def __init__(self, source: Source) -> None:
        self._source = source
        # States entered through ``with handler:`` that have not exited yet.
        self._states: list[_State] = []

    def __call__(self, *args, **kwargs):
        """
        Creates a new state for the source, forwarding the given options.
        """
        return _State(self._source, *args, **kwargs)

    def __enter__(self) -> _State:
        state = self()
        self._states.append(state)
        return state.__enter__()

    def __exit__(self, exception_type, exception_val, trace):
        state = self._states.pop()
        return state.__exit__(exception_type, exception_val, trace)


class Source(str):
    """
    The text being parsed, together with a cursor over its characters.

    The cursor consists of the index of the current character, the current
    character itself and a marker remembering an earlier index.
    """

    # Character reported as current once the input is exhausted.
    EOF = TOMLChar("\0")

    def __init__(self, _: str) -> None:
        super().__init__()

        # Collection of TOMLChars
        self._chars = iter([(i, TOMLChar(c)) for i, c in enumerate(self)])

        self._idx = 0
        self._marker = 0
        self._current = TOMLChar("")

        self._state = _StateHandler(self)

        # Load the first character (or EOF for an empty input).
        self.inc()

    def reset(self):
        """
        Advances by one character and moves the marker to the new index.
        """
        # initialize both idx and current
        self.inc()

        # reset marker
        self.mark()

    @property
    def state(self) -> _StateHandler:
        """
        The handler used to save and restore the cursor.
        """
        return self._state

    @property
    def idx(self) -> int:
        """
        Index of the current character (len(self) once the end is reached).
        """
        return self._idx

    @property
    def current(self) -> TOMLChar:
        """
        The character at the current index, or EOF once the end is reached.
        """
        return self._current

    @property
    def marker(self) -> int:
        """
        The index remembered by the last call to mark().
        """
        return self._marker

    def extract(self) -> str:
        """
        Extracts the value between marker and index
        """
        return self[self._marker : self._idx]

    def inc(self, exception: type[ParseError] | None = None) -> bool:
        """
        Increments the parser if the end of the input has not been reached.
        Returns whether or not it was able to advance.

        When the input is exhausted the index is set to the length of the
        input and the current character to EOF. If an exception type is
        given it is raised (as a parse error at that position) instead of
        returning False.
        """
        try:
            self._idx, self._current = next(self._chars)

            return True
        except StopIteration:
            self._idx = len(self)
            self._current = self.EOF
            if exception:
                raise self.parse_error(exception)

            return False

    def inc_n(self, n: int, exception: type[ParseError] | None = None) -> bool:
        """
        Increments the parser by n characters
        if the end of the input has not been reached.

        Stops at the first failed increment and returns whether all n
        increments succeeded.
        """
        return all(self.inc(exception=exception) for _ in range(n))

    def consume(self, chars, min=0, max=-1):
        """
        Consumes characters for as long as the current one is in chars.

        At most max characters are consumed; a negative max means no limit.
        Raises an UnexpectedCharError parse error if fewer than min
        characters could be consumed.
        """
        while self.current in chars and max != 0:
            min -= 1
            max -= 1
            if not self.inc():
                break

        # failed to consume minimum number of characters
        if min > 0:
            raise self.parse_error(UnexpectedCharError, self.current)

    def end(self) -> bool:
        """
        Returns True if the parser has reached the end of the input.
        """
        return self._current is self.EOF

    def mark(self) -> None:
        """
        Sets the marker to the index's current position
        """
        self._marker = self._idx

    def parse_error(
        self,
        exception: type[ParseError] = ParseError,
        *args: Any,
        **kwargs: Any,
    ) -> ParseError:
        """
        Creates a generic "parse error" at the current position.

        The exception is built (not raised) with the current line and column
        followed by any extra arguments.
        """
        line, col = self._to_linecol()

        return exception(line, col, *args, **kwargs)

    def _to_linecol(self) -> tuple[int, int]:
        """
        Converts the current index into a (line, column) pair.

        The line is 1-based; the column is the offset of the index from the
        start of that line. If the index lies past every line (for example
        at the end of an input that ends with a line break) the number of
        lines and column 0 are returned.
        """
        idx = self.idx
        bare_lines = self.splitlines()
        start = 0
        for number, (bare, full) in enumerate(
            zip(bare_lines, self.splitlines(keepends=True)), start=1
        ):
            # A line owns its content plus its terminator, which may be more
            # than one character ("\r\n"). An unterminated last line also
            # owns the position just past its content (the EOF position).
            span = max(len(full), len(bare) + 1)
            if idx < start + span:
                return (number, idx - start)

            # Advance by the real width so multi-character terminators do
            # not shift the offsets of the following lines.
            start += len(full)

        return len(bare_lines), 0