Skip to content

Commit c94d00e

Browse files
committed
fixed tests; fully working with generics
1 parent e255e1c commit c94d00e

File tree

2 files changed

+50
-19
lines changed

2 files changed

+50
-19
lines changed

tree/tree.go

+17-12
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,7 @@ func (t *Tree[T]) Root() Node[T] {
7272
// Do not set a primaryID to zero, as this value should be reserved for the
7373
// case where a node has no parent.
7474
func (t *Tree[T]) Add(nodeID uint, parentID uint, data T) (added bool, exists bool) {
75+
7576
child := &node[T]{primary: nodeID, parentID: parentID, data: data}
7677

7778
// Return false if this element has already been added
@@ -190,6 +191,13 @@ func (t *Tree[T]) FindParents(id uint) (parents []Node[T], ok bool) {
190191
return parents, true
191192
}
192193

194+
type serialNode[T any] struct {
195+
// translates the important fields of a node for serialization
196+
Primary uint
197+
ParentID uint
198+
Data T
199+
}
200+
193201
// Serialize encodes the tree as a byte stream.
194202
//
195203
// The argument TraversalType will determine the traversal order in which
@@ -214,7 +222,11 @@ func (t *Tree[T]) Serialize(trvsl TraversalType) (io.ReadCloser, <-chan error) {
214222
go func() {
215223
encoder := json.NewEncoder(writer)
216224
for n := range t.Traverse(trvsl) {
217-
err := encoder.Encode(n)
225+
err := encoder.Encode(serialNode[T]{
226+
Primary: n.GetID(),
227+
ParentID: n.GetParentID(),
228+
Data: n.GetData(),
229+
})
218230
if err != nil {
219231
errchan <- err
220232
writer.Close()
@@ -230,12 +242,6 @@ func (t *Tree[T]) Serialize(trvsl TraversalType) (io.ReadCloser, <-chan error) {
230242
return reader, errchan
231243
}
232244

233-
type serialNode struct {
234-
Primary uint
235-
ParentID uint
236-
Data json.RawMessage
237-
}
238-
239245
// Deserialize decodes a data stream into a tree.
240246
//
241247
// Decode is validated for data streams encoded via the [`Serialize`]
@@ -247,23 +253,22 @@ type serialNode struct {
247253
// error.
248254
func Deserialize[T any](stream io.ReadCloser) (*Tree[T], error) {
249255
decoder := json.NewDecoder(stream)
250-
var n node[T]
251256
t := Empty[T]()
257+
252258
for {
253259

260+
var n serialNode[T]
261+
254262
err := decoder.Decode(&n)
255263
if err == io.EOF {
256-
//log.Printf("deserialize - end of file")
257264
return t, nil
258265
}
259266

260267
if err != nil {
261-
//log.Printf("deserialize - error: %s", err)
262268
return nil, fmt.Errorf("error deserializing: %w", err)
263269
}
264270

265-
//log.Printf("deserialize - adding %d %d %+v", n.GetID(), n.GetParentID(), n.GetData())
266-
t.Add(n.GetID(), n.GetParentID(), n.GetData())
271+
t.Add(n.Primary, n.ParentID, n.Data)
267272

268273
}
269274

tree/tree_test.go

+33-7
Original file line numberDiff line numberDiff line change
@@ -571,6 +571,36 @@ func TestDeserializeMap(t *testing.T) {
571571
expBFC: []uint{},
572572
expDFC: []uint{},
573573
},
574+
"serialize breadth first": {
575+
prep: func() *Tree[elem] {
576+
t := Empty[elem]()
577+
t.Add(1, 0, elem{"one": {1, 2}, "two": {2}})
578+
t.Add(2, 1, elem{"two": {2}})
579+
t.Add(3, 2, elem{"three": {3}})
580+
t.Add(4, 1, elem{"four": {4}})
581+
t.Add(5, 4, elem{"five": {5}})
582+
return t
583+
},
584+
traversal: TraverseBreadthFirst,
585+
expErr: nil,
586+
expBFC: []uint{1, 2, 4, 3, 5},
587+
expDFC: []uint{1, 2, 3, 4, 5},
588+
dataAssert: func(t *testing.T, gotTree *Tree[elem]) {
589+
iter := gotTree.Traverse(TraverseBreadthFirst)
590+
expData := []elem{
591+
{"one": {1, 2}, "two": {2}},
592+
{"two": {2}},
593+
{"four": {4}},
594+
{"three": {3}},
595+
{"five": {5}},
596+
}
597+
gotData := []elem{}
598+
for e := range iter {
599+
gotData = append(gotData, e.GetData())
600+
}
601+
assert.Equal(t, expData, gotData)
602+
},
603+
},
574604
}
575605

576606
for name, tt := range tests {
@@ -655,18 +685,14 @@ func TestDeserializeStruct(t *testing.T) {
655685
for name, tt := range tests {
656686
t.Run(name, func(t *testing.T) {
657687

658-
// this test assumes that Serialize will throw no errors
659-
rdr, _ := tt.prep().Serialize(TraverseBreadthFirst)
688+
prepTree := tt.prep()
660689

661-
//t.Logf("Started to serialize")
690+
// this test assumes that Serialize will throw no errors
691+
rdr, _ := prepTree.Serialize(TraverseBreadthFirst)
662692

663693
gotTree, gotErr := Deserialize[embeddedSerializable](rdr)
664694

665-
//t.Logf("Finished deserializing")
666-
667695
assert.Equal(t, tt.expErr, gotErr)
668-
//t.Logf("Arguments: %+v\n", tt)
669-
t.Logf("Results: {tree: %+v, error %+v}\n", gotTree, gotErr)
670696

671697
// only check the tree value if both expected and got errors are nil
672698
if gotErr == nil && tt.expErr == nil {

0 commit comments

Comments
 (0)