import abc
import enum
from typing import Tuple, Dict, Any, Union, List, AsyncIterator, Type

import rich.progress

from redbot.core.utils._internal_utils import RichIndefiniteBarColumn

__all__ = ["BaseDriver", "IdentifierData", "ConfigCategory"]

class ConfigCategory(str, enum.Enum):
    """Built-in config categories.

    Each member's value is the string used for that category; note that
    :attr:`CHANNEL` uses the value ``"TEXTCHANNEL"``.
    """

    #: Global category.
    GLOBAL = "GLOBAL"
    #: Guild category.
    GUILD = "GUILD"
    #: Channel category.
    CHANNEL = "TEXTCHANNEL"
    #: Role category.
    ROLE = "ROLE"
    #: User category.
    USER = "USER"
    #: Member category.
    MEMBER = "MEMBER"

    @classmethod
    def get_pkey_info(
        cls, category: Union[str, "ConfigCategory"], custom_group_data: Dict[str, int]
    ) -> Tuple[int, bool]:
        """Return ``(primary_key_length, is_custom)`` for *category*.

        Anything this enum accepts as a value (or member) is a built-in category
        and takes its length from the module's built-in table. Any other name is
        looked up in *custom_group_data* and reported as a custom category.

        Raises
        ------
        KeyError
            If *category* is neither built in nor present in *custom_group_data*.
        """
        try:
            # noinspection PyArgumentList
            builtin = cls(category)
        except ValueError:
            # Not a built-in category: treat it as a custom group. An unknown
            # name surfaces as KeyError from this lookup.
            pkey_len, is_custom = custom_group_data[category], True
        else:
            pkey_len, is_custom = _CATEGORY_PKEY_COUNTS[builtin], False
        return pkey_len, is_custom
# Number of primary-key levels for each built-in category; read by
# ConfigCategory.get_pkey_info. Custom groups supply their own lengths.
_CATEGORY_PKEY_COUNTS = {
    ConfigCategory.GLOBAL: 0,
    ConfigCategory.GUILD: 1,
    ConfigCategory.CHANNEL: 1,
    ConfigCategory.ROLE: 1,
    ConfigCategory.USER: 1,
    ConfigCategory.MEMBER: 2,
}


class IdentifierData:
    """Address of a piece of config data, as passed to drivers.

    Bundles the cog name and unique identifier, the category, the primary key
    tuple and any further identifiers below it. ``primary_key_len`` is the number
    of primary-key levels the category expects and ``is_custom`` marks custom
    groups. Most fields are exposed through read-only properties;
    ``primary_key_len`` is a plain attribute.

    Equality and hashing consider only uuid, category, primary_key and
    identifiers; cog_name, primary_key_len and is_custom are ignored.
    """

    def __init__(
        self,
        cog_name: str,
        uuid: str,
        category: str,
        primary_key: Tuple[str, ...],
        identifiers: Tuple[str, ...],
        primary_key_len: int,
        is_custom: bool = False,
    ):
        self._cog_name = cog_name
        self._uuid = uuid
        self._category = category
        self._primary_key = primary_key
        self._identifiers = identifiers
        self.primary_key_len = primary_key_len
        self._is_custom = is_custom

    @property
    def cog_name(self) -> str:
        return self._cog_name

    @property
    def uuid(self) -> str:
        return self._uuid

    @property
    def category(self) -> str:
        return self._category

    @property
    def primary_key(self) -> Tuple[str, ...]:
        return self._primary_key

    @property
    def identifiers(self) -> Tuple[str, ...]:
        return self._identifiers

    @property
    def is_custom(self) -> bool:
        return self._is_custom

    def __repr__(self) -> str:
        return (
            f"<IdentifierData cog_name={self.cog_name} uuid={self.uuid} category={self.category} "
            f"primary_key={self.primary_key} identifiers={self.identifiers}>"
        )

    def __eq__(self, other) -> bool:
        # Returning NotImplemented (rather than False) for foreign types lets Python
        # try the reflected comparison; ``==`` still ends up False for unrelated objects.
        if not isinstance(other, IdentifierData):
            return NotImplemented
        return (
            self.uuid == other.uuid
            and self.category == other.category
            and self.primary_key == other.primary_key
            and self.identifiers == other.identifiers
        )

    def __hash__(self) -> int:
        # Must hash exactly the fields compared in __eq__.
        return hash((self.uuid, self.category, self.primary_key, self.identifiers))

    def get_child(self, *keys: str) -> "IdentifierData":
        """Return a new IdentifierData extended by *keys*.

        Keys first fill any primary-key slots still missing (up to
        ``primary_key_len``); the remaining keys are appended to ``identifiers``.

        Raises
        ------
        ValueError
            If any key is not a ``str``.
        """
        if not all(isinstance(i, str) for i in keys):
            raise ValueError("Identifiers must be strings.")

        primary_keys = self.primary_key
        identifiers = self.identifiers
        # Clamp at zero: if primary_key is already longer than primary_key_len, a
        # negative count would make ``keys[n:]`` slice from the END and drop keys.
        num_missing_pkeys = max(0, self.primary_key_len - len(self.primary_key))
        if num_missing_pkeys > 0:
            primary_keys += keys[:num_missing_pkeys]
        if len(keys) > num_missing_pkeys:
            identifiers += keys[num_missing_pkeys:]

        return IdentifierData(
            self.cog_name,
            self.uuid,
            self.category,
            primary_keys,
            identifiers,
            self.primary_key_len,
            self.is_custom,
        )

    def add_identifier(self, *identifier: str) -> "IdentifierData":
        """Return a new IdentifierData with *identifier* appended to ``identifiers``.

        Unlike :meth:`get_child`, nothing is routed into the primary key.

        Raises
        ------
        ValueError
            If any identifier is not a ``str``.
        """
        if not all(isinstance(i, str) for i in identifier):
            raise ValueError("Identifiers must be strings.")

        return IdentifierData(
            self.cog_name,
            self.uuid,
            self.category,
            self.primary_key,
            self.identifiers + identifier,
            self.primary_key_len,
            is_custom=self.is_custom,
        )

    def to_tuple(self) -> Tuple[str, ...]:
        """Flatten into ``(cog_name, uuid, category, *primary_key, *identifiers)``.

        Falsy components (for example an empty category string) are dropped.
        """
        return tuple(
            filter(
                None,
                (self.cog_name, self.uuid, self.category, *self.primary_key, *self.identifiers),
            )
        )
