Skip to content

Commit d70ab7c

Browse files
Added GCP cloud type for OAuth (#189)
Added the GCP cloud type (i.e. domain .gcp.databricks.com) to OAuth implementation. Signed-off-by: Raymond Cypher <[email protected]>
1 parent 5adddfc commit d70ab7c

File tree

3 files changed

+73
-0
lines changed

3 files changed

+73
-0
lines changed

auth/oauth/oauth.go

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,11 +86,16 @@ var databricksAzureDomains []string = []string{
8686
".databricks.azure.us",
8787
}
8888

89+
var databricksGCPDomains []string = []string{
90+
".gcp.databricks.com",
91+
}
92+
8993
type CloudType int
9094

9195
const (
9296
AWS = iota
9397
Azure
98+
GCP
9499
Unknown
95100
)
96101

@@ -100,6 +105,8 @@ func (cl CloudType) String() string {
100105
return "AWS"
101106
case Azure:
102107
return "Azure"
108+
case GCP:
109+
return "GCP"
103110
}
104111

105112
return "Unknown"
@@ -119,5 +126,10 @@ func InferCloudFromHost(hostname string) CloudType {
119126
}
120127
}
121128

129+
for _, d := range databricksGCPDomains {
130+
if strings.Contains(hostname, d) {
131+
return GCP
132+
}
133+
}
122134
return Unknown
123135
}

auth/oauth/u2m/authenticator.go

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,9 @@ const (
3030

3131
awsClientId = "databricks-sql-connector"
3232
awsRedirectURL = "localhost:8030"
33+
34+
gcpClientId = "databricks-sql-connector"
35+
gcpRedirectURL = "localhost:8030"
3336
)
3437

3538
func NewAuthenticator(hostName string, timeout time.Duration) (auth.Authenticator, error) {
@@ -43,6 +46,9 @@ func NewAuthenticator(hostName string, timeout time.Duration) (auth.Authenticato
4346
} else if cloud == oauth.Azure {
4447
clientID = azureClientId
4548
redirectURL = azureRedirectURL
49+
} else if cloud == oauth.GCP {
50+
clientID = gcpClientId
51+
redirectURL = gcpRedirectURL
4652
} else {
4753
return nil, errors.New("unhandled cloud type: " + cloud.String())
4854
}

examples/oauth/main.go

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,17 @@ import (
99
"time"
1010

1111
dbsql "github.com/databricks/databricks-sql-go"
12+
"github.com/databricks/databricks-sql-go/auth/oauth/m2m"
1213
"github.com/databricks/databricks-sql-go/auth/oauth/u2m"
1314
"github.com/joho/godotenv"
1415
)
1516

1617
func main() {
18+
testU2M()
19+
testM2M()
20+
}
21+
22+
func testU2M() {
1723
err := godotenv.Load()
1824

1925
if err != nil {
@@ -62,3 +68,52 @@ func main() {
6268
}
6369
fmt.Println(res)
6470
}
71+
72+
func testM2M() {
73+
err := godotenv.Load()
74+
75+
if err != nil {
76+
log.Fatal(err.Error())
77+
}
78+
79+
clientID := os.Getenv("DATABRICKS_CLIENT_ID")
80+
clientSecret := os.Getenv("DATABRICKS_CLIENT_SECRET")
81+
host := os.Getenv("DATABRICKS_HOST")
82+
authenticator := m2m.NewAuthenticator(clientID, clientSecret, host)
83+
84+
connector, err := dbsql.NewConnector(
85+
dbsql.WithServerHostname(os.Getenv("DATABRICKS_HOST")),
86+
dbsql.WithHTTPPath(os.Getenv("DATABRICKS_HTTPPATH")),
87+
dbsql.WithAuthenticator(authenticator),
88+
)
89+
if err != nil {
90+
log.Fatal(err)
91+
}
92+
93+
db := sql.OpenDB(connector)
94+
defer db.Close()
95+
96+
// Pinging should require logging in
97+
if err := db.Ping(); err != nil {
98+
fmt.Println(err)
99+
}
100+
101+
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
102+
defer cancel()
103+
104+
var res int
105+
106+
// Running query should not require logging in as we should have a token
107+
// from when ping was called.
108+
err1 := db.QueryRowContext(ctx, `select 1`).Scan(&res)
109+
110+
if err1 != nil {
111+
if err1 == sql.ErrNoRows {
112+
fmt.Println("not found")
113+
return
114+
} else {
115+
fmt.Printf("err: %v\n", err1)
116+
}
117+
}
118+
fmt.Println(res)
119+
}

0 commit comments

Comments
 (0)