Skip to content

Commit 8a82e93

Browse files
authored
Merge pull request globalsign#56 from joomcode/feature/marshal
INFRA-4180: Add support for `bson.Marshal`/`bson.Unmarshal`
2 parents 53dbeb5 + 1a9e7e9 commit 8a82e93

File tree

2 files changed

+125
-59
lines changed

2 files changed

+125
-59
lines changed

bson/decode.go

Lines changed: 65 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
package bson
2929

3030
import (
31+
"errors"
3132
"fmt"
3233
"io"
3334
"math"
@@ -69,20 +70,25 @@ func corrupted() {
6970
const (
7071
setterUnknown = iota
7172
setterNone
72-
setterType
73-
setterAddr
73+
setterTypeValue
74+
setterAddrValue
75+
setterTypeDoc
76+
setterAddrDoc
7477
)
7578

7679
var setterStyles map[reflect.Type]int
7780
var setterIfaceLegacy reflect.Type
78-
var setterIfaceHipster reflect.Type
81+
var setterIfaceValue reflect.Type
82+
var setterIfaceDoc reflect.Type
7983
var setterMutex sync.RWMutex
8084

8185
func init() {
8286
var ifaceLegacy Setter
83-
var ifaceHipster bson.ValueUnmarshaler
87+
var ifaceValue bson.ValueUnmarshaler
88+
var ifaceDoc bson.Unmarshaler
8489
setterIfaceLegacy = reflect.TypeOf(&ifaceLegacy).Elem()
85-
setterIfaceHipster = reflect.TypeOf(&ifaceHipster).Elem()
90+
setterIfaceValue = reflect.TypeOf(&ifaceValue).Elem()
91+
setterIfaceDoc = reflect.TypeOf(&ifaceDoc).Elem()
8692
setterStyles = make(map[reflect.Type]int)
8793
}
8894

@@ -96,35 +102,61 @@ func setterStyle(outt reflect.Type) int {
96102

97103
setterMutex.Lock()
98104
defer setterMutex.Unlock()
99-
if outt.Implements(setterIfaceHipster) {
100-
style = setterType
101-
} else if reflect.PtrTo(outt).Implements(setterIfaceHipster) {
102-
style = setterAddr
105+
if outt.Implements(setterIfaceValue) {
106+
style = setterTypeValue
107+
} else if reflect.PtrTo(outt).Implements(setterIfaceValue) {
108+
style = setterAddrValue
109+
} else if outt.Implements(setterIfaceDoc) {
110+
style = setterTypeDoc
111+
} else if reflect.PtrTo(outt).Implements(setterIfaceDoc) {
112+
style = setterAddrDoc
103113
} else if outt.Implements(setterIfaceLegacy) {
104-
panic(fmt.Errorf("Type %q implements %q but not implements %q", outt.String(), setterIfaceLegacy.String(), setterIfaceHipster.String()))
114+
panic(fmt.Errorf("Type %q implements %q but not implements %q or %q", outt.String(), setterIfaceLegacy.String(), setterIfaceValue.String(), setterIfaceDoc.String()))
105115
} else if reflect.PtrTo(outt).Implements(setterIfaceLegacy) {
106-
panic(fmt.Errorf("Type %q implements %q but not implements %q", outt.String(), setterIfaceLegacy.String(), setterIfaceHipster.String()))
116+
panic(fmt.Errorf("Type %q implements %q but not implements %q or %q", outt.String(), setterIfaceLegacy.String(), setterIfaceValue.String(), setterIfaceDoc.String()))
107117
} else {
108118
style = setterNone
109119
}
110120
setterStyles[outt] = style
111121
return style
112122
}
113123

114-
func getSetter(outt reflect.Type, out reflect.Value) bson.ValueUnmarshaler {
124+
func setValue(outt reflect.Type, out reflect.Value, cb func() (bsontype.Type, []byte)) (bool, error) {
115125
style := setterStyle(outt)
116-
if style == setterNone {
117-
return nil
118-
}
119-
if style == setterAddr {
126+
switch style {
127+
case setterAddrValue:
120128
if !out.CanAddr() {
121-
return nil
129+
return false, nil
130+
}
131+
btype, data := cb()
132+
return true, out.Addr().Interface().(bson.ValueUnmarshaler).UnmarshalBSONValue(btype, data)
133+
case setterTypeValue:
134+
if outt.Kind() == reflect.Ptr && out.IsNil() {
135+
out.Set(reflect.New(outt.Elem()))
136+
}
137+
btype, data := cb()
138+
return true, out.Interface().(bson.ValueUnmarshaler).UnmarshalBSONValue(btype, data)
139+
case setterAddrDoc:
140+
if !out.CanAddr() {
141+
return false, nil
142+
}
143+
btype, data := cb()
144+
if btype != bsontype.EmbeddedDocument {
145+
return true, fmt.Errorf("unexpected bson type: %v", btype)
146+
}
147+
return true, out.Addr().Interface().(bson.Unmarshaler).UnmarshalBSON(data)
148+
case setterTypeDoc:
149+
if outt.Kind() == reflect.Ptr && out.IsNil() {
150+
out.Set(reflect.New(outt.Elem()))
151+
}
152+
btype, data := cb()
153+
if btype != bsontype.EmbeddedDocument {
154+
return true, fmt.Errorf("unexpected bson type: %v", btype)
122155
}
123-
out = out.Addr()
124-
} else if outt.Kind() == reflect.Ptr && out.IsNil() {
125-
out.Set(reflect.New(outt.Elem()))
156+
return true, out.Interface().(bson.Unmarshaler).UnmarshalBSON(data)
157+
default:
158+
return false, nil
126159
}
127-
return out.Interface().(bson.ValueUnmarshaler)
128160
}
129161

130162
func clearMap(m reflect.Value) {
@@ -143,13 +175,16 @@ func (d *decoder) readDocTo(out reflect.Value) {
143175
if outk == reflect.Ptr && out.IsNil() {
144176
out.Set(reflect.New(outt.Elem()))
145177
}
146-
if setter := getSetter(outt, out); setter != nil {
178+
if ok, err := setValue(outt, out, func() (bsontype.Type, []byte) {
147179
raw := d.readRaw(ElementDocument)
148-
err := setter.UnmarshalBSONValue(raw.Type, raw.Value)
149-
if _, ok := err.(*TypeError); err != nil && !ok {
180+
return raw.Type, raw.Value
181+
}); err != nil {
182+
if _, ok := err.(*TypeError); !ok {
150183
panic(err)
151184
}
152185
return
186+
} else if ok {
187+
return
153188
}
154189
if outk == reflect.Ptr {
155190
out = out.Elem()
@@ -407,7 +442,7 @@ func (d *decoder) readSliceDoc(t reflect.Type) interface{} {
407442
func BSONElementSize(kind bsontype.Type, offset int, buffer []byte) (int, error) {
408443
value, _, ok := bsoncore.ReadValue(buffer[offset:], kind)
409444
if !ok {
410-
return 0, bsoncore.ErrCorruptedDocument
445+
return 0, errors.New("Document is corrupted")
411446
}
412447
return len(value.Data), nil
413448
}
@@ -555,20 +590,20 @@ func (d *decoder) readElemTo(out reflect.Value, kind bsontype.Type) (good bool)
555590
return true
556591
}
557592

558-
if setter := getSetter(outt, out); setter != nil {
593+
if ok, err := setValue(outt, out, func() (bsontype.Type, []byte) {
559594
raw := d.readRaw(kind)
560-
err := setter.UnmarshalBSONValue(raw.Type, raw.Value)
595+
return raw.Type, raw.Value
596+
}); err != nil {
561597
if err == ErrSetZero {
562598
out.Set(reflect.Zero(outt))
563599
return true
564600
}
565-
if err == nil {
566-
return true
567-
}
568601
if _, ok := err.(*TypeError); !ok {
569602
panic(err)
570603
}
571604
return false
605+
} else if ok {
606+
return true
572607
}
573608

574609
var in interface{}

bson/encode.go

Lines changed: 60 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -81,16 +81,20 @@ const itoaCacheSize = 32
8181
const (
8282
getterUnknown = iota
8383
getterNone
84-
getterTypeVal
85-
getterTypePtr
86-
getterAddr
84+
getterTypeValValue
85+
getterTypePtrValue
86+
getterAddrValue
87+
getterTypeValDoc
88+
getterTypePtrDoc
89+
getterAddrDoc
8790
)
8891

8992
var itoaCache []string
9093

9194
var getterStyles map[reflect.Type]int
9295
var getterIfaceLegacy reflect.Type
93-
var getterIfaceHipster reflect.Type
96+
var getterIfaceValue reflect.Type
97+
var getterIfaceDoc reflect.Type
9498
var getterMutex sync.RWMutex
9599

96100
func init() {
@@ -99,9 +103,11 @@ func init() {
99103
itoaCache[i] = strconv.Itoa(i)
100104
}
101105
var ifaceLegacy Getter
102-
var ifaceHipster bson.ValueMarshaler
106+
var ifaceValue bson.ValueMarshaler
107+
var ifaceDoc bson.ValueMarshaler
103108
getterIfaceLegacy = reflect.TypeOf(&ifaceLegacy).Elem()
104-
getterIfaceHipster = reflect.TypeOf(&ifaceHipster).Elem()
109+
getterIfaceValue = reflect.TypeOf(&ifaceValue).Elem()
110+
getterIfaceDoc = reflect.TypeOf(&ifaceDoc).Elem()
105111
getterStyles = make(map[reflect.Type]int)
106112
}
107113

@@ -122,44 +128,71 @@ func getterStyle(outt reflect.Type) int {
122128

123129
getterMutex.Lock()
124130
defer getterMutex.Unlock()
125-
if outt.Implements(getterIfaceHipster) {
131+
if outt.Implements(getterIfaceValue) {
126132
vt := outt
127133
for vt.Kind() == reflect.Ptr {
128134
vt = vt.Elem()
129135
}
130-
if vt.Implements(getterIfaceHipster) {
131-
style = getterTypeVal
136+
if vt.Implements(getterIfaceValue) {
137+
style = getterTypeValValue
132138
} else {
133-
style = getterTypePtr
139+
style = getterTypePtrValue
134140
}
135-
} else if reflect.PtrTo(outt).Implements(getterIfaceHipster) {
136-
style = getterAddr
141+
} else if reflect.PtrTo(outt).Implements(getterIfaceValue) {
142+
style = getterAddrValue
143+
} else if outt.Implements(getterIfaceDoc) {
144+
vt := outt
145+
for vt.Kind() == reflect.Ptr {
146+
vt = vt.Elem()
147+
}
148+
if vt.Implements(getterIfaceDoc) {
149+
style = getterTypeValDoc
150+
} else {
151+
style = getterTypePtrDoc
152+
}
153+
} else if reflect.PtrTo(outt).Implements(getterIfaceDoc) {
154+
style = getterAddrDoc
137155
} else if outt.Implements(getterIfaceLegacy) {
138-
panic(fmt.Errorf("Type %q implements %q but not implements %q", outt.String(), getterIfaceLegacy.String(), getterIfaceHipster.String()))
156+
panic(fmt.Errorf("Type %q implements %q but not implements %q or %q", outt.String(), getterIfaceLegacy.String(), getterIfaceValue.String(), getterIfaceDoc.String()))
139157
} else if reflect.PtrTo(outt).Implements(getterIfaceLegacy) {
140-
panic(fmt.Errorf("Type %q implements %q but not implements %q", outt.String(), getterIfaceLegacy.String(), getterIfaceHipster.String()))
158+
panic(fmt.Errorf("Type %q implements %q but not implements %q or %q", outt.String(), getterIfaceLegacy.String(), getterIfaceValue.String(), getterIfaceDoc.String()))
141159
} else {
142160
style = getterNone
143161
}
144162
getterStyles[outt] = style
145163
return style
146164
}
147165

148-
func getGetter(outt reflect.Type, out reflect.Value) bson.ValueMarshaler {
166+
func getValue(outt reflect.Type, out reflect.Value) (bool, bsontype.Type, []byte, error) {
149167
style := getterStyle(outt)
150-
if style == getterNone {
151-
return nil
152-
}
153-
if style == getterAddr {
168+
switch style {
169+
case getterAddrValue:
154170
if !out.CanAddr() {
155-
return nil
171+
return false, 0, nil, nil
156172
}
157-
return out.Addr().Interface().(bson.ValueMarshaler)
158-
}
159-
if style == getterTypeVal && out.Kind() == reflect.Ptr && out.IsNil() {
160-
return nil
173+
btype, data, err := out.Addr().Interface().(bson.ValueMarshaler).MarshalBSONValue()
174+
return true, btype, data, err
175+
case getterTypeValValue, getterTypePtrValue:
176+
if style == getterTypeValValue && out.Kind() == reflect.Ptr && out.IsNil() {
177+
return false, 0, nil, nil
178+
}
179+
btype, data, err := out.Interface().(bson.ValueMarshaler).MarshalBSONValue()
180+
return true, btype, data, err
181+
case getterAddrDoc:
182+
if !out.CanAddr() {
183+
return false, 0, nil, nil
184+
}
185+
data, err := out.Addr().Interface().(bson.Marshaler).MarshalBSON()
186+
return true, bsontype.EmbeddedDocument, data, err
187+
case getterTypeValDoc, getterTypePtrDoc:
188+
if style == getterTypeValDoc && out.Kind() == reflect.Ptr && out.IsNil() {
189+
return false, 0, nil, nil
190+
}
191+
data, err := out.Interface().(bson.Marshaler).MarshalBSON()
192+
return true, bsontype.EmbeddedDocument, data, err
193+
default:
194+
return false, 0, nil, nil
161195
}
162-
return out.Interface().(bson.ValueMarshaler)
163196
}
164197

165198
// --------------------------------------------------------------------------
@@ -171,8 +204,7 @@ type encoder struct {
171204

172205
func (e *encoder) addDoc(v reflect.Value) {
173206
for v.IsValid() {
174-
if vi := getGetter(v.Type(), v); vi != nil {
175-
btype, data, err := vi.MarshalBSONValue()
207+
if ok, btype, data, err := getValue(v.Type(), v); ok {
176208
if err != nil {
177209
panic(err)
178210
}
@@ -359,8 +391,7 @@ loop:
359391
return
360392
}
361393

362-
if getter := getGetter(v.Type(), v); getter != nil {
363-
btype, data, err := getter.MarshalBSONValue()
394+
if ok, btype, data, err := getValue(v.Type(), v); ok {
364395
if err != nil {
365396
panic(err)
366397
}

0 commit comments

Comments
 (0)