class BaseDriver(abc.ABC):
    """Abstract base for config storage backends.

    Each instance is bound to one cog's data, identified by ``cog_name`` and
    ``unique_cog_identifier``. Subclasses implement the abstract storage
    primitives; the export/import/migration helpers below are built on them.
    """

    def __init__(self, cog_name: str, identifier: str, **kwargs):
        # Extra keyword arguments are accepted and ignored at this level.
        self.cog_name = cog_name
        self.unique_cog_identifier = identifier

    @classmethod
    @abc.abstractmethod
    async def initialize(cls, **storage_details) -> None:
        """
        Prepare this driver for use.

        Parameters
        ----------
        **storage_details
            Backend-specific storage details. These should match
            :func:`data_manager.storage_details`.

        Raises
        ------
        MissingExtraRequirements
            When the driver depends on an optional extra that is not installed.
        """
        raise NotImplementedError

    @classmethod
    @abc.abstractmethod
    async def teardown(cls) -> None:
        """
        Shut this driver down.
        """
        raise NotImplementedError

    @staticmethod
    @abc.abstractmethod
    def get_config_details() -> Dict[str, Any]:
        """
        Ask the user for any extra configuration this driver needs.

        Returns
        -------
        Dict[str, Any]
            The collected configuration details.
        """
        raise NotImplementedError

    @abc.abstractmethod
    async def get(self, identifier_data: IdentifierData) -> Any:
        """
        Look up the value indicated by the given identifiers.

        Implementations are expected to raise ``KeyError`` when nothing is stored
        there; :meth:`export_data` relies on this to skip empty categories.

        Parameters
        ----------
        identifier_data

        Returns
        -------
        Any
            The stored value.
        """
        raise NotImplementedError

    @abc.abstractmethod
    async def set(self, identifier_data: IdentifierData, value=None) -> None:
        """
        Store a value under the key indicated by the given identifiers.

        Parameters
        ----------
        identifier_data
        value
            Any JSON-serializable Python object.
        """
        raise NotImplementedError

    @abc.abstractmethod
    async def clear(self, identifier_data: IdentifierData) -> None:
        """
        Remove the value indicated by the given identifiers.

        Behaves like ``del`` on a dict.

        Parameters
        ----------
        identifier_data
        """
        raise NotImplementedError

    @classmethod
    @abc.abstractmethod
    def aiter_cogs(cls) -> AsyncIterator[Tuple[str, str]]:
        """List the cogs that have data stored in this backend.

        Yields
        ------
        Tuple[str, str]
            ``(cog_name, cog_identifier)`` pairs, yielded asynchronously.
        """
        raise NotImplementedError

    @classmethod
    async def migrate_to(
        cls,
        new_driver_cls: Type["BaseDriver"],
        all_custom_group_data: Dict[str, Dict[str, Dict[str, int]]],
    ) -> None:
        """Copy every cog's data from this backend into another one.

        Both drivers must already be initialized. Only stored data is copied;
        no instance metadata is modified.

        Parameters
        ----------
        new_driver_cls
            The `BaseDriver` subclass to migrate into.
        all_custom_group_data : Dict[str, Dict[str, Dict[str, int]]]
            Mapping of cog name -> cog ID -> custom group -> primary key length.
        """
        # Backend-agnostic path: export each cog through this driver and import it
        # through the target driver while a live progress display runs.
        columns = (
            rich.progress.SpinnerColumn(),
            rich.progress.TextColumn("[progress.description]{task.description}"),
            RichIndefiniteBarColumn(),
            rich.progress.TextColumn("{task.completed} cogs processed"),
            rich.progress.TimeElapsedColumn(),
        )
        with rich.progress.Progress(*columns) as progress:
            processed = 0
            # While cogs are still being processed, ``total`` stays one ahead of
            # ``completed``; it is settled to the real count after the loop.
            task_id = progress.add_task(
                "[yellow]Migrating", completed=processed, total=processed + 1
            )
            async for cog_name, cog_id in cls.aiter_cogs():
                progress.console.print(f"Working on {cog_name}...")
                source = cls(cog_name, cog_id)
                target = new_driver_cls(cog_name, cog_id)
                cog_groups = all_custom_group_data.get(cog_name, {}).get(cog_id, {})
                exported = await source.export_data(cog_groups)
                await target.import_data(exported, cog_groups)
                processed += 1
                progress.update(task_id, completed=processed, total=processed + 1)
            progress.update(task_id, total=processed)
        print()

    @classmethod
    async def delete_all_data(cls, **kwargs) -> None:
        """Delete all data stored by this driver.

        The driver must be initialized first. This generic implementation may be
        overridden by subclasses.

        Parameters
        ----------
        **kwargs
            Driver-specific options controlling how deletion is performed.
        """
        async for cog_name, cog_id in cls.aiter_cogs():
            cog_driver = cls(cog_name, cog_id)
            # Empty category and no keys: this identifier is what gets cleared for each cog.
            whole_cog = IdentifierData(cog_name, cog_id, "", (), (), 0)
            await cog_driver.clear(whole_cog)

    @staticmethod
    def _split_primary_key(
        category: Union[ConfigCategory, str],
        custom_group_data: Dict[str, int],
        data: Dict[str, Any],
    ) -> List[Tuple[Tuple[str, ...], Dict[str, Any]]]:
        """Split *data* into ``(primary_key, value)`` pairs for *category*.

        Descends as many nested dict levels as the category has primary keys.
        A category with no primary keys gives one pair with an empty key.
        """
        depth = ConfigCategory.get_pkey_info(category, custom_group_data)[0]
        if depth == 0:
            return [((), data)]

        def walk(levels_left, node, prefix=()):
            # Yield (full key path, value) for each entry ``levels_left`` levels down.
            for key, value in node.items():
                path = prefix + (key,)
                if levels_left > 1:
                    yield from walk(levels_left - 1, value, path)
                else:
                    yield path, value

        return list(walk(depth, data))

    async def export_data(
        self, custom_group_data: Dict[str, int]
    ) -> List[Tuple[str, Dict[str, Any]]]:
        """Read this cog's data for every built-in category and custom group.

        Returns a list of ``(category, data)`` pairs. Categories for which
        :meth:`get` raises ``KeyError`` are left out.
        """
        exported = []
        for category in [*(member.value for member in ConfigCategory), *custom_group_data]:
            pkey_len, is_custom = ConfigCategory.get_pkey_info(category, custom_group_data)
            whole_category = IdentifierData(
                self.cog_name, self.unique_cog_identifier, category, (), (), pkey_len, is_custom
            )
            try:
                value = await self.get(whole_category)
            except KeyError:
                # Nothing stored for this category.
                continue
            exported.append((category, value))
        return exported

    async def import_data(
        self, cog_data: List[Tuple[str, Dict[str, Any]]], custom_group_data: Dict[str, int]
    ) -> None:
        """Write data in the shape produced by :meth:`export_data` into this driver.

        Each category's data is split into one entry per full primary key, and
        each entry is stored via :meth:`set`.
        """
        for category, category_data in cog_data:
            for pkey, value in self._split_primary_key(category, custom_group_data, category_data):
                pkey_len, is_custom = ConfigCategory.get_pkey_info(category, custom_group_data)
                target = IdentifierData(
                    self.cog_name, self.unique_cog_identifier, category, pkey, (), pkey_len, is_custom
                )
                await self.set(target, value)