Skip to content

Commit 9689f17

Browse files
authored
Merge pull request #299 from muzarski/enforce-get-coordinator
Implement exposing/enforcing coordinator for request
2 parents b1ff429 + d760c24 commit 9689f17

File tree

11 files changed

+451
-69
lines changed

11 files changed

+451
-69
lines changed

Makefile

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@ SCYLLA_TEST_FILTER := $(subst ${SPACE},${EMPTY},ClusterTests.*\
1212
:SerialConsistencyTests.*\
1313
:HeartbeatTests.*\
1414
:PreparedTests.*\
15+
:StatementNoClusterTests.*\
16+
:StatementTests.*\
1517
:NamedParametersTests.*\
1618
:CassandraTypes/CassandraTypesTests/*.Integration_Cassandra_*\
1719
:ControlConnectionTests.*\
@@ -27,6 +29,7 @@ SCYLLA_TEST_FILTER := $(subst ${SPACE},${EMPTY},ClusterTests.*\
2729
:PreparedMetadataTests.*\
2830
:UseKeyspaceCaseSensitiveTests.*\
2931
:ServerSideFailureTests.*\
32+
:ServerSideFailureThreeNodeTests.*\
3033
:TimestampTests.*\
3134
:MetricsTests.Integration_Cassandra_ErrorsRequestTimeouts\
3235
:MetricsTests.Integration_Cassandra_Requests\
@@ -69,6 +72,8 @@ CASSANDRA_TEST_FILTER := $(subst ${SPACE},${EMPTY},ClusterTests.*\
6972
:SerialConsistencyTests.*\
7073
:HeartbeatTests.*\
7174
:PreparedTests.*\
75+
:StatementNoClusterTests.*\
76+
:StatementTests.*\
7277
:NamedParametersTests.*\
7378
:CassandraTypes/CassandraTypesTests/*.Integration_Cassandra_*\
7479
:ControlConnectionTests.*\
@@ -83,6 +88,7 @@ CASSANDRA_TEST_FILTER := $(subst ${SPACE},${EMPTY},ClusterTests.*\
8388
:PreparedMetadataTests.*\
8489
:UseKeyspaceCaseSensitiveTests.*\
8590
:ServerSideFailureTests.*\
91+
:ServerSideFailureThreeNodeTests.*\
8692
:TimestampTests.*\
8793
:MetricsTests.Integration_Cassandra_ErrorsRequestTimeouts\
8894
:MetricsTests.Integration_Cassandra_Requests\

scylla-rust-wrapper/Cargo.lock

Lines changed: 8 additions & 13 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

scylla-rust-wrapper/Cargo.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ categories = ["database"]
1010
license = "MIT OR Apache-2.0"
1111

1212
[dependencies]
13-
scylla = { git = "https://github.com/scylladb/scylla-rust-driver.git", rev = "v1.1.0", features = [
13+
scylla = { git = "https://github.com/scylladb/scylla-rust-driver.git", rev = "v1.2.0", features = [
1414
"openssl-010",
1515
"metrics",
1616
] }
@@ -34,7 +34,7 @@ bindgen = "0.65"
3434
chrono = "0.4.20"
3535

3636
[dev-dependencies]
37-
scylla-proxy = { git = "https://github.com/scylladb/scylla-rust-driver.git", rev = "v1.1.0" }
37+
scylla-proxy = { git = "https://github.com/scylladb/scylla-rust-driver.git", rev = "v1.2.0" }
3838
bytes = "1.10.0"
3939

4040
assert_matches = "1.5.0"

scylla-rust-wrapper/src/future.rs

Lines changed: 46 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -5,17 +5,18 @@ use crate::cass_error::CassErrorMessage;
55
use crate::cass_error::ToCassError;
66
use crate::execution_error::CassErrorResult;
77
use crate::prepared::CassPrepared;
8-
use crate::query_result::CassResult;
8+
use crate::query_result::{CassNode, CassResult};
99
use crate::types::*;
1010
use crate::uuid::CassUuid;
1111
use futures::future;
1212
use std::future::Future;
1313
use std::mem;
1414
use std::os::raw::c_void;
15-
use std::sync::{Arc, Condvar, Mutex};
15+
use std::sync::{Arc, Condvar, Mutex, OnceLock};
1616
use tokio::task::JoinHandle;
1717
use tokio::time::Duration;
1818

19+
#[derive(Debug)]
1920
pub enum CassResultValue {
2021
Empty,
2122
QueryResult(Arc<CassResult>),
@@ -50,14 +51,14 @@ impl BoundCallback {
5051

5152
#[derive(Default)]
5253
struct CassFutureState {
53-
value: Option<CassFutureResult>,
5454
err_string: Option<String>,
5555
callback: Option<BoundCallback>,
5656
join_handle: Option<JoinHandle<()>>,
5757
}
5858

5959
pub struct CassFuture {
6060
state: Mutex<CassFutureState>,
61+
result: OnceLock<CassFutureResult>,
6162
wait_for_value: Condvar,
6263
}
6364

@@ -87,14 +88,18 @@ impl CassFuture {
8788
) -> Arc<CassFuture> {
8889
let cass_fut = Arc::new(CassFuture {
8990
state: Mutex::new(Default::default()),
91+
result: OnceLock::new(),
9092
wait_for_value: Condvar::new(),
9193
});
9294
let cass_fut_clone = Arc::clone(&cass_fut);
9395
let join_handle = RUNTIME.spawn(async move {
9496
let r = fut.await;
9597
let maybe_cb = {
9698
let mut guard = cass_fut_clone.state.lock().unwrap();
97-
guard.value = Some(r);
99+
cass_fut_clone
100+
.result
101+
.set(r)
102+
.expect("Tried to resolve future result twice!");
98103
// Take the callback and call it after releasing the lock
99104
guard.callback.take()
100105
};
@@ -115,16 +120,17 @@ impl CassFuture {
115120

116121
pub fn new_ready(r: CassFutureResult) -> Arc<Self> {
117122
Arc::new(CassFuture {
118-
state: Mutex::new(CassFutureState {
119-
value: Some(r),
120-
..Default::default()
121-
}),
123+
state: Mutex::new(CassFutureState::default()),
124+
result: OnceLock::from(r),
122125
wait_for_value: Condvar::new(),
123126
})
124127
}
125128

126-
pub fn with_waited_result<T>(&self, f: impl FnOnce(&mut CassFutureResult) -> T) -> T {
127-
self.with_waited_state(|s| f(s.value.as_mut().unwrap()))
129+
pub fn with_waited_result<'s, T>(&'s self, f: impl FnOnce(&'s CassFutureResult) -> T) -> T
130+
where
131+
T: 's,
132+
{
133+
self.with_waited_state(|_| f(self.result.get().unwrap()))
128134
}
129135

130136
/// Awaits the future until completion.
@@ -153,7 +159,7 @@ impl CassFuture {
153159
guard = self
154160
.wait_for_value
155161
.wait_while(guard, |state| {
156-
state.value.is_none() && state.join_handle.is_none()
162+
self.result.get().is_none() && state.join_handle.is_none()
157163
})
158164
// unwrap: Error appears only when mutex is poisoned.
159165
.unwrap();
@@ -171,10 +177,10 @@ impl CassFuture {
171177

172178
fn with_waited_result_timed<T>(
173179
&self,
174-
f: impl FnOnce(&mut CassFutureResult) -> T,
180+
f: impl FnOnce(&CassFutureResult) -> T,
175181
timeout_duration: Duration,
176182
) -> Result<T, FutureError> {
177-
self.with_waited_state_timed(|s| f(s.value.as_mut().unwrap()), timeout_duration)
183+
self.with_waited_state_timed(|_| f(self.result.get().unwrap()), timeout_duration)
178184
}
179185

180186
/// Tries to await the future with a given timeout.
@@ -242,7 +248,7 @@ impl CassFuture {
242248
let (guard_result, timeout_result) = self
243249
.wait_for_value
244250
.wait_timeout_while(guard, remaining_timeout, |state| {
245-
state.value.is_none() && state.join_handle.is_none()
251+
self.result.get().is_none() && state.join_handle.is_none()
246252
})
247253
// unwrap: Error appears only when mutex is poisoned.
248254
.unwrap();
@@ -275,7 +281,7 @@ impl CassFuture {
275281
return CassError::CASS_ERROR_LIB_CALLBACK_ALREADY_SET;
276282
}
277283
let bound_cb = BoundCallback { cb, data };
278-
if lock.value.is_some() {
284+
if self.result.get().is_some() {
279285
// The value is already available, we need to call the callback ourselves
280286
mem::drop(lock);
281287
bound_cb.invoke(self_ptr);
@@ -345,8 +351,7 @@ pub unsafe extern "C" fn cass_future_ready(
345351
return cass_false;
346352
};
347353

348-
let state_guard = future.state.lock().unwrap();
349-
match state_guard.value {
354+
match future.result.get() {
350355
None => cass_false,
351356
Some(_) => cass_true,
352357
}
@@ -361,7 +366,7 @@ pub unsafe extern "C" fn cass_future_error_code(
361366
return CassError::CASS_ERROR_LIB_BAD_PARAMS;
362367
};
363368

364-
future.with_waited_result(|r: &mut CassFutureResult| match r {
369+
future.with_waited_result(|r: &CassFutureResult| match r {
365370
Ok(CassResultValue::QueryError(err)) => err.to_cass_error(),
366371
Err((err, _)) => *err,
367372
_ => CassError::CASS_OK,
@@ -380,7 +385,7 @@ pub unsafe extern "C" fn cass_future_error_message(
380385
};
381386

382387
future.with_waited_state(|state: &mut CassFutureState| {
383-
let value = &state.value;
388+
let value = future.result.get();
384389
let msg = state
385390
.err_string
386391
.get_or_insert_with(|| match value.as_ref().unwrap() {
@@ -407,7 +412,7 @@ pub unsafe extern "C" fn cass_future_get_result(
407412
};
408413

409414
future
410-
.with_waited_result(|r: &mut CassFutureResult| -> Option<Arc<CassResult>> {
415+
.with_waited_result(|r: &CassFutureResult| -> Option<Arc<CassResult>> {
411416
match r.as_ref().ok()? {
412417
CassResultValue::QueryResult(qr) => Some(Arc::clone(qr)),
413418
_ => None,
@@ -426,7 +431,7 @@ pub unsafe extern "C" fn cass_future_get_error_result(
426431
};
427432

428433
future
429-
.with_waited_result(|r: &mut CassFutureResult| -> Option<Arc<CassErrorResult>> {
434+
.with_waited_result(|r: &CassFutureResult| -> Option<Arc<CassErrorResult>> {
430435
match r.as_ref().ok()? {
431436
CassResultValue::QueryError(qr) => Some(Arc::clone(qr)),
432437
_ => None,
@@ -445,7 +450,7 @@ pub unsafe extern "C" fn cass_future_get_prepared(
445450
};
446451

447452
future
448-
.with_waited_result(|r: &mut CassFutureResult| -> Option<Arc<CassPrepared>> {
453+
.with_waited_result(|r: &CassFutureResult| -> Option<Arc<CassPrepared>> {
449454
match r.as_ref().ok()? {
450455
CassResultValue::Prepared(p) => Some(Arc::clone(p)),
451456
_ => None,
@@ -464,7 +469,7 @@ pub unsafe extern "C" fn cass_future_tracing_id(
464469
return CassError::CASS_ERROR_LIB_BAD_PARAMS;
465470
};
466471

467-
future.with_waited_result(|r: &mut CassFutureResult| match r {
472+
future.with_waited_result(|r: &CassFutureResult| match r {
468473
Ok(CassResultValue::QueryResult(result)) => match result.tracing_id {
469474
Some(id) => {
470475
unsafe { *tracing_id = CassUuid::from(id) };
@@ -476,6 +481,24 @@ pub unsafe extern "C" fn cass_future_tracing_id(
476481
})
477482
}
478483

484+
#[unsafe(no_mangle)]
485+
pub unsafe extern "C" fn cass_future_coordinator(
486+
future_raw: CassBorrowedSharedPtr<CassFuture, CMut>,
487+
) -> CassBorrowedSharedPtr<CassNode, CConst> {
488+
let Some(future) = ArcFFI::as_ref(future_raw) else {
489+
tracing::error!("Provided null future to cass_future_coordinator!");
490+
return RefFFI::null();
491+
};
492+
493+
future.with_waited_result(|r| match r {
494+
Ok(CassResultValue::QueryResult(result)) => {
495+
// unwrap: Coordinator is `None` only for tests.
496+
RefFFI::as_ptr(result.coordinator.as_ref().unwrap())
497+
}
498+
_ => RefFFI::null(),
499+
})
500+
}
501+
479502
#[cfg(test)]
480503
mod tests {
481504
use crate::testing::{assert_cass_error_eq, assert_cass_future_error_message_eq};

scylla-rust-wrapper/src/integration_testing.rs

Lines changed: 45 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,12 @@ use scylla::errors::{RequestAttemptError, RequestError};
77
use scylla::observability::history::{AttemptId, HistoryListener, RequestId, SpeculativeId};
88
use scylla::policies::retry::RetryDecision;
99

10-
use crate::argconv::{BoxFFI, CMut, CassBorrowedExclusivePtr};
10+
use crate::argconv::{
11+
ArcFFI, BoxFFI, CConst, CMut, CassBorrowedExclusivePtr, CassBorrowedSharedPtr,
12+
};
1113
use crate::batch::CassBatch;
1214
use crate::cluster::CassCluster;
15+
use crate::future::{CassFuture, CassResultValue};
1316
use crate::statement::{BoundStatement, CassStatement};
1417
use crate::types::{cass_int32_t, cass_uint16_t, cass_uint64_t, size_t};
1518

@@ -60,8 +63,47 @@ pub unsafe extern "C" fn testing_cluster_get_contact_points(
6063
}
6164

6265
#[unsafe(no_mangle)]
63-
pub unsafe extern "C" fn testing_free_contact_points(contact_points: *mut c_char) {
64-
let _ = unsafe { CString::from_raw(contact_points) };
66+
pub unsafe extern "C" fn testing_future_get_host(
67+
future_raw: CassBorrowedSharedPtr<CassFuture, CConst>,
68+
host: *mut *mut c_char,
69+
host_length: *mut size_t,
70+
) {
71+
let Some(future) = ArcFFI::as_ref(future_raw) else {
72+
tracing::error!("Provided null future pointer to testing_future_get_host!");
73+
unsafe {
74+
*host = std::ptr::null_mut();
75+
*host_length = 0;
76+
};
77+
return;
78+
};
79+
80+
future.with_waited_result(|r| match r {
81+
Ok(CassResultValue::QueryResult(result)) => {
82+
// unwrap: Coordinator is none only for unit tests.
83+
let coordinator = result.coordinator.as_ref().unwrap();
84+
85+
let ip_addr_str = coordinator.node().address.ip().to_string();
86+
let length = ip_addr_str.len();
87+
88+
let ip_addr_cstr = CString::new(ip_addr_str).expect(
89+
"String obtained from IpAddr::to_string() should not contain any nul bytes!",
90+
);
91+
92+
unsafe {
93+
*host = ip_addr_cstr.into_raw();
94+
*host_length = length as size_t
95+
};
96+
}
97+
_ => unsafe {
98+
*host = std::ptr::null_mut();
99+
*host_length = 0;
100+
},
101+
})
102+
}
103+
104+
#[unsafe(no_mangle)]
105+
pub unsafe extern "C" fn testing_free_cstring(s: *mut c_char) {
106+
let _ = unsafe { CString::from_raw(s) };
65107
}
66108

67109
#[derive(Debug)]

0 commit comments

Comments
 (0)