Skip to content

Commit 952fd87

Browse files
Use new ctx when closing operation after cancel (#86)
If the context is canceled already, we can't use it to close the operation. Create new context with the context data. Signed-off-by: Andre Furlan <[email protected]>
1 parent d8bb0c9 commit 952fd87

File tree

2 files changed

+82
-5
lines changed

2 files changed

+82
-5
lines changed

connection.go

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,8 @@ func (c *conn) IsValid() bool {
9191
// ExecContext honors the context timeout and return when it is canceled.
9292
// Statement ExecContext is the same as connection ExecContext
9393
func (c *conn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) {
94-
log := logger.WithContext(c.id, driverctx.CorrelationIdFromContext(ctx), "")
94+
corrId := driverctx.CorrelationIdFromContext(ctx)
95+
log := logger.WithContext(c.id, corrId, "")
9596
msg, start := logger.Track("ExecContext")
9697
defer log.Duration(msg, start)
9798

@@ -103,12 +104,13 @@ func (c *conn) ExecContext(ctx context.Context, query string, args []driver.Name
103104

104105
if exStmtResp != nil && exStmtResp.OperationHandle != nil {
105106
// we have an operation id so update the logger
106-
log = logger.WithContext(c.id, driverctx.CorrelationIdFromContext(ctx), client.SprintGuid(exStmtResp.OperationHandle.OperationId.GUID))
107+
log = logger.WithContext(c.id, corrId, client.SprintGuid(exStmtResp.OperationHandle.OperationId.GUID))
107108

108109
// since we have an operation handle we can close the operation if necessary
109110
alreadyClosed := exStmtResp.DirectResults != nil && exStmtResp.DirectResults.CloseOperation != nil
111+
newCtx := driverctx.NewContextWithCorrelationId(driverctx.NewContextWithConnId(context.Background(), c.id), corrId)
110112
if !alreadyClosed && (opStatusResp == nil || opStatusResp.GetOperationState() != cli_service.TOperationState_CLOSED_STATE) {
111-
_, err1 := c.client.CloseOperation(ctx, &cli_service.TCloseOperationReq{
113+
_, err1 := c.client.CloseOperation(newCtx, &cli_service.TCloseOperationReq{
112114
OperationHandle: exStmtResp.OperationHandle,
113115
})
114116
if err1 != nil {

connection_test.go

Lines changed: 77 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,7 @@ func TestConn_executeStatement(t *testing.T) {
9595
},
9696
OperationHandle: &cli_service.TOperationHandle{
9797
OperationId: &cli_service.THandleIdentifier{
98-
GUID: []byte{1, 2, 3, 4, 2, 23, 4, 2, 3, 1, 2, 3, 4, 4, 223, 34, 54},
98+
GUID: []byte{1, 2, 3, 4, 2, 23, 4, 2, 3, 1, 2, 3, 4, 223, 34, 54},
9999
Secret: []byte("b"),
100100
},
101101
},
@@ -344,7 +344,7 @@ func TestConn_pollOperation(t *testing.T) {
344344
}
345345
res, err := testConn.pollOperation(context.Background(), &cli_service.TOperationHandle{
346346
OperationId: &cli_service.THandleIdentifier{
347-
GUID: []byte{1, 2, 3, 4, 2, 23, 4, 2, 3, 2, 3, 4, 4, 223, 34, 54},
347+
GUID: []byte{1, 2, 3, 4, 2, 23, 4, 2, 3, 2, 4, 7, 8, 223, 34, 54},
348348
Secret: []byte("b"),
349349
},
350350
})
@@ -1059,6 +1059,11 @@ func TestConn_ExecContext(t *testing.T) {
10591059

10601060
testClient := &client.TestClient{
10611061
FnExecuteStatement: executeStatement,
1062+
FnCloseOperation: func(ctx context.Context, req *cli_service.TCloseOperationReq) (_r *cli_service.TCloseOperationResp, _err error) {
1063+
ctxErr := ctx.Err()
1064+
assert.NoError(t, ctxErr)
1065+
return &cli_service.TCloseOperationResp{}, nil
1066+
},
10621067
}
10631068
testConn := &conn{
10641069
session: getTestSession(),
@@ -1102,6 +1107,11 @@ func TestConn_ExecContext(t *testing.T) {
11021107
testClient := &client.TestClient{
11031108
FnExecuteStatement: executeStatement,
11041109
FnGetOperationStatus: getOperationStatus,
1110+
FnCloseOperation: func(ctx context.Context, req *cli_service.TCloseOperationReq) (_r *cli_service.TCloseOperationResp, _err error) {
1111+
ctxErr := ctx.Err()
1112+
assert.NoError(t, ctxErr)
1113+
return &cli_service.TCloseOperationResp{}, nil
1114+
},
11051115
}
11061116
testConn := &conn{
11071117
session: getTestSession(),
@@ -1116,6 +1126,71 @@ func TestConn_ExecContext(t *testing.T) {
11161126
assert.Equal(t, int64(10), rowsAffected)
11171127
assert.Equal(t, 1, executeStatementCount)
11181128
})
1129+
t.Run("ExecContext uses new context to close operation", func(t *testing.T) {
1130+
var executeStatementCount, getOperationStatusCount, closeOperationCount, cancelOperationCount int
1131+
var cancel context.CancelFunc
1132+
executeStatement := func(ctx context.Context, req *cli_service.TExecuteStatementReq) (r *cli_service.TExecuteStatementResp, err error) {
1133+
executeStatementCount++
1134+
executeStatementResp := &cli_service.TExecuteStatementResp{
1135+
Status: &cli_service.TStatus{
1136+
StatusCode: cli_service.TStatusCode_SUCCESS_STATUS,
1137+
},
1138+
OperationHandle: &cli_service.TOperationHandle{
1139+
OperationId: &cli_service.THandleIdentifier{
1140+
GUID: []byte{1, 2, 3, 4, 2, 23, 4, 2, 3, 2, 3, 4, 4, 223, 34, 54},
1141+
Secret: []byte("b"),
1142+
},
1143+
},
1144+
}
1145+
return executeStatementResp, nil
1146+
}
1147+
1148+
getOperationStatus := func(ctx context.Context, req *cli_service.TGetOperationStatusReq) (r *cli_service.TGetOperationStatusResp, err error) {
1149+
getOperationStatusCount++
1150+
cancel()
1151+
getOperationStatusResp := &cli_service.TGetOperationStatusResp{
1152+
OperationState: cli_service.TOperationStatePtr(cli_service.TOperationState_FINISHED_STATE),
1153+
NumModifiedRows: thrift.Int64Ptr(10),
1154+
}
1155+
return getOperationStatusResp, nil
1156+
}
1157+
1158+
testClient := &client.TestClient{
1159+
FnExecuteStatement: executeStatement,
1160+
FnGetOperationStatus: getOperationStatus,
1161+
FnCloseOperation: func(ctx context.Context, req *cli_service.TCloseOperationReq) (_r *cli_service.TCloseOperationResp, _err error) {
1162+
closeOperationCount++
1163+
ctxErr := ctx.Err()
1164+
assert.NoError(t, ctxErr)
1165+
return &cli_service.TCloseOperationResp{}, nil
1166+
},
1167+
FnCancelOperation: func(ctx context.Context, req *cli_service.TCancelOperationReq) (r *cli_service.TCancelOperationResp, err error) {
1168+
cancelOperationCount++
1169+
cancelOperationResp := &cli_service.TCancelOperationResp{
1170+
Status: &cli_service.TStatus{
1171+
StatusCode: cli_service.TStatusCode_SUCCESS_STATUS,
1172+
},
1173+
}
1174+
return cancelOperationResp, nil
1175+
},
1176+
}
1177+
testConn := &conn{
1178+
session: getTestSession(),
1179+
client: testClient,
1180+
cfg: config.WithDefaults(),
1181+
}
1182+
ctx := context.Background()
1183+
ctx, cancel = context.WithCancel(ctx)
1184+
defer cancel()
1185+
res, err := testConn.ExecContext(ctx, "insert 10", []driver.NamedValue{})
1186+
time.Sleep(10 * time.Millisecond)
1187+
assert.Error(t, err)
1188+
assert.Nil(t, res)
1189+
assert.Equal(t, 1, executeStatementCount)
1190+
assert.Equal(t, 1, cancelOperationCount)
1191+
assert.Equal(t, 1, getOperationStatusCount)
1192+
assert.Equal(t, 1, closeOperationCount)
1193+
})
11191194
}
11201195

11211196
func TestConn_QueryContext(t *testing.T) {

0 commit comments

Comments
 (0)