163 lines
5.1 KiB
Python
163 lines
5.1 KiB
Python
# Copyright (c) 2022 Tulir Asokan
|
|
#
|
|
# This Source Code Form is subject to the terms of the Mozilla Public
|
|
# License, v. 2.0. If a copy of the MPL was not distributed with this
|
|
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
|
|
from __future__ import annotations
|
|
|
|
from typing import Any, Generic, Iterable, Sequence, Type, TypeVar
|
|
from abc import ABC, abstractmethod
|
|
from itertools import chain
|
|
|
|
from attr import dataclass
|
|
import attr
|
|
|
|
from .formatted_string import EntityType, FormattedString
|
|
|
|
|
|
class AbstractEntity(ABC):
    """Interface for a formatting entity that covers a range of some text.

    The constructor signature declares the fields an entity is built from
    (``type``, ``offset``, ``length`` and ``extra_info``); this base
    implementation does not store any of them itself.
    """

    def __init__(
        self, type: EntityType, offset: int, length: int, extra_info: dict[str, Any]
    ) -> None:
        pass

    @abstractmethod
    def copy(self) -> AbstractEntity:
        """Return a copy of this entity."""
        pass

    @abstractmethod
    def adjust_offset(self, offset: int, max_length: int = -1) -> AbstractEntity | None:
        """Return a shifted version of this entity, or ``None`` if it should be dropped.

        Args:
            offset: Amount to add to the entity's offset (may be negative).
            max_length: Length of the text the entity must fit in, or ``-1`` for no limit.
        """
        pass


class SemiAbstractEntity(AbstractEntity, ABC):
    """Partial implementation providing ``adjust_offset`` for entities with ``offset``/``length``.

    Subclasses still have to implement :meth:`copy`.
    """

    offset: int
    length: int

    def adjust_offset(self, offset: int, max_length: int = -1) -> SemiAbstractEntity | None:
        """Return a copy of this entity shifted by ``offset`` and clipped to the text bounds.

        The original entity is never modified; all changes are made on ``self.copy()``.

        Args:
            offset: Amount to add to the entity's offset (may be negative).
            max_length: Length of the text the entity has to fit in. ``-1`` (the default)
                disables the upper bound.

        Returns:
            The adjusted copy, or ``None`` when the entity ends before the start of the text
            or starts after ``max_length``. Entities clipped down to a length of exactly
            zero are still returned.
        """
        entity = self.copy()
        entity.offset += offset
        if entity.offset < 0:
            # The entity starts before the text: cut off the part that fell off the front.
            entity.length += entity.offset
            if entity.length < 0:
                return None
            entity.offset = 0
        elif entity.offset > max_length > -1:
            # The entity starts after the end of the text.
            return None
        # Deliberately a separate ``if``: an entity whose start was just cut off above can
        # still run past the end of the text and has to be clamped as well.
        if entity.offset + entity.length > max_length > -1:
            entity.length = max_length - entity.offset
        return entity
|
|
|
|
|
|
@dataclass
class SimpleEntity(SemiAbstractEntity):
    """Concrete attrs-based entity holding a type, a text range and free-form extra data."""

    # Kind of formatting this entity represents.
    type: EntityType
    # Start position of the entity within the text.
    offset: int
    # Number of characters the entity covers.
    length: int
    # Additional per-entity data; defaults to a fresh empty dict for each instance.
    extra_info: dict[str, Any] = attr.ib(factory=dict)

    def copy(self) -> SimpleEntity:
        """Return a copy of this entity created via ``attr.evolve`` with no overrides.

        NOTE(review): ``attr.evolve`` appears to pass field values through unchanged, so the
        copy would share the same ``extra_info`` dict object -- confirm if that matters.
        """
        duplicate = attr.evolve(self)
        return duplicate
