Skip to content

Commit 371ee25

Browse files
committed
Merge pull request #123 from tylertreat/fixes
Fix Ctrie snapshotting
2 parents 074c32b + 38e136e commit 371ee25

File tree

2 files changed

+61
-17
lines changed

2 files changed

+61
-17
lines changed

trie/ctrie/ctrie.go

+20-17
Original file line numberDiff line numberDiff line change
@@ -63,8 +63,10 @@ type Ctrie struct {
6363
}
6464

6565
// generation demarcates Ctrie snapshots. We use a heap-allocated reference
66-
// instead of an integer to avoid integer overflows.
67-
type generation struct{}
66+
// instead of an integer to avoid integer overflows. Struct must have a field
67+
// on it since two distinct zero-size variables may have the same address in
68+
// memory.
69+
type generation struct{ _ int }
6870

6971
// iNode is an indirection node. I-nodes remain present in the Ctrie even as
7072
// nodes above and below change. Thread-safety is achieved in part by
@@ -320,7 +322,7 @@ func (c *Ctrie) Snapshot() *Ctrie {
320322
root := c.readRoot()
321323
main := gcasRead(root, c)
322324
if c.rdcssRoot(root, main, root.copyToGen(&generation{}, c)) {
323-
return newCtrie(root.copyToGen(&generation{}, c), c.hashFactory, c.readOnly)
325+
return newCtrie(c.readRoot().copyToGen(&generation{}, c), c.hashFactory, c.readOnly)
324326
}
325327
}
326328
}
@@ -335,7 +337,7 @@ func (c *Ctrie) ReadOnlySnapshot() *Ctrie {
335337
root := c.readRoot()
336338
main := gcasRead(root, c)
337339
if c.rdcssRoot(root, main, root.copyToGen(&generation{}, c)) {
338-
return newCtrie(root, c.hashFactory, true)
340+
return newCtrie(c.readRoot(), c.hashFactory, true)
339341
}
340342
}
341343
}
@@ -363,7 +365,7 @@ func (c *Ctrie) Iterator(cancel <-chan struct{}) <-chan *Entry {
363365
ch := make(chan *Entry)
364366
snapshot := c.ReadOnlySnapshot()
365367
go func() {
366-
traverse(snapshot.root, ch, cancel)
368+
snapshot.traverse(snapshot.readRoot(), ch, cancel)
367369
close(ch)
368370
}()
369371
return ch
@@ -385,13 +387,14 @@ func (c *Ctrie) Size() uint {
385387

386388
var errCanceled = errors.New("canceled")
387389

388-
func traverse(i *iNode, ch chan<- *Entry, cancel <-chan struct{}) error {
390+
func (c *Ctrie) traverse(i *iNode, ch chan<- *Entry, cancel <-chan struct{}) error {
391+
main := gcasRead(i, c)
389392
switch {
390-
case i.main.cNode != nil:
391-
for _, br := range i.main.cNode.array {
393+
case main.cNode != nil:
394+
for _, br := range main.cNode.array {
392395
switch b := br.(type) {
393396
case *iNode:
394-
if err := traverse(b, ch, cancel); err != nil {
397+
if err := c.traverse(b, ch, cancel); err != nil {
395398
return err
396399
}
397400
case *sNode:
@@ -402,8 +405,8 @@ func traverse(i *iNode, ch chan<- *Entry, cancel <-chan struct{}) error {
402405
}
403406
}
404407
}
405-
case i.main.lNode != nil:
406-
for _, e := range i.main.lNode.Map(func(sn interface{}) interface{} {
408+
case main.lNode != nil:
409+
for _, e := range main.lNode.Map(func(sn interface{}) interface{} {
407410
return sn.(*sNode).Entry
408411
}) {
409412
select {
@@ -485,7 +488,7 @@ func (c *Ctrie) iinsert(i *iNode, entry *Entry, lev uint, parent *iNode, startGe
485488
// If the branch is an I-node, then iinsert is called recursively.
486489
in := branch.(*iNode)
487490
if startGen == in.gen {
488-
return c.iinsert(in, entry, lev+w, i, i.gen)
491+
return c.iinsert(in, entry, lev+w, i, startGen)
489492
}
490493
if gcas(i, main, &mainNode{cNode: cn.renewed(startGen, c)}, c) {
491494
return c.iinsert(i, entry, lev, parent, startGen)
@@ -810,8 +813,8 @@ func gcasComplete(i *iNode, m *mainNode, ctrie *Ctrie) *mainNode {
810813
// Signals GCAS failure. Swap old value back into I-node.
811814
fn := prev.failed
812815
if atomic.CompareAndSwapPointer((*unsafe.Pointer)(unsafe.Pointer(&i.main)),
813-
unsafe.Pointer(m), unsafe.Pointer(fn.prev)) {
814-
return fn.prev
816+
unsafe.Pointer(m), unsafe.Pointer(fn)) {
817+
return fn
815818
}
816819
m = (*mainNode)(atomic.LoadPointer(
817820
(*unsafe.Pointer)(unsafe.Pointer(&i.main))))
@@ -845,7 +848,7 @@ type rdcssDescriptor struct {
845848
old *iNode
846849
expected *mainNode
847850
nv *iNode
848-
committed bool
851+
committed int32
849852
}
850853

851854
// readRoot performs a linearizable read of the Ctrie root. This operation is
@@ -878,7 +881,7 @@ func (c *Ctrie) rdcssRoot(old *iNode, expected *mainNode, nv *iNode) bool {
878881
}
879882
if c.casRoot(old, desc) {
880883
c.rdcssComplete(false)
881-
return desc.rdcss.committed
884+
return atomic.LoadInt32(&desc.rdcss.committed) == 1
882885
}
883886
return false
884887
}
@@ -909,7 +912,7 @@ func (c *Ctrie) rdcssComplete(abort bool) *iNode {
909912
if oldeMain == exp {
910913
// Commit the RDCSS.
911914
if c.casRoot(r, nv) {
912-
desc.committed = true
915+
atomic.StoreInt32(&desc.committed, 1)
913916
return nv
914917
}
915918
continue

trie/ctrie/ctrie_test.go

+41
Original file line numberDiff line numberDiff line change
@@ -169,6 +169,47 @@ func TestConcurrency(t *testing.T) {
169169
wg.Wait()
170170
}
171171

172+
func TestConcurrency2(t *testing.T) {
173+
assert := assert.New(t)
174+
ctrie := New(nil)
175+
var wg sync.WaitGroup
176+
wg.Add(4)
177+
178+
go func() {
179+
for i := 0; i < 10000; i++ {
180+
ctrie.Insert([]byte(strconv.Itoa(i)), i)
181+
}
182+
wg.Done()
183+
}()
184+
185+
go func() {
186+
for i := 0; i < 10000; i++ {
187+
val, ok := ctrie.Lookup([]byte(strconv.Itoa(i)))
188+
if ok {
189+
assert.Equal(i, val)
190+
}
191+
}
192+
wg.Done()
193+
}()
194+
195+
go func() {
196+
for i := 0; i < 10000; i++ {
197+
ctrie.Snapshot()
198+
}
199+
wg.Done()
200+
}()
201+
202+
go func() {
203+
for i := 0; i < 10000; i++ {
204+
ctrie.ReadOnlySnapshot()
205+
}
206+
wg.Done()
207+
}()
208+
209+
wg.Wait()
210+
assert.Equal(uint(10000), ctrie.Size())
211+
}
212+
172213
func TestSnapshot(t *testing.T) {
173214
assert := assert.New(t)
174215
ctrie := New(nil)

0 commit comments

Comments
 (0)