diff --git a/core/src/main/scala-2/cats/syntax/MonadOps.scala b/core/src/main/scala-2/cats/syntax/MonadOps.scala index cc2b782378..bb469b311a 100644 --- a/core/src/main/scala-2/cats/syntax/MonadOps.scala +++ b/core/src/main/scala-2/cats/syntax/MonadOps.scala @@ -30,6 +30,8 @@ final class MonadOps[F[_], A](private val fa: F[A]) extends AnyVal { def untilM_(p: F[Boolean])(implicit M: Monad[F]): F[Unit] = M.untilM_(fa)(p) def iterateWhile(p: A => Boolean)(implicit M: Monad[F]): F[A] = M.iterateWhile(fa)(p) def iterateUntil(p: A => Boolean)(implicit M: Monad[F]): F[A] = M.iterateUntil(fa)(p) + def whenM(p: F[Boolean])(implicit M: Monad[F]): F[Unit] = M.whenM(fa)(p) + def unlessM(p: F[Boolean])(implicit M: Monad[F]): F[Unit] = M.unlessM(fa)(p) def flatMapOrKeep[A1 >: A](pfa: PartialFunction[A, F[A1]])(implicit M: Monad[F]): F[A1] = M.flatMapOrKeep[A, A1](fa)(pfa) } diff --git a/core/src/main/scala/cats/Monad.scala b/core/src/main/scala/cats/Monad.scala index c694f854e3..54af1c4136 100644 --- a/core/src/main/scala/cats/Monad.scala +++ b/core/src/main/scala/cats/Monad.scala @@ -162,6 +162,20 @@ trait Monad[F[_]] extends FlatMap[F] with Applicative[F] { tailRecM(branches.toList)(step) } + /** + * Returns the given argument (mapped to Unit) if `cond` evaluates to `false`, otherwise, + * unit lifted into F. + */ + def unlessM[A](f: F[A])(cond: F[Boolean]): F[Unit] = + flatMap(cond)(bool => if (bool) unit else void(f)) + + /** + * Returns the given argument (mapped to Unit) if `cond` evaluates to `true`, otherwise, + * unit lifted into F. + */ + def whenM[A](f: F[A])(cond: F[Boolean]): F[Unit] = + flatMap(cond)(bool => if (bool) void(f) else unit) + /** * Modifies the `A` value in `F[A]` with the supplied function, if the function is defined for the value. * Example: @@ -204,6 +218,8 @@ object Monad { def untilM_(cond: => F[Boolean]): F[Unit] = typeClassInstance.untilM_[A](self)(cond) def iterateWhile(p: A => Boolean): F[A] = typeClassInstance.iterateWhile[A](self)(p) def iterateUntil(p: A => Boolean): F[A] = typeClassInstance.iterateUntil[A](self)(p) + def whenM(cond: F[Boolean]): F[Unit] = typeClassInstance.whenM[A](self)(cond) + def unlessM(cond: F[Boolean]): F[Unit] = typeClassInstance.unlessM[A](self)(cond) } trait AllOps[F[_], A] extends Ops[F, A] with FlatMap.AllOps[F, A] with Applicative.AllOps[F, A] { type TypeClassType <: Monad[F] diff --git a/tests/shared/src/test/scala/cats/tests/MonadSuite.scala b/tests/shared/src/test/scala/cats/tests/MonadSuite.scala index aa5fa28794..675adc8894 100644 --- a/tests/shared/src/test/scala/cats/tests/MonadSuite.scala +++ b/tests/shared/src/test/scala/cats/tests/MonadSuite.scala @@ -67,6 +67,22 @@ class MonadSuite extends CatsSuite { } } + test("whenM should void when true") { + assert(List(1, 2).whenM(List(true)) == List((), ())) + } + + test("whenM should unit when false") { + assert(List(1, 2).whenM(List(false)) == List(())) + } + + test("unlessM should unit when true") { + assert(List(1, 2).unlessM(List(true)) == List(())) + } + + test("unlessM should void when false") { + assert(List(1, 2).unlessM(List(false)) == List((), ())) + } + test("whileM_ stack safety") { val (result, _) = increment.whileM_(StateT.inspect(i => !(i >= 50000))).run(0) assert(result === 50000)