Source code for dsgrid.registry.dimension_registry_manager

"""Manages the registry for dimensions"""

import getpass
import logging
from typing import Optional, Union

from prettytable import PrettyTable

from dsgrid.config.dimension_config_factory import get_dimension_config, load_dimension_config
from dsgrid.config.dimension_config import DimensionConfig
from dsgrid.config.dimensions_config import DimensionsConfig
from dsgrid.config.dimensions import (
    TimeDimensionBaseModel,
    DimensionReferenceModel,
)
from dsgrid.registry.common import ConfigKey, make_initial_config_registration
from dsgrid.utils.filters import transform_and_validate_filters, matches_filters
from dsgrid.utils.timing import timer_stats_collector, track_timing
from dsgrid.utils.utilities import display_table
from .common import RegistryType
from .registration_context import RegistrationContext
from .dimension_update_checker import DimensionUpdateChecker
from .registry_interface import DimensionRegistryInterface
from .registry_manager_base import RegistryManagerBase


logger = logging.getLogger(__name__)


[docs] class DimensionRegistryManager(RegistryManagerBase): """Manages registered dimensions.""" def __init__(self, path, params): super().__init__(path, params) self._dimensions = {} # key = ConfigKey, value = DimensionConfig @staticmethod def config_class(): return DimensionConfig @property def db(self) -> DimensionRegistryInterface: return self._db @db.setter def db(self, db: DimensionRegistryInterface): self._db = db @staticmethod def name(): return "Dimensions" def _replace_duplicates(self, config: DimensionsConfig): hashes = {} time_dims = {} for dimension in self._db.iter_models(all_versions=True): if isinstance(dimension, TimeDimensionBaseModel): time_dims[dimension.id] = dimension else: hashes[dimension.file_hash] = dimension # TODO: This only works if the matching dimension is the latest. existing_ids = set() for i, dim in enumerate(config.model.dimensions): replace_dim = False existing = None if isinstance(dim, TimeDimensionBaseModel): existing = self._get_matching_time_dimension(time_dims.values(), dim) if existing is not None: replace_dim = True elif dim.file_hash in hashes: existing = hashes[dim.file_hash] if dim.dimension_type == existing.dimension_type: if dim.name == existing.name and dim.display_name == existing.display_name: replace_dim = True else: logger.info( "Register new dimension even though records are duplicate with " "existing dimension. Existing name/display_name=%s/%s. " "New name/display_name=%s/%s", existing.name, existing.display_name, dim.name, dim.display_name, ) if replace_dim: logger.info( "Replace %s with existing dimension ID %s", dim.name, existing.dimension_id ) config.model.dimensions[i] = existing existing_ids.add(existing.dimension_id) return existing_ids @staticmethod def _get_matching_time_dimension(existing_dims, new_dim): for time_dim in existing_dims: if type(time_dim) is not type(new_dim): continue match = True exclude = set(("description", "dimension_id", "key", "id", "rev", "version")) for field in type(new_dim).model_fields: if field not in exclude and getattr(new_dim, field) != getattr(time_dim, field): match = False break if match: return time_dim return None def finalize_registration(self, config_ids: list[str], error_occurred: bool): if error_occurred: logger.info("Remove all intermediate dimensions after error") for dimension_id in config_ids: self.remove(dimension_id)
[docs] def get_by_id(self, config_id, version=None): if version is None: version = self._db.get_latest_version(config_id) key = ConfigKey(config_id, version) dimension = self._dimensions.get(key) if dimension is not None: return dimension if version is None: model = self.db.get_latest(config_id) else: model = self.db.get_by_version(config_id, version) config = get_dimension_config(model) self._dimensions[key] = config return config
def list_ids(self, dimension_type=None): """Return the dimension ids for the given type. Parameters ---------- dimension_type : DimensionType Returns ------- list """ if dimension_type is None: ids = list(self.iter_ids()) else: ids = [ x.dimension_id for x in self.db.iter_models(filter_config={"dimension_type": dimension_type}) ] ids.sort() return ids def load_dimensions(self, dimension_references): """Load dimensions from the database. Parameters ---------- dimension_references : list iterable of DimensionReferenceModel instances Returns ------- dict ConfigKey to DimensionConfig """ dimensions = {} for dim in dimension_references: key = ConfigKey(dim.dimension_id, dim.version) dimensions[key] = self.get_by_id(dim.dimension_id, version=dim.version) return dimensions @track_timing(timer_stats_collector) def register_from_config(self, config: DimensionsConfig, submitter, log_message, context=None): error_occurred = False need_to_finalize = context is None if context is None: context = RegistrationContext() try: return self._register(config, submitter, log_message, context) except Exception: error_occurred = True raise finally: if need_to_finalize: context.finalize(error_occurred) @track_timing(timer_stats_collector) def register(self, config_file, submitter, log_message): context = RegistrationContext() error_occurred = False try: config = DimensionsConfig.load(config_file) return self.register_from_config(config, submitter, log_message, context=context) except Exception: error_occurred = True raise finally: context.finalize(error_occurred) def _register(self, config, submitter, log_message, context): existing_ids = self._replace_duplicates(config) registration = make_initial_config_registration(submitter, log_message) # This function will either register the dimension specified by each model or re-use an # existing ID. The returned list must be in the same order as the list of models. final_dimension_ids = [] registered_dimension_ids = [] try: # Guarantee that registration of dimensions is all or none. for dim in config.model.dimensions: if dim.id is None: dim = self.db.insert(dim, registration) final_dimension_ids.append(dim.dimension_id) registered_dimension_ids.append(dim.dimension_id) logger.info( "%s Registered dimension id=%s type=%s version=%s name=%s", self._log_offline_mode_prefix(), dim.id, dim.dimension_type.value, registration.version, dim.name, ) else: if dim.dimension_id not in existing_ids: msg = f"Bug: {dim.dimension_id=} should have been in existing_ids" raise Exception(msg) final_dimension_ids.append(dim.dimension_id) except Exception: if registered_dimension_ids: logger.warning( "Exception occured after partial completion of dimension registration." ) for dimension_id in registered_dimension_ids: self.remove(dimension_id) raise logger.info( "Registered %s dimensions with version=%s", len(config.model.dimensions), registration.version, ) context.add_ids(RegistryType.DIMENSION, registered_dimension_ids, self) return final_dimension_ids def make_dimension_references(self, dimension_ids: list[str]): """Return a list of dimension references from a list of registered dimension IDs. This assumes that the latest version of the dimensions will be used because they were just created. Parameters ---------- dimension_ids : list[str] """ refs = [] for dim_id in dimension_ids: dim = self.db.get_latest(dim_id) refs.append( DimensionReferenceModel( dimension_id=dim_id, dimension_type=dim.dimension_type, version=dim.version, ) ) return refs
[docs] def show( self, filters: Optional[list[str]] = None, max_width: Optional[Union[int, dict]] = None, drop_fields: Optional[list[str]] = None, dimension_ids: Optional[set[str]] = None, return_table: bool = False, **kwargs, ): """Show registry in PrettyTable Parameters ---------- filters : list or tuple List of filter expressions for reigstry content (e.g., filters=["Submitter==USER", "Description contains comstock"]) max_width Max column width in PrettyTable, specify as a single value or as a dict of values by field name drop_fields List of field names not to show """ if filters: logger.info("List registered dimensions for: %s", filters) table = PrettyTable(title="Dimensions") all_field_names = ( "Type", "Query Name", "ID", "Version", "Date", "Submitter", "Description", ) if drop_fields is None: table.field_names = all_field_names else: table.field_names = tuple(x for x in all_field_names if x not in drop_fields) if max_width is None: table._max_width = { "ID": 40, "Date": 10, "Description": 40, } if isinstance(max_width, int): table.max_width = max_width elif isinstance(max_width, dict): table._max_width = max_width if filters: transformed_filters = transform_and_validate_filters(filters) field_to_index = {x: i for i, x in enumerate(table.field_names)} rows = [] for model in self.db.iter_models(): registration = self.db.get_registration(model) if dimension_ids and model.dimension_id not in dimension_ids: continue all_fields = ( model.dimension_type.value, model.dimension_query_name, model.dimension_id, registration.version, registration.date.strftime("%Y-%m-%d %H:%M:%S"), registration.submitter, registration.log_message, ) if drop_fields is None: row = all_fields else: row = tuple( y for (x, y) in zip(all_field_names, all_fields) if x not in drop_fields ) if not filters or matches_filters(row, field_to_index, transformed_filters): rows.append(row) rows.sort(key=lambda x: x[0]) table.add_rows(rows) table.align = "l" if return_table: return table display_table(table)
def update_from_file( self, config_file, dimension_id, submitter, update_type, log_message, version ): config = load_dimension_config(config_file) self._check_update(config, dimension_id, version) self.update(config, update_type, log_message, submitter=submitter) @track_timing(timer_stats_collector) def update(self, config, update_type, log_message, submitter=None): if submitter is None: submitter = getpass.getuser() return self._update(config, submitter, update_type, log_message) def _update(self, config, submitter, update_type, log_message): old_config = self.get_by_id(config.model.dimension_id) checker = DimensionUpdateChecker(old_config.model, config.model) checker.run() cur_version = old_config.model.version old_key = ConfigKey(config.model.dimension_id, cur_version) model = self._update_config(config, submitter, update_type, log_message) new_key = ConfigKey(config.model.dimension_id, model.version) self._dimensions.pop(old_key, None) self._dimensions[new_key] = get_dimension_config(model) return model def remove(self, dimension_id): self.db.delete_all(dimension_id) for key in [x for x in self._dimensions if x.id == dimension_id]: self._dimensions.pop(key) logger.info("Removed %s from the registry.", dimension_id)