Skip to content

Commit dba7b4c

Browse files
bozarodomodwyer
authored andcommitted
Fix GetBSON() method usage (#40)
* Fix GetBSON() method usage Original issue --- You can't use type with custom GetBSON() method mixed with structure field type and structure field reference type. For example, you can't create custom GetBSON() for Bar type: ``` struct Foo { a Bar b *Bar } ``` Type implementation (`func (t Bar) GetBSON()` ) would crash on `Foo.b = nil` value encoding. Reference implementation (`func (t *Bar) GetBSON()` ) would not call on `Foo.a` value encoding. After this change --- For type implementation `func (t Bar) GetBSON()` would not call on `Foo.b = nil` value encoding. In this case `nil` value would be seariazied as `nil` BSON value. For reference implementation `func (t *Bar) GetBSON()` would call even on `Foo.a` value encoding. * Minor refactoring
1 parent fd79249 commit dba7b4c

File tree

3 files changed

+141
-12
lines changed

3 files changed

+141
-12
lines changed

bson/bson_test.go

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ import (
3636
"reflect"
3737
"testing"
3838
"time"
39+
"strings"
3940

4041
"github.com/globalsign/mgo/bson"
4142
. "gopkg.in/check.v1"
@@ -381,8 +382,54 @@ func (s *S) Test64bitInt(c *C) {
381382
// --------------------------------------------------------------------------
382383
// Generic two-way struct marshaling tests.
383384

385+
type prefixPtr string
386+
type prefixVal string
387+
388+
func (t *prefixPtr) GetBSON() (interface{}, error) {
389+
if t == nil {
390+
return nil, nil
391+
}
392+
return "foo-" + string(*t), nil
393+
}
394+
395+
func (t *prefixPtr) SetBSON(raw bson.Raw) error {
396+
var s string
397+
if raw.Kind == 0x0A {
398+
return bson.ErrSetZero
399+
}
400+
if err := raw.Unmarshal(&s); err != nil {
401+
return err
402+
}
403+
if !strings.HasPrefix(s, "foo-") {
404+
return errors.New("Prefix not found: " + s)
405+
}
406+
*t = prefixPtr(s[4:])
407+
return nil
408+
}
409+
410+
func (t prefixVal) GetBSON() (interface{}, error) {
411+
return "foo-" + string(t), nil
412+
}
413+
414+
func (t *prefixVal) SetBSON(raw bson.Raw) error {
415+
var s string
416+
if raw.Kind == 0x0A {
417+
return bson.ErrSetZero
418+
}
419+
if err := raw.Unmarshal(&s); err != nil {
420+
return err
421+
}
422+
if !strings.HasPrefix(s, "foo-") {
423+
return errors.New("Prefix not found: " + s)
424+
}
425+
*t = prefixVal(s[4:])
426+
return nil
427+
}
428+
384429
var bytevar = byte(8)
385430
var byteptr = &bytevar
431+
var prefixptr = prefixPtr("bar")
432+
var prefixval = prefixVal("bar")
386433

387434
var structItems = []testItemType{
388435
{&struct{ Ptr *byte }{nil},
@@ -419,6 +466,24 @@ var structItems = []testItemType{
419466
// Byte arrays.
420467
{&struct{ V [2]byte }{[2]byte{'y', 'o'}},
421468
"\x05v\x00\x02\x00\x00\x00\x00yo"},
469+
470+
{&struct{ V prefixPtr }{prefixPtr("buzz")},
471+
"\x02v\x00\x09\x00\x00\x00foo-buzz\x00"},
472+
473+
{&struct{ V *prefixPtr }{&prefixptr},
474+
"\x02v\x00\x08\x00\x00\x00foo-bar\x00"},
475+
476+
{&struct{ V *prefixPtr }{nil},
477+
"\x0Av\x00"},
478+
479+
{&struct{ V prefixVal }{prefixVal("buzz")},
480+
"\x02v\x00\x09\x00\x00\x00foo-buzz\x00"},
481+
482+
{&struct{ V *prefixVal }{&prefixval},
483+
"\x02v\x00\x08\x00\x00\x00foo-bar\x00"},
484+
485+
{&struct{ V *prefixVal }{nil},
486+
"\x0Av\x00"},
422487
}
423488

424489
func (s *S) TestMarshalStructItems(c *C) {

bson/decode.go

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -87,18 +87,20 @@ func setterStyle(outt reflect.Type) int {
8787
setterMutex.RLock()
8888
style := setterStyles[outt]
8989
setterMutex.RUnlock()
90-
if style == setterUnknown {
91-
setterMutex.Lock()
92-
defer setterMutex.Unlock()
93-
if outt.Implements(setterIface) {
94-
setterStyles[outt] = setterType
95-
} else if reflect.PtrTo(outt).Implements(setterIface) {
96-
setterStyles[outt] = setterAddr
97-
} else {
98-
setterStyles[outt] = setterNone
99-
}
100-
style = setterStyles[outt]
90+
if style != setterUnknown {
91+
return style
92+
}
93+
94+
setterMutex.Lock()
95+
defer setterMutex.Unlock()
96+
if outt.Implements(setterIface) {
97+
style = setterType
98+
} else if reflect.PtrTo(outt).Implements(setterIface) {
99+
style = setterAddr
100+
} else {
101+
style = setterNone
101102
}
103+
setterStyles[outt] = style
102104
return style
103105
}
104106

bson/encode.go

Lines changed: 63 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ import (
3535
"reflect"
3636
"sort"
3737
"strconv"
38+
"sync"
3839
"time"
3940
)
4041

@@ -60,13 +61,28 @@ var (
6061

6162
const itoaCacheSize = 32
6263

64+
const (
65+
getterUnknown = iota
66+
getterNone
67+
getterTypeVal
68+
getterTypePtr
69+
getterAddr
70+
)
71+
6372
var itoaCache []string
6473

74+
var getterStyles map[reflect.Type]int
75+
var getterIface reflect.Type
76+
var getterMutex sync.RWMutex
77+
6578
func init() {
6679
itoaCache = make([]string, itoaCacheSize)
6780
for i := 0; i != itoaCacheSize; i++ {
6881
itoaCache[i] = strconv.Itoa(i)
6982
}
83+
var iface Getter
84+
getterIface = reflect.TypeOf(&iface).Elem()
85+
getterStyles = make(map[reflect.Type]int)
7086
}
7187

7288
func itoa(i int) string {
@@ -76,6 +92,52 @@ func itoa(i int) string {
7692
return strconv.Itoa(i)
7793
}
7894

95+
func getterStyle(outt reflect.Type) int {
96+
getterMutex.RLock()
97+
style := getterStyles[outt]
98+
getterMutex.RUnlock()
99+
if style != getterUnknown {
100+
return style
101+
}
102+
103+
getterMutex.Lock()
104+
defer getterMutex.Unlock()
105+
if outt.Implements(getterIface) {
106+
vt := outt
107+
for vt.Kind() == reflect.Ptr {
108+
vt = vt.Elem()
109+
}
110+
if vt.Implements(getterIface) {
111+
style = getterTypeVal
112+
} else {
113+
style = getterTypePtr
114+
}
115+
} else if reflect.PtrTo(outt).Implements(getterIface) {
116+
style = getterAddr
117+
} else {
118+
style = getterNone
119+
}
120+
getterStyles[outt] = style
121+
return style
122+
}
123+
124+
func getGetter(outt reflect.Type, out reflect.Value) Getter {
125+
style := getterStyle(outt)
126+
if style == getterNone {
127+
return nil
128+
}
129+
if style == getterAddr {
130+
if !out.CanAddr() {
131+
return nil
132+
}
133+
return out.Addr().Interface().(Getter)
134+
}
135+
if style == getterTypeVal && out.Kind() == reflect.Ptr && out.IsNil() {
136+
return nil
137+
}
138+
return out.Interface().(Getter)
139+
}
140+
79141
// --------------------------------------------------------------------------
80142
// Marshaling of the document value itself.
81143

@@ -253,7 +315,7 @@ func (e *encoder) addElem(name string, v reflect.Value, minSize bool) {
253315
return
254316
}
255317

256-
if getter, ok := v.Interface().(Getter); ok {
318+
if getter := getGetter(v.Type(), v); getter != nil {
257319
getv, err := getter.GetBSON()
258320
if err != nil {
259321
panic(err)

0 commit comments

Comments
 (0)