# mautrix-python/mautrix/util/formatter/entity_string.py
# (163 lines, 5.1 KiB, Python)

# Copyright (c) 2022 Tulir Asokan
#
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
from __future__ import annotations
from typing import Any, Generic, Iterable, Sequence, Type, TypeVar
from abc import ABC, abstractmethod
from itertools import chain
from attr import dataclass
import attr
from .formatted_string import EntityType, FormattedString
class AbstractEntity(ABC):
    """Interface for one formatting entity attached to a range of text.

    Concrete entities must be able to duplicate themselves (``copy``) and to
    produce a shifted version of themselves (``adjust_offset``).
    """

    def __init__(
        self, type: EntityType, offset: int, length: int, extra_info: dict[str, Any]
    ) -> None:
        """Declare the constructor signature; this base implementation stores nothing."""

    @abstractmethod
    def copy(self) -> AbstractEntity:
        """Return a duplicate of this entity."""

    @abstractmethod
    def adjust_offset(self, offset: int, max_length: int = -1) -> AbstractEntity | None:
        """Return this entity shifted by ``offset``, or ``None`` if nothing remains.

        ``max_length`` defaults to ``-1``; its interpretation is left to the
        implementing class.
        """
class SemiAbstractEntity(AbstractEntity, ABC):
    """Partial entity implementation that provides the offset arithmetic.

    Subclasses are expected to supply the ``offset`` and ``length`` attributes
    and to implement ``copy()``.
    """

    offset: int
    length: int

    def adjust_offset(self, offset: int, max_length: int = -1) -> SemiAbstractEntity | None:
        """Return a shifted copy of this entity, clipped to ``[0, max_length]``.

        The entity itself is never modified; all changes are made on a copy.

        Args:
            offset: Amount added to the entity's offset (may be negative).
            max_length: Length of the text the entity must fit into. Any value
                of ``-1`` or lower disables the upper-bound check.

        Returns:
            The adjusted copy, or ``None`` when the entity ends before
            position 0 or starts after ``max_length``.
        """
        entity = self.copy()
        entity.offset += offset
        if entity.offset < 0:
            # The entity starts before the text: cut off the part that fell
            # outside (entity.offset is negative here, so this shrinks length).
            entity.length += entity.offset
            if entity.length < 0:
                return None
            entity.offset = 0
        # The upper bound is checked independently of the branch above, so an
        # entity that began before the text AND runs past its end is clipped
        # on both sides instead of keeping a length larger than max_length.
        if max_length > -1:
            if entity.offset > max_length:
                return None
            if entity.offset + entity.length > max_length:
                entity.length = max_length - entity.offset
        return entity
@dataclass
class SimpleEntity(SemiAbstractEntity):
    """Plain attrs-based entity: a type, a text range and optional extra data."""

    # Kind of formatting this entity applies.
    type: EntityType
    # Start position of the entity within the text.
    offset: int
    # Number of text units the entity covers.
    length: int
    # Additional keyword data; every instance gets its own dict by default.
    extra_info: dict[str, Any] = attr.ib(factory=dict)

    def copy(self) -> SimpleEntity:
        """Return a new instance carrying the same field values.

        NOTE(review): ``attr.evolve`` with no overrides does not appear to
        deep-copy, so ``extra_info`` is presumably shared with the copy.
        """
        duplicate = attr.evolve(self)
        return duplicate
