import abc
import csv
import importlib
import logging
import os
from datetime import datetime, timedelta
from typing import Any, Optional, Union, Literal
import copy
from pydantic import field_serializer, field_validator, model_validator, Field, ValidationInfo
from pydantic.functional_validators import BeforeValidator
from typing_extensions import Annotated
from dsgrid.data_models import DSGBaseDatabaseModel, DSGBaseModel
from dsgrid.dimension.base_models import DimensionType, DimensionCategory
from dsgrid.dimension.time import (
TimeIntervalType,
MeasurementType,
TimeZone,
TimeDimensionType,
RepresentativePeriodFormat,
DatetimeFormat,
)
from dsgrid.registry.common import REGEX_VALID_REGISTRY_NAME
from dsgrid.utils.files import compute_file_hash
from dsgrid.utils.utilities import convert_record_dicts_to_classes
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
class DimensionBaseModel(DSGBaseDatabaseModel):
    """Common attributes for all dimensions.

    Several validators read previously validated fields through ``info.data``
    (``module`` -> ``class_name`` -> ``cls``), so the declaration order of those
    fields matters.
    """

    name: str = Field(
        title="name",
        description="Dimension name",
        json_schema_extra={
            "note": "Dimension names should be descriptive, memorable, identifiable, and reusable for "
            "other datasets and projects",
            "notes": (
                "Only alphanumeric characters and dashes are supported (no underscores or spaces).",
                "The :meth:`~dsgrid.config.dimensions.check_name` validator is used to enforce valid"
                " dimension names.",
            ),
            "updateable": False,
        },
    )
    dimension_type: DimensionType = Field(
        title="dimension_type",
        alias="type",
        description="Type of the dimension",
        json_schema_extra={
            "options": DimensionType.format_for_docs(),
        },
    )
    # NOTE(review): this field is tagged "dsg_internal" while the other internal fields in
    # this class use "dsgrid_internal". Left unchanged because the consumers of these keys
    # are not visible from this module - confirm which spelling is intended.
    dimension_id: Optional[str] = Field(
        default=None,
        title="dimension_id",
        description="Unique identifier, generated by dsgrid",
        json_schema_extra={
            "dsg_internal": True,
            "updateable": False,
        },
    )
    module: str = Field(
        title="module",
        description="Python module with the dimension class",
        default="dsgrid.dimension.standard",
    )
    class_name: str = Field(
        title="class_name",
        description="Dimension record model class name",
        alias="class",
        json_schema_extra={
            "notes": (
                "The dimension class defines the expected and allowable fields (and their data types)"
                " for the dimension records file.",
                "All dimension records must have a 'id' and 'name' field.",
                "Some dimension classes support additional fields that can be used for mapping,"
                " querying, display, etc.",
                "dsgrid in online-mode only supports dimension classes defined in the"
                " :mod:`dsgrid.dimension.standard` module. If dsgrid does not currently support a"
                " dimension class that you require, please contact the dsgrid-coordination team to"
                " request a new class feature",
            ),
        },
    )
    cls: Any = Field(
        default=None,
        title="cls",
        description="Dimension record model class",
        alias="dimension_class",
        json_schema_extra={
            "dsgrid_internal": True,
        },
    )
    description: str = Field(
        title="description",
        description="A description of the dimension records that is helpful, memorable, and "
        "identifiable",
        json_schema_extra={
            "notes": (
                "The description will get stored in the dimension record registry and may be used"
                " when searching the registry.",
            ),
        },
    )
    id: Optional[int] = Field(
        default=None,
        description="Registry database ID",
        json_schema_extra={
            "dsgrid_internal": True,
        },
    )

    @field_validator("name")
    @classmethod
    def check_name(cls, name: str) -> str:
        """Reject names that do not match REGEX_VALID_REGISTRY_NAME.

        Raises
        ------
        ValueError
            If the regular expression finds no match in ``name``.
        """
        if REGEX_VALID_REGISTRY_NAME.search(name) is None:
            raise ValueError(f"dimension name={name} does not meet the requirements")
        return name

    @field_validator("description")
    @classmethod
    def check_description(cls, description: str) -> str:
        """Reject empty descriptions and descriptions that are only a generic keyword.

        Raises
        ------
        ValueError
            If the description is empty or, compared case-insensitively, equals a
            dimension type value or one of the generic words below (or its plural).
        """
        if description == "":
            raise ValueError(f'Empty description field for dimension: "{cls}"')

        # TODO: improve validation for allowable dimension record names.
        prohibited_names = {x.value for x in DimensionType} | {
            "county",
            "counties",
            "year",
            "hourly",
            "comstock",
            "resstock",
            "tempo",
            "model",
            "source",
            "data-source",
            "dimension",
        }
        # Also reject the naive plural ("<word>s") of every prohibited word.
        prohibited_names |= {x + "s" for x in prohibited_names}
        if description.lower() in prohibited_names:
            raise ValueError(
                f"\nDimension description '{description}' is insufficient. "
                "Please be more descriptive.\n"
                "Hint: try adding a vintage, or other distinguishable text that will make "
                "this dimension memorable,\n"
                "identifiable, and reusable for other datasets and projects.\n"
                "e.g., 'Time dimension, 2012 hourly EST, period-beginning, no DST, "
                "no Leap Day Adjustment, total value'\n"
                "is a good description.\n"
            )
        return description

    @field_validator("module")
    @classmethod
    def check_module(cls, module: str) -> str:
        """Only accept module paths that start with ``dsgrid``.

        Raises
        ------
        ValueError
            If ``module`` does not start with "dsgrid".
        """
        if not module.startswith("dsgrid"):
            raise ValueError("Only dsgrid modules are supported as a dimension module.")
        return module

    @field_validator("class_name")
    @classmethod
    def get_dimension_class_name(cls, class_name, info: ValidationInfo):
        """Check that ``class_name`` names an attribute of the configured module.

        The check is skipped when ``module`` is absent from ``info.data``.

        Raises
        ------
        ValueError
            If the module has no attribute called ``class_name``.
        """
        if "module" not in info.data:
            return class_name

        mod = importlib.import_module(info.data["module"])
        # Test for None first: hasattr() requires a string attribute name.
        if class_name is None or not hasattr(mod, class_name):
            if class_name is None:
                msg = (
                    f'There is no class "{class_name}" in module: {mod}.'
                    "\nIf you are using a unique dimension name, you must "
                    "specify the dimension class."
                )
            else:
                msg = f"dimension class {class_name} not in {mod}"
            raise ValueError(msg)
        return class_name

    @field_validator("cls")
    @classmethod
    def get_dimension_class(cls, dim_class, info: ValidationInfo):
        """Resolve the record class object from ``module`` and ``class_name``.

        The value is returned untouched when either of those fields is absent from
        ``info.data``.

        Raises
        ------
        ValueError
            If a non-None value was supplied for ``cls``.
        """
        if "module" not in info.data or "class_name" not in info.data:
            return dim_class

        if dim_class is not None:
            raise ValueError(f"cls={dim_class} should not be set")

        return getattr(
            importlib.import_module(info.data["module"]),
            info.data["class_name"],
        )

    @property
    def label(self) -> str:
        """Return a label for the dimension to be used in user messages."""
        return f"{self.dimension_type} {self.name}"
