Skip to content

JWT Auth: Extract Pub/Sub Rules from Claims + Minor Fixes and Improvements #235

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 15 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 3 additions & 50 deletions application/src/main/data/upgrade/basic/schema_update.sql
Original file line number Diff line number Diff line change
Expand Up @@ -14,55 +14,8 @@
-- limitations under the License.
--

-- UPGRADE FROM VERSION 2.0.0 TO 2.0.1 START
-- UPGRADE FROM VERSION 2.1.0 TO 2.2.0 START

ALTER TABLE broker_user DROP COLUMN IF EXISTS search_text;
ALTER TABLE mqtt_client_credentials DROP COLUMN IF EXISTS search_text;
ALTER TABLE application_shared_subscription DROP COLUMN IF EXISTS search_text;
ALTER TABLE websocket_connection DROP COLUMN IF EXISTS search_text;
UPDATE mqtt_client_credentials SET credentials_type = 'X_509' WHERE credentials_type = 'SSL';

-- UPGRADE FROM VERSION 2.0.0 TO 2.0.1 END

-- UPGRADE FROM VERSION 2.0.1 TO 2.1.0 START

CREATE TABLE IF NOT EXISTS integration (
id uuid NOT NULL CONSTRAINT integration_pkey PRIMARY KEY,
created_time bigint NOT NULL,
disconnected_time bigint,
additional_info varchar,
configuration varchar(10000000),
enabled boolean,
name varchar(255),
type varchar(255),
status varchar
);

CREATE TABLE IF NOT EXISTS stats_event (
id uuid NOT NULL,
ts bigint NOT NULL,
entity_id uuid NOT NULL,
service_id varchar NOT NULL,
e_messages_processed bigint NOT NULL,
e_errors_occurred bigint NOT NULL
) PARTITION BY RANGE (ts);

CREATE TABLE IF NOT EXISTS lc_event (
id uuid NOT NULL,
ts bigint NOT NULL,
entity_id uuid NOT NULL,
service_id varchar NOT NULL,
e_type varchar NOT NULL,
e_success boolean NOT NULL,
e_error varchar
) PARTITION BY RANGE (ts);

CREATE TABLE IF NOT EXISTS error_event (
id uuid NOT NULL,
ts bigint NOT NULL,
entity_id uuid NOT NULL,
service_id varchar NOT NULL,
e_method varchar NOT NULL,
e_error varchar
) PARTITION BY RANGE (ts);

