Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import java.util.Collection;

import org.apereo.cas.client.validation.Assertion;
import org.jspecify.annotations.Nullable;

import org.springframework.security.authentication.AbstractAuthenticationToken;
import org.springframework.security.core.GrantedAuthority;
Expand Down Expand Up @@ -104,6 +105,19 @@ private CasAuthenticationToken(final Integer keyHash, final Object principal, fi
setAuthenticated(true);
}

protected CasAuthenticationToken(Builder<?> builder) {
super(builder);
Assert.isTrue(!"".equals(builder.principal), "principal cannot be null or empty");
Assert.notNull(!"".equals(builder.credentials), "credentials cannot be null or empty");
Assert.notNull(builder.userDetails, "userDetails cannot be null");
Assert.notNull(builder.assertion, "assertion cannot be null");
this.keyHash = builder.keyHash;
this.principal = builder.principal;
this.credentials = builder.credentials;
this.userDetails = builder.userDetails;
this.assertion = builder.assertion;
}

private static Integer extractKeyHash(String key) {
Assert.hasLength(key, "key cannot be null or empty");
return key.hashCode();
Expand Down Expand Up @@ -153,6 +167,11 @@ public UserDetails getUserDetails() {
return this.userDetails;
}

@Override
public Builder<?> toBuilder() {
return new Builder<>(this);
}

@Override
public String toString() {
StringBuilder sb = new StringBuilder();
Expand All @@ -162,4 +181,81 @@ public String toString() {
return (sb.toString());
}

/**
* A builder of {@link CasAuthenticationToken} instances
*
* @since 7.0
*/
public static class Builder<B extends Builder<B>> extends AbstractAuthenticationBuilder<B> {

private Integer keyHash;

private Object principal;

private Object credentials;

private UserDetails userDetails;

private Assertion assertion;

protected Builder(CasAuthenticationToken token) {
super(token);
this.keyHash = token.keyHash;
this.principal = token.principal;
this.credentials = token.credentials;
this.userDetails = token.userDetails;
this.assertion = token.assertion;
}

/**
* Use this key
* @param key the key to use
* @return the {@link Builder} for further configurations
*/
public B key(String key) {
this.keyHash = key.hashCode();
return (B) this;
}

@Override
public B principal(@Nullable Object principal) {
Assert.notNull(principal, "principal cannot be null");
this.principal = principal;
return (B) this;
}

@Override
public B credentials(@Nullable Object credentials) {
Assert.notNull(credentials, "credentials cannot be null");
this.credentials = credentials;
return (B) this;
}

/**
* Use this {@link UserDetails}
* @param userDetails the {@link UserDetails} to use
* @return the {@link Builder} for further configurations
*/
public B userDetails(UserDetails userDetails) {
this.userDetails = userDetails;
return (B) this;
}

/**
* Use this {@link Assertion}
* @param assertion the {@link Assertion} to use
* @return the {@link Builder} for further configurations
*/
public B assertion(Assertion assertion) {
this.assertion = assertion;
return (B) this;
}

@Override
public CasAuthenticationToken build() {
return new CasAuthenticationToken(this);
}

}

}
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ public class CasServiceTicketAuthenticationToken extends AbstractAuthenticationT
*
*/
public CasServiceTicketAuthenticationToken(String identifier, Object credentials) {
super(null);
super((Collection<? extends GrantedAuthority>) null);
this.identifier = identifier;
this.credentials = credentials;
setAuthenticated(false);
Expand All @@ -75,6 +75,12 @@ public CasServiceTicketAuthenticationToken(String identifier, Object credentials
super.setAuthenticated(true);
}

protected CasServiceTicketAuthenticationToken(Builder<?> builder) {
super(builder);
this.identifier = builder.principal;
this.credentials = builder.credentials;
}

public static CasServiceTicketAuthenticationToken stateful(Object credentials) {
return new CasServiceTicketAuthenticationToken(CAS_STATEFUL_IDENTIFIER, credentials);
}
Expand Down Expand Up @@ -110,4 +116,46 @@ public void eraseCredentials() {
this.credentials = null;
}

public Builder<?> toBuilder() {
return new Builder<>(this);
}

/**
* A builder of {@link CasServiceTicketAuthenticationToken} instances
*
* @since 7.0
*/
public static class Builder<B extends Builder<B>> extends AbstractAuthenticationBuilder<B> {

private String principal;

private @Nullable Object credentials;

protected Builder(CasServiceTicketAuthenticationToken token) {
super(token);
this.principal = token.identifier;
this.credentials = token.credentials;
}

@Override
public B principal(@Nullable Object principal) {
Assert.isInstanceOf(String.class, principal, "principal must be of type String");
this.principal = (String) principal;
return (B) this;
}

@Override
public B credentials(@Nullable Object credentials) {
Assert.notNull(credentials, "credentials cannot be null");
this.credentials = credentials;
return (B) this;
}

@Override
public CasServiceTicketAuthenticationToken build() {
return new CasServiceTicketAuthenticationToken(this);
}

}

}
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

import java.util.Collections;
import java.util.List;
import java.util.Set;

import org.apereo.cas.client.validation.Assertion;
import org.apereo.cas.client.validation.AssertionImpl;
Expand All @@ -26,6 +27,7 @@
import org.springframework.security.core.GrantedAuthority;
import org.springframework.security.core.authority.AuthorityUtils;
import org.springframework.security.core.authority.SimpleGrantedAuthority;
import org.springframework.security.core.userdetails.PasswordEncodedUser;
import org.springframework.security.core.userdetails.User;
import org.springframework.security.core.userdetails.UserDetails;

Expand Down Expand Up @@ -155,4 +157,29 @@ public void testToString() {
assertThat(result.lastIndexOf("Credentials (Service/Proxy Ticket):") != -1).isTrue();
}

@Test
public void toBuilderWhenApplyThenCopies() {
Assertion assertionOne = new AssertionImpl("test");
CasAuthenticationToken factorOne = new CasAuthenticationToken("key", "alice", "pass",
AuthorityUtils.createAuthorityList("FACTOR_ONE"), PasswordEncodedUser.user(), assertionOne);
Assertion assertionTwo = new AssertionImpl("test");
CasAuthenticationToken factorTwo = new CasAuthenticationToken("yek", "bob", "ssap",
AuthorityUtils.createAuthorityList("FACTOR_TWO"), PasswordEncodedUser.admin(), assertionTwo);
CasAuthenticationToken authentication = factorOne.toBuilder()
.authorities((a) -> a.addAll(factorTwo.getAuthorities()))
.key("yek")
.principal(factorTwo.getPrincipal())
.credentials(factorTwo.getCredentials())
.userDetails(factorTwo.getUserDetails())
.assertion(factorTwo.getAssertion())
.build();
Set<String> authorities = AuthorityUtils.authorityListToSet(authentication.getAuthorities());
assertThat(authentication.getKeyHash()).isEqualTo(factorTwo.getKeyHash());
assertThat(authentication.getPrincipal()).isEqualTo(factorTwo.getPrincipal());
assertThat(authentication.getCredentials()).isEqualTo(factorTwo.getCredentials());
assertThat(authentication.getUserDetails()).isEqualTo(factorTwo.getUserDetails());
assertThat(authentication.getAssertion()).isEqualTo(factorTwo.getAssertion());
assertThat(authorities).containsExactlyInAnyOrder("FACTOR_ONE", "FACTOR_TWO");
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@

package org.springframework.security.config.annotation.web.configurers;

import java.util.Collection;

import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;

Expand All @@ -31,6 +33,7 @@
import org.springframework.security.config.test.SpringTestContextExtension;
import org.springframework.security.core.Authentication;
import org.springframework.security.core.AuthenticationException;
import org.springframework.security.core.GrantedAuthority;
import org.springframework.security.core.Transient;
import org.springframework.security.web.SecurityFilterChain;
import org.springframework.test.web.servlet.MockMvc;
Expand Down Expand Up @@ -113,7 +116,7 @@ public boolean supports(Class<?> authentication) {
static class SomeTransientAuthentication extends AbstractAuthenticationToken {

SomeTransientAuthentication() {
super(null);
super((Collection<? extends GrantedAuthority>) null);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@

package org.springframework.security.config.http;

import java.util.Collection;

import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;

Expand All @@ -26,6 +28,7 @@
import org.springframework.security.config.test.SpringTestContextExtension;
import org.springframework.security.core.Authentication;
import org.springframework.security.core.AuthenticationException;
import org.springframework.security.core.GrantedAuthority;
import org.springframework.security.core.Transient;
import org.springframework.test.web.servlet.MockMvc;
import org.springframework.test.web.servlet.MvcResult;
Expand Down Expand Up @@ -82,7 +85,7 @@ public boolean supports(Class<?> authentication) {
static class SomeTransientAuthentication extends AbstractAuthenticationToken {

SomeTransientAuthentication() {
super(null);
super((Collection<? extends GrantedAuthority>) null);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,13 @@

package org.springframework.security.authentication;

import java.io.Serial;
import java.security.Principal;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.LinkedHashSet;
import java.util.function.Consumer;

import org.jspecify.annotations.Nullable;

Expand All @@ -41,6 +44,9 @@
*/
public abstract class AbstractAuthenticationToken implements Authentication, CredentialsContainer {

@Serial
private static final long serialVersionUID = -3194696462184782834L;

private final Collection<GrantedAuthority> authorities;

private @Nullable Object details;
Expand All @@ -63,6 +69,12 @@ public AbstractAuthenticationToken(@Nullable Collection<? extends GrantedAuthori
this.authorities = Collections.unmodifiableList(new ArrayList<>(authorities));
}

protected AbstractAuthenticationToken(AbstractAuthenticationBuilder<?> builder) {
this(builder.authorities);
this.authenticated = builder.authenticated;
this.details = builder.details;
}

@Override
public Collection<GrantedAuthority> getAuthorities() {
return this.authorities;
Expand Down Expand Up @@ -185,4 +197,48 @@ public String toString() {
return sb.toString();
}

/**
* A common abstract implementation of {@link Authentication.Builder}. It implements
* the builder methods that correspond to the {@link Authentication} methods that
* {@link AbstractAuthenticationToken} implements
*
* @param <B>
* @since 7.0
*/
protected abstract static class AbstractAuthenticationBuilder<B extends AbstractAuthenticationBuilder<B>>
implements Authentication.Builder<B> {

private boolean authenticated;

private @Nullable Object details;

private final Collection<GrantedAuthority> authorities;

protected AbstractAuthenticationBuilder(AbstractAuthenticationToken token) {
this.authorities = new LinkedHashSet<>(token.getAuthorities());
this.authenticated = token.isAuthenticated();
this.details = token.getDetails();
}

@Override
public B authenticated(boolean authenticated) {
this.authenticated = authenticated;
return (B) this;
}

@Override
public B details(@Nullable Object details) {
this.details = details;
return (B) this;
}

@Override
public B authorities(Consumer<Collection<GrantedAuthority>> authorities) {
authorities.accept(this.authorities);
this.authenticated = true;
return (B) this;
}

}

}
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,6 @@ public Mono<Authentication> authenticate(Authentication authentication) {
Function<ReactiveAuthenticationManager, Mono<Authentication>> logging = (m) -> m.authenticate(authentication)
.doOnError(AuthenticationException.class, (ex) -> ex.setAuthenticationRequest(authentication))
.doOnError(this.logger::debug);

return ((this.continueOnError) ? result.concatMapDelayError(logging) : result.concatMap(logging)).next();
}

Expand Down
Loading