class DimensionModel(DimensionBaseModel):
    """Defines a non-time dimension"""

    filename: Optional[str] = Field(
        title="filename",
        alias="file",
        default=None,
        description="Filename containing dimension records. Only assigned for user input and "
        "output purposes. The registry database stores records in the dimension JSON document.",
    )
    file_hash: Optional[str] = Field(
        title="file_hash",
        description="Hash of the contents of the file",
        json_schema_extra={
            "dsgrid_internal": True,
        },
        default=None,
    )
    records: list = Field(
        title="records",
        description="Dimension records in filename that get loaded at runtime",
        json_schema_extra={
            "dsgrid_internal": True,
        },
        default=[],
    )

    @field_validator("filename")
    @classmethod
    def check_file(cls, filename):
        """Validate that dimension file exists and has no errors

        Raises
        ------
        ValueError
            If the name is an S3 URL, is not an existing local file, or does not end
            with ".csv".
        """
        if filename is None:
            return filename

        # Test the S3 prefix before the local-file test so that S3 URLs get the specific
        # message instead of a generic "does not exist".
        if filename.startswith("s3://"):
            raise ValueError("records must exist in the local filesystem, not on S3")
        if not os.path.isfile(filename):
            raise ValueError(f"file {filename} does not exist")
        if not filename.endswith(".csv"):
            raise ValueError(f"only CSV is supported: {filename}")
        return filename

    @field_validator("file_hash")
    @classmethod
    def compute_file_hash(cls, file_hash, info: ValidationInfo):
        """Compute the hash of the records file when no hash was supplied.

        The supplied value is returned unchanged when it is not None or when there is
        no validated, non-None ``filename`` to hash.
        """
        filename = info.data.get("filename")
        if file_hash is None and filename is not None:
            # Inside the method body this name resolves to the module-level helper
            # imported from dsgrid.utils.files, not to this validator.
            file_hash = compute_file_hash(filename)
        return file_hash

    @field_validator("records")
    @classmethod
    def add_records(cls, records, info: ValidationInfo):
        """Add records from the file.

        Supplied records win over the file: a list of dicts is converted to instances of
        the dimension class, any other non-empty list is returned as is. Only when no
        records are supplied is the CSV file read.
        """
        dim_class = info.data.get("cls")
        if "filename" not in info.data or dim_class is None:
            return records

        if records:
            if isinstance(records[0], dict):
                records = convert_record_dicts_to_classes(
                    records, dim_class, check_duplicates=["id"]
                )
            return records

        filename = info.data["filename"]
        if filename is None:
            # Nothing to load from; do not hand None to open().
            return records

        # utf-8-sig strips a leading byte-order mark, if present.
        with open(filename, encoding="utf-8-sig") as f_in:
            records = convert_record_dicts_to_classes(
                csv.DictReader(f_in), dim_class, check_duplicates=["id"]
            )
        return records

    @field_serializer("cls", "filename")
    def serialize_cls(self, val, _):
        """Serialize ``cls`` and ``filename`` as None."""
        return None
