@@ -32,6 +32,7 @@ import (
32
32
"github.com/tigrisdata/tigris/server/defaults"
33
33
"github.com/tigrisdata/tigris/server/metrics"
34
34
"github.com/tigrisdata/tigris/server/request"
35
+ "github.com/tigrisdata/tigris/server/services/v1/auth"
35
36
"github.com/tigrisdata/tigris/server/types"
36
37
"google.golang.org/grpc"
37
38
)
@@ -135,10 +136,10 @@ func GetJWTValidators(config *config.Config) []*validator.Validator {
135
136
return jwtValidators
136
137
}
137
138
138
- func measuredAuthFunction (ctx context.Context , jwtValidators []* validator.Validator , config * config.Config , cache gcache.Cache ) (context.Context , error ) {
139
+ func measuredAuthFunction (ctx context.Context , jwtValidators []* validator.Validator , config * config.Config , cache gcache.Cache , a auth. Provider ) (context.Context , error ) {
139
140
measurement := metrics .NewMeasurement ("auth" , "auth" , metrics .AuthSpanType , metrics .GetAuthBaseTags (ctx ))
140
141
measurement .StartTracing (ctx , true )
141
- ctxResult , err := authFunction (ctx , jwtValidators , config , cache )
142
+ ctxResult , err := authFunction (ctx , jwtValidators , config , cache , a )
142
143
if err != nil {
143
144
measurement .CountErrorForScope (metrics .AuthErrorCount , measurement .GetAuthErrorTags (err ))
144
145
measurement .FinishWithError (ctxResult , err )
@@ -151,7 +152,7 @@ func measuredAuthFunction(ctx context.Context, jwtValidators []*validator.Valida
151
152
return ctxResult , nil
152
153
}
153
154
154
- func authFunction (ctx context.Context , jwtValidators []* validator.Validator , config * config.Config , cache gcache.Cache ) (ctxResult context.Context , err error ) {
155
+ func authFunction (ctx context.Context , jwtValidators []* validator.Validator , config * config.Config , cache gcache.Cache , a auth. Provider ) (ctxResult context.Context , err error ) {
155
156
reqMetadata , err := request .GetRequestMetadataFromContext (ctx )
156
157
if err != nil {
157
158
log .Warn ().Err (err ).Msg ("Failed to load request metadata" )
@@ -178,9 +179,36 @@ func authFunction(ctx context.Context, jwtValidators []*validator.Validator, con
178
179
if err != nil {
179
180
return ctx , err
180
181
}
182
+ if strings .Contains (tkn , "." ) {
183
+ return authenticateUsingAuthToken (ctx , jwtValidators , config , cache , tkn , reqMetadata )
184
+ } else if strings .HasPrefix (tkn , auth .ApiKeyPrefix ) {
185
+ return authenticateUsingApiKey (ctx , jwtValidators , config , cache , tkn , reqMetadata , a )
186
+ }
187
+ return ctx , errors .Unauthenticated ("Failed to authenticate" )
188
+ }
181
189
182
- validatedToken := getCachedToken (ctx , tkn , cache )
190
+ func authenticateUsingApiKey (ctx context.Context , _ []* validator.Validator , _ * config.Config , cache gcache.Cache , apiKey string , reqMetadata * request.Metadata , a auth.Provider ) (context.Context , error ) {
191
+ token , err := cache .Get (apiKey )
192
+ if err != nil || token == nil {
193
+ token , err = a .ValidateApiKey (ctx , apiKey , config .DefaultConfig .Auth .ApiKeys .Auds )
194
+ if err != nil {
195
+ return ctx , err
196
+ }
197
+ // put it to cache
198
+ err = cache .Set (apiKey , token )
199
+ if err != nil {
200
+ log .Err (err ).Msg ("Could not set to the cache" )
201
+ }
202
+ }
203
+ castedToken := token .(* types.AccessToken )
204
+ reqMetadata .SetNamespace (ctx , castedToken .Namespace )
205
+ reqMetadata .SetAccessToken (castedToken )
206
+ return ctx , nil
207
+ }
183
208
209
+ func authenticateUsingAuthToken (ctx context.Context , jwtValidators []* validator.Validator , config * config.Config , cache gcache.Cache , tkn string , reqMetadata * request.Metadata ) (context.Context , error ) {
210
+ validatedToken := getCachedToken (ctx , tkn , cache )
211
+ var err error
184
212
// if not found from cache
185
213
if validatedToken == nil {
186
214
count := 0
@@ -272,16 +300,17 @@ func getAuthFunction(config *config.Config) func(ctx context.Context) (context.C
272
300
lruCache := gcache .New (config .Auth .TokenValidationCacheSize ).
273
301
Expiration (time .Duration (config .Auth .TokenValidationCacheTTLSec ) * time .Second ).
274
302
Build ()
303
+ authProvider := auth .NewGotrueProvider ()
275
304
276
305
// inline closure to access the state of jwtValidator
277
306
if config .Tracing .Enabled {
278
307
return func (ctx context.Context ) (context.Context , error ) {
279
- return measuredAuthFunction (ctx , jwtValidators , config , lruCache )
308
+ return measuredAuthFunction (ctx , jwtValidators , config , lruCache , authProvider )
280
309
}
281
310
}
282
311
283
312
return func (ctx context.Context ) (context.Context , error ) {
284
- return authFunction (ctx , jwtValidators , config , lruCache )
313
+ return authFunction (ctx , jwtValidators , config , lruCache , authProvider )
285
314
}
286
315
}
287
316
0 commit comments