Skip to content

feat: implement Encode, Decode, Type for Arc<str> and Arc<[u8]> (and Rc equivalents) #3675

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 12 commits into
base: main
Choose a base branch
from
79 changes: 79 additions & 0 deletions sqlx-core/src/decode.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
//! Provides [`Decode`] for decoding values from the database.

use std::borrow::Cow;
use std::rc::Rc;
use std::sync::Arc;

use crate::database::Database;
use crate::error::BoxDynError;

Expand Down Expand Up @@ -77,3 +81,78 @@ where
}
}
}

macro_rules! impl_decode_for_smartpointer {
($smart_pointer:tt) => {
impl<'r, DB, T> Decode<'r, DB> for $smart_pointer<T>
where
DB: Database,
T: Decode<'r, DB>,
{
fn decode(value: <DB as Database>::ValueRef<'r>) -> Result<Self, BoxDynError> {
Ok(Self::new(T::decode(value)?))
}
}

impl<'r, DB> Decode<'r, DB> for $smart_pointer<str>
where
DB: Database,
&'r str: Decode<'r, DB>,
{
fn decode(value: <DB as Database>::ValueRef<'r>) -> Result<Self, BoxDynError> {
let ref_str = <&str as Decode<DB>>::decode(value)?;
Ok(ref_str.into())
}
}

impl<'r, DB> Decode<'r, DB> for $smart_pointer<[u8]>
where
DB: Database,
&'r [u8]: Decode<'r, DB>,
{
fn decode(value: <DB as Database>::ValueRef<'r>) -> Result<Self, BoxDynError> {
let ref_str = <&[u8] as Decode<DB>>::decode(value)?;
Ok(ref_str.into())
}
}
};
}

impl_decode_for_smartpointer!(Arc);
impl_decode_for_smartpointer!(Box);
impl_decode_for_smartpointer!(Rc);

// implement `Decode` for Cow<T> for all SQL types
impl<'r, DB, T> Decode<'r, DB> for Cow<'_, T>
where
DB: Database,
T: ToOwned,
<T as ToOwned>::Owned: Decode<'r, DB>,
{
fn decode(value: <DB as Database>::ValueRef<'r>) -> Result<Self, BoxDynError> {
let owned = <<T as ToOwned>::Owned as Decode<DB>>::decode(value)?;
Ok(Cow::Owned(owned))
}
}

impl<'r, DB> Decode<'r, DB> for Cow<'r, str>
where
DB: Database,
&'r str: Decode<'r, DB>,
{
fn decode(value: <DB as Database>::ValueRef<'r>) -> Result<Self, BoxDynError> {
let borrowed = <&str as Decode<DB>>::decode(value)?;
Ok(Cow::Borrowed(borrowed))
}
}

impl<'r, DB> Decode<'r, DB> for Cow<'r, [u8]>
where
DB: Database,
&'r [u8]: Decode<'r, DB>,
{
fn decode(value: <DB as Database>::ValueRef<'r>) -> Result<Self, BoxDynError> {
let borrowed = <&[u8] as Decode<DB>>::decode(value)?;
Ok(Cow::Borrowed(borrowed))
}
}
96 changes: 96 additions & 0 deletions sqlx-core/src/encode.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
//! Provides [`Encode`] for encoding values for the database.

use std::borrow::Cow;
use std::mem;
use std::rc::Rc;
use std::sync::Arc;

use crate::database::Database;
use crate::error::BoxDynError;
Expand Down Expand Up @@ -129,3 +132,96 @@ macro_rules! impl_encode_for_option {
}
};
}

macro_rules! impl_encode_for_smartpointer {
($smart_pointer:ty) => {
impl<'q, T, DB: Database> Encode<'q, DB> for $smart_pointer
where
T: Encode<'q, DB>,
{
#[inline]
fn encode(
self,
buf: &mut <DB as Database>::ArgumentBuffer<'q>,
) -> Result<IsNull, BoxDynError> {
<T as Encode<DB>>::encode_by_ref(self.as_ref(), buf)
}

#[inline]
fn encode_by_ref(
&self,
buf: &mut <DB as Database>::ArgumentBuffer<'q>,
) -> Result<IsNull, BoxDynError> {
<&T as Encode<DB>>::encode(self, buf)
}

#[inline]
fn produces(&self) -> Option<DB::TypeInfo> {
(**self).produces()
}

#[inline]
fn size_hint(&self) -> usize {
(**self).size_hint()
}
}
};
}

