Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 63 additions & 0 deletions crates/diom-core/src/types/duration_ms.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use std::{
borrow::Cow,
fmt,
num::NonZeroU64,
ops::{Add, AddAssign, Mul},
time::Duration,
};
Expand Down Expand Up @@ -199,3 +200,65 @@ impl ValidateRange<u64> for DurationMs {
Some(*self < Self::from(min))
}
}

/// Non-zero variation of [`DurationMs`].
#[derive(Clone, Copy, PartialEq, Eq)]
pub struct NonZeroDurationMs(NonZeroU64);

impl NonZeroDurationMs {
pub fn get(self) -> DurationMs {
DurationMs(self.0.get())
}
}

impl From<NonZeroU64> for NonZeroDurationMs {
/// Assume the given value represents an integer number of milliseconds
/// and treat it as a `NonZeroDurationMs`.
#[inline]
fn from(millis: NonZeroU64) -> Self {
Self(millis)
}
}

impl fmt::Debug for NonZeroDurationMs {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
self.0.fmt(f)
}
}
impl Serialize for NonZeroDurationMs {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
self.0.serialize(serializer)
}
}

impl<'de> Deserialize<'de> for NonZeroDurationMs {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: serde::Deserializer<'de>,
{
let millis = NonZeroU64::deserialize(deserializer)?;
Ok(Self::from(millis))
}
}

impl JsonSchema for NonZeroDurationMs {
fn schema_name() -> Cow<'static, str> {
"NonZeroDurationMs".into()
}

fn json_schema(_gen: &mut schemars::SchemaGenerator) -> Schema {
json_schema!({
"type": "integer",
"format": "uint64",
"minimum": 1,
"x-subtype": "DurationMs",
})
}

fn inline_schema() -> bool {
true
}
}
4 changes: 3 additions & 1 deletion crates/diom-core/src/types/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,9 @@ mod metadata;
mod unix_timestamp_ms;

pub use self::{
byte_string::ByteString, duration_ms::DurationMs, metadata::Metadata,
byte_string::ByteString,
duration_ms::{DurationMs, NonZeroDurationMs},
metadata::Metadata,
unix_timestamp_ms::UnixTimestampMs,
};

Expand Down
12 changes: 5 additions & 7 deletions src/v1/endpoints/idempotency.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use aide::axum::{ApiRouter, routing::post_with};
use axum::{Extension, extract::State};
use diom_authorization::RequestedOperation;
use diom_core::types::{ByteString, DurationMs, EntityKey, Metadata, UnixTimestampMs};
use diom_core::types::{ByteString, EntityKey, Metadata, NonZeroDurationMs, UnixTimestampMs};
use diom_derive::aide_annotate;
use diom_error::{OptionExt as _, ResultExt};
use diom_id::Module;
Expand Down Expand Up @@ -59,8 +59,7 @@ pub struct IdempotencyStartIn {
pub key: EntityKey,
/// How long to hold the lock on start before releasing it.
#[serde(rename = "lock_period_ms")]
#[validate(range(min = 1))]
pub lock_period: DurationMs,
pub lock_period: NonZeroDurationMs,
}

request_input!(IdempotencyStartIn, "start");
Expand Down Expand Up @@ -112,8 +111,7 @@ pub struct IdempotencyCompleteIn {

/// How long to keep the idempotency response for.
#[serde(rename = "ttl_ms")]
#[validate(range(min = 1))]
pub ttl: DurationMs,
pub ttl: NonZeroDurationMs,
}

