Skip to content

Commit ff041d6

Browse files
committed
Relax Sized bound for Decode, Encode
1 parent b6521ae commit ff041d6

File tree

12 files changed

+219
-75
lines changed

12 files changed

+219
-75
lines changed

sqlx-core/src/decode.rs

+53-8
Original file line numberDiff line numberDiff line change
@@ -83,8 +83,8 @@ where
8383
}
8484

8585
macro_rules! impl_decode_for_smartpointer {
86-
($smart_pointer:ty) => {
87-
impl<'r, DB, T> Decode<'r, DB> for $smart_pointer
86+
($smart_pointer:tt) => {
87+
impl<'r, DB, T> Decode<'r, DB> for $smart_pointer<T>
8888
where
8989
DB: Database,
9090
T: Decode<'r, DB>,
@@ -93,21 +93,66 @@ macro_rules! impl_decode_for_smartpointer {
9393
Ok(Self::new(T::decode(value)?))
9494
}
9595
}
96+
97+
impl<'r, DB> Decode<'r, DB> for $smart_pointer<str>
98+
where
99+
DB: Database,
100+
&'r str: Decode<'r, DB>,
101+
{
102+
fn decode(value: <DB as Database>::ValueRef<'r>) -> Result<Self, BoxDynError> {
103+
let ref_str = <&str as Decode<DB>>::decode(value)?;
104+
Ok(ref_str.into())
105+
}
106+
}
107+
108+
impl<'r, DB> Decode<'r, DB> for $smart_pointer<[u8]>
109+
where
110+
DB: Database,
111+
&'r [u8]: Decode<'r, DB>,
112+
{
113+
fn decode(value: <DB as Database>::ValueRef<'r>) -> Result<Self, BoxDynError> {
114+
let ref_str = <&[u8] as Decode<DB>>::decode(value)?;
115+
Ok(ref_str.into())
116+
}
117+
}
96118
};
97119
}
98120

99-
impl_decode_for_smartpointer!(Arc<T>);
100-
impl_decode_for_smartpointer!(Box<T>);
101-
impl_decode_for_smartpointer!(Rc<T>);
121+
impl_decode_for_smartpointer!(Arc);
122+
impl_decode_for_smartpointer!(Box);
123+
impl_decode_for_smartpointer!(Rc);
102124

103125
// implement `Decode` for Cow<T> for all SQL types
104126
impl<'r, DB, T> Decode<'r, DB> for Cow<'_, T>
105127
where
106128
DB: Database,
107-
T: Decode<'r, DB>,
108-
T: ToOwned<Owned = T>,
129+
T: ToOwned,
130+
<T as ToOwned>::Owned: Decode<'r, DB>,
131+
{
132+
fn decode(value: <DB as Database>::ValueRef<'r>) -> Result<Self, BoxDynError> {
133+
let owned = <<T as ToOwned>::Owned as Decode<DB>>::decode(value)?;
134+
Ok(Cow::Owned(owned))
135+
}
136+
}
137+
138+
impl<'r, DB> Decode<'r, DB> for Cow<'r, str>
139+
where
140+
DB: Database,
141+
&'r str: Decode<'r, DB>,
142+
{
143+
fn decode(value: <DB as Database>::ValueRef<'r>) -> Result<Self, BoxDynError> {
144+
let borrowed = <&str as Decode<DB>>::decode(value)?;
145+
Ok(Cow::Borrowed(borrowed))
146+
}
147+
}
148+
149+
impl<'r, DB> Decode<'r, DB> for Cow<'r, [u8]>
150+
where
151+
DB: Database,
152+
&'r [u8]: Decode<'r, DB>,
109153
{
110154
fn decode(value: <DB as Database>::ValueRef<'r>) -> Result<Self, BoxDynError> {
111-
Ok(Cow::Owned(T::decode(value)?))
155+
let borrowed = <&[u8] as Decode<DB>>::decode(value)?;
156+
Ok(Cow::Borrowed(borrowed))
112157
}
113158
}

sqlx-core/src/encode.rs

+6-6
Original file line numberDiff line numberDiff line change
@@ -174,29 +174,29 @@ impl_encode_for_smartpointer!(Rc<T>);
174174

