Skip to content

Commit f9ef7c2

Browse files
JigarJoshipboros
andauthored
feat: Implemented API key authentication mechanism (#1246)
* refactor: refactor metronome metrics with consistent tags (#1241) * feat: Implemented API key authentication mechanism * fix: Fixed marshaling based on code review suggestion * feat: Added whoami endpoint * refactor: minor refactor based on code review comments --------- Co-authored-by: Peter Boros <[email protected]>
1 parent 69f1c71 commit f9ef7c2

File tree

17 files changed

+480
-79
lines changed

17 files changed

+480
-79
lines changed

api/proto

Submodule proto updated from fedd435 to 2f9df79

api/server/v1/marshaler.go

+26
Original file line numberDiff line numberDiff line change
@@ -1004,6 +1004,32 @@ func (x *ListInvoicesRequest) UnmarshalJSON(data []byte) error {
10041004
return nil
10051005
}
10061006

1007+
// UnmarshalJSON on ListAppKeysRequest. Handles query param.
1008+
func (x *ListAppKeysRequest) UnmarshalJSON(data []byte) error {
1009+
var mp map[string]jsoniter.RawMessage
1010+
1011+
if err := jsoniter.Unmarshal(data, &mp); err != nil {
1012+
return err
1013+
}
1014+
1015+
for key, value := range mp {
1016+
var v any
1017+
1018+
switch key {
1019+
case "key_type":
1020+
v = &x.KeyType
1021+
case "project":
1022+
v = &x.Project
1023+
default:
1024+
continue
1025+
}
1026+
if err := jsoniter.Unmarshal(value, v); err != nil {
1027+
return err
1028+
}
1029+
}
1030+
return nil
1031+
}
1032+
10071033
func (x *GetNamespaceMetadataResponse) MarshalJSON() ([]byte, error) {
10081034
resp := struct {
10091035
MetadataKey string `json:"metadataKey,omitempty"`

api/server/v1/tx.go

+1
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,7 @@ const (
120120
QuotaLimitsMetricsMethodName = ObservabilityMethodPrefix + "QuotaLimits"
121121
QuotaUsageMethodName = ObservabilityMethodPrefix + "QuotaUsage"
122122
GetInfoMethodName = ObservabilityMethodPrefix + "GetInfo"
123+
WhoAmIMethodName = ObservabilityMethodPrefix + "WhoAmI"
123124

124125
// Realtime.
125126
PresenceMethodName = realtimeMethodPrefix + "Presence"

config/server.test.yaml

+6
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,12 @@ auth:
6262
- issuer: http://tigris_gotrue:8086
6363
algorithm: HS256
6464
audience: https://tigris-testB
65+
api_keys:
66+
auds:
67+
- https://tigris-test
68+
length: 120
69+
email_suffix: "@apikey.tigrisdata.com"
70+
user_password: hello
6571
token_cache_size: 100
6672
primary_audience: https://tigris-test
6773
oauth_provider: gotrue

server/config/options.go

+12
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,7 @@ type AuthzConfig struct {
7676
type AuthConfig struct {
7777
Enabled bool `mapstructure:"enabled" yaml:"enabled" json:"enabled"`
7878
Validators []ValidatorConfig `mapstructure:"validators" yaml:"validators" json:"validators"`
79+
ApiKeys ApiKeysConfig `mapstructure:"api_keys" yaml:"api_keys" json:"api_keys"`
7980
PrimaryAudience string `mapstructure:"primary_audience" yaml:"primary_audience" json:"primary_audience"`
8081
JWKSCacheTimeout time.Duration `mapstructure:"jwks_cache_timeout" yaml:"jwks_cache_timeout" json:"jwks_cache_timeout"`
8182
LogOnly bool `mapstructure:"log_only" yaml:"log_only" json:"log_only"`
@@ -118,6 +119,13 @@ type ValidatorConfig struct {
118119
Audience string `mapstructure:"audience" yaml:"audience" json:"audience"`
119120
}
120121

122+
type ApiKeysConfig struct {
123+
Auds []string `mapstructure:"auds" yaml:"auds" json:"auds"`
124+
Length int `mapstructure:"length" yaml:"length" json:"length"`
125+
EmailSuffix string `mapstructure:"email_suffix" yaml:"email_suffix" json:"email_suffix"`
126+
UserPassword string `mapstructure:"user_password" yaml:"user_password" json:"user_password"`
127+
}
128+
121129
type CdcConfig struct {
122130
Enabled bool `mapstructure:"enabled" yaml:"enabled" json:"enabled"`
123131
StreamInterval time.Duration
@@ -321,6 +329,10 @@ var DefaultConfig = Config{
321329
Audience: "https://tigris-api",
322330
},
323331
},
332+
ApiKeys: ApiKeysConfig{
333+
Auds: nil,
334+
Length: 120,
335+
},
324336
PrimaryAudience: "https://tigris-api",
325337
JWKSCacheTimeout: 5 * time.Minute,
326338
TokenValidationCacheSize: 1000,

server/middleware/auth.go

+35-6
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ import (
3232
"github.com/tigrisdata/tigris/server/defaults"
3333
"github.com/tigrisdata/tigris/server/metrics"
3434
"github.com/tigrisdata/tigris/server/request"
35+
"github.com/tigrisdata/tigris/server/services/v1/auth"
3536
"github.com/tigrisdata/tigris/server/types"
3637
"google.golang.org/grpc"
3738
)
@@ -135,10 +136,10 @@ func GetJWTValidators(config *config.Config) []*validator.Validator {
135136
return jwtValidators
136137
}
137138

138-
func measuredAuthFunction(ctx context.Context, jwtValidators []*validator.Validator, config *config.Config, cache gcache.Cache) (context.Context, error) {
139+
func measuredAuthFunction(ctx context.Context, jwtValidators []*validator.Validator, config *config.Config, cache gcache.Cache, a auth.Provider) (context.Context, error) {
139140
measurement := metrics.NewMeasurement("auth", "auth", metrics.AuthSpanType, metrics.GetAuthBaseTags(ctx))
140141
measurement.StartTracing(ctx, true)
141-
ctxResult, err := authFunction(ctx, jwtValidators, config, cache)
142+
ctxResult, err := authFunction(ctx, jwtValidators, config, cache, a)
142143
if err != nil {
143144
measurement.CountErrorForScope(metrics.AuthErrorCount, measurement.GetAuthErrorTags(err))
144145
measurement.FinishWithError(ctxResult, err)
@@ -151,7 +152,7 @@ func measuredAuthFunction(ctx context.Context, jwtValidators []*validator.Valida
151152
return ctxResult, nil
152153
}
153154

154-
func authFunction(ctx context.Context, jwtValidators []*validator.Validator, config *config.Config, cache gcache.Cache) (ctxResult context.Context, err error) {
155+
func authFunction(ctx context.Context, jwtValidators []*validator.Validator, config *config.Config, cache gcache.Cache, a auth.Provider) (ctxResult context.Context, err error) {
155156
reqMetadata, err := request.GetRequestMetadataFromContext(ctx)
156157
if err != nil {
157158
log.Warn().Err(err).Msg("Failed to load request metadata")
@@ -178,9 +179,36 @@ func authFunction(ctx context.Context, jwtValidators []*validator.Validator, con
178179
if err != nil {
179180
return ctx, err
180181
}
182+
if strings.Contains(tkn, ".") {
183+
return authenticateUsingAuthToken(ctx, jwtValidators, config, cache, tkn, reqMetadata)
184+
} else if strings.HasPrefix(tkn, auth.ApiKeyPrefix) {
185+
return authenticateUsingApiKey(ctx, jwtValidators, config, cache, tkn, reqMetadata, a)
186+
}
187+
return ctx, errors.Unauthenticated("Failed to authenticate")
188+
}
181189

182-
validatedToken := getCachedToken(ctx, tkn, cache)
190+
func authenticateUsingApiKey(ctx context.Context, _ []*validator.Validator, _ *config.Config, cache gcache.Cache, apiKey string, reqMetadata *request.Metadata, a auth.Provider) (context.Context, error) {
191+
token, err := cache.Get(apiKey)
192+
if err != nil || token == nil {
193+
token, err = a.ValidateApiKey(ctx, apiKey, config.DefaultConfig.Auth.ApiKeys.Auds)
194+
if err != nil {
195+
return ctx, err
196+
}
197+
// put it to cache
198+
err = cache.Set(apiKey, token)
199+
if err != nil {
200+
log.Err(err).Msg("Could not set to the cache")
201+
}
202+
}
203+
castedToken := token.(*types.AccessToken)
204+
reqMetadata.SetNamespace(ctx, castedToken.Namespace)
205+
reqMetadata.SetAccessToken(castedToken)
206+
return ctx, nil
207+
}
183208

209+
func authenticateUsingAuthToken(ctx context.Context, jwtValidators []*validator.Validator, config *config.Config, cache gcache.Cache, tkn string, reqMetadata *request.Metadata) (context.Context, error) {
210+
validatedToken := getCachedToken(ctx, tkn, cache)
211+
var err error
184212
// if not found from cache
185213
if validatedToken == nil {
186214
count := 0
@@ -272,16 +300,17 @@ func getAuthFunction(config *config.Config) func(ctx context.Context) (context.C
272300
lruCache := gcache.New(config.Auth.TokenValidationCacheSize).
273301
Expiration(time.Duration(config.Auth.TokenValidationCacheTTLSec) * time.Second).
274302
Build()
303+
authProvider := auth.NewGotrueProvider()
275304

276305
// inline closure to access the state of jwtValidator
277306
if config.Tracing.Enabled {
278307
return func(ctx context.Context) (context.Context, error) {
279-
return measuredAuthFunction(ctx, jwtValidators, config, lruCache)
308+
return measuredAuthFunction(ctx, jwtValidators, config, lruCache, authProvider)
280309
}
281310
}
282311

283312
return func(ctx context.Context) (context.Context, error) {
284-
return authFunction(ctx, jwtValidators, config, lruCache)
313+
return authFunction(ctx, jwtValidators, config, lruCache, authProvider)
285314
}
286315
}
287316

server/middleware/auth_test.go

+8-6
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ import (
2626
api "github.com/tigrisdata/tigris/api/server/v1"
2727
"github.com/tigrisdata/tigris/errors"
2828
"github.com/tigrisdata/tigris/server/config"
29+
"github.com/tigrisdata/tigris/server/services/v1/auth"
2930
"github.com/tigrisdata/tigris/util/log"
3031
"google.golang.org/grpc/metadata"
3132
)
@@ -47,35 +48,36 @@ func TestAuth(t *testing.T) {
4748
FoundationDB: config.FoundationDBConfig{},
4849
}
4950
cache := gcache.New(10).Expiration(time.Duration(5) * time.Minute).Build()
51+
authProvider := auth.NewGotrueProvider()
5052
t.Run("log_only mode: no token", func(t *testing.T) {
51-
ctx, err := authFunction(context.TODO(), []*validator.Validator{{}}, &config.DefaultConfig, cache)
53+
ctx, err := authFunction(context.TODO(), []*validator.Validator{{}}, &config.DefaultConfig, cache, authProvider)
5254
require.NotNil(t, ctx)
5355
require.Nil(t, err)
5456
})
5557

5658
t.Run("enforcing mode: no token", func(t *testing.T) {
57-
_, err := authFunction(context.TODO(), []*validator.Validator{{}}, &enforcedAuthConfig, cache)
59+
_, err := authFunction(context.TODO(), []*validator.Validator{{}}, &enforcedAuthConfig, cache, authProvider)
5860
require.NotNil(t, err)
5961
require.Equal(t, err, errors.Unauthenticated("request unauthenticated with bearer"))
6062
})
6163

6264
t.Run("enforcing mode: Bad authorization string1", func(t *testing.T) {
6365
incomingCtx := metadata.NewIncomingContext(context.TODO(), metadata.Pairs("authorization", "bearer"))
64-
_, err := authFunction(incomingCtx, []*validator.Validator{{}}, &enforcedAuthConfig, cache)
66+
_, err := authFunction(incomingCtx, []*validator.Validator{{}}, &enforcedAuthConfig, cache, authProvider)
6567
require.NotNil(t, err)
6668
require.Equal(t, err, errors.Unauthenticated("bad authorization string"))
6769
})
6870

6971
t.Run("enforcing mode: Bad token", func(t *testing.T) {
7072
incomingCtx := metadata.NewIncomingContext(context.TODO(), metadata.Pairs("authorization", "bearer somebadtoken"))
71-
_, err := authFunction(incomingCtx, []*validator.Validator{{}}, &enforcedAuthConfig, cache)
73+
_, err := authFunction(incomingCtx, []*validator.Validator{{}}, &enforcedAuthConfig, cache, authProvider)
7274
require.NotNil(t, err)
73-
require.Equal(t, err, errors.Unauthenticated("Failed to validate access token, could not be validated"))
75+
require.Equal(t, errors.Unauthenticated("Failed to authenticate"), err)
7476
})
7577

7678
t.Run("enforcing mode: Bad token 2", func(t *testing.T) {
7779
incomingCtx := metadata.NewIncomingContext(context.TODO(), metadata.Pairs("authorization", "bearer some.bad.token"))
78-
_, err := authFunction(incomingCtx, []*validator.Validator{{}}, &enforcedAuthConfig, cache)
80+
_, err := authFunction(incomingCtx, []*validator.Validator{{}}, &enforcedAuthConfig, cache, authProvider)
7981
require.NotNil(t, err)
8082
require.Contains(t, err.Error(), "Failed to validate access token")
8183
})

server/middleware/authz.go

+8-12
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@ import (
2323
"github.com/tigrisdata/tigris/lib/container"
2424
"github.com/tigrisdata/tigris/server/config"
2525
"github.com/tigrisdata/tigris/server/request"
26-
"github.com/tigrisdata/tigris/server/types"
2726
"google.golang.org/grpc"
2827
)
2928

@@ -73,6 +72,7 @@ var (
7372
api.QuotaLimitsMetricsMethodName,
7473
api.QuotaUsageMethodName,
7574
api.GetInfoMethodName,
75+
api.WhoAmIMethodName,
7676

7777
// realtime
7878
api.ReadMessagesMethodName,
@@ -152,6 +152,7 @@ var (
152152
api.QuotaLimitsMetricsMethodName,
153153
api.QuotaUsageMethodName,
154154
api.GetInfoMethodName,
155+
api.WhoAmIMethodName,
155156

156157
// realtime
157158
api.PresenceMethodName,
@@ -253,6 +254,7 @@ var (
253254
api.QuotaLimitsMetricsMethodName,
254255
api.QuotaUsageMethodName,
255256
api.GetInfoMethodName,
257+
api.WhoAmIMethodName,
256258

257259
// realtime
258260
api.PresenceMethodName,
@@ -348,6 +350,7 @@ var (
348350
api.QuotaLimitsMetricsMethodName,
349351
api.QuotaUsageMethodName,
350352
api.GetInfoMethodName,
353+
api.WhoAmIMethodName,
351354

352355
// realtime
353356
api.PresenceMethodName,
@@ -424,11 +427,11 @@ func authorize(ctx context.Context) (err error) {
424427
Msg("Empty role allowed for transition purpose")
425428
return nil
426429
}
430+
// if !isAuthorizedProject(reqMetadata, accessToken) {
431+
// authorizationErr = errors.PermissionDenied("You are not allowed to perform operation: %s", reqMetadata.GetFullMethod())
432+
//}
427433
var authorizationErr error
428-
if !isAuthorizedProject(reqMetadata, accessToken) {
429-
authorizationErr = errors.PermissionDenied("You are not allowed to perform operation: %s", reqMetadata.GetFullMethod())
430-
}
431-
if err == nil && !isAuthorizedOperation(reqMetadata.GetFullMethod(), role) {
434+
if !isAuthorizedOperation(reqMetadata.GetFullMethod(), role) {
432435
authorizationErr = errors.PermissionDenied("You are not allowed to perform operation: %s", reqMetadata.GetFullMethod())
433436
}
434437

@@ -447,13 +450,6 @@ func authorize(ctx context.Context) (err error) {
447450
return nil
448451
}
449452

450-
func isAuthorizedProject(reqMetadata *request.Metadata, accessToken *types.AccessToken) bool {
451-
if accessToken.Project != "" && reqMetadata.GetProject() != accessToken.Project {
452-
return false
453-
}
454-
return true
455-
}
456-
457453
func isAuthorizedOperation(method string, role string) bool {
458454
if methods := getMethodsForRole(role); methods != nil {
459455
return methods.Contains(method)

server/middleware/authz_test.go

+3
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,7 @@ func TestAuthzOwnerRole(t *testing.T) {
9797
require.True(t, isAuthorizedOperation(api.QuotaLimitsMetricsMethodName, ownerRoleName))
9898
require.True(t, isAuthorizedOperation(api.QuotaUsageMethodName, ownerRoleName))
9999
require.True(t, isAuthorizedOperation(api.GetInfoMethodName, ownerRoleName))
100+
require.True(t, isAuthorizedOperation(api.WhoAmIMethodName, ownerRoleName))
100101

101102
// realtime
102103
require.True(t, isAuthorizedOperation(api.PresenceMethodName, ownerRoleName))
@@ -193,6 +194,7 @@ func TestAuthzEditorRole(t *testing.T) {
193194
require.True(t, isAuthorizedOperation(api.QuotaLimitsMetricsMethodName, editorRoleName))
194195
require.True(t, isAuthorizedOperation(api.QuotaUsageMethodName, editorRoleName))
195196
require.True(t, isAuthorizedOperation(api.GetInfoMethodName, editorRoleName))
197+
require.True(t, isAuthorizedOperation(api.WhoAmIMethodName, editorRoleName))
196198

197199
// realtime
198200
require.True(t, isAuthorizedOperation(api.PresenceMethodName, editorRoleName))
@@ -263,6 +265,7 @@ func TestAuthzReadOnlyRole(t *testing.T) {
263265
require.True(t, isAuthorizedOperation(api.QuotaLimitsMetricsMethodName, readOnlyRoleName))
264266
require.True(t, isAuthorizedOperation(api.QuotaUsageMethodName, readOnlyRoleName))
265267
require.True(t, isAuthorizedOperation(api.GetInfoMethodName, readOnlyRoleName))
268+
require.True(t, isAuthorizedOperation(api.WhoAmIMethodName, readOnlyRoleName))
266269

267270
// realtime
268271
require.True(t, isAuthorizedOperation(api.ReadMessagesMethodName, readOnlyRoleName))

server/services/v1/auth/auth0.go

+5
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ import (
3333
"github.com/tigrisdata/tigris/server/metadata"
3434
"github.com/tigrisdata/tigris/server/request"
3535
"github.com/tigrisdata/tigris/server/transaction"
36+
"github.com/tigrisdata/tigris/server/types"
3637
"golang.org/x/net/context/ctxhttp"
3738
)
3839

@@ -328,6 +329,10 @@ func (*auth0) ListGlobalAppKeys(_ context.Context, _ *api.ListGlobalAppKeysReque
328329
return nil, errors.Internal("auth0 implementation doesn't support it")
329330
}
330331

332+
func (*auth0) ValidateApiKey(_ context.Context, _ string, _ []string) (*types.AccessToken, error) {
333+
return nil, errors.Internal("auth0 implementation doesn't support it")
334+
}
335+
331336
func validateOwnershipAuth0(ctx context.Context, operationName string, appId string, a *auth0) (*management.Client, string, error) {
332337
client, err := a.Management.Client.Read(appId)
333338
if err != nil {

0 commit comments

Comments
 (0)