diff --git a/p2p/simulations/http_test.go b/p2p/simulations/http_test.go index cd03e600f35c..35536f69b2de 100644 --- a/p2p/simulations/http_test.go +++ b/p2p/simulations/http_test.go @@ -234,10 +234,11 @@ func (t *testService) Snapshot() ([]byte, error) { // * get and increment a counter // * subscribe to counter increment events type TestAPI struct { - state *atomic.Value - peerCount *int64 - counter int64 - feed event.Feed + state *atomic.Value + peerCount *int64 + counter int64 + activeSubscriptions int64 + feed event.Feed } func (t *TestAPI) PeerCount() int64 { @@ -273,14 +274,17 @@ func (t *TestAPI) Events(ctx context.Context) (*rpc.Subscription, error) { events := make(chan int64) sub := t.feed.Subscribe(events) defer sub.Unsubscribe() + atomic.AddInt64(&t.activeSubscriptions, 1) for { select { case event := <-events: notifier.Notify(rpcSub.ID, event) case <-sub.Err(): + atomic.AddInt64(&t.activeSubscriptions, -1) return case <-rpcSub.Err(): + atomic.AddInt64(&t.activeSubscriptions, -1) return } } @@ -289,6 +293,10 @@ func (t *TestAPI) Events(ctx context.Context) (*rpc.Subscription, error) { return rpcSub, nil } +func (t *TestAPI) GetNumActiveSubscriptions() int64 { + return atomic.LoadInt64(&t.activeSubscriptions) +} + var testServices = adapters.LifecycleConstructors{ "test": newTestService, } @@ -557,6 +565,14 @@ func TestHTTPNodeRPC(t *testing.T) { t.Fatalf("error getting node RPC client: %s", err) } + // get the number of subscriptions before subscribing to know what number to + // expect once it becomes active + var expectedActiveSubscriptions int64 + if err := rpcClient1.CallContext(ctx, &expectedActiveSubscriptions, "test_getNumActiveSubscriptions"); err != nil { + t.Fatalf("error calling RPC method: %s", err) + } + expectedActiveSubscriptions += 1 + // subscribe to events using client 1 events := make(chan int64, 1) sub, err := rpcClient1.Subscribe(ctx, "test", events, "events") @@ -565,6 +581,22 @@ func TestHTTPNodeRPC(t *testing.T) { } defer sub.Unsubscribe() + // make sure the subscription becomes active + var numActiveSubscriptions int64 + for i := 0; i < 3; i++ { + err := rpcClient1.CallContext(ctx, &numActiveSubscriptions, "test_getNumActiveSubscriptions") + if err != nil { + t.Fatalf("error calling RPC method: %s", err) + } + if numActiveSubscriptions > 0 { + break + } + time.Sleep(100 * time.Millisecond) + } + if numActiveSubscriptions != expectedActiveSubscriptions { + t.Fatalf("subscription never became active") + } + // call some RPC methods using client 2 if err := rpcClient2.CallContext(ctx, nil, "test_add", 10); err != nil { t.Fatalf("error calling RPC method: %s", err)