-- UPGRADE FROM VERSION 2.0.1 TO 2.1.0 END
-- UPGRADE FROM VERSION 2.1.0 TO 2.2.0 END
Original file line number Diff line number Diff line change
Expand Up @@ -648,7 +648,6 @@ public static InternodeNotificationProto toMqttAuthSettingUpdateProto(MqttAuthSe
return InternodeNotificationProto.newBuilder()
.setMqttAuthSettingsProto(MqttAuthSettingsProto.newBuilder()
.addAllPriorities(toMqttAuthPriorities(mqttAuthSettings.getPriorities()))
.setUseListenerBasedProviderOnly(mqttAuthSettings.isUseListenerBasedProviderOnly())
.build()).build();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -240,7 +240,7 @@ private boolean existsBasicCredentials() {
}

private boolean existsX509Credentials() {
return existsByCredentialsType(ClientCredentialsType.SSL);
return existsByCredentialsType(ClientCredentialsType.X_509);
}

private boolean existsScramCredentials() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,26 +47,22 @@ public class DefaultAuthorizationRoutingService implements AuthorizationRoutingS
private final AdminSettingsService adminSettingsService;

private volatile List<MqttAuthProviderType> priorities;
private volatile boolean useListenerBasedProviderOnly;

@PostConstruct
public void init() {
AdminSettings mqttAuthorization = adminSettingsService.findAdminSettingsByKey(SysAdminSettingType.MQTT_AUTHORIZATION.getKey());
if (mqttAuthorization == null) {
priorities = MqttAuthProviderType.getDefaultPriorityList();
log.warn("Failed to find MQTT authorization settings. Going to apply default settings. " +
"Auth Priorities {}, Use listener based provider only: {}", priorities, useListenerBasedProviderOnly);
priorities = MqttAuthProviderType.defaultPriorityList;
log.warn("Failed to find MQTT authorization settings. Going to use default authentication execution order {}", priorities);
return;
}
MqttAuthSettings mqttAuthSettings = MqttAuthSettings.fromJsonValue(mqttAuthorization.getJsonValue());
useListenerBasedProviderOnly = mqttAuthSettings.isUseListenerBasedProviderOnly();
priorities = getPriorities(mqttAuthSettings.getPriorities());
}

@Override
public void onMqttAuthSettingsUpdate(MqttAuthSettingsProto mqttAuthSettingsProto) {
priorities = getPriorities(ProtoConverter.fromMqttAuthPriorities(mqttAuthSettingsProto.getPrioritiesList()));
useListenerBasedProviderOnly = mqttAuthSettingsProto.getUseListenerBasedProviderOnly();
}

@Override
Expand All @@ -79,10 +75,9 @@ public AuthResponse executeAuthFlow(AuthContext authContext) {
return AuthResponse.defaultAuthResponse();
}

List<MqttAuthProviderType> prioritiesForCurrentAuthContext = getPrioritiesForCurrentAuthContext(authContext);
List<String> failureReasons = new ArrayList<>(prioritiesForCurrentAuthContext.size());
List<String> failureReasons = new ArrayList<>(priorities.size());

for (MqttAuthProviderType providerType : prioritiesForCurrentAuthContext) {
for (MqttAuthProviderType providerType : priorities) {
AuthResponse response = switch (providerType) {
case JWT -> jwtMqttClientAuthProvider.authenticate(authContext);
case MQTT_BASIC -> basicMqttClientAuthProvider.authenticate(authContext);
Expand All @@ -103,14 +98,6 @@ private boolean defaultProvidersEnabled() {
jwtMqttClientAuthProvider.isEnabled();
}

private List<MqttAuthProviderType> getPrioritiesForCurrentAuthContext(AuthContext authContext) {
List<MqttAuthProviderType> effectivePriorities = new ArrayList<>(priorities);
if (useListenerBasedProviderOnly) {
effectivePriorities.remove(authContext.isSecurePortUsed() ? MqttAuthProviderType.MQTT_BASIC : MqttAuthProviderType.X_509);
}
return effectivePriorities;
}

private void addFailureReason(AuthContext authContext, AuthResponse response, String authType, List<String> failureReasons) {
String reason = response.getReason();
if (log.isDebugEnabled()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,14 @@
package org.thingsboard.mqtt.broker.service.auth.providers;

import lombok.Builder;
import lombok.Getter;
import lombok.Data;
import org.thingsboard.mqtt.broker.common.data.ClientType;
import org.thingsboard.mqtt.broker.common.data.security.MqttAuthProviderType;
import org.thingsboard.mqtt.broker.service.security.authorization.AuthRulePatterns;

import java.util.List;

@Getter
@Data
@Builder(toBuilder = true)
public class AuthResponse {

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,12 @@
package org.thingsboard.mqtt.broker.service.auth.providers.jwt;

import com.nimbusds.jwt.JWTClaimsSet;
import lombok.Data;
import lombok.extern.slf4j.Slf4j;
import org.thingsboard.mqtt.broker.common.data.ClientType;
import org.thingsboard.mqtt.broker.common.data.security.jwt.JwtMqttAuthProviderConfiguration;
import org.thingsboard.mqtt.broker.common.data.util.AuthRulesUtil;
import org.thingsboard.mqtt.broker.common.data.util.StringUtils;
import org.thingsboard.mqtt.broker.service.auth.providers.AuthContext;
import org.thingsboard.mqtt.broker.service.auth.providers.AuthResponse;
import org.thingsboard.mqtt.broker.service.security.authorization.AuthRulePatterns;
Expand All @@ -26,8 +30,24 @@
import java.util.Date;
import java.util.List;
import java.util.Objects;
import java.util.regex.Pattern;

public record JwtClaimsValidator(JwtMqttAuthProviderConfiguration configuration, AuthRulePatterns authRulePatterns) {
@Data
@Slf4j
public class JwtClaimsValidator {

private final JwtMqttAuthProviderConfiguration configuration;
private final AuthRulePatterns defaultAuthRulePatterns;

private final boolean hasPubAuthRulesClaim;
private final boolean hasSubAuthRulesClaim;

public JwtClaimsValidator(JwtMqttAuthProviderConfiguration configuration, AuthRulePatterns defaultAuthRulePatterns) {
this.configuration = configuration;
this.defaultAuthRulePatterns = defaultAuthRulePatterns;
this.hasPubAuthRulesClaim = StringUtils.isNotBlank(configuration.getPubAuthRuleClaim());
this.hasSubAuthRulesClaim = StringUtils.isNotBlank(configuration.getSubAuthRuleClaim());
}

public AuthResponse validateAll(AuthContext authContext, JWTClaimsSet claims) throws ParseException {
Date now = new Date();
Expand All @@ -44,7 +64,8 @@ public AuthResponse validateAll(AuthContext authContext, JWTClaimsSet claims) th
return AuthResponse.failure("Failed to validate JWT auth claims.");
}
ClientType clientType = resolveClientType(claims);
return AuthResponse.success(clientType, List.of(authRulePatterns));
AuthRulePatterns rulePatterns = resolveAuthRulePatterns(claims);
return AuthResponse.success(clientType, List.of(rulePatterns));
}

private ClientType resolveClientType(JWTClaimsSet claims) throws ParseException {
Expand Down Expand Up @@ -80,4 +101,32 @@ private boolean validateAuthClaims(AuthContext authContext, JWTClaimsSet claims)
}
return true;
}

private AuthRulePatterns resolveAuthRulePatterns(JWTClaimsSet claims) {
if (!hasPubAuthRulesClaim && !hasSubAuthRulesClaim) {
return defaultAuthRulePatterns;
}
return AuthRulePatterns.of(
resolvePatterns(hasPubAuthRulesClaim, claims, configuration.getPubAuthRuleClaim(), defaultAuthRulePatterns.getPubPatterns()),
resolvePatterns(hasSubAuthRulesClaim, claims, configuration.getSubAuthRuleClaim(), defaultAuthRulePatterns.getSubPatterns())
);
}

private List<Pattern> resolvePatterns(boolean enabled, JWTClaimsSet claims, String claimName, List<Pattern> fallback) {
if (!enabled) {
return fallback;
}
try {
List<String> raw = claims.getStringListClaim(claimName);
return raw == null ? fallback : AuthRulesUtil.fromStringList(raw);
} catch (Exception e) {
if (log.isDebugEnabled()) {
log.debug("[{}] Failed to parse auth rules claim, claims {}", claimName, claims.toString(), e);
} else {
log.warn("[{}] Failed to parse auth rules claim", claimName, e);
}
return fallback;
}
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -61,25 +61,25 @@ public void init() {
}

private JwtVerificationStrategy createStrategy() throws JOSEException {
return switch (configuration.getJwtVerifierType()) {
case ALGORITHM_BASED -> {
var conf = (AlgorithmBasedVerifierConfiguration) configuration.getJwtVerifierConfiguration();
yield switch (conf.getAlgorithm()) {
case HMAC_BASED -> {
String rawSecret = ((HmacBasedAlgorithmConfiguration) conf.getJwtSignAlgorithmConfiguration()).getSecret();
yield new HmacJwtVerificationStrategy(rawSecret, new JwtClaimsValidator(configuration, authRulePatterns));
}
case PEM_KEY -> {
String publicPemKey = ((PemKeyAlgorithmConfiguration) conf.getJwtSignAlgorithmConfiguration()).getPublicPemKey();
yield new PemKeyJwtVerificationStrategy(publicPemKey, new JwtClaimsValidator(configuration, authRulePatterns));
}
};
var jwtVerifierConfig = configuration.getJwtVerifierConfiguration();
var validator = new JwtClaimsValidator(configuration, authRulePatterns);

if (jwtVerifierConfig instanceof AlgorithmBasedVerifierConfiguration algConfiguration) {
var algoConfig = algConfiguration.getJwtSignAlgorithmConfiguration();
if (algoConfig instanceof HmacBasedAlgorithmConfiguration hmacConfig) {
return new HmacJwtVerificationStrategy(hmacConfig.getSecret(), validator);
}
case JWKS -> {
var conf = (JwksVerifierConfiguration) configuration.getJwtVerifierConfiguration();
yield new JwksVerificationStrategy(conf, new JwtClaimsValidator(configuration, authRulePatterns));
if (algoConfig instanceof PemKeyAlgorithmConfiguration pemConfig) {
return new PemKeyJwtVerificationStrategy(pemConfig.getPublicPemKey(), validator);
}
};
throw new IllegalArgumentException("Unsupported AlgorithmBasedVerifierConfiguration: " + algoConfig.getClass().getSimpleName());
}

if (jwtVerifierConfig instanceof JwksVerifierConfiguration jwksConfig) {
return new JwksVerificationStrategy(jwksConfig, validator);
}

throw new IllegalArgumentException("Unsupported JwtVerifierConfiguration: " + jwtVerifierConfig.getClass().getSimpleName());
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ public AuthResponse authenticate(AuthContext authContext) {
if (!enabled) {
return AuthResponse.providerDisabled(MqttAuthProviderType.X_509);
}
if (authContext.getSslHandler() == null) {
if (!authContext.isSecurePortUsed()) {
String errorMsg = SSL_HANDLER_NOT_CONSTRUCTED.getErrorMsg();
String logErrorMsg = "[{}] " + errorMsg;
log.error(logErrorMsg, authContext);
Expand Down Expand Up @@ -194,9 +194,9 @@ private SslCredentialsCacheValue getSslRegexBasedCredentials() {
SslCredentialsCacheValue sslCredentialsCacheValue = getFromSslRegexCache();
if (sslCredentialsCacheValue == null) {
log.debug("sslRegexBasedCredentials cache is empty");
List<MqttClientCredentials> sslCredentials = clientCredentialsService.findByCredentialsType(ClientCredentialsType.SSL);
List<MqttClientCredentials> sslCredentials = clientCredentialsService.findByCredentialsType(ClientCredentialsType.X_509);
if (sslCredentials.isEmpty()) {
log.debug("SSL credentials are not found in DB");
log.debug("X_509 credentials are not found in DB");
return null;
} else {
sslCredentialsCacheValue = prepareSslRegexCredentialsWithValuesFromDb(sslCredentials);
Expand All @@ -205,7 +205,7 @@ private SslCredentialsCacheValue getSslRegexBasedCredentials() {
}
} else {
if (sslCredentialsCacheValue.getCredentials().isEmpty()) {
log.debug("Got empty SSL regex based credentials list from cache");
log.debug("Got empty X_509 regex based credentials list from cache");
}
return sslCredentialsCacheValue;
}
Expand Down Expand Up @@ -249,12 +249,12 @@ private SslCredentialsCacheValue prepareSslRegexCredentialsWithValuesFromDb(List

private SslCredentialsCacheValue getFromSslRegexCache() {
Cache cache = getSslRegexCredentialsCache();
return JacksonUtil.fromString(cache.get(ClientCredentialsType.SSL, String.class), SslCredentialsCacheValue.class);
return JacksonUtil.fromString(cache.get(ClientCredentialsType.X_509, String.class), SslCredentialsCacheValue.class);
}

private void putInSslRegexCache(SslCredentialsCacheValue sslCredentialsCacheValue) {
Cache cache = getSslRegexCredentialsCache();
cache.put(ClientCredentialsType.SSL, JacksonUtil.toString(sslCredentialsCacheValue));
cache.put(ClientCredentialsType.X_509, JacksonUtil.toString(sslCredentialsCacheValue));
}

private Cache getSslRegexCredentialsCache() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,13 +32,11 @@ public class MqttAuthSettings implements Serializable {
@Serial
private static final long serialVersionUID = -8045245463193283033L;

private boolean useListenerBasedProviderOnly;
private List<MqttAuthProviderType> priorities;

public static AdminSettings createDefaults() {
MqttAuthSettings mqttAuthSettings = new MqttAuthSettings();
mqttAuthSettings.setUseListenerBasedProviderOnly(false);
mqttAuthSettings.setPriorities(MqttAuthProviderType.getDefaultPriorityList());
mqttAuthSettings.setPriorities(MqttAuthProviderType.defaultPriorityList);
return toAdminSettings(mqttAuthSettings);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,6 @@
@RequiredArgsConstructor
public class DefaultDataUpdateService implements DataUpdateService {

@Value("${security.mqtt.auth_strategy:BOTH}")
private String authStrategy;
@Value("${security.mqtt.basic.enabled:false}")
private boolean basicAuthEnabled;
@Value("${security.mqtt.ssl.enabled:false}")
Expand Down Expand Up @@ -66,8 +64,7 @@ private void createMqttAuthSettingsIfNotExist() {
return;
}
MqttAuthSettings mqttAuthSettings = new MqttAuthSettings();
mqttAuthSettings.setUseListenerBasedProviderOnly(isUseListenerBasedProviderOnly());
mqttAuthSettings.setPriorities(MqttAuthProviderType.getDefaultPriorityList());
mqttAuthSettings.setPriorities(MqttAuthProviderType.defaultPriorityList);
AdminSettings adminSettings = MqttAuthSettings.toAdminSettings(mqttAuthSettings);
adminSettingsService.saveAdminSettings(adminSettings);
log.info("Finished MQTT auth setting creation!");
Expand Down Expand Up @@ -100,11 +97,6 @@ private void createMqttAuthProvidersIfNotExist() {
log.info("Finished MQTT auth providers creation!");
}

boolean isUseListenerBasedProviderOnly() {
return Optional.ofNullable(System.getenv("SECURITY_MQTT_AUTH_STRATEGY"))
.map(String::trim).map("SINGLE"::equalsIgnoreCase).orElse("SINGLE".equalsIgnoreCase(authStrategy));
}

boolean isBasicAuthEnabled() {
return getLegacyConfig("SECURITY_MQTT_BASIC_ENABLED", basicAuthEnabled);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,4 +33,9 @@ public class AuthRulePatterns {
public static AuthRulePatterns newInstance(List<Pattern> patterns) {
return new AuthRulePatterns(patterns, patterns);
}

public static AuthRulePatterns of(List<Pattern> pubPatterns, List<Pattern> subPatterns) {
return new AuthRulePatterns(pubPatterns, subPatterns);
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -582,7 +582,6 @@ public void givenMqttAuthSettings_whenConvertToProtoAndBack_thenPrioritiesPreser
);

MqttAuthSettings settings = new MqttAuthSettings();
settings.setUseListenerBasedProviderOnly(true);
settings.setPriorities(originalPriorities);

InternodeNotificationProto proto = ProtoConverter.toMqttAuthSettingUpdateProto(settings);
Expand All @@ -600,7 +599,6 @@ public void givenMqttAuthSettings_whenConvertToProtoAndBack_thenPrioritiesPreser
@Test
public void givenMqttAuthSettingsWithNullPriorities_whenConvertToProto_thenValidateNullHandledProperly() {
MqttAuthSettings settings = new MqttAuthSettings();
settings.setUseListenerBasedProviderOnly(true);
settings.setPriorities(null);

InternodeNotificationProto proto = ProtoConverter.toMqttAuthSettingUpdateProto(settings);
Expand All @@ -610,7 +608,6 @@ public void givenMqttAuthSettingsWithNullPriorities_whenConvertToProto_thenValid
// Validate
MqttAuthSettingsProto mqttAuthSettingsProto = proto.getMqttAuthSettingsProto();
assertTrue(mqttAuthSettingsProto.getPrioritiesList().isEmpty());
assertTrue(mqttAuthSettingsProto.getUseListenerBasedProviderOnly());
}

@Test
Expand Down
Loading