class TimeRangeModel(DSGBaseModel):
    """A continuous span of time given by its first and last timestamps."""

    # Both bounds are kept as strings rather than datetime objects: this model carries no
    # string format of its own, so it cannot serialize/deserialize datetimes by itself.
    # A DatetimeRange object is used during processing.
    start: str = Field(description="First timestamp in the data", title="start")
    end: str = Field(description="Last timestamp in the data (inclusive)", title="end")
class MonthRangeModel(DSGBaseModel):
    """A continuous span of months given by its first and last month numbers."""

    # Months are plain integers, not timestamps.
    start: int = Field(
        description="First month in the data (January is 1, December is 12)",
        title="start",
    )
    end: int = Field(description="Last month in the data (inclusive)", title="end")
class IndexRangeModel(DSGBaseModel):
    """A continuous span of integer indices given by its first and last index."""

    start: int = Field(description="First of indices", title="start")
    end: int = Field(description="Last of indices (inclusive)", title="end")
class TimeDimensionBaseModel(DimensionBaseModel, abc.ABC):
    """Base model shared by every time dimension."""

    time_type: TimeDimensionType = Field(
        default=TimeDimensionType.DATETIME,
        title="time_type",
        description="Type of time dimension",
        json_schema_extra={"options": TimeDimensionType.format_for_docs()},
    )

    @field_serializer("cls")
    def serialize_cls(self, val, _):
        """Serialize ``cls`` as None."""
        return None

    @abc.abstractmethod
    def is_time_zone_required_in_geography(self):
        """Returns True if the geography dimension records must contain a time_zone column."""
