深入扩展Spring Cloud Oauth2授权模式
概述
使用Spring Cloud Oauth2时,框架默认提供了四种授权模式(授权码、简化、密码、客户端凭证),但有时这些模式并不能满足我们的需求,比如需要短信登录、微信登录等。本文从源码分析入手,讲解如何通过自定义TokenGranter来优雅地扩展授权模式。阅读本文之前,需要熟悉Spring Cloud Oauth2的相关知识。
注意:目前仅做原理分析,未提供直接使用的源码
Spring Security Oauth2认证序列图
整体逻辑图如下,后面的扩展正是从自定义CompositeTokenGranter入手的。
开始源码分析
首先我们创建一个Oauth2认证服务配置类AuthServerConfig,定义授权和令牌端点以及令牌服务等相关配置(clientDetailsService等字段的声明已省略)。
@Configuration
@EnableAuthorizationServer
public class AuthServerConfig extends AuthorizationServerConfigurerAdapter {

    @Autowired
    private AuthenticationManager authenticationManager;

    @Autowired
    private AuthorizationServerTokenServices tokenServices;

    /** Custom grant types; optional, so the server also starts when no ITokenGranter bean exists. */
    @Autowired(required = false)
    private List<ITokenGranter> tokenGranters;

    @Autowired
    private WebResponseExceptionTranslator<OAuth2Exception> webResponseExceptionTranslator;

    /**
     * Configures the authorization/token endpoints and the token services.
     * <p>
     * Installs a composite token granter (built-in grant types plus every {@link ITokenGranter}
     * bean), the authentication manager, the token services, POST as the only allowed token
     * endpoint method and the exception translator.
     *
     * @param endpoints the endpoints configurer supplied by Spring Security OAuth2
     */
    @Override
    public void configure(AuthorizationServerEndpointsConfigurer endpoints) {
        CompositeTokenGranter compositeGranter = new CompositeTokenGranter(
                authenticationManager, tokenServices, clientDetailsService, tokenGranters);

        endpoints.tokenGranter(compositeGranter);
        // authentication manager
        endpoints.authenticationManager(authenticationManager);
        // token management service
        endpoints.tokenServices(tokenServices);
        endpoints.allowedTokenEndpointRequestMethods(HttpMethod.POST);
        endpoints.exceptionTranslator(webResponseExceptionTranslator);
    }

    // Other members omitted (clientDetailsService, used above, is declared there).
}
分析@EnableAuthorizationServer
该注解引入了AuthorizationServerEndpointsConfiguration,其中定义了TokenEndpoint,并为其设置了TokenGranter:
/**
 * Excerpt from Spring's AuthorizationServerEndpointsConfiguration: registers the TokenEndpoint
 * bean that handles token requests (TokenEndpoint#postAccessToken).
 * The granter it delegates to comes from tokenGranter(), which resolves to
 * AuthorizationServerEndpointsConfigurer#tokenGranter; this is the hook a custom granter plugs into.
 */
@Bean
public TokenEndpoint tokenEndpoint() throws Exception {
TokenEndpoint tokenEndpoint = new TokenEndpoint();
tokenEndpoint.setClientDetailsService(clientDetailsService);
tokenEndpoint.setProviderExceptionHandler(exceptionTranslator());
// The TokenGranter is what actually issues tokens for each grant_type.
tokenEndpoint.setTokenGranter(tokenGranter());
tokenEndpoint.setOAuth2RequestFactory(oauth2RequestFactory());
tokenEndpoint.setOAuth2RequestValidator(oauth2RequestValidator());
tokenEndpoint.setAllowedRequestMethods(allowedTokenEndpointRequestMethods());
return tokenEndpoint;
}
TokenEndpoint#postAccessToken 就是 Oauth2 实际颁发令牌的入口。分析 tokenEndpoint.setTokenGranter(tokenGranter()) 可知,它实际调用的是 AuthorizationServerEndpointsConfigurer#tokenGranter 方法,如下:
/**
 * Excerpt from AuthorizationServerEndpointsConfigurer: returns the configured TokenGranter.
 * When the tokenGranter field is still null (nothing was supplied, e.g. via
 * endpoints.tokenGranter(...)), it installs a wrapper that delegates to a CompositeTokenGranter
 * built from getDefaultTokenGranters().
 */
private TokenGranter tokenGranter() {
if (tokenGranter == null) {
tokenGranter = new TokenGranter() {
private CompositeTokenGranter delegate;
@Override
public OAuth2AccessToken grant(String grantType, TokenRequest tokenRequest) {
// Built lazily on the first grant request; this check is unsynchronized, so concurrent
// first calls may each construct a delegate.
if (delegate == null) {
delegate = new CompositeTokenGranter(getDefaultTokenGranters());
}
return delegate.grant(grantType, tokenRequest);
}
};
}
return tokenGranter;
}
继续深入分析可以看到:如果没有通过 endpoints.tokenGranter(...) 指定 TokenGranter,默认使用的是 CompositeTokenGranter(关键),并通过 AuthorizationServerEndpointsConfigurer#getDefaultTokenGranters() 方法设置其中的 TokenGranter。在这里就可以看到默认的四种授权模式(以及刷新令牌)对应的 TokenGranter,本文也正是从此处着手扩展 CompositeTokenGranter。
/**
 * Excerpt from AuthorizationServerEndpointsConfigurer: builds the built-in granters used by the
 * default CompositeTokenGranter. These are authorization code, refresh token, implicit and client
 * credentials, plus resource-owner password when an AuthenticationManager is available.
 * All of them share the configurer's client details, token services and request factory.
 */
private List<TokenGranter> getDefaultTokenGranters() {
ClientDetailsService clientDetails = clientDetailsService();
AuthorizationServerTokenServices tokenServices = tokenServices();
// The configurer-wide code store: the authorization_code granter must redeem codes from this store.
AuthorizationCodeServices authorizationCodeServices = authorizationCodeServices();
OAuth2RequestFactory requestFactory = requestFactory();
List<TokenGranter> tokenGranters = new ArrayList<TokenGranter>();
tokenGranters.add(new AuthorizationCodeTokenGranter(tokenServices, authorizationCodeServices, clientDetails,
requestFactory));
tokenGranters.add(new RefreshTokenGranter(tokenServices, clientDetails, requestFactory));
ImplicitTokenGranter implicit = new ImplicitTokenGranter(tokenServices, clientDetails, requestFactory);
tokenGranters.add(implicit);
tokenGranters.add(new ClientCredentialsTokenGranter(tokenServices, clientDetails, requestFactory));
// The password grant needs an AuthenticationManager to check the user's credentials, so it is
// only registered when one has been configured.
if (authenticationManager != null) {
tokenGranters.add(new ResourceOwnerPasswordTokenGranter(authenticationManager, tokenServices,
clientDetails, requestFactory));
}
return tokenGranters;
}
可以模仿 Spring 的 CompositeTokenGranter 自定义一个 CompositeTokenGranter,并在前面的 AuthServerConfig#configure(AuthorizationServerEndpointsConfigurer) 中通过 endpoints.tokenGranter(...) 设置我们自定义的 TokenGranter,如下:
/**
* {@link org.springframework.security.oauth2.provider.CompositeTokenGranter}
*
* @author :jiangxiaowei
*/
@Slf4j
public class CompositeTokenGranter implements TokenGranter {
private final AuthenticationManager authenticationManager;
private final AuthorizationServerTokenServices tokenServices;
private final ClientDetailsService clientDetailsService;
private final OAuth2RequestFactory requestFactory;
private final List<TokenGranter> tokenGranters;
public CompositeTokenGranter(AuthenticationManager authenticationManager,
AuthorizationServerTokenServices tokenServices,
ClientDetailsService clientDetailsService,
List<ITokenGranter> tokenGranters) {
this.authenticationManager = authenticationManager;
this.tokenServices = tokenServices;
this.clientDetailsService = clientDetailsService;
this.requestFactory = new DefaultOAuth2RequestFactory(clientDetailsService);
this.tokenGranters = getDefaultTokenGranters();
if (!CollectionUtils.isEmpty(tokenGranters)) {
tokenGranters.stream()
.map(granter -> new TokenGranterAdaptor(granter, authenticationManager, tokenServices, clientDetailsService, requestFactory))
.forEach(granter -> {
log.info("Register custom token granter: {}", granter.getClass().getName());
this.tokenGranters.add(granter);
});
}
}
public OAuth2AccessToken grant(String grantType, TokenRequest tokenRequest) {
for (TokenGranter granter : tokenGranters) {
OAuth2AccessToken grant = granter.grant(grantType, tokenRequest);
if (grant != null) {
return grant;
}
}
return null;
}
private List<TokenGranter> getDefaultTokenGranters() {
List<TokenGranter> tokenGranters = new ArrayList<TokenGranter>();
InMemoryAuthorizationCodeServices authorizationCodeServices = new InMemoryAuthorizationCodeServices();
tokenGranters.add(new AuthorizationCodeTokenGranter(tokenServices, authorizationCodeServices, clientDetailsService, requestFactory));
tokenGranters.add(new RefreshTokenGranter(tokenServices, clientDetailsService, requestFactory));
ImplicitTokenGranter implicit = new ImplicitTokenGranter(tokenServices, clientDetailsService, requestFactory);
tokenGranters.add(implicit);
tokenGranters.add(new ClientCredentialsTokenGranter(tokenServices, clientDetailsService, requestFactory));
if (authenticationManager != null) {
tokenGranters.add(new ResourceOwnerPasswordTokenGranter(authenticationManager, tokenServices,
clientDetailsService, requestFactory));
}
return tokenGranters;
}
}
接着,我们定义相关接口 ITokenGranter 和适配器 TokenGranterAdaptor:
/**
 * Extension point for a custom OAuth2 grant type.
 * <p>
 * An implementation names the grant type it handles. It also turns the token request into an
 * authentication request, which TokenGranterAdaptor then passes to the AuthenticationManager.
 *
 * @author :jiangxiaowei
 */
public interface ITokenGranter {

    /**
     * Returns the {@code grant_type} value handled by this granter.
     *
     * @return the grant type
     */
    String grantType();

    /**
     * Builds the authentication request from the raw token-request parameters.
     *
     * @param requestParameters parameters of the token request
     * @return the authentication request
     */
    Authentication authenticate(Map<String, String> requestParameters);

    /**
     * Builds the authentication request for a client's token request.
     * <p>
     * By default the client is ignored and only the request parameters are used. Override this
     * method when the client details matter.
     *
     * @param client       the requesting client
     * @param tokenRequest the token request
     * @return the authentication request
     */
    default Authentication authenticate(ClientDetails client, TokenRequest tokenRequest) {
        return authenticate(tokenRequest.getRequestParameters());
    }
}
/**
 * Adapts an {@link ITokenGranter} to Spring's {@link AbstractTokenGranter}.
 * <p>
 * The adapted granter builds the authentication request. This class then:
 * <ul>
 *   <li>attaches the request parameters, minus credentials, as details;</li>
 *   <li>authenticates the request through the {@link AuthenticationManager};</li>
 *   <li>wraps the result in an {@link OAuth2Authentication}.</li>
 * </ul>
 * Authentication failures are reported as {@link InvalidGrantException}.
 *
 * @author :jiangxiaowei
 */
public class TokenGranterAdaptor extends AbstractTokenGranter {
    private final ITokenGranter tokenGranter;
    private final AuthenticationManager authenticationManager;

    public TokenGranterAdaptor(ITokenGranter tokenGranter, AuthenticationManager authenticationManager,
                               AuthorizationServerTokenServices tokenServices,
                               ClientDetailsService clientDetailsService,
                               OAuth2RequestFactory requestFactory) {
        super(tokenServices, clientDetailsService, requestFactory, tokenGranter.grantType());
        this.tokenGranter = tokenGranter;
        this.authenticationManager = authenticationManager;
    }

    /**
     * Authenticates the user for this grant type.
     *
     * @param client       the requesting client
     * @param tokenRequest the token request
     * @return the OAuth2 authentication combining the stored request and the user authentication
     * @throws InvalidGrantException if no authentication request can be built or authentication fails
     */
    @Override
    protected OAuth2Authentication getOAuth2Authentication(ClientDetails client, TokenRequest tokenRequest) {
        Authentication authenticate = tokenGranter.authenticate(client, tokenRequest);
        if (authenticate == null) {
            throw new InvalidGrantException("Unable to build authentication for grant type: " + tokenGranter.grantType());
        }
        Map<String, String> parameters = new LinkedHashMap<>(tokenRequest.getRequestParameters());
        // Keep secrets out of the details, which end up in the stored OAuth2Authentication
        // (same protection as Spring's ResourceOwnerPasswordTokenGranter). The credential value
        // is removed too, because a custom grant may send it under any parameter name.
        parameters.remove("password");
        parameters.remove("client_secret");
        Object credentials = authenticate.getCredentials();
        if (credentials != null) {
            parameters.values().removeIf(credentials::equals);
        }
        if (authenticate instanceof AbstractAuthenticationToken) {
            ((AbstractAuthenticationToken) authenticate).setDetails(parameters);
        }
        try {
            authenticate = authenticationManager.authenticate(authenticate);
        } catch (AccountStatusException | BadCredentialsException | UsernameNotFoundException e) {
            throw new InvalidGrantException(e.getMessage(), e);
        }
        if (authenticate == null || !authenticate.isAuthenticated()) {
            // Do not echo the request parameters here: they contain the user's credentials.
            throw new InvalidGrantException("Login authenticate failed for grant type: " + tokenGranter.grantType());
        }
        OAuth2Request storedOAuth2Request = getRequestFactory().createOAuth2Request(client, tokenRequest);
        return new OAuth2Authentication(storedOAuth2Request, authenticate);
    }
}
之后只需实现 ITokenGranter 接口,就可以自定义自己的授权模式了,比如短信(SMS)登录。
实际操作
首先模仿 DaoAuthenticationProvider 创建 AuthenticationProvider,在 AbstractUserDetailsAuthenticationProvider 源码的基础上做修改:
/**
 * Base {@link AuthenticationProvider} for custom login types (e.g. SMS code), adapted from
 * Spring's AbstractUserDetailsAuthenticationProvider.
 * <p>
 * {@link #authenticate} proceeds as follows:
 * <ol>
 *   <li>Resolves the user from the user cache, otherwise via {@link #retrieveUser}.</li>
 *   <li>Runs the pre-checks (locked / disabled / expired) and {@link #additionalAuthenticationChecks}.
 *       If a cached user fails them, it retries once with a freshly loaded user.</li>
 *   <li>Runs the post-check (credentials expired) and caches the user.</li>
 *   <li>Lets the subclass build the result via {@link #createSuccessAuthentication}.</li>
 * </ol>
 */
public abstract class AbstractAuthenticationProvider implements AuthenticationProvider, InitializingBean, MessageSourceAware {
    private final Log logger = LogFactory.getLog(getClass());
    private MessageSourceAccessor messages = SpringSecurityMessageSource.getAccessor();
    private UserCache userCache = new NullUserCache();
    private UserDetailsChecker preAuthenticationChecks = new DefaultPreAuthenticationChecks();
    private UserDetailsChecker postAuthenticationChecks = new DefaultPostAuthenticationChecks();
    // Kept from the Spring original; nothing in this class changes it. When true, the username is the principal.
    private boolean forcePrincipalAsString = false;

    /**
     * Authenticates the given request.
     *
     * @param authentication the authentication request; its name is used as the username
     * @return the successful authentication built by {@link #createSuccessAuthentication}
     * @throws AuthenticationException BadCredentialsException when the user is not found or the
     *         credentials are invalid, or the account-status exceptions raised by the checks
     */
    @Override
    public Authentication authenticate(Authentication authentication) throws AuthenticationException {
        // Determine username
        String username = (authentication.getPrincipal() == null) ? "NONE_PROVIDED"
                : authentication.getName();
        boolean cacheWasUsed = true;
        UserDetails user = this.userCache.getUserFromCache(username);
        if (user == null) {
            cacheWasUsed = false;
            try {
                user = retrieveUser(username, authentication);
            } catch (UsernameNotFoundException notFound) {
                logger.debug("User '" + username + "' not found");
                // Unknown user is reported as bad credentials
                throw new BadCredentialsException(messages.getMessage(
                        "AbstractUserDetailsAuthenticationProvider.badCredentials",
                        "Bad credentials"));
            }
            Assert.notNull(user,
                    "retrieveUser returned null - a violation of the interface contract");
        }
        try {
            preAuthenticationChecks.check(user);
            additionalAuthenticationChecks(user, authentication);
        } catch (AuthenticationException exception) {
            if (cacheWasUsed) {
                // There was a problem, so try again after checking
                // we're using latest data (i.e. not from the cache)
                cacheWasUsed = false;
                user = retrieveUser(username, authentication);
                preAuthenticationChecks.check(user);
                additionalAuthenticationChecks(user, authentication);
            } else {
                throw exception;
            }
        }
        postAuthenticationChecks.check(user);
        if (!cacheWasUsed) {
            this.userCache.putUserInCache(user);
        }
        // Principal to return: the UserDetails, or just the username when forced
        Object principalToReturn = user;
        if (forcePrincipalAsString) {
            principalToReturn = user.getUsername();
        }
        return createSuccessAuthentication(principalToReturn, authentication, user);
    }

    /**
     * Rejects requests without credentials or whose credentials fail {@link #isValidCredential}.
     *
     * @throws BadCredentialsException if the credentials are missing or invalid
     */
    protected void additionalAuthenticationChecks(UserDetails userDetails, Authentication authentication) throws AuthenticationException {
        if (authentication.getCredentials() == null) {
            this.logger.debug("Authentication failed: no credentials provided");
            throw new BadCredentialsException(this.messages.getMessage("PhoneAuthenticationProvider.badCredentials", "Bad credentials"));
        } else {
            if (!isValidCredential(authentication)) {
                this.logger.debug("Authentication failed: credential does not match stored value");
                throw new BadCredentialsException(this.messages.getMessage("PhoneAuthenticationProvider.badCredentials", "Bad credential"));
            }
        }
    }

    /**
     * Loads the user through {@link #getUserDetailsService()}.
     * <p>
     * UsernameNotFoundException and InternalAuthenticationServiceException are rethrown unchanged,
     * as in Spring's DaoAuthenticationProvider. This lets {@link #authenticate} turn a missing user
     * into BadCredentialsException. Any other exception is wrapped.
     *
     * @throws UsernameNotFoundException if the user does not exist
     * @throws InternalAuthenticationServiceException on any other lookup failure or a null result
     */
    protected UserDetails retrieveUser(String username, Authentication authentication) throws AuthenticationException {
        UserDetails loadedUser;
        try {
            loadedUser = getUserDetailsService().loadUser(authentication);
        } catch (UsernameNotFoundException | InternalAuthenticationServiceException ex) {
            throw ex;
        } catch (Exception ex) {
            throw new InternalAuthenticationServiceException(ex.getMessage(), ex);
        }
        if (loadedUser == null) {
            throw new InternalAuthenticationServiceException("UserDetailsService returned null, which is an interface contract violation");
        }
        return loadedUser;
    }

    @Override
    public final void afterPropertiesSet() throws Exception {
        Assert.notNull(this.userCache, "A user cache must be set");
        Assert.notNull(this.messages, "A message source must be set");
    }

    @Override
    public void setMessageSource(MessageSource messageSource) {
        this.messages = new MessageSourceAccessor(messageSource);
    }

    /**
     * Checks whether the credentials carried by the authentication request are valid.
     *
     * @param authentication the authentication request (credentials are non-null here)
     * @return {@code true} if the credentials are accepted
     */
    protected abstract boolean isValidCredential(Authentication authentication);

    /**
     * Returns the service used to load the user for an authentication request.
     */
    protected abstract MyUserDetailsService getUserDetailsService();

    /**
     * Creates the successful authentication returned by {@link #authenticate}.
     *
     * @param principal      the principal to expose (the UserDetails or its username)
     * @param authentication the original authentication request
     * @param user           the loaded and checked user
     */
    protected abstract Authentication createSuccessAuthentication(Object principal,
            Authentication authentication, UserDetails user);

    /** Rejects locked, disabled or expired accounts before the credential check. */
    private class DefaultPreAuthenticationChecks implements UserDetailsChecker {
        @Override
        public void check(UserDetails user) {
            if (!user.isAccountNonLocked()) {
                logger.debug("User account is locked");
                throw new LockedException(messages.getMessage(
                        "AbstractUserDetailsAuthenticationProvider.locked",
                        "User account is locked"));
            }
            if (!user.isEnabled()) {
                logger.debug("User account is disabled");
                throw new DisabledException(messages.getMessage(
                        "AbstractUserDetailsAuthenticationProvider.disabled",
                        "User is disabled"));
            }
            if (!user.isAccountNonExpired()) {
                logger.debug("User account is expired");
                throw new AccountExpiredException(messages.getMessage(
                        "AbstractUserDetailsAuthenticationProvider.expired",
                        "User account has expired"));
            }
        }
    }

    /** Rejects users whose credentials have expired, after the credential check. */
    private class DefaultPostAuthenticationChecks implements UserDetailsChecker {
        @Override
        public void check(UserDetails user) {
            if (!user.isCredentialsNonExpired()) {
                logger.debug("User account credentials have expired");
                throw new CredentialsExpiredException(messages.getMessage(
                        "AbstractUserDetailsAuthenticationProvider.credentialsExpired",
                        "User credentials have expired"));
            }
        }
    }
}
然后继承AbstractAuthenticationProvider ,创建PhoneAuthenticationProvider
/**
 * AuthenticationProvider for phone-number + SMS-code login ({@link PhoneAuthenticationToken}).
 * <p>
 * The code is not compared here: {@link #isValidCredential} always accepts. Phone and code are
 * both handed to PhoneUserDetailService.
 * NOTE(review): verification presumably happens in the remote user service; confirm.
 */
@Service
@RequiredArgsConstructor
public class PhoneAuthenticationProvider extends AbstractAuthenticationProvider {
    private final PhoneUserDetailService userDetailsService;

    /** Handles {@link PhoneAuthenticationToken} and its subclasses only. */
    @Override
    public boolean supports(Class<?> authentication) {
        return PhoneAuthenticationToken.class.isAssignableFrom(authentication);
    }

    @Override
    protected MyUserDetailsService getUserDetailsService() {
        return userDetailsService;
    }

    /** Always accepts; see the class comment. */
    @Override
    protected boolean isValidCredential(Authentication authentication) {
        return true;
    }

    /** Builds an authenticated phone token carrying the user's authorities and the request details. */
    @Override
    protected Authentication createSuccessAuthentication(Object principal, Authentication authentication, UserDetails user) {
        PhoneAuthenticationToken authenticated =
                new PhoneAuthenticationToken(principal, authentication.getCredentials(), user.getAuthorities());
        authenticated.setDetails(authentication.getDetails());
        return authenticated;
    }
}
创建自定义的UserDetails、UserDetailService(源码仅供参考,没有细节处理)
/**
 * Spring Security {@link User} that additionally carries the {@code UserInfo} loaded at login.
 * NOTE(review): if tokens/authentications are persisted by serialization, UserInfo must be
 * Serializable too; confirm.
 */
@EqualsAndHashCode(callSuper = true)
@Getter
public class MyUserDetails extends User {
    private final UserInfo userInfo;

    MyUserDetails(UserInfo info, String username, String password, boolean enabled,
                  boolean accountNonExpired, boolean credentialsNonExpired,
                  boolean accountNonLocked, Collection<? extends GrantedAuthority> authorities) {
        super(username, password, enabled, accountNonExpired, credentialsNonExpired,
                accountNonLocked, authorities);
        this.userInfo = info;
    }
}
public abstract class MyUserDetailsService implements UserDetailsService {
private final Logger log = LoggerFactory.getLogger(getClass());
public UserDetails loadUser(Authentication authentication) {
throw new IllegalStateException("Not allowed service");
}
public UserDetails loadUserByUsername(String username) throws UsernameNotFoundException {
throw new IllegalStateException("Not allowed service");
}
void checkUser(XCloudResponse<UserInfo> userResult, String username) {
UserBasic basic = Optional.ofNullable(userResult).map(XCloudResponse::getPayload).map(UserInfo::getBasic).orElse(null);
if (basic == null) {
log.info("登录用户:{} 不存在.", username);
throw new UsernameNotFoundException("登录用户:" + username + " 不存在");
}
}
UserDetails getUserDetails(XCloudResponse<UserInfo> result) {
UserInfo info = result.getPayload();
Set<String> dbAuthsSet = new HashSet<>();
// 获取角色,SecurityExpressionRoot.defaultRolePrefix
if (CollectionUtils.isNotEmpty(info.getRoles())) {
info.getRoles().stream().map(SECURITY_ROLE_PREFIX::concat).forEach(dbAuthsSet::add);
}
// 获取权限
if (CollectionUtils.isNotEmpty(info.getPermissions())) {
dbAuthsSet.addAll(info.getPermissions());
}
Collection<? extends GrantedAuthority> authorities = AuthorityUtils.createAuthorityList(dbAuthsSet.toArray(new String[0]));
UserBasic user = info.getBasic();
String password = UUID.randomUUID().toString();
return new MyUserDetails(info, user.getLoginName(), password,
true, true, true, true, authorities);
}
}
然后创建PhoneUserDetailService 继承上面MyUserDetailsService,通过调用远程Feign接口MpRemoteUserService 进行验证码认证登录
/**
 * Loads the user for phone-number + SMS-code login through the remote MpRemoteUserService.
 * <p>
 * The authentication's principal is used as the phone number and its credentials as the SMS code.
 * Both are passed to MpRemoteUserService#getUserInfoByPhone.
 * NOTE(review): that call presumably verifies the code; confirm with the remote service.
 */
@Slf4j
@Component
public class PhoneUserDetailService extends MyUserDetailsService {
    @Autowired
    private MpRemoteUserService remoteUserService;

    /**
     * Loads the user identified by the phone number and SMS code in the authentication request.
     *
     * @param authentication request whose principal is the phone number and whose credentials are the code
     * @return the user's details
     * @throws InvalidGrantException     if the remote call returns no response or no user info
     * @throws UsernameNotFoundException if the returned user info has no basic user record
     */
    @Override
    public UserDetails loadUser(Authentication authentication) {
        String phone = authentication.getPrincipal().toString();
        String code = authentication.getCredentials().toString();
        XCloudResponse<UserInfo> userInfoR = remoteUserService.getUserInfoByPhone(phone, code);
        // Reject a missing response before reading its payload.
        if (userInfoR == null || userInfoR.getPayload() == null) {
            throw new InvalidGrantException("手机号登陆失败:" + (userInfoR == null ? null : userInfoR.getMsg()));
        }
        // Identify the user by phone only: checkUser() logs this value and puts it into the
        // exception message, so it must not contain the SMS code.
        checkUser(userInfoR, phone);
        return getUserDetails(userInfoR);
    }
}
然后模仿 UsernamePasswordAuthenticationToken(AbstractAuthenticationToken 的子类)创建我们自己的 Token:
/**
 * Base authentication token for the custom login types, modelled on
 * UsernamePasswordAuthenticationToken. It holds a principal (for example a phone number) and
 * credentials (for example an SMS code).
 */
public class MyAuthenticationToken extends AbstractAuthenticationToken {
    private static final long serialVersionUID = 110L;
    protected final Object principal;
    protected Object credentials;

    /**
     * Creates an unauthenticated token: {@link #isAuthenticated()} returns {@code false}.
     * Any code that builds an authentication request may use this constructor.
     *
     * @param principal   the identity being authenticated
     * @param credentials the proof of identity
     */
    public MyAuthenticationToken(Object principal, Object credentials) {
        super(null);
        this.principal = principal;
        this.credentials = credentials;
        setAuthenticated(false);
    }

    /**
     * Creates a trusted token: {@link #isAuthenticated()} returns {@code true}.
     * Only AuthenticationManager or AuthenticationProvider implementations should use this
     * constructor, after they have verified the request.
     *
     * @param principal   the authenticated identity
     * @param credentials the credentials that were used
     * @param authorities the authorities granted to the principal
     */
    public MyAuthenticationToken(Object principal, Object credentials, Collection<? extends GrantedAuthority> authorities) {
        super(authorities);
        this.principal = principal;
        this.credentials = credentials;
        super.setAuthenticated(true);
    }

    @Override
    public Object getPrincipal() {
        return this.principal;
    }

    @Override
    public Object getCredentials() {
        return this.credentials;
    }

    /**
     * Only allows marking the token as not authenticated.
     *
     * @throws IllegalArgumentException when {@code isAuthenticated} is {@code true}; use the
     *         constructor taking authorities instead
     */
    @Override
    public void setAuthenticated(boolean isAuthenticated) throws IllegalArgumentException {
        if (isAuthenticated) {
            throw new IllegalArgumentException("Cannot set this token to trusted - use constructor which takes a GrantedAuthority list instead");
        }
        super.setAuthenticated(false);
    }

    /** Clears the credentials in addition to the superclass's own erasure. */
    @Override
    public void eraseCredentials() {
        super.eraseCredentials();
        this.credentials = null;
    }
}
再创建手机验证码登录使用的 PhoneAuthenticationToken:
/**
 * Authentication token for phone-number + SMS-code login.
 * <p>
 * In an authentication request the principal is the phone number and the credentials are the
 * code. After authentication the principal is whatever the provider supplies.
 *
 * @author jiangxiaowei
 */
public class PhoneAuthenticationToken extends MyAuthenticationToken {

    /** Creates an authenticated token carrying the granted authorities. */
    public PhoneAuthenticationToken(Object principal, Object credentials, Collection<? extends GrantedAuthority> authorities) {
        super(principal, credentials, authorities);
    }

    /** Creates an unauthenticated authentication request. */
    public PhoneAuthenticationToken(Object principal, Object credentials) {
        super(principal, credentials);
    }
}
大致梳理一下
- 自定义CompositeTokenGranter,参考
org.springframework.security.oauth2.provider.CompositeTokenGranter
- 创建自己的ITokenGranter接口,及适配器TokenGranterAdaptor(简化TokenGranter的代码编写)
- 自定义Authentication(Token),并且继承AbstractAuthenticationToken
- 自定义UserDetails
- 自定义UserDetailsService,并且继承MyUserDetailsService,实现方法loadUser
- 自定义AuthenticationProvider,并且继承AbstractAuthenticationProvider,实现对应方法
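- 自定义TokenGranter,实现接口ITokenGranter,并注册为 Spring Bean,以便 AuthServerConfig 注入。例如短信登录:grantType() 返回自定义的授权类型,authenticate(Map) 从请求参数中取出手机号和验证码,返回 new PhoneAuthenticationToken(phone, code)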