Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Lightly refactor messages #5

Merged
merged 1 commit into from
Aug 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion snews/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
# -*- coding: utf-8 -*-
from .data import detectors
from .models import messages, timing
from .schema import SNEWSJsonSchema

__all__ = ["data", "models", "schemas"]
__all__ = ["detectors", "messages", "timing", "SNEWSJsonSchema"]
3 changes: 3 additions & 0 deletions snews/__main__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# -*- coding: utf-8 -*-

# Standard modules
import inspect
import json
import logging
from pathlib import Path
Expand Down Expand Up @@ -32,6 +33,8 @@ def generate_model_schemas(outdir: str = None, models_module: list = models, dry
model_class = getattr(models, model_class_name)
for model_name in model_class.__all__:
model = getattr(model_class, model_name)
if not inspect.isclass(model):
continue

if not issubclass(model, BaseModel):
continue
Expand Down
141 changes: 83 additions & 58 deletions snews/examples/tutorial.ipynb

Large diffs are not rendered by default.

116 changes: 89 additions & 27 deletions snews/models/messages.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,4 @@
# -*- coding: utf-8 -*-
__all__ = [
"HeartBeat",
"Retraction",
"CoincidenceTierMessage",
"SignificanceTierMessage",
"TimingTierMessage"
]

# Standard library modules
from datetime import UTC, datetime, timedelta
Expand All @@ -15,15 +8,24 @@

# Third-party modules
import numpy as np
from pydantic import (UUID4, BaseModel, Field, NonNegativeFloat,
field_validator, model_validator, root_validator,
validator)
from pydantic import (BaseModel, Field, NonNegativeFloat, ValidationError,
field_validator, model_validator)

# Local modules
from ..__version__ import schema_version
from ..data import detectors
from ..models.timing import PrecisionTimestamp

__all__ = [
"HeartbeatMessage",
"RetractionMessage",
"CoincidenceTierMessage",
"SignificanceTierMessage",
"TimingTierMessage",
"compatible_message_types",
"create_messages",
]


# .................................................................................................
def convert_timestamp_to_ns_precision(timestamp: Union[str, datetime, np.datetime64]) -> str:
Expand Down Expand Up @@ -68,10 +70,11 @@ class Config:
description="Textual identifier for the message"
)