class AlignedTime(DSGBaseModel):
    """Datetime format in which the data carries absolute timestamps that share the same
    start and end for every geography."""

    # Fixed tag; DateTimeDimensionModel.datetime_format discriminates on "format_type".
    format_type: Literal[DatetimeFormat.ALIGNED] = DatetimeFormat.ALIGNED
    timezone: TimeZone = Field(
        description="Time zone of data",
        title="timezone",
        json_schema_extra={"options": TimeZone.format_descriptions_for_docs()},
    )
class LocalTimeAsStrings(DSGBaseModel):
    """Datetime format in which the data carries absolute timestamps formatted as strings
    with offsets from UTC. The timestamps line up across geographies once adjusted for
    time zone, but are staggered on an absolute time scale."""

    # Fixed tag; DateTimeDimensionModel.datetime_format discriminates on "format_type".
    format_type: Literal[DatetimeFormat.LOCAL_AS_STRINGS] = DatetimeFormat.LOCAL_AS_STRINGS
    data_str_format: str = Field(
        default="yyyy-MM-dd HH:mm:ssZZZZZ",
        title="data_str_format",
        description="Timestamp string format (for parsing the time column of the dataframe)",
        json_schema_extra={
            "notes": (
                "The string format is used to parse the timestamps in the dataframe while in Spark, "
                "(e.g., yyyy-MM-dd HH:mm:ssZZZZZ). "
                "Cheatsheet reference: `<https://spark.apache.org/docs/latest/sql-ref-datetime-pattern.html>`_.",
            ),
        },
    )

    @field_validator("data_str_format")
    @classmethod
    def check_data_str_format(cls, data_str_format):
        """Currently always raises NotImplementedError.

        The zone-designator check after the raise is unreachable; it is kept as the
        validation intended for when this format is implemented.
        """
        raise NotImplementedError("DatetimeFormat.LOCAL_AS_STRINGS is not fully implemented.")

        # Unreachable: at least one zone/offset pattern letter must be present.
        zone_letters = ("x", "X", "Z", "z", "V", "O")
        if not any(letter in data_str_format for letter in zone_letters):
            raise ValueError("data_str_format must provide time zone or zone offset.")
        return data_str_format