request_input!(IdempotencyCompleteIn, "complete");
Expand Down Expand Up @@ -150,7 +148,7 @@ async fn idempotency_start(
.fetch_namespace(data.namespace.as_ref())?
.ok_or_not_found()?;

let operation = TryStartOperation::new(namespace, data.key.to_string(), data.lock_period);
let operation = TryStartOperation::new(namespace, data.key.to_string(), data.lock_period.get());
let response = repl.client_write(operation).await.or_internal_error()?.0?;

Ok(MsgPackOrJson(response.result.into()))
Expand All @@ -173,7 +171,7 @@ async fn idempotency_complete(
data.key.to_string(),
data.response,
data.context,
data.ttl,
data.ttl.get(),
);
repl.client_write(operation).await.or_internal_error()?.0?;

Expand Down
7 changes: 3 additions & 4 deletions src/v1/endpoints/kv.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use aide::axum::{ApiRouter, routing::post_with};
use axum::{Extension, extract::State};
use diom_authorization::RequestedOperation;
use diom_core::types::{ByteString, Consistency, DurationMs, EntityKey, UnixTimestampMs};
use diom_core::types::{ByteString, Consistency, EntityKey, NonZeroDurationMs, UnixTimestampMs};
use diom_derive::aide_annotate;
use diom_error::{OptionExt, ResultExt};
use diom_id::Module;
Expand Down Expand Up @@ -53,8 +53,7 @@ pub struct KvSetIn {

/// Time to live in milliseconds
#[serde(rename = "ttl_ms")]
#[validate(range(min = 1))]
pub ttl: Option<DurationMs>,
pub ttl: Option<NonZeroDurationMs>,

#[serde(default)]
pub behavior: OperationBehavior,
Expand Down Expand Up @@ -142,7 +141,7 @@ async fn kv_set(
namespace,
data.key,
data.value,
data.ttl,
data.ttl.map(NonZeroDurationMs::get),
data.behavior,
data.version,
);
Expand Down
28 changes: 12 additions & 16 deletions src/v1/endpoints/rate_limit.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
use std::num::NonZeroU64;

use aide::axum::{ApiRouter, routing::post_with};
use axum::{Extension, extract::State};
use diom_authorization::RequestedOperation;
use diom_core::types::{DurationMs, EntityKey, UnixTimestampMs};
use diom_core::types::{DurationMs, EntityKey, NonZeroDurationMs, UnixTimestampMs};
use diom_derive::aide_annotate;
use diom_error::{OptionExt, ResultExt};
use diom_id::Module;
Expand Down Expand Up @@ -45,31 +47,28 @@ macro_rules! request_input {
impl From<RateLimitConfig> for TokenBucket {
fn from(val: RateLimitConfig) -> Self {
TokenBucket {
bucket_size: val.capacity,
refill_rate: val.refill_amount,
refill_interval: val.refill_interval,
bucket_size: val.capacity.get(),
refill_rate: val.refill_amount.get(),
refill_interval: val.refill_interval.get(),
}
}
}

#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize, Validate, JsonSchema)]
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize, JsonSchema)]
pub struct RateLimitConfig {
/// Maximum capacity of the bucket
#[validate(range(min = 1))]
pub capacity: u64,
pub capacity: NonZeroU64,

/// Number of tokens to add per refill interval
#[validate(range(min = 1))]
pub refill_amount: u64,
pub refill_amount: NonZeroU64,

/// Interval in milliseconds between refills (minimum 1 millisecond)
#[serde(rename = "refill_interval_ms", default = "default_interval_ms")]
#[validate(range(min = 1))]
pub refill_interval: DurationMs,
pub refill_interval: NonZeroDurationMs,
}

fn default_interval_ms() -> DurationMs {
1000.into()
fn default_interval_ms() -> NonZeroDurationMs {
const { NonZeroU64::new(1000).unwrap() }.into()
}

#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize, Validate, JsonSchema)]
Expand All @@ -84,7 +83,6 @@ pub struct RateLimitCheckIn {
pub tokens: u64,

/// Rate limiter configuration
#[validate(nested)]
pub config: RateLimitConfig,
}

Expand Down Expand Up @@ -115,7 +113,6 @@ pub struct RateLimitGetRemainingIn {
pub key: EntityKey,

/// Rate limiter configuration
#[validate(nested)]
pub config: RateLimitConfig,
}

Expand Down Expand Up @@ -207,7 +204,6 @@ pub struct RateLimitResetIn {
pub key: EntityKey,

/// Rate limiter configuration
#[validate(nested)]
pub config: RateLimitConfig,
}

Expand Down
Loading