# Source code for dsgrid.registry.dimension_registry_manager

"""Manages the registry for dimensions"""

import logging
from collections import defaultdict
from pathlib import Path
from typing import Optional, Union
from uuid import uuid4

from prettytable import PrettyTable
from sqlalchemy import Connection

from dsgrid.config.dimension_config_factory import get_dimension_config, load_dimension_config
from dsgrid.config.dimension_config import DimensionConfig
from dsgrid.config.dimensions_config import DimensionsConfig
from dsgrid.config.dimensions import (
    TimeDimensionBaseModel,
    DimensionReferenceModel,
)
from dsgrid.registry.common import ConfigKey, RegistryType, VersionUpdateType
from dsgrid.registry.registry_interface import DimensionRegistryInterface
from dsgrid.utils.filters import transform_and_validate_filters, matches_filters
from dsgrid.utils.timing import timer_stats_collector, track_timing
from dsgrid.utils.utilities import display_table
from .registration_context import RegistrationContext
from .dimension_update_checker import DimensionUpdateChecker
from .registry_manager_base import RegistryManagerBase


logger = logging.getLogger(__name__)


class DimensionRegistryManager(RegistryManagerBase):
    """Manages registered dimensions.

    Dimension configs loaded from the database are cached in memory, keyed by
    ``ConfigKey(dimension_id, version)``.
    """

    def __init__(self, path, params):
        super().__init__(path, params)
        # In-memory cache of loaded configs: key = ConfigKey, value = DimensionConfig
        self._dimensions: dict[ConfigKey, DimensionConfig] = {}

    @staticmethod
    def config_class():
        """Return the config class handled by this manager."""
        return DimensionConfig

    @property
    def db(self) -> DimensionRegistryInterface:
        """The database interface used for all dimension queries."""
        return self._db

    @db.setter
    def db(self, db: DimensionRegistryInterface):
        self._db = db

    @staticmethod
    def name():
        """Return the display name of this registry."""
        return "Dimensions"

    def _replace_duplicates(self, config: DimensionsConfig, context: RegistrationContext):
        """Replace dimensions in ``config`` that duplicate already-registered dimensions.

        A time dimension is a duplicate when ``_get_matching_time_dimension`` finds a match.
        Any other dimension is a duplicate when a registered dimension has the same file hash,
        dimension type, name, and display name. Duplicates are replaced in place in
        ``config.model.dimensions`` with the registered model.

        Returns
        -------
        set
            IDs of the existing dimensions that were substituted into ``config``.

        """
        hashes = defaultdict(list)
        time_dims = {}
        for dimension in self.db.iter_models(context.connection, all_versions=True):
            if isinstance(dimension, TimeDimensionBaseModel):
                time_dims[dimension.id] = dimension
            else:
                hashes[dimension.file_hash].append(dimension)

        existing_ids = set()
        for i, dim in enumerate(config.model.dimensions):
            replace_dim = False
            existing = None
            if isinstance(dim, TimeDimensionBaseModel):
                existing = self._get_matching_time_dimension(time_dims.values(), dim)
                if existing is not None:
                    replace_dim = True
            elif dim.file_hash in hashes:
                for existing in hashes[dim.file_hash]:
                    if (
                        dim.dimension_type == existing.dimension_type
                        and dim.name == existing.name
                        and dim.display_name == existing.display_name
                    ):
                        replace_dim = True
                        break
                if not replace_dim:
                    # Same records, but the identifying metadata differs: keep it as new.
                    logger.info(
                        "Register new dimension even though records are duplicate with "
                        "one or more existing dimensions. New name/display_name=%s/%s",
                        dim.name,
                        dim.display_name,
                    )
            if replace_dim:
                assert existing is not None
                logger.info(
                    "Replace %s with existing dimension ID %s", dim.name, existing.dimension_id
                )
                config.model.dimensions[i] = existing
                existing_ids.add(existing.dimension_id)
        return existing_ids

    @staticmethod
    def _get_matching_time_dimension(existing_dims, new_dim):
        """Return the first dimension in ``existing_dims`` that matches ``new_dim``, or None.

        Two time dimensions match when they are of exactly the same class and every field
        listed in the class's ``model_fields`` compares equal, ignoring ``dimension_id``,
        ``version``, and ``id``.
        """
        # NOTE(review): presumably excluded because they are assigned at registration and
        # so never agree between a new and an existing dimension -- confirm.
        exclude = {"dimension_id", "version", "id"}
        new_type = type(new_dim)
        # Computed once; the field list is the same for every candidate.
        fields = [x for x in new_type.model_fields if x not in exclude]
        for time_dim in existing_dims:
            if type(time_dim) is not new_type:
                continue
            if all(getattr(new_dim, x) == getattr(time_dim, x) for x in fields):
                return time_dim
        return None

    def get_by_id(
        self, config_id: str, version: Optional[str] = None, conn: Optional[Connection] = None
    ):
        """Return the dimension config for ``config_id``, loading and caching it if needed.

        Parameters
        ----------
        config_id : str
            Dimension ID.
        version : str, optional
            Version to load. If None, the latest version reported by the database is used.
        conn : Connection, optional
            Database connection passed through to the registry interface.

        Returns
        -------
        DimensionConfig

        """
        if version is None:
            version = self.db.get_latest_version(conn, config_id)

        key = ConfigKey(config_id, version)
        dimension = self._dimensions.get(key)
        if dimension is not None:
            return dimension

        # Only reachable with version None if the database reported no latest version.
        if version is None:
            model = self.db.get_latest(conn, config_id)
        else:
            model = self.db.get_by_version(conn, config_id, version)

        config = get_dimension_config(model)
        self._dimensions[key] = config
        return config

    def list_ids(self, dimension_type=None, conn: Optional[Connection] = None):
        """Return the dimension ids for the given type.

        Parameters
        ----------
        dimension_type : DimensionType
            If None, return the IDs of all dimensions.
        conn : Connection, optional

        Returns
        -------
        list
            Sorted dimension IDs.

        """
        if dimension_type is None:
            ids = super().list_ids(conn)
        else:
            ids = [
                x.dimension_id
                for x in self.db.iter_models(
                    conn, filter_config={"dimension_type": dimension_type}
                )
            ]

        ids.sort()
        return ids

    def load_dimensions(self, dimension_references, conn: Optional[Connection] = None):
        """Load dimensions from the database.

        Parameters
        ----------
        dimension_references : list
            iterable of DimensionReferenceModel instances
        conn : Connection, optional

        Returns
        -------
        dict
            ConfigKey to DimensionConfig

        """
        return {
            ConfigKey(ref.dimension_id, ref.version): self.get_by_id(
                ref.dimension_id, version=ref.version, conn=conn
            )
            for ref in dimension_references
        }

    @track_timing(timer_stats_collector)
    def register_from_config(
        self,
        config: DimensionsConfig,
        context: RegistrationContext,
    ) -> list[str]:
        """Register the dimensions in ``config`` within an existing registration context.

        Returns
        -------
        list[str]
            Dimension IDs in the same order as ``config.model.dimensions``.

        """
        return self._register(config, context)

    @track_timing(timer_stats_collector)
    def register(self, config_file, submitter: str, log_message: str) -> list[str]:
        """Load a dimensions config file and register its dimensions.

        The registration runs in a new ``RegistrationContext`` created with
        ``VersionUpdateType.MAJOR``.

        Returns
        -------
        list[str]
            Dimension IDs in the same order as the dimensions in the config.

        """
        with RegistrationContext(
            self.db, log_message, VersionUpdateType.MAJOR, submitter
        ) as context:
            config = DimensionsConfig.load(config_file)
            return self.register_from_config(config, context=context)

    def _register(self, config, context: RegistrationContext) -> list[str]:
        """Insert new dimensions into the database and re-use IDs of duplicates.

        Returns
        -------
        list[str]
            Dimension IDs in the same order as ``config.model.dimensions``.

        Raises
        ------
        RuntimeError
            Raised if a dimension already has an ID but was not identified as a duplicate
            of an existing dimension.

        """
        existing_ids = self._replace_duplicates(config, context)
        registered_dimension_ids = []

        # This function will either register the dimension specified by each model or re-use an
        # existing ID. The returned list must be in the same order as the list of models.
        final_dimension_ids = []
        for dim in config.model.dimensions:
            if dim.id is None:
                assert dim.dimension_id is None
                dim.dimension_id = str(uuid4())
                dim.version = "1.0.0"
                dim = self.db.insert(context.connection, dim, context.registration)
                final_dimension_ids.append(dim.dimension_id)
                registered_dimension_ids.append(dim.dimension_id)
                logger.info(
                    "%s Registered dimension id=%s type=%s version=%s name=%s",
                    self._log_offline_mode_prefix(),
                    dim.id,
                    dim.dimension_type.value,
                    dim.version,
                    dim.name,
                )
            else:
                if dim.dimension_id not in existing_ids:
                    msg = f"Bug: {dim.dimension_id=} should have been in existing_ids"
                    raise RuntimeError(msg)
                final_dimension_ids.append(dim.dimension_id)

        # NOTE(review): this count includes dimensions that re-used an existing ID.
        logger.info("Registered %s dimensions", len(config.model.dimensions))

        # Only newly inserted dimensions are reported to the context.
        context.add_ids(RegistryType.DIMENSION, registered_dimension_ids, self)
        return final_dimension_ids

    def make_dimension_references(self, conn: Connection, dimension_ids: list[str]):
        """Return a list of dimension references from a list of registered dimension IDs.
        This assumes that the latest version of the dimensions will be used because they were
        just created.

        Parameters
        ----------
        conn : Connection
        dimension_ids : list[str]

        Returns
        -------
        list
            One DimensionReferenceModel per ID, in the same order as ``dimension_ids``.

        """
        refs = []
        for dim_id in dimension_ids:
            dim = self.db.get_latest(conn, dim_id)
            refs.append(
                DimensionReferenceModel(
                    dimension_id=dim_id,
                    dimension_type=dim.dimension_type,
                    version=dim.version,
                )
            )
        return refs

    def show(
        self,
        conn: Optional[Connection] = None,
        filters: Optional[list[str]] = None,
        max_width: Optional[Union[int, dict]] = None,
        drop_fields: Optional[list[str]] = None,
        dimension_ids: Optional[set[str]] = None,
        return_table: bool = False,
        **kwargs,
    ):
        """Show registry in PrettyTable

        Parameters
        ----------
        conn : Connection, optional
            Database connection passed through to the registry interface.
        filters : list or tuple
            List of filter expressions for registry content
            (e.g., filters=["Submitter==USER", "Description contains comstock"])
        max_width
            Max column width in PrettyTable, specify as a single value or as a dict of values
            by field name
        drop_fields
            List of field names not to show
        dimension_ids : set[str], optional
            If set and non-empty, only show dimensions whose ID is in this set.
        return_table : bool
            If True, return the table instead of displaying it.
        **kwargs
            Accepted but not used by this method.

        Returns
        -------
        PrettyTable | None
            The table if ``return_table`` is True, otherwise None.

        """
        if filters:
            logger.info("List registered dimensions for: %s", filters)

        table = PrettyTable(title="Dimensions")
        all_field_names = (
            "Type",
            "Query Name",
            "ID",
            "Version",
            "Date",
            "Submitter",
            "Description",
        )
        if drop_fields is None:
            table.field_names = all_field_names
        else:
            table.field_names = tuple(x for x in all_field_names if x not in drop_fields)

        if max_width is None:
            table._max_width = {
                "ID": 40,
                "Date": 10,
                "Description": 40,
            }
        if isinstance(max_width, int):
            table.max_width = max_width
        elif isinstance(max_width, dict):
            table._max_width = max_width

        transformed_filters = transform_and_validate_filters(filters) if filters else None
        field_to_index = {x: i for i, x in enumerate(table.field_names)}

        rows = []
        for model in self.db.iter_models(conn):
            # Skip excluded dimensions before looking up their registration so that no
            # query is issued for rows that will be discarded.
            if dimension_ids and model.dimension_id not in dimension_ids:
                continue

            registration = self.db.get_registration(conn, model)
            all_fields = (
                model.dimension_type.value,
                model.dimension_query_name,
                model.dimension_id,
                model.version,
                registration.timestamp.strftime("%Y-%m-%d %H:%M:%S"),
                registration.submitter,
                registration.log_message,
            )
            if drop_fields is None:
                row = all_fields
            else:
                row = tuple(
                    y for (x, y) in zip(all_field_names, all_fields) if x not in drop_fields
                )

            if not filters or matches_filters(row, field_to_index, transformed_filters):
                rows.append(row)

        # Sort on the first displayed column.
        rows.sort(key=lambda x: x[0])
        table.add_rows(rows)
        table.align = "l"
        if return_table:
            return table
        display_table(table)

    def update_from_file(
        self,
        config_file: Path,
        dimension_id: str,
        submitter: str,
        update_type: VersionUpdateType,
        log_message: str,
        version: str,
    ):
        """Load a dimension config from a file and register it as an update.

        The loaded config is checked with ``_check_update`` against ``dimension_id`` and
        ``version`` before the update is applied.
        """
        with RegistrationContext(self.db, log_message, update_type, submitter) as context:
            config = load_dimension_config(config_file)
            self._check_update(context.connection, config, dimension_id, version)
            self.update_with_context(config, context)

    @track_timing(timer_stats_collector)
    def update(
        self,
        config,
        update_type: VersionUpdateType,
        log_message: str,
        submitter: Optional[str] = None,
    ) -> DimensionConfig:
        """Update a registered dimension in a new registration context.

        Returns
        -------
        DimensionConfig
            The config for the updated dimension.

        """
        with RegistrationContext(self.db, log_message, update_type, submitter) as context:
            return self.update_with_context(config, context)

    def update_with_context(self, config, context: RegistrationContext) -> DimensionConfig:
        """Update a registered dimension within an existing registration context.

        Runs ``DimensionUpdateChecker`` on the currently registered model and the new model,
        applies the update, and replaces the cached config for the old version with the
        config for the new one.

        Returns
        -------
        DimensionConfig
            The config for the updated dimension.

        """
        old_config = self.get_by_id(config.model.dimension_id, conn=context.connection)
        checker = DimensionUpdateChecker(old_config.model, config.model)
        checker.run()

        # Build the old key before updating, while the current version is still known.
        cur_version = old_config.model.version
        old_key = ConfigKey(config.model.dimension_id, cur_version)
        model = self._update_config(config, context)
        new_key = ConfigKey(model.dimension_id, model.version)

        self._dimensions.pop(old_key, None)
        self._dimensions[new_key] = get_dimension_config(model)
        return self._dimensions[new_key]

    def finalize_registration(self, conn: Connection, config_ids: set[str], error_occurred: bool):
        """Evict the cached configs for ``config_ids`` if the registration failed.

        Nothing is done when ``error_occurred`` is False. ``conn`` is not used here.
        """
        if error_occurred:
            for key in [x for x in self._dimensions if x.id in config_ids]:
                self._dimensions.pop(key)

    def remove(self, dimension_id, conn: Optional[Connection] = None):
        """Delete all versions of a dimension from the database and from the cache."""
        self.db.delete_all(conn, dimension_id)
        # Materialize the matching keys first; the dict is mutated while removing them.
        for key in [x for x in self._dimensions if x.id == dimension_id]:
            self._dimensions.pop(key)

        logger.info("Removed %s from the registry.", dimension_id)