Skip to content

Allow configuration of SessionAuthenticationStrategy for CSRF #7083

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2002-2018 the original author or authors.
* Copyright 2002-2019 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -30,6 +30,7 @@
import org.springframework.security.web.access.AccessDeniedHandler;
import org.springframework.security.web.access.AccessDeniedHandlerImpl;
import org.springframework.security.web.access.DelegatingAccessDeniedHandler;
import org.springframework.security.web.authentication.session.SessionAuthenticationStrategy;
import org.springframework.security.web.csrf.CsrfAuthenticationStrategy;
import org.springframework.security.web.csrf.CsrfFilter;
import org.springframework.security.web.csrf.CsrfLogoutHandler;
Expand Down Expand Up @@ -81,6 +82,7 @@ public final class CsrfConfigurer<H extends HttpSecurityBuilder<H>>
new HttpSessionCsrfTokenRepository());
private RequestMatcher requireCsrfProtectionMatcher = CsrfFilter.DEFAULT_CSRF_MATCHER;
private List<RequestMatcher> ignoredCsrfProtectionMatchers = new ArrayList<>();
private SessionAuthenticationStrategy sessionAuthenticationStrategy;
private final ApplicationContext context;

/**
Expand Down Expand Up @@ -179,6 +181,26 @@ public CsrfConfigurer<H> ignoringRequestMatchers(RequestMatcher... requestMatche
.and();
}

/**
* <p>
* Specify the {@link SessionAuthenticationStrategy} to use. The default is a
* {@link CsrfAuthenticationStrategy}.
* </p>
*
* @author Michael Vitz
* @since 5.1
*
* @param sessionAuthenticationStrategy the {@link SessionAuthenticationStrategy} to use
* @return the {@link CsrfConfigurer} for further customizations
*/
public CsrfConfigurer<H> sessionAuthenticationStrategy(
SessionAuthenticationStrategy sessionAuthenticationStrategy) {
Assert.notNull(sessionAuthenticationStrategy,
"sessionAuthenticationStrategy cannot be null");
this.sessionAuthenticationStrategy = sessionAuthenticationStrategy;
return this;
}

@SuppressWarnings("unchecked")
@Override
public void configure(H http) throws Exception {
Expand All @@ -200,7 +222,7 @@ public void configure(H http) throws Exception {
.getConfigurer(SessionManagementConfigurer.class);
if (sessionConfigurer != null) {
sessionConfigurer.addSessionAuthenticationStrategy(
new CsrfAuthenticationStrategy(this.csrfTokenRepository));
getSessionAuthenticationStrategy());
}
filter = postProcess(filter);
http.addFilter(filter);
Expand Down Expand Up @@ -289,6 +311,23 @@ private AccessDeniedHandler createAccessDeniedHandler(H http) {
return new DelegatingAccessDeniedHandler(handlers, defaultAccessDeniedHandler);
}

/**
* Gets the {@link SessionAuthenticationStrategy} to use. If none was set by the user a
* {@link CsrfAuthenticationStrategy} is created.
*
* @author Michael Vitz
* @since 5.1
*
* @return the {@link SessionAuthenticationStrategy}
*/
private SessionAuthenticationStrategy getSessionAuthenticationStrategy() {
if (sessionAuthenticationStrategy != null) {
return sessionAuthenticationStrategy;
} else {
return new CsrfAuthenticationStrategy(this.csrfTokenRepository);
}
}

/**
* Allows registering {@link RequestMatcher} instances that should be ignored (even if
* the {@link HttpServletRequest} matches the
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,10 @@
import org.springframework.security.config.annotation.web.configuration.EnableWebSecurity;
import org.springframework.security.config.annotation.web.configuration.WebSecurityConfigurerAdapter;
import org.springframework.security.config.test.SpringTestRule;
import org.springframework.security.core.Authentication;
import org.springframework.security.core.userdetails.PasswordEncodedUser;
import org.springframework.security.web.access.AccessDeniedHandler;
import org.springframework.security.web.authentication.session.SessionAuthenticationStrategy;
import org.springframework.security.web.csrf.CsrfTokenRepository;
import org.springframework.security.web.csrf.DefaultCsrfToken;
import org.springframework.security.web.firewall.StrictHttpFirewall;
Expand All @@ -51,22 +53,12 @@
import static org.assertj.core.api.Assertions.assertThatThrownBy;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.isNull;
import static org.mockito.Mockito.atLeastOnce;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
import static org.mockito.Mockito.*;
import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.csrf;
import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.user;
import static org.springframework.security.test.web.servlet.response.SecurityMockMvcResultMatchers.authenticated;
import static org.springframework.security.test.web.servlet.response.SecurityMockMvcResultMatchers.unauthenticated;
import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.delete;
import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.get;
import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.head;
import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.options;
import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.patch;
import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.post;
import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.put;
import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.request;
import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.*;
import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.redirectedUrl;
import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.status;

Expand All @@ -75,6 +67,8 @@
*
* @author Rob Winch
* @author Eleftheria Stein
* @author Michael Vitz
* @author Sam Simmons
*/
public class CsrfConfigurerTests {
@Rule
Expand Down Expand Up @@ -603,6 +597,66 @@ protected void configure(AuthenticationManagerBuilder auth) throws Exception {
}
}

@EnableWebSecurity
static class NullAuthenticationStrategy extends WebSecurityConfigurerAdapter {
@Override
protected void configure(HttpSecurity http) throws Exception {
// @formatter:off
http
.csrf()
.sessionAuthenticationStrategy(null);
// @formatter:on
}
}

@Test
public void getWhenNullAuthenticationStrategyThenException() {
assertThatThrownBy(() -> this.spring.register(NullAuthenticationStrategy.class).autowire())
.isInstanceOf(BeanCreationException.class)
.hasRootCauseInstanceOf(IllegalArgumentException.class);
}

@EnableWebSecurity
static class CsrfAuthenticationStrategyConfig extends WebSecurityConfigurerAdapter {
static SessionAuthenticationStrategy STRATEGY;

@Override
protected void configure(HttpSecurity http) throws Exception {
// @formatter:off
http
.formLogin()
.and()
.csrf()
.sessionAuthenticationStrategy(STRATEGY);
// @formatter:on
}

@Override
protected void configure(AuthenticationManagerBuilder auth) throws Exception {
// @formatter:off
auth
.inMemoryAuthentication()
.withUser(PasswordEncodedUser.user());
// @formatter:on
}
}

@Test
public void csrfAuthenticationStrategyConfiguredThenStrategyUsed() throws Exception {
CsrfAuthenticationStrategyConfig.STRATEGY = mock(SessionAuthenticationStrategy.class);

this.spring.register(CsrfAuthenticationStrategyConfig.class).autowire();

this.mvc.perform(post("/login")
.with(csrf())
.param("username", "user")
.param("password", "password"))
.andExpect(redirectedUrl("/"));

verify(CsrfAuthenticationStrategyConfig.STRATEGY, atLeastOnce())
.onAuthentication(any(Authentication.class), any(HttpServletRequest.class), any(HttpServletResponse.class));
}

@RestController
static class BasicController {
@GetMapping("/")
Expand Down