Skip to content

Commit 19567d8

Browse files
authored
implement serde::{Serialize, Deserialize} for HashSet
1 parent 7426f7e commit 19567d8

File tree

1 file changed

+103
-2
lines changed

1 file changed

+103
-2
lines changed

src/serde_impls.rs

Lines changed: 103 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
1-
use serde::de::{MapAccess, Visitor};
1+
use serde::de::{MapAccess, SeqAccess, Visitor};
22
use serde::{Deserialize, Deserializer, Serialize, Serializer};
33

44
use std::fmt::{self, Formatter};
55
use std::hash::{BuildHasher, Hash};
66
use std::marker::PhantomData;
77

8-
use crate::{Guard, HashMap, HashMapRef};
8+
use crate::{Guard, HashMap, HashMapRef, HashSet, HashSetRef};
99

1010
struct MapVisitor<K, V, S> {
1111
_marker: PhantomData<HashMap<K, V, S>>,
@@ -94,9 +94,93 @@ where
9494
}
9595
}
9696

97+
struct SetVisitor<K, S> {
98+
_marker: PhantomData<HashSet<K, S>>,
99+
}
100+
101+
impl<K, S, G> Serialize for HashSetRef<'_, K, S, G>
102+
where
103+
K: Serialize + Hash + Eq,
104+
G: Guard,
105+
S: BuildHasher,
106+
{
107+
fn serialize<Sr>(&self, serializer: Sr) -> Result<Sr::Ok, Sr::Error>
108+
where
109+
Sr: Serializer,
110+
{
111+
serializer.collect_seq(self)
112+
}
113+
}
114+
115+
impl<K, S> Serialize for HashSet<K, S>
116+
where
117+
K: Serialize + Hash + Eq,
118+
S: BuildHasher,
119+
{
120+
fn serialize<Sr>(&self, serializer: Sr) -> Result<Sr::Ok, Sr::Error>
121+
where
122+
Sr: Serializer,
123+
{
124+
self.pin().serialize(serializer)
125+
}
126+
}
127+
128+
impl<'de, K, S> Deserialize<'de> for HashSet<K, S>
129+
where
130+
K: Deserialize<'de> + Hash + Eq,
131+
S: Default + BuildHasher,
132+
{
133+
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
134+
where
135+
D: Deserializer<'de>,
136+
{
137+
deserializer.deserialize_seq(SetVisitor::new())
138+
}
139+
}
140+
141+
impl<K, S> SetVisitor<K, S> {
142+
pub(crate) fn new() -> Self {
143+
Self {
144+
_marker: PhantomData,
145+
}
146+
}
147+
}
148+
149+
impl<'de, K, S> Visitor<'de> for SetVisitor<K, S>
150+
where
151+
K: Deserialize<'de> + Hash + Eq,
152+
S: Default + BuildHasher,
153+
{
154+
type Value = HashSet<K, S>;
155+
156+
fn expecting(&self, f: &mut Formatter<'_>) -> fmt::Result {
157+
write!(f, "a map")
158+
}
159+
160+
fn visit_seq<M>(self, mut access: M) -> Result<Self::Value, M::Error>
161+
where
162+
M: SeqAccess<'de>,
163+
{
164+
let values = match access.size_hint() {
165+
Some(size) => HashSet::with_capacity_and_hasher(size, S::default()),
166+
None => HashSet::default(),
167+
};
168+
169+
{
170+
let values = values.pin();
171+
while let Some(key) = access.next_element()? {
172+
values.insert(key);
173+
}
174+
}
175+
176+
Ok(values)
177+
}
178+
}
179+
97180
#[cfg(test)]
98181
mod test {
99182
use crate::HashMap;
183+
use crate::HashSet;
100184

101185
#[test]
102186
fn test_map() {
@@ -114,4 +198,21 @@ mod test {
114198

115199
assert_eq!(map, deserialized);
116200
}
201+
202+
#[test]
203+
fn test_set() {
204+
let map: HashSet<u8> = HashSet::new();
205+
let guard = map.guard();
206+
207+
map.insert(0, &guard);
208+
map.insert(1, &guard);
209+
map.insert(2, &guard);
210+
map.insert(3, &guard);
211+
map.insert(4, &guard);
212+
213+
let serialized = serde_json::to_string(&map).unwrap();
214+
let deserialized = serde_json::from_str(&serialized).unwrap();
215+
216+
assert_eq!(map, deserialized);
217+
}
117218
}

0 commit comments

Comments
 (0)