Skip to content

Commit a32b90b

Browse files
author
Heqing Ya
authored
Escape Parameters (#9)
* escape parameters * fix test * rename const
1 parent e07c3f8 commit a32b90b

File tree

5 files changed

+89
-10
lines changed

5 files changed

+89
-10
lines changed

connection.go

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,10 @@ func (c *Conn) QueryContext(ctx context.Context, q string, args []driver.NamedVa
6565
}
6666

6767
tmpl := template(q)
68-
stmt := statement(tmpl, args)
68+
stmt, err := statement(tmpl, args)
69+
if err != nil {
70+
return nil, err
71+
}
6972
return query(ctx, session, stmt)
7073
}
7174

@@ -77,7 +80,10 @@ func (c *Conn) ExecContext(ctx context.Context, q string, args []driver.NamedVal
7780
}
7881

7982
tmpl := template(q)
80-
stmt := statement(tmpl, args)
83+
stmt, err := statement(tmpl, args)
84+
if err != nil {
85+
return nil, err
86+
}
8187
return exec(ctx, session, stmt)
8288
}
8389

escaper.go

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
package dbsql
2+
3+
import (
4+
"database/sql/driver"
5+
"fmt"
6+
"strings"
7+
"time"
8+
)
9+
10+
const (
11+
TimeFmt = "2006-01-02T15:04:05.999-07:00"
12+
)
13+
14+
func EscapeArg(arg driver.NamedValue) (string, error) {
15+
switch v := arg.Value.(type) {
16+
case int64:
17+
return fmt.Sprintf("%v", v), nil
18+
case float64:
19+
return fmt.Sprintf("%v", v), nil
20+
case bool:
21+
return fmt.Sprintf("%v", v), nil
22+
case string:
23+
return fmt.Sprintf("'%v'", strings.ReplaceAll(v, "'", "''")), nil
24+
case time.Time:
25+
return "'" + v.Format(TimeFmt) + "'", nil
26+
default:
27+
return "", fmt.Errorf("unsupported parameter type %T for value %v", v, v)
28+
}
29+
}

escaper_test.go

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
package dbsql
2+
3+
import (
4+
"database/sql/driver"
5+
"testing"
6+
"time"
7+
)
8+
9+
func TestEscaper(t *testing.T) {
10+
testcases := []struct {
11+
Value driver.Value
12+
ExpectedOutput string
13+
}{
14+
{Value: "a'b'c", ExpectedOutput: `'a''b''c'`},
15+
{Value: int64(1024), ExpectedOutput: `1024`},
16+
{Value: float64(1024.5), ExpectedOutput: `1024.5`},
17+
{Value: true, ExpectedOutput: "true"},
18+
{Value: time.Date(2020, time.April, 11, 21, 34, 01, 0, time.UTC), ExpectedOutput: "'2020-04-11T21:34:01+00:00'"},
19+
}
20+
21+
for _, test := range testcases {
22+
actual, err := EscapeArg(driver.NamedValue{Value: test.Value})
23+
if err != nil {
24+
t.Error(err)
25+
}
26+
if actual != test.ExpectedOutput {
27+
t.Errorf("expecting %v, actual value: %v", test.ExpectedOutput, actual)
28+
}
29+
30+
}
31+
}

statement.go

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,10 @@ func (s *Stmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driv
6464
if err != nil {
6565
return nil, err
6666
}
67-
stmt := statement(s.stmt, args)
67+
stmt, err := statement(s.stmt, args)
68+
if err != nil {
69+
return nil, err
70+
}
6871
return query(ctx, session, stmt)
6972
}
7073

@@ -74,7 +77,10 @@ func (s *Stmt) ExecContext(ctx context.Context, args []driver.NamedValue) (drive
7477
if err != nil {
7578
return nil, err
7679
}
77-
stmt := statement(s.stmt, args)
80+
stmt, err := statement(s.stmt, args)
81+
if err != nil {
82+
return nil, err
83+
}
7884
return exec(ctx, session, stmt)
7985
}
8086

@@ -92,7 +98,7 @@ func template(query string) string {
9298
return query
9399
}
94100

95-
func statement(tmpl string, args []driver.NamedValue) string {
101+
func statement(tmpl string, args []driver.NamedValue) (string, error) {
96102
stmt := tmpl
97103
for _, arg := range args {
98104
var re *regexp.Regexp
@@ -101,10 +107,14 @@ func statement(tmpl string, args []driver.NamedValue) string {
101107
} else {
102108
re = regexp.MustCompile(fmt.Sprintf("@p%d%s", arg.Ordinal, `\b`))
103109
}
104-
val := fmt.Sprintf("%v", arg.Value)
110+
escaped, err := EscapeArg(arg)
111+
if err != nil {
112+
return "", err
113+
}
114+
val := fmt.Sprintf("%v", escaped)
105115
stmt = re.ReplaceAllString(stmt, val)
106116
}
107-
return stmt
117+
return stmt, nil
108118
}
109119

110120
func query(ctx context.Context, session *hive.Session, stmt string) (driver.Rows, error) {

statement_test.go

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ func TestStatement(t *testing.T) {
1616
args: []driver.NamedValue{
1717
{Ordinal: 1, Value: "val_1"},
1818
},
19-
target: "val_1 p1",
19+
target: "'val_1' p1",
2020
},
2121
{
2222
stmt: "@p1 @p10 @p11 @named @named1 @p1",
@@ -25,12 +25,15 @@ func TestStatement(t *testing.T) {
2525
{Ordinal: 10, Name: "named", Value: "val_named"},
2626
{Ordinal: 11, Value: "val_11"},
2727
},
28-
target: "val_1 @p10 val_11 val_named @named1 val_1",
28+
target: "'val_1' @p10 'val_11' 'val_named' @named1 'val_1'",
2929
},
3030
}
3131

3232
for _, tt := range tests {
33-
result := statement(tt.stmt, tt.args)
33+
result, err := statement(tt.stmt, tt.args)
34+
if err != nil {
35+
t.Error(err)
36+
}
3437

3538
if result != tt.target {
3639
t.Fatalf("mismatch for statement: %q\n\ttarget: %q\n\tresult: %q", tt.stmt, tt.target, result)

0 commit comments

Comments
 (0)