# Type of the entity objects held by an EntityString; bound to AbstractEntity.
TEntity = TypeVar("TEntity", bound=AbstractEntity)
# Type of the entity-type value passed to EntityString.format() (unconstrained).
TEntityType = TypeVar("TEntityType")
class EntityString(Generic[TEntity, TEntityType], FormattedString):
    """Formatted string represented as plain text plus a list of entities.

    Every entity describes a range of ``text`` through its offset and length.
    The in-place mutators (``append``, ``prepend``, ``format``, ``trim``)
    return ``self`` so that calls can be chained; ``split`` and ``join``
    produce new instances.
    """

    text: str
    _entities: list[TEntity]
    # Class instantiated by format() when a new entity is added.
    entity_class: Type[AbstractEntity] = SimpleEntity

    def __init__(self, text: str = "", entities: list[TEntity] | None = None) -> None:
        """Create a string from ``text`` and an optional list of entities.

        A missing or empty ``entities`` argument results in a fresh empty list.
        """
        self.text = text
        self._entities = entities or []

    def __repr__(self) -> str:
        # !r quotes and escapes the text, so embedded quotes or newlines
        # cannot make the representation ambiguous.
        return f"{self.__class__.__name__}(text={self.text!r}, entities={self.entities})"

    def __str__(self) -> str:
        """Return the plain text without any formatting information."""
        return self.text

    @property
    def entities(self) -> list[TEntity]:
        """The list of entities attached to the text."""
        return self._entities

    @entities.setter
    def entities(self, val: Iterable[TEntity]) -> None:
        # Accept any iterable (including generators) and drop the ``None``
        # results that adjust_offset() yields for entities that fell out of range.
        self._entities = [entity for entity in val if entity is not None]

    def _offset_entities(self, offset: int) -> EntityString:
        """Shift every entity by ``offset``, clipping against the current text length."""
        self.entities = (entity.adjust_offset(offset, len(self.text)) for entity in self.entities)
        return self

    def append(self, *args: str | FormattedString) -> EntityString:
        """Append each argument to the end of this string, in order.

        Entities of appended ``EntityString`` arguments are shifted by the
        length of the text that precedes them; any other argument is added as
        plain text via ``str()``.
        """
        for msg in args:
            if isinstance(msg, EntityString):
                base = len(self.text)
                # Build the shifted entities eagerly: extending the list from a
                # lazy generator would never terminate when ``msg is self``,
                # because the generator would iterate the list being extended.
                shifted = [entity.adjust_offset(base) for entity in msg.entities]
                self.entities += shifted
                self.text += msg.text
            else:
                self.text += str(msg)
        return self

    def prepend(self, *args: str | FormattedString) -> EntityString:
        """Insert each argument at the start of this string.

        Arguments are processed one after another, each being placed in front
        of the current text, so the last argument ends up first. Existing
        entities are shifted right by the length of the inserted text.
        """
        for msg in args:
            if isinstance(msg, EntityString):
                self.text = msg.text + self.text
                self.entities = chain(
                    msg.entities, (entity.adjust_offset(len(msg.text)) for entity in self.entities)
                )
            else:
                text = str(msg)
                self.text = text + self.text
                self.entities = (entity.adjust_offset(len(text)) for entity in self.entities)
        return self

    def format(
        self, entity_type: TEntityType, offset: int = None, length: int = None, **kwargs
    ) -> EntityString:
        """Add an entity of ``entity_type`` to this string.

        Args:
            entity_type: Type of the new entity.
            offset: Start of the entity; defaults to the start of the text.
            length: Length of the entity; when omitted (or zero) the entity
                extends from ``offset`` to the end of the text.
            **kwargs: Stored on the entity as ``extra_info``.
        """
        start = offset or 0
        # The default span covers only what is left after ``start``; using the
        # full text length here would push the entity past the end of the text
        # whenever a non-zero offset is given.
        span = length or max(len(self.text) - start, 0)
        self.entities.append(
            self.entity_class(
                type=entity_type,
                offset=start,
                length=span,
                extra_info=kwargs,
            )
        )
        return self

    def trim(self) -> EntityString:
        """Strip leading and trailing whitespace from the text.

        Entities are shifted left by the number of leading characters removed
        and clipped against the new text length.
        """
        orig_len = len(self.text)
        self.text = self.text.lstrip()
        # Only the leading whitespace moves entity offsets.
        diff = orig_len - len(self.text)
        self.text = self.text.rstrip()
        self._offset_entities(-diff)
        return self

    def split(self, separator: str, max_items: int = -1) -> list[EntityString]:
        """Split the string on ``separator`` into new instances of the same class.

        Args:
            separator: Text to split on.
            max_items: Maximum number of parts; the default ``-1`` means no limit.

        Returns:
            One instance per part. Each carries the entities that
            ``adjust_offset`` keeps for that part's range.
        """
        text_parts = self.text.split(separator, max_items - 1)
        output: list[EntityString] = []
        offset = 0
        for part in text_parts:
            msg = type(self)(part)
            msg.entities = (entity.adjust_offset(-offset, len(part)) for entity in self.entities)
            output.append(msg)
            # The next part starts after this part and the separator following it.
            offset += len(part) + len(separator)
        return output

    @classmethod
    def join(cls, items: Sequence[str | EntityString], separator: str = " ") -> EntityString:
        """Concatenate ``items`` into one new string with ``separator`` between them.

        Items that are not ``EntityString`` instances are converted with
        ``str()`` and contribute no entities. Entities of the other items are
        shifted to their position in the joined text.
        """
        main = cls()
        texts: list[str] = []
        entities: list[TEntity] = []
        position = 0
        for msg in items:
            if not isinstance(msg, EntityString):
                msg = cls(text=str(msg))
            entities.extend(entity.adjust_offset(position) for entity in msg.entities)
            texts.append(msg.text)
            position += len(msg.text) + len(separator)
        main.text = separator.join(texts)
        main.entities = entities
        return main