Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Polish JdbcOAuth2AuthorizationService #1908

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,19 @@
@ImportRuntimeHints(JdbcOAuth2AuthorizationService.JdbcOAuth2AuthorizationServiceRuntimeHintsRegistrar.class)
public class JdbcOAuth2AuthorizationService implements OAuth2AuthorizationService {

private static final String REFRESH_TOKEN_VALUE = "refresh_token_value";
private static final String AUTHORIZATION_CODE_VALUE = "authorization_code_value";
private static final String ACCESS_TOKEN_VALUE = "access_token_value";
private static final String OIDC_ID_TOKEN_VALUE = "oidc_id_token_value";
private static final String USER_CODE_VALUE = "user_code_value";
private static final String DEVICE_CODE_VALUE = "device_code_value";
private static final String AUTHORIZATION_CODE_METADATA = "authorization_code_metadata";
private static final String ACCESS_TOKEN_METADATA = "access_token_metadata";
private static final String OIDC_ID_TOKEN_METADATA = "oidc_id_token_metadata";
private static final String REFRESH_TOKEN_METADATA = "refresh_token_metadata";
private static final String USER_CODE_METADATA = "user_code_metadata";
private static final String DEVICE_CODE_METADATA = "device_code_metadata";

// @formatter:off
private static final String COLUMN_NAMES = "id, "
+ "registered_client_id, "
Expand Down Expand Up @@ -279,40 +292,40 @@ public OAuth2Authorization findByToken(String token, @Nullable OAuth2TokenType t
List<SqlParameterValue> parameters = new ArrayList<>();
if (tokenType == null) {
parameters.add(new SqlParameterValue(Types.VARCHAR, token));
parameters.add(mapToSqlParameter("authorization_code_value", token));
parameters.add(mapToSqlParameter("access_token_value", token));
parameters.add(mapToSqlParameter("oidc_id_token_value", token));
parameters.add(mapToSqlParameter("refresh_token_value", token));
parameters.add(mapToSqlParameter("user_code_value", token));
parameters.add(mapToSqlParameter("device_code_value", token));
parameters.add(mapToSqlParameter(AUTHORIZATION_CODE_VALUE, token));
parameters.add(mapToSqlParameter(ACCESS_TOKEN_VALUE, token));
parameters.add(mapToSqlParameter(OIDC_ID_TOKEN_VALUE, token));
parameters.add(mapToSqlParameter(REFRESH_TOKEN_VALUE, token));
parameters.add(mapToSqlParameter(USER_CODE_VALUE, token));
parameters.add(mapToSqlParameter(DEVICE_CODE_VALUE, token));
return findBy(UNKNOWN_TOKEN_TYPE_FILTER, parameters);
}
else if (OAuth2ParameterNames.STATE.equals(tokenType.getValue())) {
parameters.add(new SqlParameterValue(Types.VARCHAR, token));
return findBy(STATE_FILTER, parameters);
}
else if (OAuth2ParameterNames.CODE.equals(tokenType.getValue())) {
parameters.add(mapToSqlParameter("authorization_code_value", token));
parameters.add(mapToSqlParameter(AUTHORIZATION_CODE_VALUE, token));
return findBy(AUTHORIZATION_CODE_FILTER, parameters);
}
else if (OAuth2TokenType.ACCESS_TOKEN.equals(tokenType)) {
parameters.add(mapToSqlParameter("access_token_value", token));
parameters.add(mapToSqlParameter(ACCESS_TOKEN_VALUE, token));
return findBy(ACCESS_TOKEN_FILTER, parameters);
}
else if (OidcParameterNames.ID_TOKEN.equals(tokenType.getValue())) {
parameters.add(mapToSqlParameter("oidc_id_token_value", token));
parameters.add(mapToSqlParameter(OIDC_ID_TOKEN_VALUE, token));
return findBy(ID_TOKEN_FILTER, parameters);
}
else if (OAuth2TokenType.REFRESH_TOKEN.equals(tokenType)) {
parameters.add(mapToSqlParameter("refresh_token_value", token));
parameters.add(mapToSqlParameter(REFRESH_TOKEN_VALUE, token));
return findBy(REFRESH_TOKEN_FILTER, parameters);
}
else if (OAuth2ParameterNames.USER_CODE.equals(tokenType.getValue())) {
parameters.add(mapToSqlParameter("user_code_value", token));
parameters.add(mapToSqlParameter(USER_CODE_VALUE, token));
return findBy(USER_CODE_FILTER, parameters);
}
else if (OAuth2ParameterNames.DEVICE_CODE.equals(tokenType.getValue())) {
parameters.add(mapToSqlParameter("device_code_value", token));
parameters.add(mapToSqlParameter(DEVICE_CODE_VALUE, token));
return findBy(DEVICE_CODE_FILTER, parameters);
}
return null;
Expand Down Expand Up @@ -375,29 +388,29 @@ private static void initColumnMetadata(JdbcOperations jdbcOperations) {

columnMetadata = getColumnMetadata(jdbcOperations, "attributes", Types.BLOB);
columnMetadataMap.put(columnMetadata.getColumnName(), columnMetadata);
columnMetadata = getColumnMetadata(jdbcOperations, "authorization_code_value", Types.BLOB);
columnMetadata = getColumnMetadata(jdbcOperations, AUTHORIZATION_CODE_VALUE, Types.BLOB);
columnMetadataMap.put(columnMetadata.getColumnName(), columnMetadata);
columnMetadata = getColumnMetadata(jdbcOperations, "authorization_code_metadata", Types.BLOB);
columnMetadata = getColumnMetadata(jdbcOperations, AUTHORIZATION_CODE_METADATA, Types.BLOB);
columnMetadataMap.put(columnMetadata.getColumnName(), columnMetadata);
columnMetadata = getColumnMetadata(jdbcOperations, "access_token_value", Types.BLOB);
columnMetadata = getColumnMetadata(jdbcOperations, ACCESS_TOKEN_VALUE, Types.BLOB);
columnMetadataMap.put(columnMetadata.getColumnName(), columnMetadata);
columnMetadata = getColumnMetadata(jdbcOperations, "access_token_metadata", Types.BLOB);
columnMetadata = getColumnMetadata(jdbcOperations, ACCESS_TOKEN_METADATA, Types.BLOB);
columnMetadataMap.put(columnMetadata.getColumnName(), columnMetadata);
columnMetadata = getColumnMetadata(jdbcOperations, "oidc_id_token_value", Types.BLOB);
columnMetadata = getColumnMetadata(jdbcOperations, OIDC_ID_TOKEN_VALUE, Types.BLOB);
columnMetadataMap.put(columnMetadata.getColumnName(), columnMetadata);
columnMetadata = getColumnMetadata(jdbcOperations, "oidc_id_token_metadata", Types.BLOB);
columnMetadata = getColumnMetadata(jdbcOperations, OIDC_ID_TOKEN_METADATA, Types.BLOB);
columnMetadataMap.put(columnMetadata.getColumnName(), columnMetadata);
columnMetadata = getColumnMetadata(jdbcOperations, "refresh_token_value", Types.BLOB);
columnMetadata = getColumnMetadata(jdbcOperations, REFRESH_TOKEN_VALUE, Types.BLOB);
columnMetadataMap.put(columnMetadata.getColumnName(), columnMetadata);
columnMetadata = getColumnMetadata(jdbcOperations, "refresh_token_metadata", Types.BLOB);
columnMetadata = getColumnMetadata(jdbcOperations, REFRESH_TOKEN_METADATA, Types.BLOB);
columnMetadataMap.put(columnMetadata.getColumnName(), columnMetadata);
columnMetadata = getColumnMetadata(jdbcOperations, "user_code_value", Types.BLOB);
columnMetadata = getColumnMetadata(jdbcOperations, USER_CODE_VALUE, Types.BLOB);
columnMetadataMap.put(columnMetadata.getColumnName(), columnMetadata);
columnMetadata = getColumnMetadata(jdbcOperations, "user_code_metadata", Types.BLOB);
columnMetadata = getColumnMetadata(jdbcOperations, USER_CODE_METADATA, Types.BLOB);
columnMetadataMap.put(columnMetadata.getColumnName(), columnMetadata);
columnMetadata = getColumnMetadata(jdbcOperations, "device_code_value", Types.BLOB);
columnMetadata = getColumnMetadata(jdbcOperations, DEVICE_CODE_VALUE, Types.BLOB);
columnMetadataMap.put(columnMetadata.getColumnName(), columnMetadata);
columnMetadata = getColumnMetadata(jdbcOperations, "device_code_metadata", Types.BLOB);
columnMetadata = getColumnMetadata(jdbcOperations, DEVICE_CODE_METADATA, Types.BLOB);
columnMetadataMap.put(columnMetadata.getColumnName(), columnMetadata);
}

Expand Down Expand Up @@ -490,24 +503,24 @@ public OAuth2Authorization mapRow(ResultSet rs, int rowNum) throws SQLException

Instant tokenIssuedAt;
Instant tokenExpiresAt;
String authorizationCodeValue = getLobValue(rs, "authorization_code_value");
String authorizationCodeValue = getLobValue(rs, AUTHORIZATION_CODE_VALUE);

if (StringUtils.hasText(authorizationCodeValue)) {
tokenIssuedAt = rs.getTimestamp("authorization_code_issued_at").toInstant();
tokenExpiresAt = rs.getTimestamp("authorization_code_expires_at").toInstant();
Map<String, Object> authorizationCodeMetadata = parseMap(
getLobValue(rs, "authorization_code_metadata"));
getLobValue(rs, AUTHORIZATION_CODE_METADATA));

OAuth2AuthorizationCode authorizationCode = new OAuth2AuthorizationCode(authorizationCodeValue,
tokenIssuedAt, tokenExpiresAt);
builder.token(authorizationCode, (metadata) -> metadata.putAll(authorizationCodeMetadata));
}

String accessTokenValue = getLobValue(rs, "access_token_value");
String accessTokenValue = getLobValue(rs, ACCESS_TOKEN_VALUE);
if (StringUtils.hasText(accessTokenValue)) {
tokenIssuedAt = rs.getTimestamp("access_token_issued_at").toInstant();
tokenExpiresAt = rs.getTimestamp("access_token_expires_at").toInstant();
Map<String, Object> accessTokenMetadata = parseMap(getLobValue(rs, "access_token_metadata"));
Map<String, Object> accessTokenMetadata = parseMap(getLobValue(rs, ACCESS_TOKEN_METADATA));
OAuth2AccessToken.TokenType tokenType = null;
if (OAuth2AccessToken.TokenType.BEARER.getValue().equalsIgnoreCase(rs.getString("access_token_type"))) {
tokenType = OAuth2AccessToken.TokenType.BEARER;
Expand All @@ -527,47 +540,47 @@ else if (OAuth2AccessToken.TokenType.DPOP.getValue()
builder.token(accessToken, (metadata) -> metadata.putAll(accessTokenMetadata));
}

String oidcIdTokenValue = getLobValue(rs, "oidc_id_token_value");
String oidcIdTokenValue = getLobValue(rs, OIDC_ID_TOKEN_VALUE);
if (StringUtils.hasText(oidcIdTokenValue)) {
tokenIssuedAt = rs.getTimestamp("oidc_id_token_issued_at").toInstant();
tokenExpiresAt = rs.getTimestamp("oidc_id_token_expires_at").toInstant();
Map<String, Object> oidcTokenMetadata = parseMap(getLobValue(rs, "oidc_id_token_metadata"));
Map<String, Object> oidcTokenMetadata = parseMap(getLobValue(rs, OIDC_ID_TOKEN_METADATA));

OidcIdToken oidcToken = new OidcIdToken(oidcIdTokenValue, tokenIssuedAt, tokenExpiresAt,
(Map<String, Object>) oidcTokenMetadata.get(OAuth2Authorization.Token.CLAIMS_METADATA_NAME));
builder.token(oidcToken, (metadata) -> metadata.putAll(oidcTokenMetadata));
}

String refreshTokenValue = getLobValue(rs, "refresh_token_value");
String refreshTokenValue = getLobValue(rs, REFRESH_TOKEN_VALUE);
if (StringUtils.hasText(refreshTokenValue)) {
tokenIssuedAt = rs.getTimestamp("refresh_token_issued_at").toInstant();
tokenExpiresAt = null;
Timestamp refreshTokenExpiresAt = rs.getTimestamp("refresh_token_expires_at");
if (refreshTokenExpiresAt != null) {
tokenExpiresAt = refreshTokenExpiresAt.toInstant();
}
Map<String, Object> refreshTokenMetadata = parseMap(getLobValue(rs, "refresh_token_metadata"));
Map<String, Object> refreshTokenMetadata = parseMap(getLobValue(rs, REFRESH_TOKEN_METADATA));

OAuth2RefreshToken refreshToken = new OAuth2RefreshToken(refreshTokenValue, tokenIssuedAt,
tokenExpiresAt);
builder.token(refreshToken, (metadata) -> metadata.putAll(refreshTokenMetadata));
}

String userCodeValue = getLobValue(rs, "user_code_value");
String userCodeValue = getLobValue(rs, USER_CODE_VALUE);
if (StringUtils.hasText(userCodeValue)) {
tokenIssuedAt = rs.getTimestamp("user_code_issued_at").toInstant();
tokenExpiresAt = rs.getTimestamp("user_code_expires_at").toInstant();
Map<String, Object> userCodeMetadata = parseMap(getLobValue(rs, "user_code_metadata"));
Map<String, Object> userCodeMetadata = parseMap(getLobValue(rs, USER_CODE_METADATA));

OAuth2UserCode userCode = new OAuth2UserCode(userCodeValue, tokenIssuedAt, tokenExpiresAt);
builder.token(userCode, (metadata) -> metadata.putAll(userCodeMetadata));
}

String deviceCodeValue = getLobValue(rs, "device_code_value");
String deviceCodeValue = getLobValue(rs, DEVICE_CODE_VALUE);
if (StringUtils.hasText(deviceCodeValue)) {
tokenIssuedAt = rs.getTimestamp("device_code_issued_at").toInstant();
tokenExpiresAt = rs.getTimestamp("device_code_expires_at").toInstant();
Map<String, Object> deviceCodeMetadata = parseMap(getLobValue(rs, "device_code_metadata"));
Map<String, Object> deviceCodeMetadata = parseMap(getLobValue(rs, DEVICE_CODE_METADATA));

OAuth2DeviceCode deviceCode = new OAuth2DeviceCode(deviceCodeValue, tokenIssuedAt, tokenExpiresAt);
builder.token(deviceCode, (metadata) -> metadata.putAll(deviceCodeMetadata));
Expand Down Expand Up @@ -670,13 +683,13 @@ public List<SqlParameterValue> apply(OAuth2Authorization authorization) {

OAuth2Authorization.Token<OAuth2AuthorizationCode> authorizationCode = authorization
.getToken(OAuth2AuthorizationCode.class);
List<SqlParameterValue> authorizationCodeSqlParameters = toSqlParameterList("authorization_code_value",
"authorization_code_metadata", authorizationCode);
List<SqlParameterValue> authorizationCodeSqlParameters = toSqlParameterList(AUTHORIZATION_CODE_VALUE,
AUTHORIZATION_CODE_METADATA, authorizationCode);
parameters.addAll(authorizationCodeSqlParameters);

OAuth2Authorization.Token<OAuth2AccessToken> accessToken = authorization.getToken(OAuth2AccessToken.class);
List<SqlParameterValue> accessTokenSqlParameters = toSqlParameterList("access_token_value",
"access_token_metadata", accessToken);
List<SqlParameterValue> accessTokenSqlParameters = toSqlParameterList(ACCESS_TOKEN_VALUE,
ACCESS_TOKEN_METADATA, accessToken);
parameters.addAll(accessTokenSqlParameters);
String accessTokenType = null;
String accessTokenScopes = null;
Expand All @@ -691,23 +704,23 @@ public List<SqlParameterValue> apply(OAuth2Authorization authorization) {
parameters.add(new SqlParameterValue(Types.VARCHAR, accessTokenScopes));

OAuth2Authorization.Token<OidcIdToken> oidcIdToken = authorization.getToken(OidcIdToken.class);
List<SqlParameterValue> oidcIdTokenSqlParameters = toSqlParameterList("oidc_id_token_value",
"oidc_id_token_metadata", oidcIdToken);
List<SqlParameterValue> oidcIdTokenSqlParameters = toSqlParameterList(OIDC_ID_TOKEN_VALUE,
OIDC_ID_TOKEN_METADATA, oidcIdToken);
parameters.addAll(oidcIdTokenSqlParameters);

OAuth2Authorization.Token<OAuth2RefreshToken> refreshToken = authorization.getRefreshToken();
List<SqlParameterValue> refreshTokenSqlParameters = toSqlParameterList("refresh_token_value",
"refresh_token_metadata", refreshToken);
List<SqlParameterValue> refreshTokenSqlParameters = toSqlParameterList(REFRESH_TOKEN_VALUE,
REFRESH_TOKEN_METADATA, refreshToken);
parameters.addAll(refreshTokenSqlParameters);

OAuth2Authorization.Token<OAuth2UserCode> userCode = authorization.getToken(OAuth2UserCode.class);
List<SqlParameterValue> userCodeSqlParameters = toSqlParameterList("user_code_value", "user_code_metadata",
List<SqlParameterValue> userCodeSqlParameters = toSqlParameterList(USER_CODE_VALUE, USER_CODE_METADATA,
userCode);
parameters.addAll(userCodeSqlParameters);

OAuth2Authorization.Token<OAuth2DeviceCode> deviceCode = authorization.getToken(OAuth2DeviceCode.class);
List<SqlParameterValue> deviceCodeSqlParameters = toSqlParameterList("device_code_value",
"device_code_metadata", deviceCode);
List<SqlParameterValue> deviceCodeSqlParameters = toSqlParameterList(DEVICE_CODE_VALUE,
DEVICE_CODE_METADATA, deviceCode);
parameters.addAll(deviceCodeSqlParameters);

return parameters;
Expand Down