Skip to content

Commit 34b7340

Browse files
committed
Add connection option WithSkipTLSHostVerify for privatelink host
Signed-off-by: Jacky Hu <[email protected]>
1 parent 697ea4f commit 34b7340

File tree

3 files changed

+50
-2
lines changed

3 files changed

+50
-2
lines changed

connector.go

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package dbsql
22

33
import (
44
"context"
5+
"crypto/tls"
56
"database/sql/driver"
67
"fmt"
78
"net/http"
@@ -233,6 +234,20 @@ func WithSessionParams(params map[string]string) connOption {
233234
}
234235
}
235236

237+
// WithSkipTLSHostVerify disables the verification of the hostname in the TLS certificate.
238+
// WARNING:
239+
// When this option is used, TLS is susceptible to machine-in-the-middle attacks.
240+
// Please only use this option when the hostname is an internal private link hostname
241+
func WithSkipTLSHostVerify() connOption {
242+
return func(c *config.Config) {
243+
if c.TLSConfig == nil {
244+
c.TLSConfig = &tls.Config{MinVersion: tls.VersionTLS12, InsecureSkipVerify: true} // #nosec G402
245+
} else {
246+
c.TLSConfig.InsecureSkipVerify = true // #nosec G402
247+
}
248+
}
249+
}
250+
236251
// WithAuthenticator sets up the Authentication. Mandatory if access token is not provided.
237252
func WithAuthenticator(authr auth.Authenticator) connOption {
238253
return func(c *config.Config) {

connector_test.go

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,9 @@ import (
66
"time"
77

88
"github.com/databricks/databricks-sql-go/auth/pat"
9+
"github.com/databricks/databricks-sql-go/internal/client"
910
"github.com/databricks/databricks-sql-go/internal/config"
11+
"github.com/hashicorp/go-retryablehttp"
1012
"github.com/stretchr/testify/assert"
1113
"github.com/stretchr/testify/require"
1214
)
@@ -38,6 +40,7 @@ func TestNewConnector(t *testing.T) {
3840
WithTransport(roundTripper),
3941
WithCloudFetch(true),
4042
WithMaxDownloadThreads(15),
43+
WithSkipTLSHostVerify(),
4144
)
4245
expectedCloudFetchConfig := config.CloudFetchConfig{
4346
UseCloudFetch: true,
@@ -67,6 +70,7 @@ func TestNewConnector(t *testing.T) {
6770
expectedCfg := config.WithDefaults()
6871
expectedCfg.DriverVersion = DriverVersion
6972
expectedCfg.UserConfig = expectedUserConfig
73+
expectedCfg.TLSConfig.InsecureSkipVerify = true
7074
coni, ok := con.(*connector)
7175
require.True(t, ok)
7276
assert.Nil(t, err)
@@ -184,6 +188,28 @@ func TestNewConnector(t *testing.T) {
184188
}
185189

186190
})
191+
192+
t.Run("Connector test WithSkipTLSHostVerify with PoolClient", func(t *testing.T) {
193+
hostname := "databricks-host"
194+
con, err := NewConnector(
195+
WithServerHostname(hostname),
196+
WithSkipTLSHostVerify(),
197+
)
198+
assert.Nil(t, err)
199+
200+
coni, ok := con.(*connector)
201+
require.True(t, ok)
202+
userConfig := coni.cfg.UserConfig
203+
require.Equal(t, hostname, userConfig.Host)
204+
205+
httpClient, ok := coni.client.Transport.(*retryablehttp.RoundTripper)
206+
require.True(t, ok)
207+
poolClient, ok := httpClient.Client.HTTPClient.Transport.(*client.Transport)
208+
require.True(t, ok)
209+
internalClient, ok := poolClient.Base.(*http.Transport)
210+
require.True(t, ok)
211+
require.True(t, internalClient.TLSClientConfig.InsecureSkipVerify)
212+
})
187213
}
188214

189215
type mockRoundTripper struct{}

internal/client/client.go

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package client
22

33
import (
44
"context"
5+
"crypto/tls"
56
"crypto/x509"
67
"encoding/json"
78
"fmt"
@@ -545,14 +546,20 @@ func RetryableClient(cfg *config.Config) *http.Client {
545546
return retryableClient.StandardClient()
546547
}
547548

548-
func PooledTransport() *http.Transport {
549+
func PooledTransport(cfg *config.Config) *http.Transport {
550+
var tlsConfig *tls.Config
551+
if (cfg.TLSConfig != nil) && cfg.TLSConfig.InsecureSkipVerify {
552+
tlsConfig = cfg.TLSConfig
553+
}
554+
549555
transport := &http.Transport{
550556
Proxy: http.ProxyFromEnvironment,
551557
DialContext: (&net.Dialer{
552558
Timeout: 30 * time.Second,
553559
KeepAlive: 30 * time.Second,
554560
DualStack: true,
555561
}).DialContext,
562+
TLSClientConfig: tlsConfig,
556563
ForceAttemptHTTP2: true,
557564
MaxIdleConns: 100,
558565
IdleConnTimeout: 180 * time.Second,
@@ -577,7 +584,7 @@ func PooledClient(cfg *config.Config) *http.Client {
577584
}
578585
} else {
579586
tr = &Transport{
580-
Base: PooledTransport(),
587+
Base: PooledTransport(cfg),
581588
Authr: cfg.Authenticator,
582589
}
583590
}

0 commit comments

Comments
 (0)