class DateTimeDimensionModel(TimeDimensionBaseModel):
    """Defines a time dimension where timestamps translate to datetime objects."""

    datetime_format: Union[AlignedTime, LocalTimeAsStrings] = Field(
        title="datetime_format",
        discriminator="format_type",
        description=(
            "\nFormat of the datetime used to define the data format, alignment between geography,"
            "\nand time zone information.\n"
        ),
    )
    measurement_type: MeasurementType = Field(
        title="measurement_type",
        default=MeasurementType.TOTAL,
        description=(
            "\nThe type of measurement represented by a value associated with a timestamp:"
            "\nmean, min, max, measured, total\n"
        ),
        json_schema_extra={
            "options": MeasurementType.format_for_docs(),
        },
    )
    # The seconds directive is "%S"; lowercase "%s" is not accepted by datetime.strptime,
    # which is what _check_time_ranges uses to parse the range bounds.
    str_format: str = Field(
        title="str_format",
        default="%Y-%m-%d %H:%M:%S",
        description="Timestamp string format (for parsing the time ranges)",
        json_schema_extra={
            "notes": (
                "The string format is used to parse the timestamps provided in the time ranges. "
                "Cheatsheet reference: `<https://strftime.org/>`_.",
            ),
        },
    )
    frequency: timedelta = Field(
        title="frequency",
        description="Resolution of the timestamps",
        json_schema_extra={
            "notes": (
                "Reference: `Datetime timedelta objects"
                " <https://docs.python.org/3/library/datetime.html#timedelta-objects>`_",
            ),
        },
    )
    ranges: list[TimeRangeModel] = Field(
        title="time_ranges",
        description="Defines the continuous ranges of time in the data, inclusive of start and end time.",
    )
    time_interval_type: TimeIntervalType = Field(
        title="time_interval",
        description="The range of time that the value associated with a timestamp represents, e.g., period-beginning",
        json_schema_extra={
            "options": TimeIntervalType.format_descriptions_for_docs(),
        },
    )

    @model_validator(mode="before")
    @classmethod
    def handle_legacy_fields(cls, values):
        """Translate legacy input keys into the current schema.

        ``leap_day_adjustment`` is dropped (only the value "none" is accepted) and a
        top-level ``timezone`` is converted into an aligned ``datetime_format``.

        Raises
        ------
        ValueError
            If ``leap_day_adjustment`` is present with a value other than "none".
        """
        if not isinstance(values, dict):
            return values

        # Work on a shallow copy so the caller's dict is not modified.
        values = dict(values)
        if "leap_day_adjustment" in values:
            if values["leap_day_adjustment"] != "none":
                msg = f"Unknown data_schema format: {values=}"
                raise ValueError(msg)
            values.pop("leap_day_adjustment")

        if "timezone" in values:
            values["datetime_format"] = {
                "format_type": DatetimeFormat.ALIGNED.value,
                "timezone": values.pop("timezone"),
            }

        return values

    @field_validator("frequency")
    @classmethod
    def check_frequency(cls, frequency: timedelta) -> timedelta:
        """Reject a frequency of exactly 365 or 366 days.

        Raises
        ------
        ValueError
            If ``frequency`` equals 365 or 366 days.
        """
        if frequency in [timedelta(days=365), timedelta(days=366)]:
            msg = (
                f"{frequency=}, datetime config does not allow 365 or 366 days frequency, "
                "use class=AnnualTime, time_type=annual to specify a year series."
            )
            raise ValueError(msg)
        return frequency

    @field_validator("ranges")
    @classmethod
    def check_times(
        cls, ranges: list[TimeRangeModel], info: ValidationInfo
    ) -> list[TimeRangeModel]:
        """Check the time ranges against ``str_format`` and ``frequency``.

        The check is skipped when either of those fields is absent from ``info.data``.
        """
        if "str_format" not in info.data or "frequency" not in info.data:
            return ranges
        return _check_time_ranges(ranges, info.data["str_format"], info.data["frequency"])

    def is_time_zone_required_in_geography(self):
        """Return False: this time type does not need a time_zone column in geography."""
        return False
class AnnualTimeDimensionModel(TimeDimensionBaseModel):
    """Defines an annual time dimension where timestamps are years."""

    time_type: TimeDimensionType = Field(default=TimeDimensionType.ANNUAL)
    measurement_type: MeasurementType = Field(
        title="measurement_type",
        default=MeasurementType.TOTAL,
        description=(
            "\nThe type of measurement represented by a value associated with a timestamp:"
            "\ne.g., mean, total\n"
        ),
        json_schema_extra={
            "options": MeasurementType.format_for_docs(),
        },
    )
    str_format: str = Field(
        title="str_format",
        default="%Y",
        description="Timestamp string format",
        json_schema_extra={
            "notes": (
                "The string format is used to parse the timestamps provided in the time ranges. "
                "Cheatsheet reference: `<https://strftime.org/>`_.",
            ),
        },
    )
    # NOTE(review): the default is None although the annotation is a list - confirm
    # whether ranges is meant to be optional.
    ranges: list[TimeRangeModel] = Field(
        default=None,
        title="time_ranges",
        description="Defines the contiguous ranges of time in the data, inclusive of start and end time.",
    )
    include_leap_day: bool = Field(
        title="include_leap_day",
        default=False,
        description="Whether annual time includes leap day.",
    )

    @field_validator("ranges")
    @classmethod
    def check_times(
        cls, ranges: list[TimeRangeModel], info: ValidationInfo
    ) -> list[TimeRangeModel]:
        """Check that every range parses with ``str_format`` and is ordered.

        The check is skipped when ``str_format`` is absent from ``info.data``, matching
        the guard in DateTimeDimensionModel.check_times.
        """
        if "str_format" not in info.data:
            return ranges
        return _check_annual_ranges(ranges, info.data["str_format"])

    @field_validator("measurement_type")
    @classmethod
    def check_measurement_type(cls, measurement_type: MeasurementType) -> MeasurementType:
        """Only allow MeasurementType.TOTAL.

        Raises
        ------
        ValueError
            If any other measurement type is given.
        """
        # This restriction exists because any other measurement type would require a frequency,
        # and that isn't part of the model definition.
        if measurement_type != MeasurementType.TOTAL:
            msg = f"Annual time currently only supports MeasurementType total: {measurement_type}"
            raise ValueError(msg)
        return measurement_type

    def is_time_zone_required_in_geography(self):
        """Return False: this time type does not need a time_zone column in geography."""
        return False
