Skip to content

Commit

Permalink
add HashMap::try_insert_with
Browse files Browse the repository at this point in the history
  • Loading branch information
ibraheemdev committed Jan 15, 2025
1 parent 87311b3 commit 1229f93
Show file tree
Hide file tree
Showing 3 changed files with 175 additions and 3 deletions.
50 changes: 50 additions & 0 deletions src/map.rs
Original file line number Diff line number Diff line change
Expand Up @@ -565,6 +565,42 @@ where
}
}

/// Tries to insert a key and value computed from a closure into the map,
/// and returns a reference to the value that was inserted.
///
/// If the map already had this key present, nothing is updated, and
/// the existing value is returned.
///
/// # Examples
///
/// ```
/// use papaya::HashMap;
///
/// let map = HashMap::new();
/// let map = map.pin();
///
/// assert_eq!(map.try_insert_with(37, || "a").unwrap(), &"a");
///
/// let current = map.try_insert_with(37, || "b").unwrap_err();
/// assert_eq!(current, &"a");
/// ```
#[inline]
pub fn try_insert_with<'g, F>(
&self,
key: K,
f: F,
guard: &'g impl Guard,
) -> Result<&'g V, &'g V>
where
F: FnOnce() -> V,
K: 'g,
{
self.raw.check_guard(guard);

// Safety: Checked the guard above.
unsafe { self.raw.try_insert_with(key, f, guard) }
}

/// Returns a reference to the value corresponding to the key, or inserts a default value.
///
/// If the given key is present, the corresponding value is returned. If it is not present,
Expand Down Expand Up @@ -1341,6 +1377,20 @@ where
}
}

/// Tries to insert a key and value computed from a closure into the map,
/// and returns a reference to the value that was inserted.
///
/// See [`HashMap::try_insert_with`] for details.
/// ```
#[inline]
pub fn try_insert_with<F>(&self, key: K, f: F) -> Result<&V, &V>
where
F: FnOnce() -> V,
{
// Safety: `self.guard` was created from our map.
unsafe { self.map.raw.try_insert_with(key, f, &self.guard) }
}

/// Returns a reference to the value corresponding to the key, or inserts a default value.
///
/// See [`HashMap::get_or_insert`] for details.
Expand Down
58 changes: 56 additions & 2 deletions src/raw/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1175,6 +1175,14 @@ where
}

// Restores the state if an operation fails.
//
// This allows the result of the compute closure with a given input to be memoized.
// This is useful at it avoids calling the closure multiple times if an update needs
// to be retried in a new table.
//
// Additionally, update and insert operations are memoized separately, although this
// is not guaranteed in the public API. This means that internal methods can rely on
// `compute(None)` being called at most once.
#[inline]
fn restore(&mut self, input: Option<*mut Entry<K, V>>, output: Operation<V, T>) {
match input {
Expand Down Expand Up @@ -1238,6 +1246,42 @@ where
K: Hash + Eq,
S: BuildHasher,
{
/// Tries to insert a key and value computed from a closure into the map,
/// and returns a reference to the value that was inserted.
//
// # Safety
//
// The guard must be valid to use with this map.
#[inline]
pub unsafe fn try_insert_with<'g, F>(
&self,
key: K,
f: F,
guard: &'g impl Guard,
) -> Result<&'g V, &'g V>
where
F: FnOnce() -> V,
K: 'g,
{
let mut f = Some(f);
let compute = |entry| match entry {
// There is already an existing value.
Some((_, current)) => Operation::Abort(current),

// Insert the initial value.
//
// Note that this case is guaranteed to be executed at most
// once as insert values are memoized, so this can never panic.
None => Operation::Insert((f.take().unwrap())()),
};

match self.compute(key, compute, guard) {
Compute::Aborted(current) => Err(current),
Compute::Inserted(_, value) => Ok(value),
_ => unreachable!(),
}
}

/// Returns a reference to the value corresponding to the key, or inserts a default value
/// computed from a closure.
//
Expand All @@ -1254,7 +1298,11 @@ where
let compute = |entry| match entry {
// Return the existing value.
Some((_, current)) => Operation::Abort(current),

// Insert the initial value.
//
// Note that this case is guaranteed to be executed at most
// once as insert values are memoized, so this can never panic.
None => Operation::Insert((f.take().unwrap())()),
};

Expand Down Expand Up @@ -1282,8 +1330,10 @@ where
K: 'g,
{
let compute = |entry| match entry {
Some((_, value)) => Operation::Insert(update(value)),
// There is nothing to update.
None => Operation::Abort(()),
// Perform the update.
Some((_, value)) => Operation::Insert(update(value)),
};

match self.compute(key, compute, guard) {
Expand Down Expand Up @@ -1317,7 +1367,11 @@ where
let compute = |entry| match entry {
// Perform the update.
Some((_, value)) => Operation::Insert::<_, ()>(update(value)),

// Insert the initial value.
//
// Note that this case is guaranteed to be executed at most
// once as insert values are memoized, so this can never panic.
None => Operation::Insert((f.take().unwrap())()),
};

Expand Down Expand Up @@ -1359,7 +1413,7 @@ where
// Deallocate the entry if it was not inserted.
if matches!(result, Compute::Removed(..) | Compute::Aborted(_)) {
if let LazyEntry::Init(entry) = entry {
// Safety: We allocated this box above and it was not inserted into the map.
// Safety: The entry was allocated but not inserted into the map.
let _ = unsafe { Box::from_raw(entry) };
}
}
Expand Down
70 changes: 69 additions & 1 deletion tests/basic.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
// Adapted from: https://github.com/jonhoo/flurry/blob/main/tests/basic.rs

use papaya::{Compute, HashMap, Operation};
use papaya::{Compute, HashMap, OccupiedError, Operation};

use std::hash::{BuildHasher, BuildHasherDefault, Hasher};
use std::sync::Arc;
Expand Down Expand Up @@ -247,6 +247,74 @@ fn get_or_insert() {
});
}

#[test]
fn try_insert() {
with_map::<usize, usize>(|map| {
let map = map();
let guard = map.guard();

assert_eq!(map.try_insert(42, 1, &guard), Ok(&1));
assert_eq!(map.len(), 1);

{
let guard = map.guard();
let e = map.get(&42, &guard).unwrap();
assert_eq!(e, &1);
}

assert_eq!(
map.try_insert(42, 2, &guard),
Err(OccupiedError {
current: &1,
not_inserted: 2
})
);
assert_eq!(map.len(), 1);

{
let guard = map.guard();
let e = map.get(&42, &guard).unwrap();
assert_eq!(e, &1);
}

assert_eq!(map.try_insert(43, 2, &guard), Ok(&2));
});
}

#[test]
fn try_insert_with() {
with_map::<usize, usize>(|map| {
let map = map();
let guard = map.guard();

map.try_insert_with(42, || 1, &guard).unwrap();
assert_eq!(map.len(), 1);

{
let guard = map.guard();
let e = map.get(&42, &guard).unwrap();
assert_eq!(e, &1);
}

let mut called = false;
let insert = || {
called = true;
2
};
assert_eq!(map.try_insert_with(42, insert, &guard), Err(&1));
assert_eq!(map.len(), 1);
assert!(!called);

{
let guard = map.guard();
let e = map.get(&42, &guard).unwrap();
assert_eq!(e, &1);
}

assert_eq!(map.try_insert_with(43, || 2, &guard), Ok(&2));
});
}

#[test]
fn get_or_insert_with() {
with_map::<usize, usize>(|map| {
Expand Down

0 comments on commit 1229f93

Please sign in to comment.