diff --git a/common/src/main/scala/no/ndla/common/model/domain/myndla/auth/AuthUtility.scala b/common/src/main/scala/no/ndla/common/model/domain/myndla/auth/AuthUtility.scala deleted file mode 100644 index 1918b90ed0..0000000000 --- a/common/src/main/scala/no/ndla/common/model/domain/myndla/auth/AuthUtility.scala +++ /dev/null @@ -1,46 +0,0 @@ -/* - * Part of NDLA common - * Copyright (C) 2024 NDLA - * - * See LICENSE - * - */ - -package no.ndla.common.model.domain.myndla.auth - -import sttp.model.headers.{AuthenticationScheme, WWWAuthenticateChallenge} -import sttp.tapir.* -import sttp.tapir.CodecFormat.TextPlain -import sttp.tapir.EndpointInput.{AuthInfo, AuthType} - -import scala.collection.immutable.ListMap - -object AuthUtility { - private val authScheme = AuthenticationScheme.Bearer.name - private def filterHeaders(headers: List[String]) = headers.filter(_.toLowerCase.startsWith(authScheme.toLowerCase)) - private def stringPrefixWithSpace = Mapping.stringPrefixCaseInsensitiveForList(authScheme + " ") - val feideTokenAuthCodec: Codec[List[String], Option[String], TextPlain] = { - val codec = implicitly[Codec[List[String], Option[String], CodecFormat.TextPlain]] - Codec - .id[List[String], CodecFormat.TextPlain](codec.format, Schema.binary) - .map(filterHeaders(_))(identity) - .map(stringPrefixWithSpace) - .mapDecode(codec.decode)(codec.encode) - .schema(codec.schema) - } - - def feideOauth() = { - val authType: AuthType.ScopedOAuth2 = EndpointInput - .AuthType - .OAuth2(None, None, ListMap.empty, None) - .requiredScopes(Seq.empty) - - EndpointInput.Auth( - input = sttp.tapir.header("FeideAuthorization")(using feideTokenAuthCodec), - challenge = WWWAuthenticateChallenge.bearer, - authType = authType, - info = AuthInfo.Empty.securitySchemeName("oauth2"), - ) - } - -} diff --git a/network/src/main/scala/no/ndla/network/tapir/SwaggerController.scala b/network/src/main/scala/no/ndla/network/tapir/SwaggerController.scala index 6f2f5a26ac..12219d62e2 100644 --- a/network/src/main/scala/no/ndla/network/tapir/SwaggerController.scala +++ b/network/src/main/scala/no/ndla/network/tapir/SwaggerController.scala @@ -74,8 +74,10 @@ class SwaggerController(services: List[TapirController], swaggerInfo: SwaggerInf val options = OpenAPIDocsOptions.default val docs = OpenAPIDocsInterpreter(options).serverEndpointsToOpenAPI(swaggerEndpoints, info) val generatedComponents = docs.components.getOrElse(Components.Empty) - val newComponents = generatedComponents.copy(securitySchemes = ListMap("oauth2" -> Right(securityScheme))) - val docsWithComponents = docs.components(newComponents).asJson + val newComponents = generatedComponents.copy(securitySchemes = + generatedComponents.securitySchemes ++ ListMap("oauth2" -> Right(securityScheme)) + ) + val docsWithComponents = docs.components(newComponents).asJson docsWithComponents.asJson } diff --git a/network/src/main/scala/no/ndla/network/tapir/TapirController.scala b/network/src/main/scala/no/ndla/network/tapir/TapirController.scala index 5909103655..37e15e4aed 100644 --- a/network/src/main/scala/no/ndla/network/tapir/TapirController.scala +++ b/network/src/main/scala/no/ndla/network/tapir/TapirController.scala @@ -12,40 +12,28 @@ import cats.implicits.catsSyntaxEitherId import com.typesafe.scalalogging.StrictLogging import io.circe.{Decoder, Encoder} import no.ndla.common.SchemaImplicits -import no.ndla.common.model.domain.myndla.auth.AuthUtility +import no.ndla.common.model.api.myndla.MyNDLAUserDTO import no.ndla.network.clients.MyNDLAProvider -import no.ndla.network.model.{ - CombinedUser, - CombinedUserRequired, - CombinedUserWithBoth, - CombinedUserWithMyNDLAUser, - FeideUserWrapper, - HttpRequestException, - OptionalCombinedUser, -} -import no.ndla.network.tapir.auth.{Permission, TokenUser} -import sttp.model.StatusCode -import sttp.monad.MonadError -import sttp.tapir.* -import sttp.tapir.server.{PartialServerEndpoint, ServerEndpoint} +import no.ndla.network.model.* import no.ndla.network.tapir.NoNullJsonPrinter.jsonBody import no.ndla.network.tapir.TapirUtil.errorOutputVariantFor -import no.ndla.network.tapir.auth.TokenUser.{filterHeaders, stringPrefixWithSpace} -import sttp.model.headers.WWWAuthenticateChallenge +import no.ndla.network.tapir.auth.{FeideAuth, NdlaAuth, Permission, TokenUser} +import sttp.model.StatusCode import sttp.shared.Identity -import sttp.tapir.CodecFormat.TextPlain -import sttp.tapir.EndpointInput.{AuthInfo, AuthType} +import sttp.tapir.* +import sttp.tapir.server.{PartialServerEndpoint, ServerEndpoint} -import scala.collection.immutable.ListMap import scala.util.{Failure, Success} abstract class TapirController(using myNDLAApiClient: MyNDLAProvider, errorHelpers: ErrorHelpers, errorHandling: ErrorHandling, -) extends TapirErrorHandling - with StrictLogging - with SchemaImplicits { +) extends TapirErrorHandling, + NdlaAuth, + FeideAuth, + StrictLogging, + SchemaImplicits { type Eff[A] = Identity[A] val enableSwagger: Boolean = true val serviceName: String = this.getClass.getSimpleName @@ -87,136 +75,71 @@ abstract class TapirController(using case None => errorHelpers.unauthorized.asLeft } - private def encodeFeideUserWrapper(user: FeideUserWrapper): String = user.token - private def decodeFeideUserWrapper(s: String): DecodeResult[FeideUserWrapper] = { - myNDLAApiClient.getDomainUser(s) match { - case Failure(ex) => DecodeResult.Error(s, ex) - case Success(user) => DecodeResult.Value(FeideUserWrapper(s, Some(user))) - } - } - - private implicit val userinfoCodec: Codec[String, FeideUserWrapper, TextPlain] = Codec - .string - .mapDecode(decodeFeideUserWrapper)(encodeFeideUserWrapper) - private val feideHeaderCodec = implicitly[Codec[List[String], Option[FeideUserWrapper], CodecFormat.TextPlain]] - private val authCodec: Codec[List[String], Option[FeideUserWrapper], TextPlain] = Codec - .id[List[String], CodecFormat.TextPlain](feideHeaderCodec.format, Schema.binary) - .map(filterHeaders)(identity) - .map(stringPrefixWithSpace) - .mapDecode(feideHeaderCodec.decode)(feideHeaderCodec.encode) - .schema(feideHeaderCodec.schema) - - private val feideWrapperAuth: EndpointInput.Auth[Option[FeideUserWrapper], AuthType.OAuth2] = { - EndpointInput.Auth( - input = sttp.tapir.header("FeideAuthorization")(using authCodec), - challenge = WWWAuthenticateChallenge.bearer, - authType = EndpointInput.AuthType.OAuth2(None, None, ListMap.empty, None), - info = AuthInfo.Empty.securitySchemeName("oauth2"), - ) - } - private val unauthorizedErrorOutput = errorOutputVariantFor(StatusCode.Unauthorized.code) private val forbiddenErrorOutput = errorOutputVariantFor(StatusCode.Forbidden.code) extension [A, I, E, O, R](self: Endpoint[Unit, I, AllErrors, O, R]) { def requirePermission[F[_]]( requiredPermission: Permission* - ): PartialServerEndpoint[Option[TokenUser], TokenUser, I, AllErrors, O, R, F] = { - val endpointWithPermissionErrors = self - .errorOutVariantPrepend(unauthorizedErrorOutput) - .errorOutVariantPrepend(forbiddenErrorOutput) - val newEndpoint = endpointWithPermissionErrors.securityIn(TokenUser.oauth2Input(requiredPermission)) - val authFunc = requireScope(requiredPermission*) - val securityLogic = (m: MonadError[F]) => (a: Option[TokenUser]) => m.unit(authFunc(a)) - PartialServerEndpoint(newEndpoint, securityLogic) - } - - def withFeideUser[F[_]] - : PartialServerEndpoint[Option[FeideUserWrapper], FeideUserWrapper, I, AllErrors, O, R, F] = { - val newEndpoint = self.securityIn(feideWrapperAuth) - val authFunc: Option[FeideUserWrapper] => Either[AllErrors, FeideUserWrapper] = { - case Some(value) => value.asRight - case None => errorHelpers.unauthorized.asLeft + ): PartialServerEndpoint[Option[TokenUser], TokenUser, I, AllErrors, O, R, F] = self + .errorOutVariantPrepend(unauthorizedErrorOutput) + .errorOutVariantPrepend(forbiddenErrorOutput) + .securityIn(ndlaOptionalAuth(requiredPermission)) + .serverSecurityLogicPure(requireScope(requiredPermission*)) + + def withFeideUser[F[_]]: PartialServerEndpoint[FeideUserWrapper, FeideUserWrapper, I, AllErrors, O, R, F] = self + .securityIn(feideRequiredAuth) + .serverSecurityLogicPure { + case user @ FeideUserWrapper(_, Some(_)) => user.asRight + case _ => errorHelpers.unauthorized.asLeft } - val securityLogic = (m: MonadError[F]) => (a: Option[FeideUserWrapper]) => m.unit(authFunc(a)) - PartialServerEndpoint(newEndpoint, securityLogic) - } def withOptionalFeideUser[F[_]] - : PartialServerEndpoint[Option[FeideUserWrapper], Option[FeideUserWrapper], I, AllErrors, O, R, F] = { - val newEndpoint = self.securityIn(feideWrapperAuth) - val authFunc: Option[FeideUserWrapper] => Either[AllErrors, Option[FeideUserWrapper]] = { - case None => None.asRight - case someUser => someUser.asRight + : PartialServerEndpoint[Option[FeideUserWrapper], Option[FeideUserWrapper], I, AllErrors, O, R, F] = self + .securityIn(feideOptionalAuth) + .serverSecurityLogicPure { + case user @ Some(FeideUserWrapper(_, Some(_))) => user.asRight + case _ => None.asRight + } + + private def getMyNdlaUser(token: String): Option[MyNDLAUserDTO] = + myNDLAApiClient.getUserWithFeideToken(token) match { + case Failure(ex: HttpRequestException) if ex.code == 401 || ex.code == 403 => None + case Failure(ex) => + logger.warn("Got unexpected exception when fetching myndla user", ex) + None + case Success(user) => Some(user) } - val securityLogic = (m: MonadError[F]) => (a: Option[FeideUserWrapper]) => m.unit(authFunc(a)) - PartialServerEndpoint(newEndpoint, securityLogic) - } def withOptionalMyNDLAUserOrTokenUser[F[_]] - : PartialServerEndpoint[(Option[TokenUser], Option[String]), CombinedUser, I, AllErrors, O, R, F] = { - val newEndpoint = self.securityIn(TokenUser.oauth2Input(Seq.empty)).securityIn(AuthUtility.feideOauth()) - - val authFunc: ((Option[TokenUser], Option[String])) => Either[AllErrors, CombinedUser] = - (userInputOptions: (Option[TokenUser], Option[String])) => { - val maybeUser = userInputOptions._1 - val maybeToken = userInputOptions._2 - - val myndlaUser = maybeToken.flatMap { token => - myNDLAApiClient.getUserWithFeideToken(token) match { - case Failure(ex: HttpRequestException) if ex.code == 401 || ex.code == 403 => None - case Failure(ex) => - logger.warn("Got unexpected exception when fetching myndla user", ex) - None - case Success(user) => Some(user) - } - } - - val combinedUser = OptionalCombinedUser(maybeUser, myndlaUser) - Right(combinedUser) - } - val securityLogic = (m: MonadError[F]) => (a: (Option[TokenUser], Option[String])) => m.unit(authFunc(a)) - PartialServerEndpoint(newEndpoint, securityLogic) - } + : PartialServerEndpoint[(Option[TokenUser], Option[String]), CombinedUser, I, AllErrors, O, R, F] = self + .securityIn(ndlaOptionalAuth) + .securityIn(feideOptionalUncheckedAuth) + .serverSecurityLogicPure { (maybeUser, maybeToken) => + val maybeMyNdlaUser = maybeToken.flatMap(getMyNdlaUser) + val combinedUser = OptionalCombinedUser(maybeUser, maybeMyNdlaUser) + Right(combinedUser) + } def withRequiredMyNDLAUserOrTokenUser[F[_]] - : PartialServerEndpoint[(Option[TokenUser], Option[String]), CombinedUserRequired, I, AllErrors, O, R, F] = { - val newEndpoint = self.securityIn(TokenUser.oauth2Input(Seq.empty)).securityIn(AuthUtility.feideOauth()) - - val authFunc: ((Option[TokenUser], Option[String])) => Either[AllErrors, CombinedUserRequired] = - (userInputOptions: (Option[TokenUser], Option[String])) => { - val maybeUser = userInputOptions._1 - val maybeToken = userInputOptions._2 - - val myndlaUser = maybeToken.flatMap { token => - myNDLAApiClient.getUserWithFeideToken(token) match { - case Failure(ex: HttpRequestException) if ex.code == 401 || ex.code == 403 => None - case Failure(ex) => - logger.warn("Got unexpected exception when fetching myndla user", ex) - None - case Success(user) => Some(user) - } - } - - (maybeUser, myndlaUser) match { - case (Some(tokenUser), Some(ndlaUser)) => CombinedUserWithBoth(tokenUser, ndlaUser).asRight - case (Some(tokenUser), None) => tokenUser.toCombined.asRight - case (None, Some(ndlaUser)) => CombinedUserWithMyNDLAUser(None, ndlaUser).asRight - case _ => errorHelpers.unauthorized.asLeft - } + : PartialServerEndpoint[(Option[TokenUser], Option[String]), CombinedUserRequired, I, AllErrors, O, R, F] = self + .securityIn(ndlaOptionalAuth) + .securityIn(feideOptionalUncheckedAuth) + .serverSecurityLogicPure { (maybeUser, maybeToken) => + val maybeMyNdlaUser = maybeToken.flatMap(getMyNdlaUser) + (maybeUser, maybeMyNdlaUser) match { + case (Some(tokenUser), Some(ndlaUser)) => CombinedUserWithBoth(tokenUser, ndlaUser).asRight + case (Some(tokenUser), None) => tokenUser.toCombined.asRight + case (None, Some(ndlaUser)) => CombinedUserWithMyNDLAUser(None, ndlaUser).asRight + case _ => errorHelpers.unauthorized.asLeft } - val securityLogic = (m: MonadError[F]) => (a: (Option[TokenUser], Option[String])) => m.unit(authFunc(a)) - PartialServerEndpoint(newEndpoint, securityLogic) - } + } } extension [A, I, E, O, R, X](self: Endpoint[Unit, I, X, O, R]) { - def withOptionalUser[F[_]]: PartialServerEndpoint[Option[TokenUser], Option[TokenUser], I, X, O, R, F] = { - val newEndpoint = self.securityIn(TokenUser.oauth2Input(Seq.empty)) - val authFunc = (tokenUser: Option[TokenUser]) => Right(tokenUser): Either[X, Option[TokenUser]] - val securityLogic = (m: MonadError[F]) => (a: Option[TokenUser]) => m.unit(authFunc(a)) - PartialServerEndpoint(newEndpoint, securityLogic) - } + def withOptionalUser[F[_]]: PartialServerEndpoint[Option[TokenUser], Option[TokenUser], I, X, O, R, F] = self + .securityIn(ndlaOptionalAuth) + .serverSecurityLogicPure(Right(_)) } private val zeroNoContentHeader: EndpointIO.FixedHeader[Unit] = header("Content-Length", "0") diff --git a/network/src/main/scala/no/ndla/network/tapir/auth/FeideAuth.scala b/network/src/main/scala/no/ndla/network/tapir/auth/FeideAuth.scala new file mode 100644 index 0000000000..48e25efa00 --- /dev/null +++ b/network/src/main/scala/no/ndla/network/tapir/auth/FeideAuth.scala @@ -0,0 +1,62 @@ +/* + * Part of NDLA network + * Copyright (C) 2024 NDLA + * + * See LICENSE + * + */ + +package no.ndla.network.tapir.auth + +import no.ndla.network.clients.MyNDLAProvider +import no.ndla.network.model.FeideUserWrapper +import sttp.model.headers.{AuthenticationScheme, WWWAuthenticateChallenge} +import sttp.tapir.* +import sttp.tapir.EndpointInput.AuthType + +import scala.collection.immutable.ListMap +import scala.util.{Failure, Success} + +trait FeideAuth(using myNdlaApiClient: MyNDLAProvider) { + private val headerName = "FeideAuthorization" + private val schemeName = "FeideAuth" + private val issuer = "https://auth.dataporten.no" + private val authorizationUrl = s"$issuer/oauth/authorization" + private val tokenUrl = s"$issuer/oauth/token" + private val challenge = WWWAuthenticateChallenge.bearer + + private val bearerMapping: Mapping[String, String] = + Mapping.stringPrefixCaseInsensitive(AuthenticationScheme.Bearer.name + " ") + private val feideUserWrapperMapping = Mapping.fromDecode(decodeFeideUserWrapper)(encodeFeideUserWrapper) + private val bearerFeideUserWrapperMapping = bearerMapping.map(feideUserWrapperMapping) + private val optionalBearerFeideUserWrapperMapping = TapirAuthUtil.makeOptionalMapping(bearerFeideUserWrapperMapping) + private val optionalBearerMapping = TapirAuthUtil.makeOptionalMapping(bearerMapping) + + private val requiredHeaderInput = header[String](headerName).map(bearerFeideUserWrapperMapping) + private val optionalHeaderInput = header[Option[String]](headerName).map(optionalBearerFeideUserWrapperMapping) + private val optionalUncheckedHeaderInput = header[Option[String]](headerName).map(optionalBearerMapping) + + val feideRequiredAuth: EndpointInput.Auth[FeideUserWrapper, AuthType.OAuth2] = + oauth2EndpointInput(requiredHeaderInput) + val feideOptionalAuth: EndpointInput.Auth[Option[FeideUserWrapper], AuthType.OAuth2] = + oauth2EndpointInput(optionalHeaderInput) + val feideOptionalUncheckedAuth: EndpointInput.Auth[Option[String], AuthType.OAuth2] = + oauth2EndpointInput(optionalUncheckedHeaderInput) + + private def encodeFeideUserWrapper(user: FeideUserWrapper): String = user.token + private def decodeFeideUserWrapper(s: String): DecodeResult[FeideUserWrapper] = { + myNdlaApiClient.getDomainUser(s) match { + case Success(user) => DecodeResult.Value(FeideUserWrapper(s, Some(user))) + case Failure(ex) => DecodeResult.Error(s, ex) + } + } + + private def oauth2EndpointInput[T]( + headerInput: EndpointIO.Header[T] + ): EndpointInput.Auth[T, EndpointInput.AuthType.OAuth2] = EndpointInput.Auth( + headerInput, + challenge, + EndpointInput.AuthType.OAuth2(Some(authorizationUrl), Some(tokenUrl), ListMap(), None), + EndpointInput.AuthInfo.Empty.securitySchemeName(schemeName), + ) +} diff --git a/network/src/main/scala/no/ndla/network/tapir/auth/NdlaAuth.scala b/network/src/main/scala/no/ndla/network/tapir/auth/NdlaAuth.scala new file mode 100644 index 0000000000..5497f32baf --- /dev/null +++ b/network/src/main/scala/no/ndla/network/tapir/auth/NdlaAuth.scala @@ -0,0 +1,63 @@ +/* + * Part of NDLA network + * Copyright (C) 2026 NDLA + * + * See LICENSE + * + */ + +package no.ndla.network.tapir.auth + +import no.ndla.network.jwt.JWTExtractor +import sttp.tapir.* +import sttp.tapir.EndpointInput.AuthType + +import scala.util.{Failure, Success} + +case class UserInfoException() extends RuntimeException("Could not build `TokenUser` from token.") + +trait NdlaAuth { + private val schemeName = "NDLAAuth" + private val tokenUserMapping = Mapping.fromDecode(decodeTokenUser)(encodeTokenUser) + private val optionalTokenUserMapping = TapirAuthUtil.makeOptionalMapping(tokenUserMapping) + + val ndlaOptionalAuth: EndpointInput.Auth[Option[TokenUser], AuthType.OAuth2] = TapirAuth + .oauth2 + .authorizationCodeFlowOptional("", "") + .securitySchemeName(schemeName) + .map(optionalTokenUserMapping) + + def ndlaOptionalAuth( + requiredPermissions: Seq[Permission] + ): EndpointInput.Auth[Option[TokenUser], AuthType.ScopedOAuth2] = { + val scopes = Permission.toSwaggerMap(requiredPermissions) + val requiredScopes = requiredPermissions.map(_.entryName) + TapirAuth + .oauth2 + .authorizationCodeFlowOptional("", "", scopes = scopes) + .securitySchemeName(schemeName) + .map(optionalTokenUserMapping) + .requiredScopes(requiredScopes) + } + + private def encodeTokenUser(user: TokenUser): String = user.id + private def decodeTokenUser(s: String): DecodeResult[TokenUser] = { + val jWTExtractor = JWTExtractor(s) + fromExtractor(jWTExtractor, s) match { + case Failure(ex) => DecodeResult.Error(s, ex) + case Success(value) => DecodeResult.Value(value) + } + } + + private def fromExtractor(jWTExtractor: JWTExtractor, token: String) = { + val userId = jWTExtractor.extractUserId() + val roles = jWTExtractor.extractUserRoles() + val userName = jWTExtractor.extractUserName() + val clientId = jWTExtractor.extractClientId() + + userId.orElse(clientId).orElse(userName) match { + case Some(userInfoName) => Success(TokenUser(userInfoName, Permission.fromStrings(roles), Some(token))) + case None => Failure(UserInfoException()) + } + } +} diff --git a/network/src/main/scala/no/ndla/network/tapir/auth/Permission.scala b/network/src/main/scala/no/ndla/network/tapir/auth/Permission.scala index ff3666efb9..06eb8ea713 100644 --- a/network/src/main/scala/no/ndla/network/tapir/auth/Permission.scala +++ b/network/src/main/scala/no/ndla/network/tapir/auth/Permission.scala @@ -40,7 +40,7 @@ object Permission extends Enum[Permission] with CirceEnum[Permission] { def fromStrings(s: List[String]): Set[Permission] = s.flatMap(fromString).toSet implicit val schema: Schema[Permission] = schemaForEnumEntry[Permission] - def thatStartsWith(start: String): List[Permission] = values.filter(_.entryName.startsWith(start)).toList - def toSwaggerMap(scopes: List[Permission]): ListMap[String, String] = + def thatStartsWith(start: String): List[Permission] = values.filter(_.entryName.startsWith(start)).toList + def toSwaggerMap(scopes: Seq[Permission]): ListMap[String, String] = ListMap.from(scopes.map(s => s.entryName -> s.entryName)) } diff --git a/network/src/main/scala/no/ndla/network/tapir/auth/TapirAuthUtil.scala b/network/src/main/scala/no/ndla/network/tapir/auth/TapirAuthUtil.scala new file mode 100644 index 0000000000..3222aa3071 --- /dev/null +++ b/network/src/main/scala/no/ndla/network/tapir/auth/TapirAuthUtil.scala @@ -0,0 +1,22 @@ +/* + * Part of NDLA network + * Copyright (C) 2026 NDLA + * + * See LICENSE + * + */ + +package no.ndla.network.tapir.auth + +import sttp.tapir.* + +object TapirAuthUtil { + def makeOptionalMapping[T](typeMapping: Mapping[String, T]): Mapping[Option[String], Option[T]] = + Mapping.fromDecode[Option[String], Option[T]] { + case Some(v) => typeMapping.decode(v).map(Some(_)) + case None => DecodeResult.Value(None) + } { + case Some(v) => Some(typeMapping.encode(v)) + case None => None + } +} diff --git a/network/src/main/scala/no/ndla/network/tapir/auth/TokenUser.scala b/network/src/main/scala/no/ndla/network/tapir/auth/TokenUser.scala index a38b33e853..a2750b1533 100644 --- a/network/src/main/scala/no/ndla/network/tapir/auth/TokenUser.scala +++ b/network/src/main/scala/no/ndla/network/tapir/auth/TokenUser.scala @@ -9,16 +9,7 @@ package no.ndla.network.tapir.auth import cats.implicits.* -import no.ndla.network.jwt.JWTExtractor import no.ndla.network.model.{CombinedUserWithTokenUser, JWTClaims} -import sttp.model.HeaderNames -import sttp.model.headers.{AuthenticationScheme, WWWAuthenticateChallenge} -import sttp.tapir.CodecFormat.TextPlain -import sttp.tapir.EndpointInput.{AuthInfo, AuthType} -import sttp.tapir.* - -import scala.collection.immutable.ListMap -import scala.util.{Failure, Success, Try} case class TokenUser(id: String, permissions: Set[Permission], jwt: JWTClaims, originalToken: Option[String]) { def hasPermission(permission: Permission): Boolean = permissions.contains(permission) @@ -53,60 +44,7 @@ object TokenUser { val PublicUser: TokenUser = TokenUser("public", Set.empty, None) val SystemUser: TokenUser = TokenUser("system", Permission.values.toSet, None) - case class UserInfoException() extends RuntimeException("Could not build `TokenUser` from token.") - - private def fromExtractor(jWTExtractor: JWTExtractor, token: String) = { - val userId = jWTExtractor.extractUserId() - val roles = jWTExtractor.extractUserRoles() - val userName = jWTExtractor.extractUserName() - val clientId = jWTExtractor.extractClientId() - - userId.orElse(clientId).orElse(userName) match { - case Some(userInfoName) => Success(TokenUser(userInfoName, Permission.fromStrings(roles), Some(token))) - case None => Failure(UserInfoException()) - } - } - - def fromToken(token: String): Try[TokenUser] = { - val jWTExtractor = JWTExtractor(token) - fromExtractor(jWTExtractor, token) - } - extension (self: Option[TokenUser]) { def hasPermission(permission: Permission): Boolean = self.exists(user => user.hasPermission(permission)) } - - def encode(user: TokenUser): String = user.id - def decode(s: String): DecodeResult[TokenUser] = fromToken(s) match { - case Failure(ex) => DecodeResult.Error(s, ex) - case Success(value) => DecodeResult.Value(value) - } - - private implicit val userinfoCodec: Codec[String, TokenUser, TextPlain] = Codec.string.mapDecode(decode)(encode) - private val authScheme = AuthenticationScheme.Bearer.name - private val codec = implicitly[Codec[List[String], Option[TokenUser], CodecFormat.TextPlain]] - def filterHeaders(headers: List[String]): List[String] = - headers.filter(_.toLowerCase.startsWith(authScheme.toLowerCase)) - def stringPrefixWithSpace: Mapping[List[String], List[String]] = - Mapping.stringPrefixCaseInsensitiveForList(authScheme + " ") - val authCodec: Codec[List[String], Option[TokenUser], TextPlain] = Codec - .id[List[String], CodecFormat.TextPlain](codec.format, Schema.binary) - .map(filterHeaders)(identity) - .map(stringPrefixWithSpace) - .mapDecode(codec.decode)(codec.encode) - .schema(codec.schema) - - def oauth2Input(permissions: Seq[Permission]): EndpointInput.Auth[Option[TokenUser], AuthType.ScopedOAuth2] = { - val authType: AuthType.ScopedOAuth2 = EndpointInput - .AuthType - .OAuth2(None, None, ListMap.from(permissions.map(p => p.entryName -> p.entryName)), None) - .requiredScopes(permissions.map(_.entryName)) - - EndpointInput.Auth( - input = sttp.tapir.header(HeaderNames.Authorization)(using authCodec), - challenge = WWWAuthenticateChallenge.bearer, - authType = authType, - info = AuthInfo.Empty.securitySchemeName("oauth2"), - ) - } }