class RepresentativePeriodTimeDimensionModel(TimeDimensionBaseModel):
    """Defines a representative time dimension."""

    time_type: TimeDimensionType = Field(default=TimeDimensionType.REPRESENTATIVE_PERIOD)
    measurement_type: MeasurementType = Field(
        default=MeasurementType.TOTAL,
        title="measurement_type",
        description=(
            "\nThe type of measurement represented by a value associated with a timestamp:"
            "\ne.g., mean, total\n"
        ),
        json_schema_extra={"options": MeasurementType.format_for_docs()},
    )
    format: RepresentativePeriodFormat = Field(
        description="Format of the timestamps in the load data",
        title="format",
    )
    # Ranges are expressed in month numbers (see MonthRangeModel).
    ranges: list[MonthRangeModel] = Field(
        description="Defines the continuous ranges of time in the data, inclusive of start and end time.",
        title="ranges",
    )
    time_interval_type: TimeIntervalType = Field(
        description="The range of time that the value associated with a timestamp represents",
        title="time_interval",
        json_schema_extra={"options": TimeIntervalType.format_descriptions_for_docs()},
    )

    def is_time_zone_required_in_geography(self):
        """Return True: geography records must carry a time_zone column for this time type."""
        return True
class IndexTimeDimensionModel(TimeDimensionBaseModel):
    """Defines a time dimension where timestamps are indices."""

    time_type: TimeDimensionType = Field(default=TimeDimensionType.INDEX)
    measurement_type: MeasurementType = Field(
        title="measurement_type",
        default=MeasurementType.TOTAL,
        description=(
            "\nThe type of measurement represented by a value associated with a timestamp:"
            "\ne.g., mean, total\n"
        ),
        json_schema_extra={
            "options": MeasurementType.format_for_docs(),
        },
    )
    ranges: list[IndexRangeModel] = Field(
        title="ranges",
        description="Defines the continuous ranges of indices of the data, inclusive of start and end index.",
    )
    frequency: timedelta = Field(
        title="frequency",
        description="Resolution of the timestamps for which the ranges represent.",
        json_schema_extra={
            "notes": (
                "Reference: `Datetime timedelta objects"
                " <https://docs.python.org/3/library/datetime.html#timedelta-objects>`_",
            ),
        },
    )
    starting_timestamps: list[str] = Field(
        title="starting timestamps",
        description="Starting timestamp for each of the ranges.",
    )
    # The seconds directive is "%S"; lowercase "%s" is not a valid datetime.strptime
    # directive, so the previous default could not parse any timestamp.
    str_format: str = Field(
        title="str_format",
        default="%Y-%m-%d %H:%M:%S",
        description="Timestamp string format",
        json_schema_extra={
            "notes": (
                "The string format is used to parse the starting timestamp provided. "
                "Cheatsheet reference: `<https://strftime.org/>`_.",
            ),
        },
    )
    time_interval_type: TimeIntervalType = Field(
        title="time_interval",
        description="The range of time that the value associated with a timestamp represents, e.g., period-beginning",
        json_schema_extra={
            "options": TimeIntervalType.format_descriptions_for_docs(),
        },
    )

    @field_validator("starting_timestamps")
    @classmethod
    def check_timestamps(cls, starting_timestamps, info: ValidationInfo) -> list[str]:
        """Require exactly one starting timestamp per index range.

        The check is skipped when ``ranges`` is absent from ``info.data``.

        Raises
        ------
        ValueError
            If the number of timestamps differs from the number of ranges.
        """
        if "ranges" not in info.data:
            return starting_timestamps
        if len(starting_timestamps) != len(info.data["ranges"]):
            msg = f"{starting_timestamps=} must match the number of ranges."
            raise ValueError(msg)
        return starting_timestamps

    @field_validator("ranges")
    @classmethod
    def check_indices(cls, ranges: list[IndexRangeModel]) -> list[IndexRangeModel]:
        """Check that no range ends before it starts."""
        return _check_index_ranges(ranges)

    def is_time_zone_required_in_geography(self) -> bool:
        """Return True: geography records must carry a time_zone column for this time type."""
        return True
