
Commit 90a2d09

Backport of CSI: improve controller RPC reliability into release/1.6.x (#18015)
This pull request was automerged via backport-assistant
1 parent: 3024020 · commit: 90a2d09

File tree

.changelog/17996.txt
nomad/client_csi_endpoint.go
nomad/csi_endpoint.go
nomad/csi_endpoint_test.go
nomad/server.go

5 files changed: +178 −23 lines changed

.changelog/17996.txt

Lines changed: 11 additions & 0 deletions

```diff
@@ -0,0 +1,11 @@
+```release-note:bug
+csi: Fixed a bug in sending concurrent requests to CSI controller plugins by serializing them per plugin
+```
+
+```release-note:bug
+csi: Fixed a bug where CSI controller requests could be sent to unhealthy plugins
+```
+
+```release-note:bug
+csi: Fixed a bug where CSI controller requests could not be sent to controllers on nodes ineligible for scheduling
+```
```

nomad/client_csi_endpoint.go

Lines changed: 42 additions & 17 deletions

```diff
@@ -4,8 +4,9 @@
 package nomad
 
 import (
+    "errors"
     "fmt"
-    "math/rand"
+    "sort"
     "strings"
     "time"
 
@@ -262,9 +263,9 @@ func (a *ClientCSI) clientIDsForController(pluginID string) ([]string, error) {
 
     ws := memdb.NewWatchSet()
 
-    // note: plugin IDs are not scoped to region/DC but volumes are.
-    // so any node we get for a controller is already in the same
-    // region/DC for the volume.
+    // note: plugin IDs are not scoped to region but volumes are. so any Nomad
+    // client we get for a controller is already in the same region for the
+    // volume.
     plugin, err := snap.CSIPluginByID(ws, pluginID)
     if err != nil {
         return nil, fmt.Errorf("error getting plugin: %s, %v", pluginID, err)
@@ -273,31 +274,55 @@ func (a *ClientCSI) clientIDsForController(pluginID string) ([]string, error) {
         return nil, fmt.Errorf("plugin missing: %s", pluginID)
     }
 
-    // iterating maps is "random" but unspecified and isn't particularly
-    // random with small maps, so not well-suited for load balancing.
-    // so we shuffle the keys and iterate over them.
     clientIDs := []string{}
 
+    if len(plugin.Controllers) == 0 {
+        return nil, fmt.Errorf("failed to find instances of controller plugin %q", pluginID)
+    }
+
+    var merr error
     for clientID, controller := range plugin.Controllers {
         if !controller.IsController() {
-            // we don't have separate types for CSIInfo depending on
-            // whether it's a controller or node. this error shouldn't
-            // make it to production but is to aid developers during
-            // development
+            // we don't have separate types for CSIInfo depending on whether
+            // it's a controller or node. this error should never make it to
+            // production
+            merr = errors.Join(merr, fmt.Errorf(
+                "plugin instance %q is not a controller but was registered as one - this is always a bug", controller.AllocID))
+            continue
+        }
+
+        if !controller.Healthy {
+            merr = errors.Join(merr, fmt.Errorf(
+                "plugin instance %q is not healthy", controller.AllocID))
             continue
         }
+
         node, err := getNodeForRpc(snap, clientID)
-        if err == nil && node != nil && node.Ready() {
-            clientIDs = append(clientIDs, clientID)
+        if err != nil || node == nil {
+            merr = errors.Join(merr, fmt.Errorf(
+                "cannot find node %q for plugin instance %q", clientID, controller.AllocID))
+            continue
+        }
+
+        if node.Status != structs.NodeStatusReady {
+            merr = errors.Join(merr, fmt.Errorf(
+                "node %q for plugin instance %q is not ready", clientID, controller.AllocID))
+            continue
         }
+
+        clientIDs = append(clientIDs, clientID)
     }
+
     if len(clientIDs) == 0 {
-        return nil, fmt.Errorf("failed to find clients running controller plugin %q", pluginID)
+        return nil, fmt.Errorf("failed to find clients running controller plugin %q: %v",
+            pluginID, merr)
     }
 
-    rand.Shuffle(len(clientIDs), func(i, j int) {
-        clientIDs[i], clientIDs[j] = clientIDs[j], clientIDs[i]
-    })
+    // Many plugins don't handle concurrent requests as described in the spec,
+    // and have undocumented expectations of using k8s-specific sidecars to
+    // leader elect. Sort the client IDs so that we prefer sending requests to
+    // the same controller to hack around this.
+    clientIDs = sort.StringSlice(clientIDs)
 
     return clientIDs, nil
 }
```
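The rewrite above replaces shuffle-based load balancing with a filter that records why each controller instance was skipped (via errors.Join, Go 1.20+) and orders the survivors so requests keep landing on the same controller. A minimal standalone sketch of that filter-and-accumulate pattern, using hypothetical candidate fields rather than the real structs.CSIInfo and node types:

```go
package main

import (
	"errors"
	"fmt"
	"sort"
)

// candidate is a stand-in for a registered controller instance
// (hypothetical; the real code inspects structs.CSIInfo and node state).
type candidate struct {
	ClientID string
	Healthy  bool
	Ready    bool
}

// eligibleClientIDs keeps only healthy candidates on ready nodes. Every
// rejection is recorded with errors.Join so the caller sees all the reasons
// at once if nothing survives the filter.
func eligibleClientIDs(cands []candidate) ([]string, error) {
	var merr error
	ids := []string{}
	for _, c := range cands {
		if !c.Healthy {
			merr = errors.Join(merr, fmt.Errorf("instance %q is not healthy", c.ClientID))
			continue
		}
		if !c.Ready {
			merr = errors.Join(merr, fmt.Errorf("node for instance %q is not ready", c.ClientID))
			continue
		}
		ids = append(ids, c.ClientID)
	}
	if len(ids) == 0 {
		return nil, fmt.Errorf("no eligible controllers: %v", merr)
	}
	// Order deterministically so callers tend to reuse the same controller,
	// mirroring the change above.
	sort.Strings(ids)
	return ids, nil
}

func main() {
	ids, err := eligibleClientIDs([]candidate{
		{ClientID: "node-b", Healthy: true, Ready: true},
		{ClientID: "node-a", Healthy: true, Ready: true},
		{ClientID: "node-c", Healthy: false, Ready: true},
	})
	fmt.Println(ids, err) // [node-a node-b] <nil>
}
```

Because every rejection is joined into one error, the final "failed to find clients running controller plugin" message explains each skipped instance instead of failing without context.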

nomad/csi_endpoint.go

Lines changed: 69 additions & 6 deletions

```diff
@@ -4,6 +4,7 @@
 package nomad
 
 import (
+    "context"
     "fmt"
     "net/http"
     "strings"
@@ -549,7 +550,9 @@ func (v *CSIVolume) controllerPublishVolume(req *structs.CSIVolumeClaimRequest,
     cReq.PluginID = plug.ID
     cResp := &cstructs.ClientCSIControllerAttachVolumeResponse{}
 
-    err = v.srv.RPC(method, cReq, cResp)
+    err = v.serializedControllerRPC(plug.ID, func() error {
+        return v.srv.RPC(method, cReq, cResp)
+    })
     if err != nil {
         if strings.Contains(err.Error(), "FailedPrecondition") {
             return fmt.Errorf("%v: %v", structs.ErrCSIClientRPCRetryable, err)
@@ -586,6 +589,57 @@ func (v *CSIVolume) volAndPluginLookup(namespace, volID string) (*structs.CSIPlu
     return plug, vol, nil
 }
 
+// serializedControllerRPC ensures we're only sending a single controller RPC to
+// a given plugin if the RPC can cause conflicting state changes.
+//
+// The CSI specification says that we SHOULD send no more than one in-flight
+// request per *volume* at a time, with an allowance for losing state
+// (ex. leadership transitions) which the plugins SHOULD handle gracefully.
+//
+// In practice many CSI plugins rely on k8s-specific sidecars for serializing
+// storage provider API calls globally (ex. concurrently attaching EBS volumes
+// to an EC2 instance results in a race for device names). So we have to be much
+// more conservative about concurrency in Nomad than the spec allows.
+func (v *CSIVolume) serializedControllerRPC(pluginID string, fn func() error) error {
+
+    for {
+        v.srv.volumeControllerLock.Lock()
+        future := v.srv.volumeControllerFutures[pluginID]
+        if future == nil {
+            future, futureDone := context.WithCancel(v.srv.shutdownCtx)
+            v.srv.volumeControllerFutures[pluginID] = future
+            v.srv.volumeControllerLock.Unlock()
+
+            err := fn()
+
+            // close the future while holding the lock and not in a defer so
+            // that we can ensure we've cleared it from the map before allowing
+            // anyone else to take the lock and write a new one
+            v.srv.volumeControllerLock.Lock()
+            futureDone()
+            delete(v.srv.volumeControllerFutures, pluginID)
+            v.srv.volumeControllerLock.Unlock()
+
+            return err
+        } else {
+            v.srv.volumeControllerLock.Unlock()
+
+            select {
+            case <-future.Done():
+                continue
+            case <-v.srv.shutdownCh:
+                // The csi_hook publish workflow on the client will retry if it
+                // gets this error. On unpublish, we don't want to block client
+                // shutdown so we give up on error. The new leader's
+                // volumewatcher will iterate all the claims at startup to
+                // detect this and mop up any claims in the NodeDetached state
+                // (volume GC will run periodically as well)
+                return structs.ErrNoLeader
+            }
+        }
+    }
+}
+
 // allowCSIMount is called on Job register to check mount permission
 func allowCSIMount(aclObj *acl.ACL, namespace string) bool {
     return aclObj.AllowPluginRead() &&
@@ -863,8 +917,11 @@ func (v *CSIVolume) controllerUnpublishVolume(vol *structs.CSIVolume, claim *str
         Secrets: vol.Secrets,
     }
     req.PluginID = vol.PluginID
-    err = v.srv.RPC("ClientCSI.ControllerDetachVolume", req,
-        &cstructs.ClientCSIControllerDetachVolumeResponse{})
+
+    err = v.serializedControllerRPC(vol.PluginID, func() error {
+        return v.srv.RPC("ClientCSI.ControllerDetachVolume", req,
+            &cstructs.ClientCSIControllerDetachVolumeResponse{})
+    })
     if err != nil {
         return fmt.Errorf("could not detach from controller: %v", err)
     }
@@ -1139,7 +1196,9 @@ func (v *CSIVolume) deleteVolume(vol *structs.CSIVolume, plugin *structs.CSIPlug
     cReq.PluginID = plugin.ID
     cResp := &cstructs.ClientCSIControllerDeleteVolumeResponse{}
 
-    return v.srv.RPC(method, cReq, cResp)
+    return v.serializedControllerRPC(plugin.ID, func() error {
+        return v.srv.RPC(method, cReq, cResp)
+    })
 }
 
 func (v *CSIVolume) ListExternal(args *structs.CSIVolumeExternalListRequest, reply *structs.CSIVolumeExternalListResponse) error {
@@ -1286,7 +1345,9 @@ func (v *CSIVolume) CreateSnapshot(args *structs.CSISnapshotCreateRequest, reply
         }
         cReq.PluginID = pluginID
         cResp := &cstructs.ClientCSIControllerCreateSnapshotResponse{}
-        err = v.srv.RPC(method, cReq, cResp)
+        err = v.serializedControllerRPC(pluginID, func() error {
+            return v.srv.RPC(method, cReq, cResp)
+        })
         if err != nil {
             multierror.Append(&mErr, fmt.Errorf("could not create snapshot: %v", err))
             continue
@@ -1360,7 +1421,9 @@ func (v *CSIVolume) DeleteSnapshot(args *structs.CSISnapshotDeleteRequest, reply
         cReq := &cstructs.ClientCSIControllerDeleteSnapshotRequest{ID: snap.ID}
         cReq.PluginID = plugin.ID
         cResp := &cstructs.ClientCSIControllerDeleteSnapshotResponse{}
-        err = v.srv.RPC(method, cReq, cResp)
+        err = v.serializedControllerRPC(plugin.ID, func() error {
+            return v.srv.RPC(method, cReq, cResp)
+        })
        if err != nil {
             multierror.Append(&mErr, fmt.Errorf("could not delete %q: %v", snap.ID, err))
         }
```
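serializedControllerRPC is essentially a per-key lock built from context "futures": the first caller for a plugin installs a cancellable context in a shared map, runs its RPC, then cancels the context and removes the entry under the lock; later callers find the entry, block on Done(), and retry. A stripped-down sketch of the same pattern, with hypothetical names and without the server plumbing or shutdown handling:

```go
package main

import (
	"context"
	"fmt"
	"sync"
	"time"
)

// keySerializer runs at most one fn at a time per key. Waiters block on a
// context "future" that the running goroutine cancels when it finishes.
type keySerializer struct {
	mu      sync.Mutex
	futures map[string]context.Context
}

func newKeySerializer() *keySerializer {
	return &keySerializer{futures: map[string]context.Context{}}
}

func (s *keySerializer) run(key string, fn func() error) error {
	for {
		s.mu.Lock()
		if future := s.futures[key]; future != nil {
			// someone else holds the key; wait for their future to close
			s.mu.Unlock()
			<-future.Done()
			continue
		}
		future, done := context.WithCancel(context.Background())
		s.futures[key] = future
		s.mu.Unlock()

		err := fn()

		// clear the slot before releasing waiters, as the real method does
		s.mu.Lock()
		done()
		delete(s.futures, key)
		s.mu.Unlock()
		return err
	}
}

func main() {
	s := newKeySerializer()
	var wg sync.WaitGroup
	start := time.Now()
	for i := 0; i < 2; i++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			s.run("plugin1", func() error { time.Sleep(50 * time.Millisecond); return nil })
		}()
	}
	wg.Wait()
	fmt.Println("two serialized calls took", time.Since(start)) // ~100ms, not ~50ms
}
```

The real method additionally selects on the server's shutdown channel so a waiter gives up with structs.ErrNoLeader during leadership transitions instead of blocking forever; the volumewatcher and periodic volume GC then clean up any claims left behind.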

nomad/csi_endpoint_test.go

Lines changed: 48 additions & 0 deletions

```diff
@@ -6,6 +6,7 @@ package nomad
 import (
     "fmt"
     "strings"
+    "sync"
     "testing"
     "time"
 
@@ -21,6 +22,7 @@ import (
     cconfig "github.com/hashicorp/nomad/client/config"
     cstructs "github.com/hashicorp/nomad/client/structs"
     "github.com/hashicorp/nomad/helper/uuid"
+    "github.com/hashicorp/nomad/lib/lang"
     "github.com/hashicorp/nomad/nomad/mock"
     "github.com/hashicorp/nomad/nomad/state"
     "github.com/hashicorp/nomad/nomad/structs"
@@ -1971,3 +1973,49 @@ func TestCSI_RPCVolumeAndPluginLookup(t *testing.T) {
     require.Nil(t, vol)
     require.EqualError(t, err, fmt.Sprintf("volume not found: %s", id2))
 }
+
+func TestCSI_SerializedControllerRPC(t *testing.T) {
+    ci.Parallel(t)
+
+    srv, shutdown := TestServer(t, func(c *Config) { c.NumSchedulers = 0 })
+    defer shutdown()
+    testutil.WaitForLeader(t, srv.RPC)
+
+    var wg sync.WaitGroup
+    wg.Add(3)
+
+    timeCh := make(chan lang.Pair[string, time.Duration])
+
+    testFn := func(pluginID string, dur time.Duration) {
+        defer wg.Done()
+        c := NewCSIVolumeEndpoint(srv, nil)
+        now := time.Now()
+        err := c.serializedControllerRPC(pluginID, func() error {
+            time.Sleep(dur)
+            return nil
+        })
+        elapsed := time.Since(now)
+        timeCh <- lang.Pair[string, time.Duration]{pluginID, elapsed}
+        must.NoError(t, err)
+    }
+
+    go testFn("plugin1", 50*time.Millisecond)
+    go testFn("plugin2", 50*time.Millisecond)
+    go testFn("plugin1", 50*time.Millisecond)
+
+    totals := map[string]time.Duration{}
+    for i := 0; i < 3; i++ {
+        pair := <-timeCh
+        totals[pair.First] += pair.Second
+    }
+
+    wg.Wait()
+
+    // plugin1 RPCs should block each other
+    must.GreaterEq(t, 150*time.Millisecond, totals["plugin1"])
+    must.Less(t, 200*time.Millisecond, totals["plugin1"])
+
+    // plugin1 RPCs should not block plugin2 RPCs
+    must.GreaterEq(t, 50*time.Millisecond, totals["plugin2"])
+    must.Less(t, 100*time.Millisecond, totals["plugin2"])
+}
```

nomad/server.go

Lines changed: 8 additions & 0 deletions

```diff
@@ -218,6 +218,13 @@
     // volumeWatcher is used to release volume claims
     volumeWatcher *volumewatcher.Watcher
 
+    // volumeControllerFutures is a map of plugin IDs to pending controller RPCs. If
+    // no RPC is pending for a given plugin, this may be nil.
+    volumeControllerFutures map[string]context.Context
+
+    // volumeControllerLock synchronizes access controllerFutures map
+    volumeControllerLock sync.Mutex
+
     // keyringReplicator is used to replicate root encryption keys from the
     // leader
     keyringReplicator *KeyringReplicator
@@ -445,6 +452,7 @@ func NewServer(config *Config, consulCatalog consul.CatalogAPI, consulConfigEntr
         s.logger.Error("failed to create volume watcher", "error", err)
         return nil, fmt.Errorf("failed to create volume watcher: %v", err)
     }
+    s.volumeControllerFutures = map[string]context.Context{}
 
     // Start the eval broker notification system so any subscribers can get
     // updates when the processes SetEnabled is triggered.
```
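One detail worth noting: the futures map is allocated in NewServer before any controller RPC can run because, in Go, reading a missing key from a nil map returns the zero value while writing to a nil map panics. A tiny standalone illustration (not Nomad code):

```go
package main

import (
	"context"
	"fmt"
)

func main() {
	var futures map[string]context.Context // nil map

	// reads are safe: a missing key returns the zero value (nil here)
	fmt.Println(futures["plugin1"] == nil) // true

	// writes are not: this line would panic with "assignment to entry in nil map"
	// futures["plugin1"] = context.Background()

	// so the server allocates the map up front, as NewServer does above
	futures = map[string]context.Context{}
	futures["plugin1"] = context.Background()
	fmt.Println(len(futures)) // 1
}
```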
