Skip to content

Commit 5e02b0c

Browse files
committed
add support for aws sso credentials provider
1 parent 828b8cb commit 5e02b0c

File tree

3 files changed

+601
-0
lines changed

3 files changed

+601
-0
lines changed

v1/plugins/rest/auth.go

+7
Original file line numberDiff line numberDiff line change
@@ -769,6 +769,7 @@ type awsSigningAuthPlugin struct {
769769
AWSAssumeRoleCredentials *awsAssumeRoleCredentialService `json:"assume_role_credentials,omitempty"`
770770
AWSWebIdentityCredentials *awsWebIdentityCredentialService `json:"web_identity_credentials,omitempty"`
771771
AWSProfileCredentials *awsProfileCredentialService `json:"profile_credentials,omitempty"`
772+
AWSSSOCredentials *awsSSOCredentialsService `json:"sso_credentials,omitempty"`
772773

773774
AWSService string `json:"service,omitempty"`
774775
AWSSignatureVersion string `json:"signature_version,omitempty"`
@@ -884,6 +885,11 @@ func (ap *awsSigningAuthPlugin) awsCredentialService() awsCredentialService {
884885
chain.addService(ap.AWSMetadataCredentials)
885886
}
886887

888+
if ap.AWSSSOCredentials != nil {
889+
ap.AWSSSOCredentials.logger = ap.logger
890+
chain.addService(ap.AWSSSOCredentials)
891+
}
892+
887893
return &chain
888894
}
889895

@@ -941,6 +947,7 @@ func (ap *awsSigningAuthPlugin) validateAndSetDefaults(serviceType string) error
941947
cfgs[ap.AWSAssumeRoleCredentials != nil]++
942948
cfgs[ap.AWSWebIdentityCredentials != nil]++
943949
cfgs[ap.AWSProfileCredentials != nil]++
950+
cfgs[ap.AWSSSOCredentials != nil]++
944951