class NoOpTimeDimensionModel(TimeDimensionBaseModel):
    """Defines a NoOp time dimension."""

    time_type: TimeDimensionType = Field(default=TimeDimensionType.NOOP)

    def is_time_zone_required_in_geography(self) -> bool:
        """Return False: this time type does not need a time_zone column in geography."""
        return False
class DimensionReferenceModel(DSGBaseModel):
    """Points at a dimension stored in the registry by type, ID and version."""

    dimension_type: DimensionType = Field(
        alias="type",
        title="dimension_type",
        description="Type of the dimension",
        json_schema_extra={"options": DimensionType.format_for_docs()},
    )
    dimension_id: str = Field(
        title="dimension_id",
        description="Unique ID of the dimension in the registry",
        json_schema_extra={
            "notes": (
                "The dimension ID is generated by dsgrid when a dimension is registered.",
                "Only alphanumerics and dashes are supported.",
            ),
        },
    )
    version: str = Field(
        title="version",
        description="Version of the dimension",
        json_schema_extra={
            "requirements": (
                "The version string must be in semver format (e.g., '1.0.0') and it must be "
                " a valid/existing version in the registry.",
            ),
            # TODO: add notes about warnings for outdated versions DSGRID-189 & DSGRID-148
        },
    )
def _time_dimension_model_for(time_type):
    """Return the time dimension model class whose time type value equals ``time_type``.

    Raises ValueError listing the valid options when nothing matches.
    """
    candidates = (
        (TimeDimensionType.DATETIME, DateTimeDimensionModel),
        (TimeDimensionType.ANNUAL, AnnualTimeDimensionModel),
        (TimeDimensionType.REPRESENTATIVE_PERIOD, RepresentativePeriodTimeDimensionModel),
        (TimeDimensionType.INDEX, IndexTimeDimensionModel),
        (TimeDimensionType.NOOP, NoOpTimeDimensionModel),
    )
    for member, model_cls in candidates:
        if time_type == member.value:
            return model_cls

    options = [x.value for x in TimeDimensionType]
    raise ValueError(f"{time_type} not supported, valid options: {options}")


