Skip to content

Commit e81c58e

Browse files
committed
microsoft: support client_credentials flow using client assertions
Fixes golang#465 - Commonly referred to as Client Certificate authentication - Similar to JWT two-legged auth but incompatible
1 parent 08078c5 commit e81c58e

File tree

1 file changed

+171
-0
lines changed

1 file changed

+171
-0
lines changed

microsoft/microsoft.go

+171
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,22 @@
66
package microsoft // import "golang.org/x/oauth2/microsoft"
77

88
import (
9+
"context"
10+
"crypto/sha1"
11+
"crypto/x509"
12+
"encoding/base64"
13+
"encoding/json"
14+
"encoding/pem"
15+
"fmt"
916
"golang.org/x/oauth2"
17+
"golang.org/x/oauth2/internal"
18+
"golang.org/x/oauth2/jws"
19+
"io"
20+
"io/ioutil"
21+
"net/http"
22+
"net/url"
23+
"strings"
24+
"time"
1025
)
1126

1227
// LiveConnectEndpoint is Windows's Live ID OAuth 2.0 endpoint.
@@ -29,3 +44,159 @@ func AzureADEndpoint(tenant string) oauth2.Endpoint {
2944
TokenURL: "https://login.microsoftonline.com/" + tenant + "/oauth2/v2.0/token",
3045
}
3146
}
47+
48+
// Config is the configuration for using client credentials flow with a client assertion.
49+
//
50+
// For more information see:
51+
// https://docs.microsoft.com/en-us/azure/active-directory/develop/active-directory-certificate-credentials
52+
type Config struct {
53+
// ClientID is the application's ID.
54+
ClientID string
55+
56+
// PrivateKey contains the contents of an RSA private key or the
57+
// contents of a PEM file that contains a private key. The provided
58+
// private key is used to sign JWT assertions.
59+
// PEM containers with a passphrase are not supported.
60+
// Use the following command to convert a PKCS 12 file into a PEM.
61+
//
62+
// $ openssl pkcs12 -in key.p12 -out key.pem -nodes
63+
//
64+
PrivateKey []byte
65+
66+
// Certificate contains the (optionally PEM encoded) X509 certificate registered
67+
// for the application with which you are authenticating.
68+
Certificate []byte
69+
70+
// Scopes optionally specifies a list of requested permission scopes.
71+
Scopes []string
72+
73+
// TokenURL is the token endpoint. Typically you can use the AzureADEndpoint
74+
// function to obtain this value, but it may change for non-public clouds.
75+
TokenURL string
76+
77+
// Expires optionally specifies how long the token is valid for.
78+
Expires time.Duration
79+
80+
// Audience optionally specifies the intended audience of the
81+
// request. If empty, the value of TokenURL is used as the
82+
// intended audience.
83+
Audience string
84+
}
85+
86+
// TokenSource returns a JWT TokenSource using the configuration
87+
// in c and the HTTP client from the provided context.
88+
func (c *Config) TokenSource(ctx context.Context) oauth2.TokenSource {
89+
return oauth2.ReuseTokenSource(nil, assertionSource{ctx, c})
90+
}
91+
92+
// Client returns an HTTP client wrapping the context's
93+
// HTTP transport and adding Authorization headers with tokens
94+
// obtained from c.
95+
//
96+
// The returned client and its Transport should not be modified.
97+
func (c *Config) Client(ctx context.Context) *http.Client {
98+
return oauth2.NewClient(ctx, c.TokenSource(ctx))
99+
}
100+
101+
// assertionSource is a source that always does a signed JWT request for a token.
102+
// It should typically be wrapped with a reuseTokenSource.
103+
type assertionSource struct {
104+
ctx context.Context
105+
conf *Config
106+
}
107+
108+
func (a assertionSource) Token() (*oauth2.Token, error) {
109+
crt := a.conf.Certificate
110+
if der, _ := pem.Decode(a.conf.Certificate); der != nil {
111+
crt = der.Bytes
112+
}
113+
cert, err := x509.ParseCertificate(crt)
114+
if err != nil {
115+
return nil, fmt.Errorf("oauth2: cannot parse certificate: %v", err)
116+
}
117+
s := sha1.Sum(cert.Raw)
118+
fp := base64.URLEncoding.EncodeToString(s[:])
119+
h := jws.Header{
120+
Algorithm: "RS256",
121+
Typ: "JWT",
122+
KeyID: fp,
123+
}
124+
125+
claimSet := &jws.ClaimSet{
126+
Iss: a.conf.ClientID,
127+
Sub: a.conf.ClientID,
128+
Aud: a.conf.TokenURL,
129+
}
130+
if t := a.conf.Expires; t > 0 {
131+
claimSet.Exp = time.Now().Add(t).Unix()
132+
}
133+
if aud := a.conf.Audience; aud != "" {
134+
claimSet.Aud = aud
135+
}
136+
137+
pk, err := internal.ParseKey(a.conf.PrivateKey)
138+
if err != nil {
139+
return nil, err
140+
}
141+
142+
payload, err := jws.Encode(&h, claimSet, pk)
143+
if err != nil {
144+
return nil, err
145+
}
146+
147+
hc := oauth2.NewClient(a.ctx, nil)
148+
v := url.Values{
149+
"client_assertion": {payload},
150+
"client_assertion_type": {"urn:ietf:params:oauth:client-assertion-type:jwt-bearer"},
151+
"client_id": {a.conf.ClientID},
152+
"grant_type": {"client_credentials"},
153+
"scope": {strings.Join(a.conf.Scopes, " ")},
154+
}
155+
resp, err := hc.PostForm(a.conf.TokenURL, v)
156+
if err != nil {
157+
return nil, fmt.Errorf("oauth2: cannot fetch token: %v", err)
158+
}
159+
160+
defer resp.Body.Close()
161+
body, err := ioutil.ReadAll(io.LimitReader(resp.Body, 1<<20))
162+
if err != nil {
163+
return nil, fmt.Errorf("oauth2: cannot fetch token: %v", err)
164+
}
165+
166+
if c := resp.StatusCode; c < 200 || c > 299 {
167+
return nil, &oauth2.RetrieveError{
168+
Response: resp,
169+
Body: body,
170+
}
171+
}
172+
173+
var tokenRes struct {
174+
AccessToken string `json:"access_token"`
175+
TokenType string `json:"token_type"`
176+
IDToken string `json:"id_token"`
177+
Scope string `json:"scope"`
178+
ExpiresIn int64 `json:"expires_in"` // relative seconds from now
179+
ExpiresOn int64 `json:"expires_on"` // timestamp
180+
}
181+
if err := json.Unmarshal(body, &tokenRes); err != nil {
182+
return nil, fmt.Errorf("oauth2: cannot fetch token: %v", err)
183+
}
184+
185+
token := &oauth2.Token{
186+
AccessToken: tokenRes.AccessToken,
187+
TokenType: tokenRes.TokenType,
188+
}
189+
if secs := tokenRes.ExpiresIn; secs > 0 {
190+
token.Expiry = time.Now().Add(time.Duration(secs) * time.Second)
191+
}
192+
if v := tokenRes.IDToken; v != "" {
193+
// decode returned id token to get expiry
194+
claimSet, err := jws.Decode(v)
195+
if err != nil {
196+
return nil, fmt.Errorf("oauth2: error decoding JWT token: %v", err)
197+
}
198+
token.Expiry = time.Unix(claimSet.Exp, 0)
199+
}
200+
201+
return token, nil
202+
}

0 commit comments

Comments
 (0)