Skip to content

Commit db03838

Browse files
authored
[PECO-2050] Add custom auth headers into cloud fetch request (#249)
When file encryption is enabled with customer provided keys (SSE-CPK), we must pass the keys in HTTP headers in the fetch request. These headers are provided in the property `httpHeaders` in the `TSparkArrowResultLink`
2 parents 1e9d6ac + 977c5c1 commit db03838

File tree

2 files changed

+24
-3
lines changed

2 files changed

+24
-3
lines changed

internal/rows/arrowbased/batchloader.go

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -277,6 +277,12 @@ func fetchBatchBytes(
277277
return nil, err
278278
}
279279

280+
if link.HttpHeaders != nil {
281+
for key, value := range link.HttpHeaders {
282+
req.Header.Set(key, value)
283+
}
284+
}
285+
280286
client := http.DefaultClient
281287
res, err := client.Do(req)
282288
if err != nil {

internal/rows/arrowbased/batchloader_test.go

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,14 +4,16 @@ import (
44
"bytes"
55
"context"
66
"fmt"
7-
dbsqlerr "github.com/databricks/databricks-sql-go/errors"
8-
"github.com/databricks/databricks-sql-go/internal/cli_service"
9-
"github.com/databricks/databricks-sql-go/internal/config"
107
"net/http"
118
"net/http/httptest"
129
"testing"
1310
"time"
1411

12+
dbsqlerr "github.com/databricks/databricks-sql-go/errors"
13+
"github.com/databricks/databricks-sql-go/internal/cli_service"
14+
"github.com/databricks/databricks-sql-go/internal/config"
15+
"github.com/pkg/errors"
16+
1517
"github.com/apache/arrow/go/v12/arrow"
1618
"github.com/apache/arrow/go/v12/arrow/array"
1719
"github.com/apache/arrow/go/v12/arrow/ipc"
@@ -28,8 +30,19 @@ func TestCloudFetchIterator(t *testing.T) {
2830
defer server.Close()
2931

3032
t.Run("should fetch all the links", func(t *testing.T) {
33+
cloudFetchHeaders := map[string]string{
34+
"foo": "bar",
35+
}
36+
3137
handler = func(w http.ResponseWriter, r *http.Request) {
3238
w.WriteHeader(http.StatusOK)
39+
for name, value := range cloudFetchHeaders {
40+
if values, ok := r.Header[name]; ok {
41+
if values[0] != value {
42+
panic(errors.New("Missing auth headers"))
43+
}
44+
}
45+
}
3346
_, err := w.Write(generateMockArrowBytes(generateArrowRecord()))
3447
if err != nil {
3548
panic(err)
@@ -44,12 +57,14 @@ func TestCloudFetchIterator(t *testing.T) {
4457
ExpiryTime: time.Now().Add(10 * time.Minute).Unix(),
4558
StartRowOffset: startRowOffset,
4659
RowCount: 1,
60+
HttpHeaders: cloudFetchHeaders,
4761
},
4862
{
4963
FileLink: server.URL,
5064
ExpiryTime: time.Now().Add(10 * time.Minute).Unix(),
5165
StartRowOffset: startRowOffset + 1,
5266
RowCount: 1,
67+
HttpHeaders: cloudFetchHeaders,
5368
},
5469
}
5570

0 commit comments

Comments
 (0)