945952
if cfgs[true] == 0 {
946953
return errors.New("a AWS credential service must be specified when S3 signing is enabled")

v1/plugins/rest/aws.go

+332
Original file line numberDiff line numberDiff line change
@@ -5,14 +5,18 @@
55
package rest
66

77
import (
8+
"bytes"
89
"context"
10+
"crypto/sha1"
11+
"encoding/hex"
912
"encoding/json"
1013
"encoding/xml"
1114
"errors"
1215
"fmt"
1316
"net/http"
1417
"net/url"
1518
"os"
19+
"path"
1620
"path/filepath"
1721
"strings"
1822
"time"
@@ -51,6 +55,7 @@ const (
5155
awsRoleArnEnvVar = "AWS_ROLE_ARN"
5256
awsWebIdentityTokenFileEnvVar = "AWS_WEB_IDENTITY_TOKEN_FILE"
5357
awsCredentialsFileEnvVar = "AWS_SHARED_CREDENTIALS_FILE"
58+
awsConfigFileEnvVar = "AWS_CONFIG_FILE"
5459
awsProfileEnvVar = "AWS_PROFILE"
5560

5661
// ref. https://docs.aws.amazon.com/sdkref/latest/guide/settings-global.html
@@ -95,6 +100,333 @@ func (*awsEnvironmentCredentialService) credentials(context.Context) (aws.Creden
95100
return creds, nil
96101
}
97102

103+
type ssoSessionDetails struct {
104+
StartUrl string `json:"startUrl"`
105+
Region string `json:"region"`
106+
Name string
107+
AccountID string
108+
RoleName string
109+
AccessToken string `json:"accessToken"`
110+
ExpiresAt time.Time `json:"expiresAt"`
111+
RegistrationExpiresAt time.Time `json:"registrationExpiresAt"`
112+
RefreshToken string `json:"refreshToken"`
113+
ClientId string `json:"clientId"`
114+
ClientSecret string `json:"clientSecret"`
115+
}
116+
117+
type awsSSOCredentialsService struct {
118+
Path string `json:"path,omitempty"`
119+
SSOCachePath string `json:"cache_path,omitempty"`
120+
121+
Profile string `json:"profile,omitempty"`
122+
123+
logger logging.Logger
124+
125+
creds aws.Credentials
126+
127+
credentialsExpiresAt time.Time
128+
129+
session *ssoSessionDetails
130+
}
131+
132+
func (cs *awsSSOCredentialsService) configPath() (string, error) {
133+
if len(cs.Path) != 0 {
134+
return cs.Path, nil
135+
}
136+
137+
if cs.Path = os.Getenv(awsConfigFileEnvVar); len(cs.Path) != 0 {
138+
return cs.Path, nil
139+
}
140+
141+
homeDir, err := os.UserHomeDir()
142+
if err != nil {
143+
return "", fmt.Errorf("user home directory not found: %w", err)
144+
}
145+
146+
cs.Path = filepath.Join(homeDir, ".aws", "config")
147+
148+
return cs.Path, nil
149+
}
150+
func (cs *awsSSOCredentialsService) ssoCachePath() (string, error) {
151+
if len(cs.SSOCachePath) != 0 {
152+
return cs.SSOCachePath, nil
153+
}
154+
155+
homeDir, err := os.UserHomeDir()
156+
if err != nil {
157+
return "", fmt.Errorf("user home directory not found: %w", err)
158+
}
159+
160+
cs.Path = filepath.Join(homeDir, ".aws", "sso", "cache")
161+
162+
return cs.Path, nil
163+
}
164+
165+
func (cs *awsSSOCredentialsService) cacheKeyFileName() (string, error) {
166+
167+
val := cs.session.StartUrl
168+
if cs.session.Name != "" {
169+
val = cs.session.Name
170+
}
171+
172+
hash := sha1.New()
173+
hash.Write([]byte(val))
174+
cacheKey := hex.EncodeToString(hash.Sum(nil))
175+
176+
return cacheKey + ".json", nil
177+
}
178+
179+
func (cs *awsSSOCredentialsService) loadSSOCredentials() error {
180+
ssoCachePath, err := cs.ssoCachePath()
181+
if err != nil {
182+
return fmt.Errorf("failed to get sso cache path: %w", err)
183+
}
184+
185+
cacheKeyFile, err := cs.cacheKeyFileName()
186+
if err != nil {
187+
return err
188+
}
189+
190+
cacheFile := path.Join(ssoCachePath, cacheKeyFile)
191+
cache, err := os.ReadFile(cacheFile)
192+
if err != nil {
193+
return fmt.Errorf("failed to load cache file: %v", err)
194+
}
195+
196+
if err := json.Unmarshal(cache, &cs.session); err != nil {
197+
return fmt.Errorf("failed to unmarshal cache file: %v", err)
198+
}
199+
200+
return nil
201+
202+
}
203+
204+
func (cs *awsSSOCredentialsService) loadSession() error {
205+
configPath, err := cs.configPath()
206+
if err != nil {
207+
return fmt.Errorf("failed to get config path: %w", err)
208+
}
209+
config, err := ini.Load(configPath)
210+
if err != nil {
211+
return fmt.Errorf("failed to load config file: %w", err)
212+
}
213+
214+
section, err := config.GetSection("profile " + cs.Profile)
215+
216+
if err != nil {
217+
return fmt.Errorf("failed to find profile %s", cs.Profile)
218+
}
219+
220+
accountID, err := section.GetKey("sso_account_id")
221+
if err != nil {
222+
return fmt.Errorf("failed to find sso_account_id key in profile %s", cs.Profile)
223+
}
224+
225+
region, err := section.GetKey("region")
226+
if err != nil {
227+
return fmt.Errorf("failed to find region key in profile %s", cs.Profile)
228+
}
229+
230+
roleName, err := section.GetKey("sso_role_name")
231+
if err != nil {
232+
return fmt.Errorf("failed to find sso_role_name key in profile %s", cs.Profile)
233+
}
234+
235+
ssoSession, err := section.GetKey("sso_session")
236+
if err != nil {
237+
return fmt.Errorf("failed to find sso_session key in profile %s", cs.Profile)
238+
}
239+
240+
sessionName := ssoSession.Value()
241+
242+
session, err := config.GetSection("sso-session " + sessionName)
243+
if err != nil {
244+
return fmt.Errorf("failed to find sso-session %s", sessionName)
245+
}
246+
247+
startUrl, err := session.GetKey("sso_start_url")
248+
if err != nil {
249+
return fmt.Errorf("failed to find sso_start_url key in sso-session %s", sessionName)
250+
}
251+
252+
cs.session = &ssoSessionDetails{
253+
StartUrl: startUrl.Value(),
254+
Name: sessionName,
255+
AccountID: accountID.Value(),
256+
Region: region.Value(),
257+
RoleName: roleName.Value(),
258+
}
259+
260+
return nil
261+
}
262+
263+
func (cs *awsSSOCredentialsService) tryRefreshToken() error {
264+
// Check if refresh token is empty
265+
if cs.session.RefreshToken == "" {
266+
return errors.New("refresh token is empty")
267+
}
268+
269+
// Use the refresh token to get a new access token
270+
// using the clientId, clientSecret and refreshToken from the loaded token
271+
// return the new token
272+
// if error, return error
273+
274+
type refreshTokenRequest struct {
275+
ClientId string `json:"clientId"`
276+
ClientSecret string `json:"clientSecret"`
277+
RefreshToken string `json:"refreshToken"`
278+
GrantType string `json:"grantType"`
279+
}
280+
281+
data := refreshTokenRequest{
282+
ClientId: cs.session.ClientId,
283+
ClientSecret: cs.session.ClientSecret,
284+
RefreshToken: cs.session.RefreshToken,
285+
GrantType: "refresh_token",
286+
}
287+
288+
body, err := json.Marshal(data)
289+
if err != nil {
290+
return fmt.Errorf("failed to marshal refresh token request: %v", err)
291+
}
292+
293+
endpoint := fmt.Sprintf("https://oidc.%s.amazonaws.com/token", cs.session.Region)
294+
r, err := http.NewRequest("POST", endpoint, bytes.NewReader(body))
295+
if err != nil {
296+
return fmt.Errorf("failed to create new request: %v", err)
297+
}
298+
299+
r.Header.Add("Content-Type", "application/json")
300+
c := &http.Client{}
301+
resp, err := c.Do(r)
302+
if err != nil {
303+
return fmt.Errorf("failed to do request: %v", err)
304+
}
305+
defer resp.Body.Close()
306+
307+
type refreshTokenResponse struct {
308+
AccessToken string `json:"accessToken"`
309+
ExpiresIn int `json:"expiresIn"`
310+
RefreshToken string `json:"refreshToken"`
311+
}
312+
313+
refreshedToken := refreshTokenResponse{}
314+
315+
if err := json.NewDecoder(resp.Body).Decode(&refreshedToken); err != nil {
316+
return fmt.Errorf("failed to decode response: %v", err)
317+
}
318+
319+
cs.session.AccessToken = refreshedToken.AccessToken
320+
cs.session.ExpiresAt = time.Now().Add(time.Duration(refreshedToken.ExpiresIn) * time.Second)
321+
cs.session.RefreshToken = refreshedToken.RefreshToken
322+
323+
return nil
324+
}
325+
326+
func (cs *awsSSOCredentialsService) refreshCredentials() error {
327+
url := fmt.Sprintf("https://portal.sso.%s.amazonaws.com/federation/credentials?account_id=%s&role_name=%s", cs.session.Region, cs.session.AccountID, cs.session.RoleName)
328+
329+
req, err := http.NewRequest("GET", url, nil)
330+
if err != nil {
331+
return err
332+
}
333+
334+
req.Header.Set("Authorization", "Bearer "+cs.session.AccessToken)
335+
req.Header.Set("Content-Type", "application/json")
336+
337+
client := &http.Client{}
338+
resp, err := client.Do(req)
339+
if err != nil {
340+
return err
341+
}
342+
defer resp.Body.Close()
343+
344+
type roleCredentials struct {
345+
AccessKeyId string `json:"accessKeyId"`
346+
SecretAccessKey string `json:"secretAccessKey"`
347+
SessionToken string `json:"sessionToken"`
348+
Expiration int64 `json:"expiration"`
349+
}
350+
type getRoleCredentialsResponse struct {
351+
RoleCredentials roleCredentials `json:"roleCredentials"`
352+
}
353+
354+
var result getRoleCredentialsResponse
355+
356+
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
357+
return fmt.Errorf("failed to decode response: %v", err)
358+
}
359+
360+
cs.creds = aws.Credentials{
361+
AccessKey: result.RoleCredentials.AccessKeyId,
362+
SecretKey: result.RoleCredentials.SecretAccessKey,
363+
SessionToken: result.RoleCredentials.SessionToken,
364+
RegionName: cs.session.Region,
365+
}
366+
367+
cs.credentialsExpiresAt = time.Unix(result.RoleCredentials.Expiration, 0)
368+
369+
return nil
370+
}
371+
372+
func (cs *awsSSOCredentialsService) loadProfile() {
373+
if cs.Profile != "" {
374+
return
375+
}
376+
377+
cs.Profile = os.Getenv(awsProfileEnvVar)
378+
379+
if cs.Profile == "" {
380+
cs.Profile = "default"
381+
}
382+
383+
}
384+
385+
func (cs *awsSSOCredentialsService) init() error {
386+
cs.loadProfile()
387+
388+
if err := cs.loadSession(); err != nil {
389+
return fmt.Errorf("failed to load session: %w", err)
390+
}
391+
392+
if err := cs.loadSSOCredentials(); err != nil {
393+
return fmt.Errorf("failed to load SSO credentials: %w", err)
394+
}
395+
396+
// this enforces fetching credentials
397+
cs.credentialsExpiresAt = time.Unix(0, 0)
398+
return nil
399+
}
400+
401+
func (cs *awsSSOCredentialsService) credentials(context.Context) (aws.Credentials, error) {
402+
if cs.session == nil {
403+
if err := cs.init(); err != nil {
404+
return aws.Credentials{}, err
405+
}
406+
}
407+
408+
if cs.credentialsExpiresAt.Before(time.Now().Add(5 * time.Minute)) {
409+
// Check if the sso token we have is still valid,
410+
// if not, try to refresh it
411+
if cs.session.ExpiresAt.Before(time.Now()) {
412+
// we try and get a new token if we can
413+
if cs.session.RegistrationExpiresAt.Before(time.Now()) {
414+
return aws.Credentials{}, errors.New("cannot refresh token, registration expired")
415+
}
416+
417+
if err := cs.tryRefreshToken(); err != nil {
418+
return aws.Credentials{}, fmt.Errorf("failed to refresh token: %w", err)
419+
}
420+
}
421+
422+
if err := cs.refreshCredentials(); err != nil {
423+
return aws.Credentials{}, fmt.Errorf("failed to refresh credentials: %w", err)
424+
}
425+
}
426+
427+
return cs.creds, nil
428+
}
429+
98430
// awsProfileCredentialService represents a credential provider for AWS that extracts credentials from the AWS
99431
// credentials file
100432
type awsProfileCredentialService struct {

0 commit comments

Comments
 (0)