def handle_dimension_union(values):
    """Convert raw dimension dicts into the matching dimension model instances.

    Works on a deep copy of ``values``. Entries that already are DimensionBaseModel
    instances are kept. For the rest, the dimension type is read from "type" (falling
    back to "dimension_type"): time dimensions are built with the model selected by
    "time_type", everything else becomes a DimensionModel.
    """
    converted = copy.deepcopy(values)
    for position, entry in enumerate(converted):
        if isinstance(entry, DimensionBaseModel):
            continue

        dim_type = entry.get("type")
        if dim_type is None:
            dim_type = entry["dimension_type"]

        # NOTE: Errors inside DimensionModel or DateTimeDimensionModel will be duplicated by Pydantic
        if dim_type != DimensionType.TIME.value:
            converted[position] = DimensionModel(**entry)
            continue

        model_cls = _time_dimension_model_for(entry["time_type"])
        converted[position] = model_cls(**entry)

    return converted
# Every concrete dimension model that may appear in a list of dimensions.
_AnyDimensionModel = Union[
    DimensionModel,
    DateTimeDimensionModel,
    AnnualTimeDimensionModel,
    RepresentativePeriodTimeDimensionModel,
    IndexTimeDimensionModel,
    NoOpTimeDimensionModel,
]

# List-of-dimensions type; handle_dimension_union is attached as a "before" validator.
DimensionsListModel = Annotated[
    list[_AnyDimensionModel],
    BeforeValidator(handle_dimension_union),
]
def _check_time_ranges(ranges: list[TimeRangeModel], str_format: str, frequency: timedelta):
assert isinstance(frequency, timedelta)
for trange in ranges:
# Make sure start and end time parse.
start = datetime.strptime(trange.start, str_format)
end = datetime.strptime(trange.end, str_format)
# Make sure start and end is tz-naive.
if start.tzinfo is not None or end.tzinfo is not None:
msg = f"datetime range {trange} start and end need to be tz-naive. Pass in the time zone info via datetime_format"
raise ValueError(msg)
if end < start:
msg = f"datetime range {trange} end must not be less than start."
raise ValueError(msg)
if (end - start) % frequency != timedelta(0):
msg = f"datetime range {trange} is inconsistent with {frequency}"
raise ValueError(msg)
return ranges
def _check_annual_ranges(ranges: list[TimeRangeModel], str_format: str):
for trange in ranges:
# Make sure start and end time parse.
start = datetime.strptime(trange.start, str_format)
end = datetime.strptime(trange.end, str_format)
if end < start:
msg = f"annual time range {trange} end must not be less than start."
raise ValueError(msg)
return ranges
def _check_index_ranges(ranges: list[IndexRangeModel]):
for trange in ranges:
if trange.end < trange.start:
msg = f"index range {trange} end must not be less than start."
raise ValueError(msg)
return ranges
class DimensionCommonModel(DSGBaseModel):
    """Flat view of the attributes shared by all dimensions.

    Unlike DimensionBaseModel, every field here is required and the class defines no
    validators of its own.
    """

    # Identity
    name: str
    dimension_id: str
    # Classification
    dimension_type: DimensionType
    class_name: str
    # Free-text description of the dimension
    description: str
class ProjectDimensionModel(DimensionCommonModel):
    """Attributes shared by all dimensions that are assigned to a project: the common
    attributes plus the dimension's category."""

    category: DimensionCategory
def create_dimension_common_model(model) -> DimensionCommonModel:
    """Build a DimensionCommonModel from any dimension model instance.

    Copies every field of ``model`` whose name is also a DimensionCommonModel field,
    giving the API one common model for all dimensions without going through the
    DimensionBaseModel validators.
    """
    wanted = DimensionCommonModel.model_fields
    kwargs = {}
    for field_name in type(model).model_fields:
        if field_name in wanted:
            kwargs[field_name] = getattr(model, field_name)
    return DimensionCommonModel(**kwargs)
def create_project_dimension_model(model, category: DimensionCategory) -> ProjectDimensionModel:
    """Build a ProjectDimensionModel from a dimension model and its project category.

    The common attributes are taken from ``create_dimension_common_model(model)`` and
    ``category`` is stored by its enum value.
    """
    common = create_dimension_common_model(model)
    payload = {**common.model_dump(), "category": category.value}
    return ProjectDimensionModel(**payload)