Skip to content

Commit 0e6b271

Browse files
Denys Smirnovdennwc
authored andcommitted
schema: support writing structs that contain loops; fix #731
1 parent eba6c31 commit 0e6b271

File tree

2 files changed

+130
-45
lines changed

2 files changed

+130
-45
lines changed

schema/schema.go

Lines changed: 84 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -866,51 +866,57 @@ func isZero(rv reflect.Value) bool {
866866
return rv.Interface() == reflect.Zero(rv.Type()).Interface() // TODO(dennwc): rewrite
867867
}
868868

869-
func (c *Config) writeOneValReflect(w quad.Writer, id quad.Value, pred quad.Value, rv reflect.Value, rev bool) error {
870-
if isZero(rv) {
871-
return nil
872-
}
873-
targ, ok := quad.AsValue(rv.Interface())
874-
if !ok {
875-
if rv.Kind() == reflect.Ptr {
876-
rv = rv.Elem()
877-
}
878-
targ, ok = quad.AsValue(rv.Interface())
879-
if !ok && rv.Kind() == reflect.Struct {
880-
sid, err := c.WriteAsQuads(w, rv.Interface())
881-
if err != nil {
882-
return err
883-
}
884-
targ, ok = sid, true
885-
}
886-
}
887-
if !ok {
888-
return fmt.Errorf("unsupported type: %T", rv.Interface())
889-
}
890-
s, o := id, targ
869+
func (c *Config) writeQuad(w quad.Writer, s, p, o quad.Value, rev bool) error {
891870
if rev {
892871
s, o = o, s
893872
}
894-
return w.WriteQuad(quad.Quad{Subject: s, Predicate: pred, Object: o, Label: c.Label})
873+
return w.WriteQuad(quad.Quad{Subject: s, Predicate: p, Object: o, Label: c.Label})
895874
}
896875

897-
func (c *Config) writeValueAs(w quad.Writer, id quad.Value, rv reflect.Value, pref string, rules fieldRules) error {
898-
if rv.Kind() == reflect.Ptr {
899-
rv = rv.Elem()
876+
// writeOneValReflect writes a set of quads corresponding to a value. It may omit writing quads if value is zero.
877+
func (c *Config) writeOneValReflect(w quad.Writer, id quad.Value, pred quad.Value, rv reflect.Value, rev bool, seen map[uintptr]quad.Value) error {
878+
if isZero(rv) {
879+
return nil
900880
}
901-
rt := rv.Type()
881+
// write field value and get an ID
882+
sid, err := c.writeAsQuads(w, rv, seen)
883+
if err != nil {
884+
return err
885+
}
886+
// write a quad pointing to this value
887+
return c.writeQuad(w, id, pred, sid, rev)
888+
}
889+
890+
func (c *Config) writeTypeInfo(w quad.Writer, id quad.Value, rt reflect.Type) error {
902891
typesMu.RLock()
903892
iri := typeToIRI[rt]
904893
typesMu.RUnlock()
905-
if iri != quad.IRI("") {
906-
if err := w.WriteQuad(quad.Quad{Subject: id, Predicate: c.iri(iriType), Object: c.iri(iri), Label: c.Label}); err != nil {
907-
return err
894+
if iri == quad.IRI("") {
895+
return nil
896+
}
897+
return c.writeQuad(w, id, c.iri(iriType), c.iri(iri), false)
898+
}
899+
900+
func (c *Config) writeValueAs(w quad.Writer, id quad.Value, rv reflect.Value, pref string, rules fieldRules, seen map[uintptr]quad.Value) error {
901+
switch kind := rv.Kind(); kind {
902+
case reflect.Ptr, reflect.Map:
903+
ptr := rv.Pointer()
904+
if _, ok := seen[ptr]; ok {
905+
return nil
908906
}
907+
seen[ptr] = id
908+
if kind == reflect.Ptr {
909+
rv = rv.Elem()
910+
}
911+
}
912+
rt := rv.Type()
913+
if err := c.writeTypeInfo(w, id, rt); err != nil {
914+
return err
909915
}
910916
for i := 0; i < rt.NumField(); i++ {
911917
f := rt.Field(i)
912918
if f.Anonymous {
913-
if err := c.writeValueAs(w, id, rv.Field(i), pref+f.Name+".", rules); err != nil {
919+
if err := c.writeValueAs(w, id, rv.Field(i), pref+f.Name+".", rules, seen); err != nil {
914920
return err
915921
}
916922
continue
@@ -928,7 +934,7 @@ func (c *Config) writeValueAs(w quad.Writer, id quad.Value, rv reflect.Value, pr
928934
if f.Type.Kind() == reflect.Slice {
929935
sl := rv.Field(i)
930936
for j := 0; j < sl.Len(); j++ {
931-
if err := c.writeOneValReflect(w, id, r.Pred, sl.Index(j), r.Rev); err != nil {
937+
if err := c.writeOneValReflect(w, id, r.Pred, sl.Index(j), r.Rev, seen); err != nil {
932938
return err
933939
}
934940
}
@@ -937,7 +943,7 @@ func (c *Config) writeValueAs(w quad.Writer, id quad.Value, rv reflect.Value, pr
937943
if !r.Opt && isZero(fv) {
938944
return ErrReqFieldNotSet{Field: f.Name}
939945
}
940-
if err := c.writeOneValReflect(w, id, r.Pred, fv, r.Rev); err != nil {
946+
if err := c.writeOneValReflect(w, id, r.Pred, fv, r.Rev, seen); err != nil {
941947
return err
942948
}
943949
}
@@ -991,29 +997,64 @@ func (c *Config) idFor(rules fieldRules, rt reflect.Type, rv reflect.Value, pref
991997
//
992998
// See LoadTo for a list of quads mapping rules.
993999
func (c *Config) WriteAsQuads(w quad.Writer, o interface{}) (quad.Value, error) {
994-
if v, ok := o.(quad.Value); ok {
995-
return v, nil
1000+
return c.writeAsQuads(w, reflect.ValueOf(o), make(map[uintptr]quad.Value))
1001+
}
1002+
1003+
var reflQuadValue = reflect.TypeOf((*quad.Value)(nil)).Elem()
1004+
1005+
func (c *Config) writeAsQuads(w quad.Writer, rv reflect.Value, seen map[uintptr]quad.Value) (quad.Value, error) {
1006+
rt := rv.Type()
1007+
// if node is a primitive - return directly
1008+
if rt.Implements(reflQuadValue) {
1009+
return rv.Interface().(quad.Value), nil
1010+
}
1011+
prv := rv
1012+
kind := rt.Kind()
1013+
// check if we've seen this node already
1014+
switch kind {
1015+
case reflect.Ptr, reflect.Map:
1016+
ptr := prv.Pointer()
1017+
if sid, ok := seen[ptr]; ok {
1018+
return sid, nil
1019+
}
1020+
if kind == reflect.Ptr {
1021+
rv = rv.Elem()
1022+
rt = rv.Type()
1023+
kind = rt.Kind()
1024+
}
9961025
}
997-
rv := reflect.ValueOf(o)
998-
if rv.Kind() == reflect.Ptr {
999-
rv = rv.Elem()
1026+
// check if it's a type that quads package supports
1027+
// note, that it may be a struct such as time.Time
1028+
if val, ok := quad.AsValue(rv.Interface()); ok {
1029+
return val, nil
10001030
}
1001-
rt := rv.Type()
1031+
// TODO(dennwc): support maps
1032+
if kind != reflect.Struct {
1033+
return nil, fmt.Errorf("unsupported type: %v", rt)
1034+
}
1035+
// get conversion rules for this struct type
10021036
rules, err := c.rulesFor(rt)
10031037
if err != nil {
10041038
return nil, fmt.Errorf("can't load rules: %v", err)
10051039
}
10061040
if len(rules) == 0 {
10071041
return nil, fmt.Errorf("no rules for struct: %v", rt)
10081042
}
1043+
// get an ID from the struct value
10091044
id, err := c.idFor(rules, rt, rv, "")
10101045
if err != nil {
10111046
return nil, err
10121047
}
10131048
if id == nil {
1014-
id = c.genID(o)
1049+
id = c.genID(prv.Interface())
1050+
}
1051+
// save a node ID to avoid loops
1052+
switch prv.Kind() {
1053+
case reflect.Ptr, reflect.Map:
1054+
ptr := prv.Pointer()
1055+
seen[ptr] = id
10151056
}
1016-
if err = c.writeValueAs(w, id, rv, "", rules); err != nil {
1057+
if err = c.writeValueAs(w, id, rv, "", rules, seen); err != nil {
10171058
return nil, err
10181059
}
10191060
return id, nil
@@ -1031,13 +1072,14 @@ func (c *Config) WriteNamespaces(w quad.Writer, n *voc.Namespaces) error {
10311072
if err != nil {
10321073
return fmt.Errorf("can't load rules: %v", err)
10331074
}
1075+
seen := make(map[uintptr]quad.Value)
10341076
for _, ns := range n.List() {
10351077
obj := namespace{
10361078
Full: quad.IRI(ns.Full),
10371079
Prefix: quad.IRI(ns.Prefix),
10381080
}
10391081
rv := reflect.ValueOf(obj)
1040-
if err = c.writeValueAs(w, obj.Full, rv, "", rules); err != nil {
1082+
if err = c.writeValueAs(w, obj.Full, rv, "", rules, seen); err != nil {
10411083
return err
10421084
}
10431085
}

schema/schema_test.go

Lines changed: 46 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,12 @@ type Coords struct {
109109
Lng float64 `json:"ex:lng"`
110110
}
111111

112+
type NodeLoop struct {
113+
ID quad.IRI `quad:"@id"`
114+
Name string `quad:"name"`
115+
Next *NodeLoop `quad:"next"`
116+
}
117+
112118
func iri(s string) quad.IRI { return quad.IRI(s) }
113119

114120
const typeIRI = quad.IRI(rdf.Type)
@@ -291,6 +297,43 @@ var testWriteValueCases = []struct {
291297
},
292298
nil,
293299
},
300+
{
301+
"self loop",
302+
func() *NodeLoop {
303+
a := &NodeLoop{ID: iri("A"), Name: "Node A"}
304+
a.Next = a
305+
return a
306+
}(),
307+
iri("A"),
308+
[]quad.Quad{
309+
{iri("A"), iri("name"), quad.String("Node A"), nil},
310+
{iri("A"), iri("next"), iri("A"), nil},
311+
},
312+
nil,
313+
},
314+
{
315+
"pointer chain",
316+
func() *NodeLoop {
317+
a := &NodeLoop{ID: iri("A"), Name: "Node A"}
318+
b := &NodeLoop{ID: iri("B"), Name: "Node B"}
319+
c := &NodeLoop{ID: iri("C"), Name: "Node C"}
320+
321+
a.Next = b
322+
b.Next = c
323+
c.Next = a
324+
return a
325+
}(),
326+
iri("A"),
327+
[]quad.Quad{
328+
{iri("A"), iri("name"), quad.String("Node A"), nil},
329+
{iri("B"), iri("name"), quad.String("Node B"), nil},
330+
{iri("C"), iri("name"), quad.String("Node C"), nil},
331+
{iri("C"), iri("next"), iri("A"), nil},
332+
{iri("B"), iri("next"), iri("C"), nil},
333+
{iri("A"), iri("next"), iri("B"), nil},
334+
},
335+
nil,
336+
},
294337
}
295338

296339
type quadSlice []quad.Quad
@@ -302,12 +345,12 @@ func (s *quadSlice) WriteQuad(q quad.Quad) error {
302345

303346
func TestWriteAsQuads(t *testing.T) {
304347
sch := schema.NewConfig()
305-
for i, c := range testWriteValueCases {
348+
for _, c := range testWriteValueCases {
306349
t.Run(c.name, func(t *testing.T) {
307350
var out quadSlice
308351
id, err := sch.WriteAsQuads(&out, c.obj)
309352
if err != c.err {
310-
t.Errorf("case %d failed: %v != %v", i, err, c.err)
353+
t.Errorf("unexpected error: %v (expected: %v)", err, c.err)
311354
} else if c.err != nil {
312355
return // case with expected error; omit other checks
313356
}
@@ -661,4 +704,4 @@ func TestSaveNamespaces(t *testing.T) {
661704
if !reflect.DeepEqual(expect, q) {
662705
t.Fatalf("wrong quads returned: got: %v, expect: %v", q, expect)
663706
}
664-
}
707+
}

0 commit comments

Comments
 (0)