175175
impl<'q, T, DB: Database> Encode<'q, DB> for Cow<'_, T>
176176
where
177-
T: Encode<'q, DB>,
178-
T: ToOwned<Owned = T>,
177+
for<'a> &'a T: Encode<'q, DB>,
178+
T: ToOwned,
179179
{
180180
#[inline]
181181
fn encode(self, buf: &mut <DB as Database>::ArgumentBuffer<'q>) -> Result<IsNull, BoxDynError> {
182-
<T as Encode<DB>>::encode_by_ref(self.as_ref(), buf)
182+
<&T as Encode<DB>>::encode_by_ref(&self.as_ref(), buf)
183183
}
184184

185185
#[inline]
186186
fn encode_by_ref(
187187
&self,
188188
buf: &mut <DB as Database>::ArgumentBuffer<'q>,
189189
) -> Result<IsNull, BoxDynError> {
190-
<&T as Encode<DB>>::encode(self, buf)
190+
<&T as Encode<DB>>::encode_by_ref(&self.as_ref(), buf)
191191
}
192192

193193
#[inline]
194194
fn produces(&self) -> Option<DB::TypeInfo> {
195-
(**self).produces()
195+
<&T as Encode<DB>>::produces(&self.as_ref())
196196
}
197197

198198
#[inline]
199199
fn size_hint(&self) -> usize {
200-
(**self).size_hint()
200+
<&T as Encode<DB>>::size_hint(&self.as_ref())
201201
}
202202
}

sqlx-core/src/types/mod.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -269,7 +269,7 @@ impl_type_for_smartpointer!(Rc<T>);
269269
impl<T, DB: Database> Type<DB> for Cow<'_, T>
270270
where
271271
T: Type<DB>,
272-
T: ToOwned<Owned = T>,
272+
T: ToOwned,
273273
T: ?Sized,
274274
{
275275
fn type_info() -> DB::TypeInfo {

sqlx-mysql/src/types/bytes.rs

-6
Original file line numberDiff line numberDiff line change
@@ -46,12 +46,6 @@ impl Encode<'_, MySql> for Box<[u8]> {
4646
}
4747
}
4848