impl_encode_for_smartpointer!(Arc<T>);
impl_encode_for_smartpointer!(Box<T>);
impl_encode_for_smartpointer!(Rc<T>);

impl<'q, T, DB: Database> Encode<'q, DB> for Cow<'_, T>
where
T: Encode<'q, DB>,
T: ToOwned,
{
#[inline]
fn encode(self, buf: &mut <DB as Database>::ArgumentBuffer<'q>) -> Result<IsNull, BoxDynError> {
<&T as Encode<DB>>::encode_by_ref(&self.as_ref(), buf)
}

#[inline]
fn encode_by_ref(
&self,
buf: &mut <DB as Database>::ArgumentBuffer<'q>,
) -> Result<IsNull, BoxDynError> {
<&T as Encode<DB>>::encode_by_ref(&self.as_ref(), buf)
}

#[inline]
fn produces(&self) -> Option<DB::TypeInfo> {
<&T as Encode<DB>>::produces(&self.as_ref())
}

#[inline]
fn size_hint(&self) -> usize {
<&T as Encode<DB>>::size_hint(&self.as_ref())
}
}

#[macro_export]
macro_rules! forward_encode_impl {
($for_type:ty, $forward_to:ty, $db:ident) => {
impl<'q> Encode<'q, $db> for $for_type {
fn encode_by_ref(
&self,
buf: &mut <$db as Database>::ArgumentBuffer<'q>,
) -> Result<IsNull, BoxDynError> {
<$forward_to as Encode<$db>>::encode(&**self, buf)
}
}
};
($for_type:ty, $forward_to:ty, $db:ident, $before:expr) => {
impl<'q> Encode<'q, $db> for $for_type {
fn encode_by_ref(
&self,
buf: &mut <$db as Database>::ArgumentBuffer<'q>,
) -> Result<IsNull, BoxDynError> {
let value = $before(self);
<$forward_to as Encode<$db>>::encode(value, buf)
}
}
};
}
39 changes: 39 additions & 0 deletions sqlx-core/src/types/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
//! To represent nullable SQL types, `Option<T>` is supported where `T` implements `Type`.
//! An `Option<T>` represents a potentially `NULL` value from SQL.

use std::{borrow::Cow, rc::Rc, sync::Arc};

use crate::database::Database;
use crate::type_info::TypeInfo;

Expand Down Expand Up @@ -241,3 +243,40 @@ impl<T: Type<DB>, DB: Database> Type<DB> for Option<T> {
ty.is_null() || <T as Type<DB>>::compatible(ty)
}
}

macro_rules! impl_type_for_smartpointer {
($smart_pointer:ty) => {
impl<T, DB: Database> Type<DB> for $smart_pointer
where
T: Type<DB>,
T: ?Sized,
{
fn type_info() -> DB::TypeInfo {
<T as Type<DB>>::type_info()
}

fn compatible(ty: &DB::TypeInfo) -> bool {
<T as Type<DB>>::compatible(ty)
}
}
};
}

impl_type_for_smartpointer!(Arc<T>);
impl_type_for_smartpointer!(Box<T>);
impl_type_for_smartpointer!(Rc<T>);

impl<T, DB: Database> Type<DB> for Cow<'_, T>
where
T: Type<DB>,
T: ToOwned,
T: ?Sized,
{
fn type_info() -> DB::TypeInfo {
<T as Type<DB>>::type_info()
}

fn compatible(ty: &DB::TypeInfo) -> bool {
<T as Type<DB>>::compatible(ty)
}
}
33 changes: 11 additions & 22 deletions sqlx-mysql/src/types/bytes.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,9 @@
use std::borrow::Cow;
use std::rc::Rc;
use std::sync::Arc;

use sqlx_core::database::Database;

use crate::decode::Decode;
use crate::encode::{Encode, IsNull};
use crate::error::BoxDynError;
Expand Down Expand Up @@ -40,28 +46,6 @@ impl<'r> Decode<'r, MySql> for &'r [u8] {
}
}

impl Type<MySql> for Box<[u8]> {
fn type_info() -> MySqlTypeInfo {
<&[u8] as Type<MySql>>::type_info()
}

fn compatible(ty: &MySqlTypeInfo) -> bool {
<&[u8] as Type<MySql>>::compatible(ty)
}
}

impl Encode<'_, MySql> for Box<[u8]> {
fn encode_by_ref(&self, buf: &mut Vec<u8>) -> Result<IsNull, BoxDynError> {
<&[u8] as Encode<MySql>>::encode(self.as_ref(), buf)
}
}