uid: UUID4 = Field(
uuid: str = Field(
title="Unique message ID",
default_factory=uuid4,
description="Unique identifier for the message"
description="Unique identifier for the message",
validate_default=True
)

tier: Tier = Field(
Expand All @@ -83,13 +86,15 @@ class Config:
sent_time_utc: Optional[str] = Field(
default=None,
title="Sent time (UTC)",
description="Time the message was sent in ISO 8601-1:2019 format"
description="Time the message was sent in ISO 8601-1:2019 format",
validate_default=True
)

machine_time_utc: Optional[str] = Field(
default=None,
title="Machine time (UTC)",
description="Time of the event at the detector in ISO 8601-1:2019 format"
description="Time of the event at the detector in ISO 8601-1:2019 format",
validate_default=True
)

is_pre_sn: Optional[bool] = Field(
Expand Down Expand Up @@ -123,14 +128,21 @@ class Config:
frozen=True,
)

@validator("sent_time_utc", "machine_time_utc", pre=True, always=True)
@field_validator("sent_time_utc", "machine_time_utc", mode="before")
def _convert_timestamp_to_ns_precision(cls, v):
"""
Convert to nanosecond precision (before running Pydantic validators).
"""
if v is not None:
return convert_timestamp_to_ns_precision(timestamp=v)

@field_validator("uuid", mode="before")
def _cast_uuid_to_string(cls, v):
"""
Cast UUID to string (before running Pydantic validators).
"""
return str(v)

@model_validator(mode="after")
def _format_id(self):
"""
Expand All @@ -143,6 +155,18 @@ def _format_id(self):

return self

def fields(self):
"""
Return a list of fields for the message.
"""
return list(self.model_fields.keys())

def required_fields(self):
"""
Return a list of required fields for the message.
"""
return [k for k, v in self.model_fields.items() if v.is_required()]


# .................................................................................................
class DetectorMessageBase(MessageBase):
Expand Down Expand Up @@ -171,7 +195,7 @@ def _validate_detector_name(self) -> str:


# .................................................................................................
class HeartBeat(DetectorMessageBase):
class HeartbeatMessage(DetectorMessageBase):
"""
Heartbeat detector message.
"""
Expand All @@ -186,7 +210,7 @@ class Config:
examples=["ON", "OFF"]
)

@root_validator(pre=True)
@model_validator(mode="before")
def _set_tier(cls, values):
values['tier'] = Tier.HEART_BEAT
return values
Expand All @@ -204,15 +228,15 @@ def _validate_model(self):


# .................................................................................................
class Retraction(DetectorMessageBase):
class RetractionMessage(DetectorMessageBase):
"""
Retraction detector message.
"""

class Config:
validate_assignment = True

retract_message_uid: Optional[UUID4] = Field(
retract_message_uid: Optional[str] = Field(
default=None,
title="Unique message ID",
description="Unique identifier for the message to retract"
Expand All @@ -230,25 +254,25 @@ class Config:
description="Reason for retraction",
)

@root_validator(pre=True)
@model_validator(mode="before")
def _set_tier(cls, values):
values['tier'] = Tier.RETRACTION
return values

@model_validator(mode="after")
def _validate_model(self):
if self.retract_latest and self.retract_message_uid is not None:
raise ValueError("retract_message_uid cannot be specified when retract_latest=True")
raise ValueError("retract_message_uuid cannot be specified when retract_latest=True")

if not self.retract_latest and self.retract_message_uid is None:
raise ValueError("Must specify either retract_message_uid or retract_latest=True")
raise ValueError("Must specify either retract_message_uuid or retract_latest=True")
return self


# .................................................................................................
class TierMessageBase(DetectorMessageBase):
"""
Tier detector base message
Tier base message
"""

class Config:
Expand Down Expand Up @@ -276,13 +300,13 @@ class TimingTierMessage(TierMessageBase):
class Config:
validate_assignment = True

timing_series: List[str] = Field(
timing_series: List[Union[str, int]] = Field(
...,
title="Timing Series",
description="Timing series of the event",
)

@root_validator(pre=True)
@model_validator(mode="before")
def _set_tier(cls, values):
values['tier'] = Tier.TIMING_TIER
return values
Expand Down Expand Up @@ -322,7 +346,7 @@ class Config:
description="Time bin width of the event",
)

@root_validator(pre=True)
@model_validator(mode="before")
def _set_tier(cls, values):
values['tier'] = Tier.SIGNIFICANCE_TIER
return values
Expand Down Expand Up @@ -358,7 +382,7 @@ class Config:
description="Time of the first neutrino in the event in ISO 8601-1:2019 format"
)

@root_validator(pre=True)
@model_validator(mode="before")
def _set_tier(cls, values):
values['tier'] = Tier.COINCIDENCE_TIER
return values
Expand All @@ -384,3 +408,41 @@ def _validate_neutrino_time(self):
raise ValueError("neutrino_time_utc must be in the past")

return self


# .................................................................................................
def compatible_message_types(**kwargs) -> list:
"""
Return a list of message types that are compatible with the given keyword arguments.
"""

message_types = [
HeartbeatMessage,
RetractionMessage,
CoincidenceTierMessage,
SignificanceTierMessage,
TimingTierMessage,
]

compatible_message_types = []
for message_type in message_types:
try:
message_type(**kwargs)
compatible_message_types.append(message_type)
except ValidationError:
pass

return compatible_message_types


# .................................................................................................
def create_messages(**kwargs) -> list:
"""
Return a list of messages initialized with the given keyword arguments.
"""

messages = []
for message_type in compatible_message_types(**kwargs):
messages.append(message_type(**kwargs))

return messages
7 changes: 3 additions & 4 deletions snews/schema/CoincidenceTierMessage.schema.json
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
{
"$schema": "https://json-schema.org/draft/2020-12/schema",
"schema_author": "Supernova Early Warning System (SNEWS)",
"schema_version": "1a",
"schema_version": "0.1",
"$defs": {
"Tier": {
"enum": [
Expand Down Expand Up @@ -30,9 +30,8 @@
"description": "Textual identifier for the message",
"title": "Human-readable message ID"
},
"uid": {
"uuid": {
"description": "Unique identifier for the message",
"format": "uuid4",
"title": "Unique message ID",
"type": "string"
},
Expand Down Expand Up @@ -132,7 +131,7 @@
"type": "null"
}
],
"default": "1a",
"default": "0.1",
"description": "Schema version of the message",
"title": "Schema Version"
},
Expand Down
2 changes: 1 addition & 1 deletion snews/schema/Detector.schema.json
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
{
"$schema": "https://json-schema.org/draft/2020-12/schema",
"schema_author": "Supernova Early Warning System (SNEWS)",
"schema_version": "1a",
"schema_version": "0.1",
"$defs": {
"DetectorType": {
"enum": [
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
{
"$schema": "https://json-schema.org/draft/2020-12/schema",
"schema_author": "Supernova Early Warning System (SNEWS)",
"schema_version": "1a",
"schema_version": "0.1",
"$defs": {
"Tier": {
"enum": [
Expand Down Expand Up @@ -30,9 +30,8 @@
"description": "Textual identifier for the message",
"title": "Human-readable message ID"
},
"uid": {
"uuid": {
"description": "Unique identifier for the message",
"format": "uuid4",
"title": "Unique message ID",
"type": "string"
},
Expand Down Expand Up @@ -132,7 +131,7 @@
"type": "null"
}
],
"default": "1a",
"default": "0.1",
"description": "Schema version of the message",
"title": "Schema Version"
},
Expand All @@ -156,6 +155,6 @@
"detector_name",
"detector_status"
],
"title": "HeartBeat",
"title": "HeartbeatMessage",
"type": "object"
}
Loading
Loading