49-
impl<'r> Decode<'r, MySql> for Box<[u8]> {
50-
fn decode(value: MySqlValueRef<'r>) -> Result<Self, BoxDynError> {
51-
<&[u8] as Decode<MySql>>::decode(value).map(Box::from)
52-
}
53-
}
54-
5549
impl Type<MySql> for Vec<u8> {
5650
fn type_info() -> MySqlTypeInfo {
5751
<[u8] as Type<MySql>>::type_info()

sqlx-mysql/src/types/str.rs

+5-10
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
1+
use std::borrow::Cow;
2+
13
use crate::decode::Decode;
24
use crate::encode::{Encode, IsNull};
35
use crate::error::BoxDynError;
46
use crate::io::MySqlBufMutExt;
57
use crate::protocol::text::{ColumnFlags, ColumnType};
68
use crate::types::Type;
79
use crate::{MySql, MySqlTypeInfo, MySqlValueRef};
8-
use std::borrow::Cow;
910

1011
impl Type<MySql> for str {
1112
fn type_info() -> MySqlTypeInfo {
@@ -52,12 +53,6 @@ impl Encode<'_, MySql> for Box<str> {
5253
}
5354
}
5455

55-
impl<'r> Decode<'r, MySql> for Box<str> {
56-
fn decode(value: MySqlValueRef<'r>) -> Result<Self, BoxDynError> {
57-
<&str as Decode<MySql>>::decode(value).map(Box::from)
58-
}
59-
}
60-
6156
impl Type<MySql> for String {
6257
fn type_info() -> MySqlTypeInfo {
6358
<str as Type<MySql>>::type_info()
@@ -89,8 +84,8 @@ impl Encode<'_, MySql> for Cow<'_, str> {
8984
}
9085
}
9186

92-
impl<'r> Decode<'r, MySql> for Cow<'r, str> {
93-
fn decode(value: MySqlValueRef<'r>) -> Result<Self, BoxDynError> {
94-
value.as_str().map(Cow::Borrowed)
87+
impl Encode<'_, MySql> for Cow<'_, [u8]> {
88+
fn encode_by_ref(&self, buf: &mut Vec<u8>) -> Result<IsNull, BoxDynError> {
89+
<&[u8] as Encode<MySql>>::encode(self.as_ref(), buf)
9590
}
9691
}

sqlx-postgres/src/types/bytes.rs

-9
Original file line numberDiff line numberDiff line change
@@ -80,15 +80,6 @@ fn text_hex_decode_input(value: PgValueRef<'_>) -> Result<&[u8], BoxDynError> {
8080
.map_err(Into::into)
8181
}
8282

83-
impl Decode<'_, Postgres> for Box<[u8]> {
84-
fn decode(value: PgValueRef<'_>) -> Result<Self, BoxDynError> {
85-
Ok(match value.format() {
86-
PgValueFormat::Binary => Box::from(value.as_bytes()?),
87-
PgValueFormat::Text => Box::from(hex::decode(text_hex_decode_input(value)?)?),
88-
})
89-
}
90-
}
91-
9283
impl Decode<'_, Postgres> for Vec<u8> {
9384
fn decode(value: PgValueRef<'_>) -> Result<Self, BoxDynError> {
9485
Ok(match value.format() {

sqlx-postgres/src/types/str.rs

+12-18
Original file line numberDiff line numberDiff line change
@@ -82,15 +82,6 @@ impl Encode<'_, Postgres> for &'_ str {
8282
}
8383
}
8484

85-
impl Encode<'_, Postgres> for Cow<'_, str> {
86-
fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> Result<IsNull, BoxDynError> {
87-
match self {
88-
Cow::Borrowed(str) => <&str as Encode<Postgres>>::encode(*str, buf),
89-
Cow::Owned(str) => <&str as Encode<Postgres>>::encode(&**str, buf),
90-
}
91-
}
92-
}
93-
9485
impl Encode<'_, Postgres> for Box<str> {
9586
fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> Result<IsNull, BoxDynError> {
9687
<&str as Encode<Postgres>>::encode(&**self, buf)
@@ -109,20 +100,23 @@ impl<'r> Decode<'r, Postgres> for &'r str {
109100
}
110101
}
111102

112-
impl<'r> Decode<'r, Postgres> for Cow<'r, str> {
113-
fn decode(value: PgValueRef<'r>) -> Result<Self, BoxDynError> {
114-
Ok(Cow::Borrowed(value.as_str()?))
103+
impl Decode<'_, Postgres> for String {
104+
fn decode(value: PgValueRef<'_>) -> Result<Self, BoxDynError> {
105+
Ok(value.as_str()?.to_owned())
115106
}
116107
}
117108

118-
impl<'r> Decode<'r, Postgres> for Box<str> {
119-
fn decode(value: PgValueRef<'r>) -> Result<Self, BoxDynError> {
120-
Ok(Box::from(value.as_str()?))
109+
impl Encode<'_, Postgres> for Cow<'_, str> {
110+
fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> Result<IsNull, BoxDynError> {
111+
match self {
112+
Cow::Borrowed(str) => <&str as Encode<Postgres>>::encode(*str, buf),
113+
Cow::Owned(str) => <&str as Encode<Postgres>>::encode(&**str, buf),
114+
}
121115
}
122116
}
123117

124-
impl Decode<'_, Postgres> for String {
125-
fn decode(value: PgValueRef<'_>) -> Result<Self, BoxDynError> {
126-
Ok(value.as_str()?.to_owned())
118+
impl Encode<'_, Postgres> for Cow<'_, [u8]> {
119+
fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> Result<IsNull, BoxDynError> {
120+
<&[u8] as Encode<Postgres>>::encode(self.as_ref(), buf)
127121
}
128122
}

sqlx-sqlite/src/types/bytes.rs

-6
Original file line numberDiff line numberDiff line change
@@ -53,12 +53,6 @@ impl Encode<'_, Sqlite> for Box<[u8]> {
5353
}
5454
}
5555

56-
impl Decode<'_, Sqlite> for Box<[u8]> {
57-
fn decode(value: SqliteValueRef<'_>) -> Result<Self, BoxDynError> {
58-
Ok(Box::from(value.blob()))
59-
}
60-
}
61-
6256
impl Type<Sqlite> for Vec<u8> {
6357
fn type_info() -> SqliteTypeInfo {
6458
<&[u8] as Type<Sqlite>>::type_info()

sqlx-sqlite/src/types/str.rs

+14-9
Original file line numberDiff line numberDiff line change
@@ -49,12 +49,6 @@ impl Encode<'_, Sqlite> for Box<str> {
4949
}
5050
}
5151

52-
impl Decode<'_, Sqlite> for Box<str> {
53-
fn decode(value: SqliteValueRef<'_>) -> Result<Self, BoxDynError> {
54-
value.text().map(Box::from)
55-
}
56-
}
57-
5852
impl Type<Sqlite> for String {
5953
fn type_info() -> SqliteTypeInfo {
6054
<&str as Type<Sqlite>>::type_info()
@@ -101,8 +95,19 @@ impl<'q> Encode<'q, Sqlite> for Cow<'q, str> {
10195
}
10296
}
10397

104-
impl<'r> Decode<'r, Sqlite> for Cow<'r, str> {
105-
fn decode(value: SqliteValueRef<'r>) -> Result<Self, BoxDynError> {
106-
value.text().map(Cow::Borrowed)
98+
impl<'q> Encode<'q, Sqlite> for Cow<'q, [u8]> {
99+
fn encode(self, args: &mut Vec<SqliteArgumentValue<'q>>) -> Result<IsNull, BoxDynError> {
100+
args.push(SqliteArgumentValue::Blob(self));
101+
102+
Ok(IsNull::No)
103+
}
104+
105+
fn encode_by_ref(
106+
&self,
107+
args: &mut Vec<SqliteArgumentValue<'q>>,
108+
) -> Result<IsNull, BoxDynError> {
109+
args.push(SqliteArgumentValue::Blob(self.clone()));
110+
111+
Ok(IsNull::No)
107112
}
108113
}

tests/mysql/types.rs

+50-1
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,14 @@
11
extern crate time_ as time;
22

3+
use std::borrow::Cow;
34
use std::net::SocketAddr;
5+
use std::rc::Rc;
46
#[cfg(feature = "rust_decimal")]
57
use std::str::FromStr;
8+
use std::sync::Arc;
69

710
use sqlx::mysql::MySql;
8-
use sqlx::{Executor, Row};
11+
use sqlx::{Executor, FromRow, Row};
912

1013
use sqlx::types::Text;
1114

@@ -384,3 +387,49 @@ CREATE TEMPORARY TABLE user_login (
384387

385388
Ok(())
386389
}
390+
391+
#[sqlx_macros::test]
392+
async fn test_smartpointers() -> anyhow::Result<()> {
393+
let mut conn = new::<MySql>().await?;
394+
395+
let user_age: (Arc<i32>, Cow<'static, i32>, Box<i32>, i32) =
396+
sqlx::query_as("SELECT ?, ?, ?, ?")
397+
.bind(Arc::new(1i32))
398+
.bind(Cow::<'_, i32>::Borrowed(&2i32))
399+
.bind(Box::new(3i32))
400+
.bind(Rc::new(4i32))
401+
.fetch_one(&mut conn)
402+
.await?;
403+
404+
assert!(user_age.0.as_ref() == &1);
405+
assert!(user_age.1.as_ref() == &2);
406+
assert!(user_age.2.as_ref() == &3);
407+
assert!(user_age.3 == 4);
408+
Ok(())
409+
}
410+
411+
#[sqlx_macros::test]
412+
async fn test_str_slice() -> anyhow::Result<()> {
413+
let mut conn = new::<MySql>().await?;
414+
415+
let box_str: Box<str> = "John".into();
416+
let box_slice: Box<[u8]> = [1, 2, 3, 4].into();
417+
let cow_str: Cow<'static, str> = "Phil".into();
418+
let cow_slice: Cow<'static, [u8]> = Cow::Borrowed(&[1, 2, 3, 4]);
419+
// : (Box<str>, Box<[u8]>, Cow<'static, str>, Cow<'static, [u8]>)
420+
let row = sqlx::query("SELECT ?, ?, ?, ?")
421+
.bind(&box_str)
422+
.bind(&box_slice)
423+
.bind(&cow_str)
424+
.bind(&cow_slice)
425+
.fetch_one(&mut conn)
426+
.await?;
427+
428+
let data: (Box<str>, Box<[u8]>, Cow<'_, str>, Cow<'_, [u8]>) = FromRow::from_row(&row)?;
429+
430+
assert!(data.0 == box_str);
431+
assert!(data.1 == box_slice);
432+
assert!(data.2 == cow_str);
433+
assert!(data.3 == cow_slice);
434+
Ok(())
435+
}

0 commit comments

Comments
 (0)