1 change: 1 addition & 0 deletions service_capacity_modeling/capacity_planner.py
@@ -348,6 +348,7 @@ def register_group(self, group: Callable[[], Dict[str, CapacityModel]]):
self.register_model(name, model)

def register_model(self, name: str, capacity_model: CapacityModel):
capacity_model.validate_implementation()
self._models[name] = capacity_model

@property
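Since register_model() now validates every model it registers, an incomplete default_desires surfaces at registration time rather than at the first plan. A minimal sketch of that fail-fast behavior, assuming capacity_planner.py keeps exposing a module-level planner that registers the built-in model groups on import:

```python
# Sketch only: assumes the module-level `planner` instance registers the
# built-in model groups when this module is imported. With the new call to
# validate_implementation() inside register_model(), the import itself
# raises ValueError if any registered model's default_desires leaves a
# required field unset.
from service_capacity_modeling.capacity_planner import planner  # noqa: F401
```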
58 changes: 58 additions & 0 deletions service_capacity_modeling/interface.py
@@ -9,6 +9,8 @@
from functools import lru_cache
from typing import Any
from typing import Dict
from typing import get_args
from typing import get_origin
from typing import List
from typing import Optional
from typing import Sequence
@@ -874,6 +876,62 @@ class CapacityDesires(ExcludeUnsetModel):
# by setting the buffers you want to preserve in buffers.derived
buffers: Buffers = Buffers()

def validate_required_fields_set(self):
"""Validate that all required fields have been explicitly set

This ensures that default_desires implementations in models explicitly
set all required fields rather than relying on defaults.
"""
# Fields to skip validation for (deprecated or allowed to use defaults)
SKIP_FIELDS = {"service_tier"}

# Model types that are leaf value objects and shouldn't be recursed into
# These are treated as primitive values: being set at all is considered valid
SKIP_TYPES = {"Interval", "FixedInterval", "Buffers", "Consistency"}

def validate_base_model_dfs(model: BaseModel, path: List[str]):
"""Recursively validate that all required fields are explicitly set"""
fields_set = model.model_fields_set
validation_suffix = (
f" is required and must be explicitly set for "
f"{model.__class__.__name__}"
)

for field_name, field_info in model.__class__.model_fields.items():
# Skip fields in deny list
if field_name in SKIP_FIELDS:
continue

# Skip optional/nullable fields (fields with Optional[...] type)
# Use typing.get_origin and get_args to check for Union with None
annotation = field_info.annotation
origin = get_origin(annotation)
if origin is Union:
args = get_args(annotation)
if type(None) in args:
# This is an Optional field, skip it
continue

# Build full path for error messages
full_path = ".".join(path + [field_name])

# Check if field was explicitly set
if field_name not in fields_set:
raise ValueError(f"{full_path}{validation_suffix}")

# Recursively validate nested models (but skip leaf value objects)
field_value = getattr(model, field_name)
if isinstance(field_value, BaseModel):
if (
field_value.__class__ is Buffers
and "default" not in field_value.model_fields_set
):
raise ValueError(f"{full_path}.default{validation_suffix}")
if field_value.__class__.__name__ not in SKIP_TYPES:
validate_base_model_dfs(field_value, path + [field_name])

validate_base_model_dfs(self, [])

@property
def reference_shape(self) -> Instance:
if not self.current_clusters:
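The walk relies on pydantic v2's model_fields_set, which only contains fields that were passed explicitly, so values that merely fell back to their defaults do not count as set. A standalone sketch of how the new method reports a miss (not repo test code; the expected message is an example):

```python
from service_capacity_modeling.interface import CapacityDesires
from service_capacity_modeling.interface import DataShape
from service_capacity_modeling.interface import certain_float

# data_shape is set (partially); query_pattern is left to its default
partial = CapacityDesires(
    data_shape=DataShape(estimated_state_size_gib=certain_float(100)),
)

try:
    partial.validate_required_fields_set()
except ValueError as ve:
    # Expected shape of the message, e.g.:
    # "query_pattern is required and must be explicitly set for CapacityDesires"
    print(ve)
```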
11 changes: 11 additions & 0 deletions service_capacity_modeling/models/__init__.py
@@ -109,6 +109,17 @@ def capacity_plan(
(_, _, _, _, _) = (instance, drive, context, desires, extra_model_arguments)
return None

def validate_implementation(self):
"""Validate the implementation of the model"""
desires = self.default_desires(CapacityDesires(), {})
model_name = self.__class__.__name__
try:
desires.validate_required_fields_set()
except ValueError as ve:
raise ValueError(
f"Model {model_name} has invalid default_desires: {ve}"
) from ve

@staticmethod
def regret(
regret_params: CapacityRegretParameters,
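For reference, the wrapping re-raises the underlying ValueError with the model's class name prefixed, so a failing registration names the offending model. A standalone sketch (SketchModel is illustrative only, not repo code):

```python
from typing import Any
from typing import Dict

from service_capacity_modeling.interface import CapacityDesires
from service_capacity_modeling.models import CapacityModel


class SketchModel(CapacityModel):
    @staticmethod
    def default_desires(
        user_desires: CapacityDesires, extra_model_arguments: Dict[str, Any]
    ) -> CapacityDesires:
        return CapacityDesires()  # nothing explicitly set


try:
    SketchModel().validate_implementation()
except ValueError as ve:
    # e.g. "Model SketchModel has invalid default_desires: query_pattern is
    # required and must be explicitly set for CapacityDesires"
    print(ve)
```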
21 changes: 21 additions & 0 deletions service_capacity_modeling/models/org/netflix/cassandra.py
@@ -24,6 +24,7 @@
from service_capacity_modeling.interface import CurrentClusterCapacity
from service_capacity_modeling.interface import DataShape
from service_capacity_modeling.interface import Drive
from service_capacity_modeling.interface import fixed_float
from service_capacity_modeling.interface import FixedInterval
from service_capacity_modeling.interface import GlobalConsistency
from service_capacity_modeling.interface import Instance
@@ -837,6 +838,14 @@ def default_desires(user_desires, extra_model_arguments: Dict[str, Any]):
target_consistency=AccessConsistency.eventual,
),
),
estimated_read_per_second=Interval(
low=128, mid=1024, high=2048, confidence=0.95
),
estimated_write_per_second=Interval(
low=64, mid=256, high=1024, confidence=0.95
),
estimated_read_parallelism=certain_int(1),
estimated_write_parallelism=certain_int(1),
estimated_mean_read_size_bytes=Interval(
low=128, mid=1024, high=65536, confidence=0.95
),
@@ -885,6 +894,8 @@ def default_desires(user_desires, extra_model_arguments: Dict[str, Any]):
# We dynamically allocate the C* JVM memory in the plan
# but account for the Priam sidecar here
reserved_instance_app_mem_gib=4,
reserved_instance_system_mem_gib=12,
durability_slo_order=fixed_float(10000),
),
buffers=buffers,
)
@@ -915,6 +926,14 @@ def default_desires(user_desires, extra_model_arguments: Dict[str, Any]):
estimated_mean_write_latency_ms=Interval(
low=0.2, mid=0.6, high=2, confidence=0.98
),
estimated_read_per_second=Interval(
low=128, mid=1024, high=2048, confidence=0.95
),
estimated_write_per_second=Interval(
low=64, mid=256, high=1024, confidence=0.95
),
estimated_read_parallelism=certain_int(1),
estimated_write_parallelism=certain_int(1),
# Assume they're scanning -> slow reads
read_latency_slo_ms=FixedInterval(
minimum_value=1,
@@ -945,6 +964,8 @@ def default_desires(user_desires, extra_model_arguments: Dict[str, Any]):
# We dynamically allocate the C* JVM memory in the plan
# but account for the Priam sidecar here
reserved_instance_app_mem_gib=4,
reserved_instance_system_mem_gib=12,
durability_slo_order=fixed_float(10000),
),
buffers=buffers,
)
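The new Cassandra defaults mix ranged Intervals with the certain_int and fixed_float helpers. A small sketch of the value shapes involved, assuming (not verified against interface.py) that certain_int and fixed_float pin an Interval / FixedInterval to a single value:

```python
# Sketch only; the helper semantics below are assumed, not taken from interface.py.
from service_capacity_modeling.interface import Interval
from service_capacity_modeling.interface import certain_int
from service_capacity_modeling.interface import fixed_float

reads = Interval(low=128, mid=1024, high=2048, confidence=0.95)  # ranged estimate
parallelism = certain_int(1)      # assumed: low == mid == high == 1
durability = fixed_float(10000)   # assumed: a FixedInterval pinned at 10000

print(reads.mid, parallelism.mid, durability.mid)
```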
230 changes: 230 additions & 0 deletions tests/test_model_validation.py
@@ -0,0 +1,230 @@
from typing import Any
from typing import Dict

import pytest

from service_capacity_modeling.interface import AccessConsistency
from service_capacity_modeling.interface import AccessPattern
from service_capacity_modeling.interface import Buffer
from service_capacity_modeling.interface import BufferComponent
from service_capacity_modeling.interface import Buffers
from service_capacity_modeling.interface import CapacityDesires
from service_capacity_modeling.interface import certain_float
from service_capacity_modeling.interface import Consistency
from service_capacity_modeling.interface import DataShape
from service_capacity_modeling.interface import FixedInterval
from service_capacity_modeling.interface import GlobalConsistency
from service_capacity_modeling.interface import QueryPattern
from service_capacity_modeling.models import CapacityModel

valid_query_pattern = QueryPattern(
access_pattern=AccessPattern.latency,
access_consistency=GlobalConsistency(
same_region=Consistency(
target_consistency=AccessConsistency.read_your_writes,
staleness_slo_sec=FixedInterval(low=0, mid=0.1, high=1),
),
cross_region=Consistency(
target_consistency=AccessConsistency.best_effort,
staleness_slo_sec=FixedInterval(low=10, mid=60, high=600),
),
),
estimated_read_per_second=certain_float(1000),
estimated_write_per_second=certain_float(100),
estimated_mean_read_latency_ms=certain_float(1),
estimated_mean_write_latency_ms=certain_float(1),
estimated_mean_read_size_bytes=certain_float(1024),
estimated_mean_write_size_bytes=certain_float(512),
estimated_read_parallelism=certain_float(1),
estimated_write_parallelism=certain_float(1),
read_latency_slo_ms=FixedInterval(low=0.4, mid=4, high=10, confidence=0.98),
write_latency_slo_ms=FixedInterval(low=0.4, mid=4, high=10, confidence=0.98),
)

valid_data_shape = DataShape(
estimated_state_size_gib=certain_float(100),
estimated_working_set_percent=certain_float(0.8),
estimated_compression_ratio=certain_float(1),
reserved_instance_app_mem_gib=2,
reserved_instance_system_mem_gib=1,
durability_slo_order=FixedInterval(
low=1000, mid=10000, high=100000, confidence=0.98
),
)


class ValidModel(CapacityModel):
"""A model that properly sets all required fields"""

@staticmethod
def default_desires(
user_desires: CapacityDesires, extra_model_arguments: Dict[str, Any]
):
return CapacityDesires(
query_pattern=valid_query_pattern,
data_shape=valid_data_shape,
buffers=Buffers(default=Buffer(ratio=1.5)),
)


class InvalidModelMissingBuffer(CapacityModel):
"""A model that doesn't set buffers"""

@staticmethod
def default_desires(
user_desires: CapacityDesires, extra_model_arguments: Dict[str, Any]
):
return CapacityDesires(
query_pattern=valid_query_pattern,
data_shape=valid_data_shape,
# Missing buffers
)


class InvalidModelMissingBufferDefault(CapacityModel):
"""A model that doesn't set buffers"""

@staticmethod
def default_desires(
user_desires: CapacityDesires, extra_model_arguments: Dict[str, Any]
):
return CapacityDesires(
query_pattern=valid_query_pattern,
data_shape=valid_data_shape,
buffers=Buffers(
desired={
"compute": Buffer(ratio=1.5, components=[BufferComponent.compute])
}
), # Missing default buffer
)


class InvalidModelMissingTopLevel(CapacityModel):
"""A model that doesn't set query_pattern"""

@staticmethod
def default_desires(
user_desires: CapacityDesires, extra_model_arguments: Dict[str, Any]
):
# Only sets data_shape, missing query_pattern
return CapacityDesires(
data_shape=DataShape(
estimated_state_size_gib=certain_float(100),
estimated_working_set_percent=certain_float(0.8),
),
)


class InvalidModelMissingDataShape(CapacityModel):
"""A model that sets query_pattern but leaves nested fields as defaults"""

@staticmethod
def default_desires(
user_desires: CapacityDesires, extra_model_arguments: Dict[str, Any]
):
# Sets query_pattern but doesn't explicitly set data shape
return CapacityDesires(
query_pattern=valid_query_pattern,
)


class InvalidModelMissingQueryPattern(CapacityModel):
"""A model that sets query_pattern but leaves nested fields as defaults"""

@staticmethod
def default_desires(
user_desires: CapacityDesires, extra_model_arguments: Dict[str, Any]
):
# Sets data_shape but doesn't explicitly set query_pattern
return CapacityDesires(
data_shape=valid_data_shape,
)


class InvalidModelPartiallySetDataModels(CapacityModel):
"""A model that partially sets nested fields but misses some"""

@staticmethod
def default_desires(
user_desires: CapacityDesires, extra_model_arguments: Dict[str, Any]
):
return CapacityDesires(
query_pattern=valid_query_pattern,
data_shape=DataShape(
estimated_state_size_gib=certain_float(100),
estimated_working_set_percent=certain_float(0.8),
),
)


class InvalidModelPartiallySetQueryPattern(CapacityModel):
"""A model that partially sets nested fields but misses some"""

@staticmethod
def default_desires(
user_desires: CapacityDesires, extra_model_arguments: Dict[str, Any]
):
return CapacityDesires(
query_pattern=QueryPattern(
access_pattern=AccessPattern.latency,
# Missing access_consistency and other required fields
estimated_mean_read_latency_ms=certain_float(1),
),
data_shape=valid_data_shape,
)


def test_valid_model():
"""Test that a properly implemented model passes validation"""
model = ValidModel()
# Should not raise
model.validate_implementation()


def test_invalid_model_missing_top_level():
"""Test that missing top-level required field is caught"""
model = InvalidModelMissingTopLevel()
with pytest.raises(ValueError, match="query_pattern is required"):
model.validate_implementation()


def test_invalid_model_missing_nested():
"""Test that missing nested required fields are caught"""
model = InvalidModelMissingQueryPattern()
with pytest.raises(
ValueError, match="query_pattern is required and must be explicitly set"
):
model.validate_implementation()

model = InvalidModelMissingDataShape()
with pytest.raises(
ValueError, match="data_shape is required and must be explicitly set"
):
model.validate_implementation()

model = InvalidModelPartiallySetQueryPattern()
with pytest.raises(
ValueError, match="query_pattern\\..*is required and must be explicitly set"
):
model.validate_implementation()

model = InvalidModelPartiallySetDataModels()
with pytest.raises(
ValueError, match="data_shape\\..*is required and must be explicitly set"
):
model.validate_implementation()


def test_invalid_model_buffers():
"""Test that partially set nested fields are caught"""
model = InvalidModelMissingBuffer()
with pytest.raises(
ValueError, match="buffers is required and must be explicitly set"
):
model.validate_implementation()

model = InvalidModelMissingBufferDefault()
with pytest.raises(
ValueError, match="buffers\\.default is required and must be explicitly set"
):
model.validate_implementation()