impl<'r> Decode<'r, MySql> for Box<[u8]> {
fn decode(value: MySqlValueRef<'r>) -> Result<Self, BoxDynError> {
<&[u8] as Decode<MySql>>::decode(value).map(Box::from)
}
}

impl Type<MySql> for Vec<u8> {
fn type_info() -> MySqlTypeInfo {
<[u8] as Type<MySql>>::type_info()
Expand All @@ -83,3 +67,8 @@ impl Decode<'_, MySql> for Vec<u8> {
<&[u8] as Decode<MySql>>::decode(value).map(ToOwned::to_owned)
}
}

forward_encode_impl!(Arc<[u8]>, &[u8], MySql);
forward_encode_impl!(Rc<[u8]>, &[u8], MySql);
forward_encode_impl!(Box<[u8]>, &[u8], MySql);
forward_encode_impl!(Cow<'_, [u8]>, &[u8], MySql);
64 changes: 11 additions & 53 deletions sqlx-mysql/src/types/str.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,16 @@
use std::borrow::Cow;
use std::rc::Rc;
use std::sync::Arc;

use sqlx_core::database::Database;

use crate::decode::Decode;
use crate::encode::{Encode, IsNull};
use crate::error::BoxDynError;
use crate::io::MySqlBufMutExt;
use crate::protocol::text::{ColumnFlags, ColumnType};
use crate::types::Type;
use crate::{MySql, MySqlTypeInfo, MySqlValueRef};
use std::borrow::Cow;

impl Type<MySql> for str {
fn type_info() -> MySqlTypeInfo {
Expand Down Expand Up @@ -46,28 +51,6 @@ impl<'r> Decode<'r, MySql> for &'r str {
}
}

impl Type<MySql> for Box<str> {
fn type_info() -> MySqlTypeInfo {
<&str as Type<MySql>>::type_info()
}

fn compatible(ty: &MySqlTypeInfo) -> bool {
<&str as Type<MySql>>::compatible(ty)
}
}

impl Encode<'_, MySql> for Box<str> {
fn encode_by_ref(&self, buf: &mut Vec<u8>) -> Result<IsNull, BoxDynError> {
<&str as Encode<MySql>>::encode(&**self, buf)
}
}

impl<'r> Decode<'r, MySql> for Box<str> {
fn decode(value: MySqlValueRef<'r>) -> Result<Self, BoxDynError> {
<&str as Decode<MySql>>::decode(value).map(Box::from)
}
}

impl Type<MySql> for String {
fn type_info() -> MySqlTypeInfo {
<str as Type<MySql>>::type_info()
Expand All @@ -78,39 +61,14 @@ impl Type<MySql> for String {
}
}

impl Encode<'_, MySql> for String {
fn encode_by_ref(&self, buf: &mut Vec<u8>) -> Result<IsNull, BoxDynError> {
<&str as Encode<MySql>>::encode(&**self, buf)
}
}

impl Decode<'_, MySql> for String {
fn decode(value: MySqlValueRef<'_>) -> Result<Self, BoxDynError> {
<&str as Decode<MySql>>::decode(value).map(ToOwned::to_owned)
}
}

impl Type<MySql> for Cow<'_, str> {
fn type_info() -> MySqlTypeInfo {
<&str as Type<MySql>>::type_info()
}

fn compatible(ty: &MySqlTypeInfo) -> bool {
<&str as Type<MySql>>::compatible(ty)
}
}

impl Encode<'_, MySql> for Cow<'_, str> {
fn encode_by_ref(&self, buf: &mut Vec<u8>) -> Result<IsNull, BoxDynError> {
match self {
Cow::Borrowed(str) => <&str as Encode<MySql>>::encode(*str, buf),
Cow::Owned(str) => <&str as Encode<MySql>>::encode(&**str, buf),
}
}
}

impl<'r> Decode<'r, MySql> for Cow<'r, str> {
fn decode(value: MySqlValueRef<'r>) -> Result<Self, BoxDynError> {
value.as_str().map(Cow::Borrowed)
}
}
forward_encode_impl!(Arc<str>, &str, MySql);
forward_encode_impl!(Rc<str>, &str, MySql);
forward_encode_impl!(Cow<'_, str>, &str, MySql);
forward_encode_impl!(Box<str>, &str, MySql);
forward_encode_impl!(String, &str, MySql);
Loading
Loading