Skip to content

Commit c62f46b

Browse files
committed
WIP: Move Run loop body into own function to recover panic
1 parent 5a8f4f4 commit c62f46b

File tree

2 files changed

+158
-64
lines changed

2 files changed

+158
-64
lines changed

stm/locking_transaction.go

Lines changed: 82 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -239,95 +239,113 @@ func RunInTransaction(fn TxFn) (interface{}, error) {
239239
func (tx *Tx) Run(fn TxFn) (interface{}, error) {
240240

241241
done := false
242-
var ret interface{}
243242
locked := make([]*Ref, 0, 10)
244243
// notify := make([]*Notify)
245244

246245
defer func() {
247246
for k := len(locked) - 1; k >= 0; k-- {
248247
locked[k].exitWriteLock()
249248
}
249+
250+
locked = nil
251+
for r, _ := range tx.ensures {
252+
r.exitReadLock()
253+
}
254+
tx.ensures = nil
255+
if done {
256+
tx.Stop(txCommitted)
257+
} else {
258+
tx.Stop(txRetry)
259+
}
260+
if done {
261+
// do notifies and agent actions, if we every implement
262+
}
263+
250264
}()
251265

252-
locked = nil
253-
for r, _ := range tx.ensures {
254-
r.exitReadLock()
255-
}
256-
tx.ensures = nil
257-
if done {
258-
tx.Stop(txCommitted)
259-
} else {
260-
tx.Stop(txRetry)
261-
}
262-
if done {
263-
// do notifies and agent actions, if we every implement
266+
for i := 0; !done && i < retryLimit; i++ {
267+
ret, err := tx.tryRun(i, fn, &locked)
268+
if err == nil {
269+
done = true
270+
return ret, nil
271+
}
264272
}
265273

266-
for i := 0; !done && i < retryLimit; i++ {
274+
return nil, errors.New("Transaction failed after reaching retry limit")
275+
}
276+
277+
// One iteration of the Run loop
278+
// Split out so that we can catch a retry panic
279+
func (tx *Tx) tryRun(i int, fn TxFn, locked *[]*Ref) (ret interface{}, err error) {
267280

268-
tx.getReadPoint()
269-
if i == 0 {
270-
tx.startPoint = tx.readPoint
271-
tx.startTime = time.Now()
281+
ret, err = nil, nil
282+
283+
defer func() {
284+
r := recover()
285+
if r == retryError {
286+
ret, err = nil, retryError
287+
} else if r != nil {
288+
panic(r)
272289
}
290+
}()
273291

274-
tx.info = newTxInfo(txRunning, tx.startPoint)
275-
ret = fn(tx)
292+
tx.getReadPoint()
293+
if i == 0 {
294+
tx.startPoint = tx.readPoint
295+
tx.startTime = time.Now()
296+
}
276297

277-
// make sure no one has killed us before this point, and can't from now on
278-
if atomic.CompareAndSwapUint32(&tx.info.status, txRunning, txCommitting) {
279-
for r, calls := range tx.commutes {
280-
if _, ok := tx.sets[r]; ok {
281-
continue
282-
}
283-
_, wasEnsured := tx.ensures[r]
284-
tx.releaseIfEnsured(r)
285-
tryWriteLock(r)
286-
locked = append(locked, r)
287-
if wasEnsured && r.currValPoint() > tx.readPoint {
288-
panic(retryError)
289-
}
298+
tx.info = newTxInfo(txRunning, tx.startPoint)
299+
ret = fn(tx)
290300

291-
refInfo := r.tinfo
292-
if refInfo != nil && refInfo != tx.info && refInfo.isRunning() {
293-
if !tx.barge(refInfo) {
294-
panic(retryError)
295-
}
296-
}
297-
val := r.tryGetVal()
298-
tx.vals[r] = val
299-
for _, call := range calls {
300-
tx.vals[r] = call.fn(tx.vals[r], call.args...)
301-
}
301+
// make sure no one has killed us before this point, and can't from now on
302+
if atomic.CompareAndSwapUint32(&tx.info.status, txRunning, txCommitting) {
303+
for r, calls := range tx.commutes {
304+
if _, ok := tx.sets[r]; ok {
305+
continue
302306
}
303-
304-
for r, _ := range tx.sets {
305-
tryWriteLock(r)
306-
locked = append(locked, r)
307+
_, wasEnsured := tx.ensures[r]
308+
tx.releaseIfEnsured(r)
309+
tryWriteLock(r)
310+
*locked = append(*locked, r)
311+
if wasEnsured && r.currValPoint() > tx.readPoint {
312+
panic(retryError)
307313
}
308314

309-
// if we do validations for refs, it goes here
310-
311-
// at this point,
312-
// all values are calculated,
313-
// all refs to be written are locked
314-
// no more client code to be called
315-
commitPoint := getCommitPoint()
316-
for r, newV := range tx.vals {
317-
//oldV := r.tryGetVal()
318-
r.setValue(newV, commitPoint)
319-
// todo: call notifies
315+
refInfo := r.tinfo
316+
if refInfo != nil && refInfo != tx.info && refInfo.isRunning() {
317+
if !tx.barge(refInfo) {
318+
panic(retryError)
319+
}
320+
}
321+
val := r.tryGetVal()
322+
tx.vals[r] = val
323+
for _, call := range calls {
324+
tx.vals[r] = call.fn(tx.vals[r], call.args...)
320325
}
321-
done = true
322-
atomic.StoreUint32(&tx.info.status, txCommitted)
323326
}
324327

325-
if done {
326-
return ret, nil
328+
for r, _ := range tx.sets {
329+
tryWriteLock(r)
330+
*locked = append(*locked, r)
327331
}
328332

333+
// if we do validations for refs, it goes here
334+
335+
// at this point,
336+
// all values are calculated,
337+
// all refs to be written are locked
338+
// no more client code to be called
339+
commitPoint := getCommitPoint()
340+
for r, newV := range tx.vals {
341+
//oldV := r.tryGetVal()
342+
r.setValue(newV, commitPoint)
343+
// todo: call notifies
344+
}
345+
atomic.StoreUint32(&tx.info.status, txCommitted)
329346
}
330-
return nil, errors.New("Transaction failed after reaching retry limit")
347+
348+
return
331349
}
332350

333351
// Get the value of a Ref (most recently sent in this transaction or value prior to entering)

stm/tx_test.go

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,10 @@
55
package stm
66

77
import (
8+
"runtime"
9+
"sync"
810
"testing"
11+
"time"
912
)
1013

1114
func TestSimpleTxCall(t *testing.T) {
@@ -225,5 +228,78 @@ func TestSimpleCommute(t *testing.T) {
225228
if save2 != init1+add1+add2 {
226229
t.Errorf("Expected save to have value %v, got %v", init1+add1+add2, save2)
227230
}
231+
}
232+
233+
func TestSimpleInterference(t *testing.T) {
234+
235+
runtime.GOMAXPROCS(4)
236+
237+
init1 := 12
238+
val1 := 100
239+
val2 := 200
240+
add1 := 1
241+
242+
r1 := NewRef(init1)
243+
ret := 300
244+
245+
fc := func(old interface{}, args ...interface{}) interface{} {
246+
return old.(int) + args[0].(int)
247+
}
248+
249+
var done sync.WaitGroup
250+
done.Add(2)
251+
252+
var once sync.Once
253+
254+
nfEnter, nfExit := 0, 0
255+
ngEnter, ngExit := 0, 0
256+
257+
ch := make(chan bool, 1)
258+
259+
signalG := func() {
260+
ch <- true
261+
}
262+
263+
f := func(tx *Tx) interface{} {
264+
t.Logf("F: Entering")
265+
nfEnter++
266+
r1.Set(tx, val1)
267+
t.Logf("F: between sets")
268+
once.Do(signalG)
269+
time.Sleep(20 * time.Nanosecond)
270+
r1.Set(tx, val2)
271+
t.Logf("F: after sets")
272+
nfExit++
273+
return ret
274+
}
275+
276+
go func() {
277+
defer done.Done()
278+
t.Logf("F: before TX")
279+
RunInTransaction(f)
280+
t.Logf("F: after TX")
281+
}()
282+
283+
g := func(tx *Tx) interface{} {
284+
t.Logf("G: entering")
285+
ngEnter++
286+
r1.Commute(tx, fc, add1)
287+
t.Logf("G: after commute")
288+
ngExit++
289+
return ret
290+
}
291+
292+
go func() {
293+
defer done.Done()
294+
<-ch
295+
t.Logf("G: before TX")
296+
RunInTransaction(g)
297+
t.Logf("G: after TX")
298+
}()
299+
300+
time.Sleep(time.Nanosecond)
301+
done.Wait()
302+
303+
t.Errorf("%v %v %v %v %v", ngEnter, ngExit, r1.Deref(nil), nfEnter, nfExit)
228304

229305
}

0 commit comments

Comments
 (0)