Skip to content

Commit 4da02dc

Browse files
Close operation after executing statement (#65)
When executing a query there is a rows object returned to the client and the client is then responsible for closing the row set when finished with it. Closing the rows causes the driver to close the operation on the server. When executing a statement (ex. DROP TABLE ...) the client doesn't receive anything by which they can close the operation. This results in an operation that is completed but not closed. Eventually the open operation is closed due to inactivity but the driver should be closing the operation. - updated connection.ExecContext to close the operation if 1) we receive an operation handle back from the server and 2) the operations status is not already closed. - added unit test to check that we do/don't close the operation depending on the returned operation status. Signed-off-by: Raymond Cypher <[email protected]>
2 parents 4488f73 + fb0e622 commit 4da02dc

File tree

4 files changed

+213
-20
lines changed

4 files changed

+213
-20
lines changed

connection.go

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -90,21 +90,38 @@ func (c *conn) IsValid() bool {
9090
func (c *conn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) {
9191
log := logger.WithContext(c.id, driverctx.CorrelationIdFromContext(ctx), "")
9292
msg, start := logger.Track("ExecContext")
93+
defer log.Duration(msg, start)
94+
9395
ctx = driverctx.NewContextWithConnId(ctx, c.id)
9496
if len(args) > 0 {
9597
return nil, errors.New(ErrParametersNotSupported)
9698
}
9799
exStmtResp, opStatusResp, err := c.runQuery(ctx, query, args)
98100

99101
if exStmtResp != nil && exStmtResp.OperationHandle != nil {
102+
// we have an operation id so update the logger
100103
log = logger.WithContext(c.id, driverctx.CorrelationIdFromContext(ctx), client.SprintGuid(exStmtResp.OperationHandle.OperationId.GUID))
104+
105+
// since we have an operation handle we can close the operation if necessary
106+
alreadyClosed := exStmtResp.DirectResults != nil && exStmtResp.DirectResults.CloseOperation != nil
107+
if !alreadyClosed && (opStatusResp == nil || opStatusResp.GetOperationState() != cli_service.TOperationState_CLOSED_STATE) {
108+
_, err1 := c.client.CloseOperation(ctx, &cli_service.TCloseOperationReq{
109+
OperationHandle: exStmtResp.OperationHandle,
110+
})
111+
if err1 != nil {
112+
log.Err(err1).Msg("databricks: failed to close operation after executing statement")
113+
}
114+
}
101115
}
102-
defer log.Duration(msg, start)
103116

104117
if err != nil {
118+
// TODO: are there error situations in which the operation still needs to be closed?
119+
// Currently if there is an error we never get back a TExecuteStatementResponse so
120+
// can't try to close.
105121
log.Err(err).Msgf("databricks: failed to execute query: query %s", query)
106122
return nil, wrapErrf(err, "failed to execute query")
107123
}
124+
108125
res := result{AffectedRows: opStatusResp.GetNumModifiedRows()}
109126

110127
return &res, nil
@@ -261,10 +278,12 @@ func (c *conn) executeStatement(ctx context.Context, query string, args []driver
261278
if err != nil {
262279
return nil, err
263280
}
281+
264282
exStmtResp, ok := res.(*cli_service.TExecuteStatementResp)
265283
if !ok {
266284
return exStmtResp, errors.New("databricks: invalid execute statement response")
267285
}
286+
268287
return exStmtResp, err
269288
}
270289

connection_test.go

Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,101 @@ func TestConn_executeStatement(t *testing.T) {
8686
assert.NoError(t, err)
8787
assert.Equal(t, 1, executeStatementCount)
8888
})
89+
90+
t.Run("ExecStatement should close operation on success", func(t *testing.T) {
91+
var executeStatementCount, closeOperationCount int
92+
executeStatementResp := &cli_service.TExecuteStatementResp{
93+
Status: &cli_service.TStatus{
94+
StatusCode: cli_service.TStatusCode_SUCCESS_STATUS,
95+
},
96+
OperationHandle: &cli_service.TOperationHandle{
97+
OperationId: &cli_service.THandleIdentifier{
98+
GUID: []byte{1, 2, 3, 4, 2, 23, 4, 2, 3, 1, 2, 3, 4, 4, 223, 34, 54},
99+
Secret: []byte("b"),
100+
},
101+
},
102+
DirectResults: &cli_service.TSparkDirectResults{
103+
OperationStatus: &cli_service.TGetOperationStatusResp{
104+
Status: &cli_service.TStatus{
105+
StatusCode: cli_service.TStatusCode_SUCCESS_STATUS,
106+
},
107+
OperationState: cli_service.TOperationStatePtr(cli_service.TOperationState_ERROR_STATE),
108+
ErrorMessage: strPtr("error message"),
109+
DisplayMessage: strPtr("display message"),
110+
},
111+
ResultSetMetadata: &cli_service.TGetResultSetMetadataResp{
112+
Status: &cli_service.TStatus{
113+
StatusCode: cli_service.TStatusCode_SUCCESS_STATUS,
114+
},
115+
},
116+
ResultSet: &cli_service.TFetchResultsResp{
117+
Status: &cli_service.TStatus{
118+
StatusCode: cli_service.TStatusCode_SUCCESS_STATUS,
119+
},
120+
},
121+
},
122+
}
123+
124+
testClient := &client.TestClient{
125+
FnExecuteStatement: func(ctx context.Context, req *cli_service.TExecuteStatementReq) (r *cli_service.TExecuteStatementResp, err error) {
126+
executeStatementCount++
127+
return executeStatementResp, nil
128+
},
129+
FnCloseOperation: func(ctx context.Context, req *cli_service.TCloseOperationReq) (_r *cli_service.TCloseOperationResp, _err error) {
130+
closeOperationCount++
131+
return &cli_service.TCloseOperationResp{}, nil
132+
},
133+
}
134+
135+
testConn := &conn{
136+
session: getTestSession(),
137+
client: testClient,
138+
cfg: config.WithDefaults(),
139+
}
140+
141+
type opStateTest struct {
142+
state cli_service.TOperationState
143+
err string
144+
closeOperationCount int
145+
}
146+
147+
// test behaviour with all terminal operation states
148+
operationStateTests := []opStateTest{
149+
{state: cli_service.TOperationState_ERROR_STATE, err: "error state", closeOperationCount: 1},
150+
{state: cli_service.TOperationState_FINISHED_STATE, err: "", closeOperationCount: 1},
151+
{state: cli_service.TOperationState_CANCELED_STATE, err: "cancelled state", closeOperationCount: 1},
152+
{state: cli_service.TOperationState_CLOSED_STATE, err: "closed state", closeOperationCount: 0},
153+
{state: cli_service.TOperationState_TIMEDOUT_STATE, err: "timeout state", closeOperationCount: 1},
154+
}
155+
156+
for _, opTest := range operationStateTests {
157+
closeOperationCount = 0
158+
executeStatementCount = 0
159+
executeStatementResp.DirectResults.OperationStatus.OperationState = &opTest.state
160+
executeStatementResp.DirectResults.OperationStatus.DisplayMessage = &opTest.err
161+
_, err := testConn.ExecContext(context.Background(), "select 1", []driver.NamedValue{})
162+
if opTest.err == "" {
163+
assert.NoError(t, err)
164+
} else {
165+
assert.EqualError(t, err, opTest.err)
166+
}
167+
assert.Equal(t, 1, executeStatementCount)
168+
assert.Equal(t, opTest.closeOperationCount, closeOperationCount)
169+
}
170+
171+
// if the execute statement response contains direct results with a non-nil CloseOperation member
172+
// we shouldn't call close
173+
closeOperationCount = 0
174+
executeStatementCount = 0
175+
executeStatementResp.DirectResults.CloseOperation = &cli_service.TCloseOperationResp{}
176+
finished := cli_service.TOperationState_FINISHED_STATE
177+
executeStatementResp.DirectResults.OperationStatus.OperationState = &finished
178+
_, err := testConn.ExecContext(context.Background(), "select 1", []driver.NamedValue{})
179+
assert.NoError(t, err)
180+
assert.Equal(t, 1, executeStatementCount)
181+
assert.Equal(t, 0, closeOperationCount)
182+
})
183+
89184
}
90185

91186
func TestConn_pollOperation(t *testing.T) {

examples/createdrop/main.go

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
package main
2+
3+
import (
4+
"context"
5+
"database/sql"
6+
"log"
7+
"os"
8+
"strconv"
9+
"time"
10+
11+
dbsql "github.com/databricks/databricks-sql-go"
12+
dbsqlctx "github.com/databricks/databricks-sql-go/driverctx"
13+
dbsqllog "github.com/databricks/databricks-sql-go/logger"
14+
"github.com/joho/godotenv"
15+
)
16+
17+
func main() {
18+
// use this package to set up logging. By default logging level is `warn`. If you want to disable logging, use `disabled`
19+
if err := dbsqllog.SetLogLevel("debug"); err != nil {
20+
log.Fatal(err)
21+
}
22+
// sets the logging output. By default it will use os.Stderr. If running in terminal, it will use ConsoleWriter to make it pretty
23+
// dbsqllog.SetLogOutput(os.Stdout)
24+
25+
// this is just to make it easy to load all variables
26+
if err := godotenv.Load(); err != nil {
27+
log.Fatal(err)
28+
}
29+
port, err := strconv.Atoi(os.Getenv("DATABRICKS_PORT"))
30+
if err != nil {
31+
log.Fatal(err)
32+
}
33+
34+
// programmatically initializes the connector
35+
// another way is to use a DNS. In this case the equivalent DNS would be:
36+
// "token:<my_token>@hostname:port/http_path?catalog=hive_metastore&schema=default&timeout=60&maxRows=10&&timezone=America/Sao_Paulo&ANSI_MODE=true"
37+
connector, err := dbsql.NewConnector(
38+
// minimum configuration
39+
dbsql.WithServerHostname(os.Getenv("DATABRICKS_HOST")),
40+
dbsql.WithPort(port),
41+
dbsql.WithHTTPPath(os.Getenv("DATABRICKS_HTTPPATH")),
42+
dbsql.WithAccessToken(os.Getenv("DATABRICKS_ACCESSTOKEN")),
43+
//optional configuration
44+
dbsql.WithSessionParams(map[string]string{"timezone": "America/Sao_Paulo", "ansi_mode": "true"}),
45+
dbsql.WithUserAgentEntry("workflow-example"),
46+
dbsql.WithInitialNamespace("hive_metastore", "default"),
47+
dbsql.WithTimeout(time.Minute), // defaults to no timeout. Global timeout. Any query will be canceled if taking more than this time.
48+
dbsql.WithMaxRows(10), // defaults to 10000
49+
)
50+
if err != nil {
51+
// This will not be a connection error, but a DSN parse error or
52+
// another initialization error.
53+
log.Fatal(err)
54+
55+
}
56+
// Opening a driver typically will not attempt to connect to the database.
57+
db := sql.OpenDB(connector)
58+
// make sure to close it later
59+
defer db.Close()
60+
61+
ogCtx := dbsqlctx.NewContextWithCorrelationId(context.Background(), "createdrop-example")
62+
63+
// sets the timeout to 30 seconds. More than that we ping will fail. The default is 15 seconds
64+
ctx1, cancel := context.WithTimeout(ogCtx, 30*time.Second)
65+
defer cancel()
66+
if err := db.PingContext(ctx1); err != nil {
67+
log.Fatal(err)
68+
}
69+
70+
// create a table with some data. This has no context timeout, it will follow the timeout of one minute set for the connection.
71+
if _, err := db.ExecContext(ogCtx, `CREATE TABLE IF NOT EXISTS diamonds USING CSV LOCATION '/databricks-datasets/Rdatasets/data-001/csv/ggplot2/diamonds.csv' options (header = true, inferSchema = true)`); err != nil {
72+
log.Fatal(err)
73+
}
74+
75+
if _, err := db.ExecContext(ogCtx, `DROP TABLE diamonds `); err != nil {
76+
log.Fatal(err)
77+
}
78+
}

examples/workflow/main.go

Lines changed: 20 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import (
44
"context"
55
"database/sql"
66
"fmt"
7+
"log"
78
"os"
89
"strconv"
910
"time"
@@ -17,18 +18,18 @@ import (
1718
func main() {
1819
// use this package to set up logging. By default logging level is `warn`. If you want to disable logging, use `disabled`
1920
if err := dbsqllog.SetLogLevel("debug"); err != nil {
20-
panic(err)
21+
log.Fatal(err)
2122
}
2223
// sets the logging output. By default it will use os.Stderr. If running in terminal, it will use ConsoleWriter to make it pretty
2324
// dbsqllog.SetLogOutput(os.Stdout)
2425

2526
// this is just to make it easy to load all variables
2627
if err := godotenv.Load(); err != nil {
27-
panic(err)
28+
log.Fatal(err)
2829
}
2930
port, err := strconv.Atoi(os.Getenv("DATABRICKS_PORT"))
3031
if err != nil {
31-
panic(err)
32+
log.Fatal(err)
3233
}
3334

3435
// programmatically initializes the connector
@@ -50,7 +51,7 @@ func main() {
5051
if err != nil {
5152
// This will not be a connection error, but a DSN parse error or
5253
// another initialization error.
53-
panic(err)
54+
log.Fatal(err)
5455

5556
}
5657
// Opening a driver typically will not attempt to connect to the database.
@@ -88,18 +89,18 @@ func main() {
8889
ctx1, cancel := context.WithTimeout(ogCtx, 30*time.Second)
8990
defer cancel()
9091
if err := db.PingContext(ctx1); err != nil {
91-
panic(err)
92+
log.Fatal(err)
9293
}
9394

9495
// create a table with some data. This has no context timeout, it will follow the timeout of one minute set for the connection.
9596
if _, err := db.ExecContext(ogCtx, `CREATE TABLE IF NOT EXISTS diamonds USING CSV LOCATION '/databricks-datasets/Rdatasets/data-001/csv/ggplot2/diamonds.csv' options (header = true, inferSchema = true)`); err != nil {
96-
panic(err)
97+
log.Fatal(err)
9798
}
9899

99100
// QueryRowContext is a shortcut function to get a single value
100101
var max float64
101102
if err := db.QueryRowContext(ogCtx, `select max(carat) from diamonds`).Scan(&max); err != nil {
102-
panic(err)
103+
log.Fatal(err)
103104
} else {
104105
fmt.Printf("max carat in dataset is: %f\n", max)
105106
}
@@ -109,7 +110,7 @@ func main() {
109110
defer cancel()
110111

111112
if rows, err := db.QueryContext(ctx2, "select * from diamonds limit 19"); err != nil {
112-
panic(err)
113+
log.Fatal(err)
113114
} else {
114115
type row struct {
115116
_c0 int
@@ -127,11 +128,11 @@ func main() {
127128

128129
cols, err := rows.Columns()
129130
if err != nil {
130-
panic(err)
131+
log.Fatal(err)
131132
}
132133
types, err := rows.ColumnTypes()
133134
if err != nil {
134-
panic(err)
135+
log.Fatal(err)
135136
}
136137
for i, c := range cols {
137138
fmt.Printf("column %d is %s and has type %v\n", i, c, types[i].DatabaseTypeName())
@@ -141,7 +142,7 @@ func main() {
141142
// After row 10 this will cause one fetch call, as 10 rows (maxRows config) will come from the first execute statement call.
142143
r := row{}
143144
if err := rows.Scan(&r._c0, &r.carat, &r.cut, &r.color, &r.clarity, &r.depth, &r.table, &r.price, &r.x, &r.y, &r.z); err != nil {
144-
panic(err)
145+
log.Fatal(err)
145146
}
146147
res = append(res, r)
147148
}
@@ -156,7 +157,7 @@ func main() {
156157
var curTimezone string
157158

158159
if err := db.QueryRowContext(ogCtx, `select current_date(), current_timestamp(), current_timezone()`).Scan(&curDate, &curTimestamp, &curTimezone); err != nil {
159-
panic(err)
160+
log.Fatal(err)
160161
} else {
161162
// this will print now at timezone America/Sao_Paulo is: 2022-11-16 20:25:15.282 -0300 -03
162163
fmt.Printf("current timestamp at timezone %s is: %s\n", curTimezone, curTimestamp)
@@ -170,11 +171,11 @@ func main() {
170171
array_col array < int >,
171172
map_col map < string, int >,
172173
struct_col struct < string_field string, array_field array < int > >)`); err != nil {
173-
panic(err)
174+
log.Fatal(err)
174175
}
175176
var numRows int
176177
if err := db.QueryRowContext(ogCtx, `select count(*) from array_map_struct`).Scan(&numRows); err != nil {
177-
panic(err)
178+
log.Fatal(err)
178179
} else {
179180
fmt.Printf("table has %d rows\n", numRows)
180181
}
@@ -186,7 +187,7 @@ func main() {
186187
array(1, 2, 3),
187188
map('key1', 1),
188189
struct('string_val', array(4, 5, 6)))`); err != nil {
189-
panic(err)
190+
log.Fatal(err)
190191
} else {
191192
i, err1 := res.RowsAffected()
192193
if err1 != nil {
@@ -197,7 +198,7 @@ func main() {
197198
}
198199

199200
if rows, err := db.QueryContext(ogCtx, "select * from array_map_struct"); err != nil {
200-
panic(err)
201+
log.Fatal(err)
201202
} else {
202203
// complex data types are returned as string
203204
type row struct {
@@ -208,11 +209,11 @@ func main() {
208209
res := []row{}
209210
cols, err := rows.Columns()
210211
if err != nil {
211-
panic(err)
212+
log.Fatal(err)
212213
}
213214
types, err := rows.ColumnTypes()
214215
if err != nil {
215-
panic(err)
216+
log.Fatal(err)
216217
}
217218
for i, c := range cols {
218219
fmt.Printf("column %d is %s and has type %v\n", i, c, types[i].DatabaseTypeName())
@@ -221,7 +222,7 @@ func main() {
221222
for rows.Next() {
222223
r := row{}
223224
if err := rows.Scan(&r.arrayVal, &r.mapVal, &r.structVal); err != nil {
224-
panic(err)
225+
log.Fatal(err)
225226
}
226227
res = append(res, r)
227228
}

0 commit comments

Comments
 (0)