|
|
|
|
|
|
# Type variable for the entity objects stored in an EntityString; bound to AbstractEntity.
TEntity = TypeVar("TEntity", bound=AbstractEntity)
|
|
# Unconstrained type variable for the entity-type value accepted by EntityString.format().
TEntityType = TypeVar("TEntityType")
|
|
|
|
|
|
class EntityString(Generic[TEntity, TEntityType], FormattedString):
    """A formatted string stored as plain text plus a list of entities covering ranges of it.

    All mutating methods change the instance in place and return ``self`` so calls can be
    chained. Entities are repositioned through their ``adjust_offset`` method; any entity for
    which that method returns ``None`` is dropped by the ``entities`` setter.
    """

    text: str
    _entities: list[TEntity]
    # Class instantiated by format() when a new entity is added.
    entity_class: Type[AbstractEntity] = SimpleEntity

    def __init__(self, text: str = "", entities: list[TEntity] | None = None) -> None:
        """Create a string from ``text`` and an optional list of entities.

        Args:
            text: The plain text content.
            entities: Entities applying to ``text``. ``None`` or an empty list results in a
                new empty list being used.
        """
        self.text = text
        self._entities = entities or []

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(text='{self.text}', entities={self.entities})"

    def __str__(self) -> str:
        """Return the plain text without any formatting information."""
        return self.text

    @property
    def entities(self) -> list[TEntity]:
        """The list of entities attached to the text."""
        return self._entities

    @entities.setter
    def entities(self, val: Iterable[TEntity]) -> None:
        # ``None`` items (e.g. entities dropped by adjust_offset) are filtered out here.
        self._entities = [entity for entity in val if entity is not None]

    def _offset_entities(self, offset: int) -> EntityString:
        """Shift every entity by ``offset``, bounded by the current length of the text."""
        text_length = len(self.text)
        self.entities = [entity.adjust_offset(offset, text_length) for entity in self.entities]
        return self

    def append(self, *args: str | FormattedString) -> EntityString:
        """Append each argument to the end of this string, in order.

        ``EntityString`` arguments contribute their text and shifted copies of their entities;
        anything else is converted with ``str()`` and appended as plain text.
        """
        for msg in args:
            if isinstance(msg, EntityString):
                shift = len(self.text)
                # Materialise the shifted entities before extending the list: when ``msg`` is
                # ``self``, extending from a lazy generator over the same list never ends.
                shifted = [entity.adjust_offset(shift) for entity in msg.entities]
                self.entities = self.entities + shifted
                self.text += msg.text
            else:
                self.text += str(msg)
        return self

    def prepend(self, *args: str | FormattedString) -> EntityString:
        """Insert each argument at the start of this string.

        Arguments are processed one after another, so a later argument ends up in front of
        an earlier one. Existing entities are shifted by the length of the inserted text.
        """
        for msg in args:
            if isinstance(msg, EntityString):
                # Compute everything from the pre-mutation state so that prepending a string
                # to itself still shifts by the original text length.
                prefix_text = msg.text
                prefix_entities = list(msg.entities)
                shift = len(prefix_text)
                shifted = [entity.adjust_offset(shift) for entity in self.entities]
                self.text = prefix_text + self.text
                self.entities = chain(prefix_entities, shifted)
            else:
                text = str(msg)
                shift = len(text)
                self.text = text + self.text
                self.entities = [entity.adjust_offset(shift) for entity in self.entities]
        return self

    def format(
        self,
        entity_type: TEntityType,
        offset: int | None = None,
        length: int | None = None,
        **kwargs: Any,
    ) -> EntityString:
        """Add a new entity of ``entity_type`` to this string.

        Args:
            entity_type: Value passed as ``type`` to ``entity_class``.
            offset: Start of the entity; ``None`` means the start of the text.
            length: Length of the entity; ``None`` means "up to the end of the text".
                An explicit ``0`` is respected rather than being treated as unset.
            **kwargs: Passed to ``entity_class`` as ``extra_info``.
        """
        start = 0 if offset is None else offset
        if length is None:
            # Default to the remainder of the text so the entity never overruns it.
            length = max(len(self.text) - start, 0)
        self.entities.append(
            self.entity_class(
                type=entity_type,
                offset=start,
                length=length,
                extra_info=kwargs,
            )
        )
        return self

    def trim(self) -> EntityString:
        """Strip leading and trailing whitespace from the text and reposition the entities."""
        orig_len = len(self.text)
        self.text = self.text.lstrip()
        # Only the characters removed from the front move the entities.
        diff = orig_len - len(self.text)
        self.text = self.text.rstrip()
        self._offset_entities(-diff)
        return self

    def split(self, separator: str, max_items: int = -1) -> list[EntityString]:
        """Split the string on ``separator`` into new instances of the same class.

        Args:
            separator: The delimiter string to split the text on.
            max_items: Maximum number of parts to return; zero or a negative value means
                no limit.

        Returns:
            One new string per text part, each given the entities of this string adjusted
            to that part's position and length.
        """
        # str.split only documents -1 as "no limit", so normalise instead of passing -2.
        max_split = max_items - 1 if max_items > 0 else -1
        text_parts = self.text.split(separator, max_split)
        output: list[EntityString] = []

        offset = 0
        for part in text_parts:
            msg = type(self)(part)
            msg.entities = [entity.adjust_offset(-offset, len(part)) for entity in self.entities]
            output.append(msg)

            # The next part starts after this part and the separator that followed it.
            offset += len(part) + len(separator)

        return output

    @classmethod
    def join(cls, items: Sequence[str | EntityString], separator: str = " ") -> EntityString:
        """Concatenate ``items`` into a new string with ``separator`` between them.

        Items that are not ``EntityString`` instances are converted with ``str()``. Entities
        of each item are copied with their offsets shifted to the item's position in the
        joined text.
        """
        main = cls()
        text_parts: list[str] = []
        entities: list[TEntity] = []
        offset = 0
        for msg in items:
            if not isinstance(msg, EntityString):
                msg = cls(text=str(msg))
            entities.extend(entity.adjust_offset(offset) for entity in msg.entities)
            text_parts.append(msg.text)
            offset += len(msg.text) + len(separator)
        # Build the text once and assign the entities once instead of growing both per item.
        main.text = separator.join(text_parts)
        main.entities = entities
        return main
|