diff --git a/locate/redis/locator.go b/locate/redis/locator.go index 1cbaf43..a2a7fa6 100644 --- a/locate/redis/locator.go +++ b/locate/redis/locator.go @@ -3,7 +3,9 @@ package redis import ( "context" "fmt" + "github.com/dobyte/due/v2/cluster" "github.com/dobyte/due/v2/encoding/json" + "github.com/dobyte/due/v2/errors" "github.com/dobyte/due/v2/locate" "github.com/dobyte/due/v2/log" "github.com/go-redis/redis/v8" @@ -24,11 +26,13 @@ const name = "redis" var _ locate.Locator = &Locator{} type Locator struct { - ctx context.Context - cancel context.CancelFunc - opts *options - sfg singleflight.Group // singleFlight - watchers sync.Map + opts *options + ctx context.Context + cancel context.CancelFunc + sfg singleflight.Group + watchers sync.Map + unbindGateScript *redis.Script + unbindNodeScript *redis.Script } func NewLocator(opts ...Option) *Locator { @@ -52,8 +56,10 @@ func NewLocator(opts ...Option) *Locator { } l := &Locator{} - l.ctx, l.cancel = context.WithCancel(o.ctx) l.opts = o + l.ctx, l.cancel = context.WithCancel(o.ctx) + l.unbindGateScript = redis.NewScript(unbindGateScript) + l.unbindNodeScript = redis.NewScript(unbindNodeScript) return l } @@ -66,9 +72,10 @@ func (l *Locator) Name() string { // LocateGate 定位用户所在网关 func (l *Locator) LocateGate(ctx context.Context, uid int64) (string, error) { key := fmt.Sprintf(userGateKey, l.opts.prefix, uid) + val, err, _ := l.sfg.Do(key, func() (interface{}, error) { val, err := l.opts.client.Get(ctx, key).Result() - if err != nil && err != redis.Nil { + if err != nil && !errors.Is(err, redis.Nil) { return "", err } @@ -84,9 +91,10 @@ func (l *Locator) LocateGate(ctx context.Context, uid int64) (string, error) { // LocateNode 定位用户所在节点 func (l *Locator) LocateNode(ctx context.Context, uid int64, name string) (string, error) { key := fmt.Sprintf(userNodeKey, l.opts.prefix, uid) + val, err, _ := l.sfg.Do(key+name, func() (interface{}, error) { val, err := l.opts.client.HGet(ctx, key, name).Result() - if err != nil && err != redis.Nil { + if err != nil && !errors.Is(err, redis.Nil) { return "", err } @@ -102,14 +110,13 @@ func (l *Locator) LocateNode(ctx context.Context, uid int64, name string) (strin // BindGate 绑定网关 func (l *Locator) BindGate(ctx context.Context, uid int64, gid string) error { key := fmt.Sprintf(userGateKey, l.opts.prefix, uid) - err := l.opts.client.Set(ctx, key, gid, redis.KeepTTL).Err() - if err != nil { + + if err := l.opts.client.Set(ctx, key, gid, redis.KeepTTL).Err(); err != nil { return err } - err = l.publish(ctx, locate.BindGate, uid, gid) - if err != nil { - log.Errorf("location event publish failed: %v", err) + if err := l.broadcast(ctx, locate.BindGate, uid, gid); err != nil { + log.Errorf("location event broadcast failed: %v", err) } return nil @@ -118,14 +125,13 @@ func (l *Locator) BindGate(ctx context.Context, uid int64, gid string) error { // BindNode 绑定节点 func (l *Locator) BindNode(ctx context.Context, uid int64, name, nid string) error { key := fmt.Sprintf(userNodeKey, l.opts.prefix, uid) - err := l.opts.client.HSet(ctx, key, name, nid).Err() - if err != nil { + + if err := l.opts.client.HSet(ctx, key, name, nid).Err(); err != nil { return err } - err = l.publish(ctx, locate.BindNode, uid, nid, name) - if err != nil { - log.Errorf("location event publish failed: %v", err) + if err := l.broadcast(ctx, locate.BindNode, uid, nid, name); err != nil { + log.Errorf("location event broadcast failed: %v", err) } return nil @@ -133,82 +139,61 @@ func (l *Locator) BindNode(ctx context.Context, uid int64, name, nid string) err // UnbindGate 解绑网关 func (l *Locator) UnbindGate(ctx context.Context, uid int64, gid string) error { - oldGID, err := l.LocateGate(ctx, uid) - if err != nil { - return err - } - - if oldGID == "" || oldGID != gid { - return nil - } - key := fmt.Sprintf(userGateKey, l.opts.prefix, uid) - err = l.opts.client.Del(ctx, key).Err() + + rst, err := l.unbindGateScript.Run(ctx, l.opts.client, []string{key}, gid).StringSlice() if err != nil { return err } - err = l.publish(ctx, locate.UnbindGate, uid, gid) - if err != nil { - log.Errorf("location event publish failed: %v", err) + if rst[0] == "OK" { + if err = l.broadcast(ctx, locate.UnbindGate, uid, gid); err != nil { + log.Errorf("location event broadcast failed: %v", err) + } } return nil } // UnbindNode 解绑节点 -func (l *Locator) UnbindNode(ctx context.Context, uid int64, name string, nid string) error { - oldNID, err := l.LocateNode(ctx, uid, name) - if err != nil { - return err - } - - if oldNID == "" || oldNID != nid { - return nil - } - +func (l *Locator) UnbindNode(ctx context.Context, uid int64, name, nid string) error { key := fmt.Sprintf(userNodeKey, l.opts.prefix, uid) - err = l.opts.client.Del(ctx, key).Err() + + rst, err := l.unbindNodeScript.Run(ctx, l.opts.client, []string{key}, name, nid).StringSlice() if err != nil { return err } - err = l.publish(ctx, locate.UnbindNode, uid, nid, name) - if err != nil { - log.Errorf("location event publish failed: %v", err) + if rst[0] == "OK" { + if err = l.broadcast(ctx, locate.UnbindNode, uid, nid, name); err != nil { + log.Errorf("location event broadcast failed: %v", err) + } } return nil } -func (l *Locator) publish(ctx context.Context, typ locate.EventType, uid int64, insID string, insName ...string) error { - var ( - kind string - name string - ) +// 广播事件 +func (l *Locator) broadcast(ctx context.Context, typ locate.EventType, uid int64, insID string, insName ...string) error { + evt := &locate.Event{UID: uid, Type: typ, InsID: insID} + switch typ { case locate.BindGate, locate.UnbindGate: - kind = "gate" + evt.InsKind = cluster.Gate.String() case locate.BindNode, locate.UnbindNode: - kind = "node" + evt.InsKind = cluster.Node.String() } if len(insName) > 0 { - name = insName[0] + evt.InsName = insName[0] } - msg, err := marshal(&locate.Event{ - UID: uid, - Type: typ, - InsID: insID, - InsKind: kind, - InsName: name, - }) + msg, err := marshal(evt) if err != nil { return err } - return l.opts.client.Publish(ctx, fmt.Sprintf(clusterEventKey, l.opts.prefix, kind), msg).Err() + return l.opts.client.Publish(ctx, fmt.Sprintf(clusterEventKey, l.opts.prefix, evt.InsKind), msg).Err() } func (l *Locator) toUniqueKey(kinds ...string) string { @@ -248,13 +233,16 @@ func marshal(event *locate.Event) (string, error) { if err != nil { return "", err } + return string(buf), nil } func unmarshal(data []byte) (*locate.Event, error) { - event := &locate.Event{} - if err := json.Unmarshal(data, event); err != nil { + evt := &locate.Event{} + + if err := json.Unmarshal(data, evt); err != nil { return nil, err } - return event, nil + + return evt, nil } diff --git a/locate/redis/script.go b/locate/redis/script.go new file mode 100644 index 0000000..d996eb1 --- /dev/null +++ b/locate/redis/script.go @@ -0,0 +1,27 @@ +package redis + +// 解绑网关脚本 +const unbindGateScript = ` + local val = redis.call('GET', KEYS[1]) + + if val == '' or val ~= ARGV[1] then + return {'NO'} + end + + redis.call('DEL', KEYS[1]) + + return {'OK'} +` + +// 解绑节点脚本 +const unbindNodeScript = ` + local val = redis.call('HGET', KEYS[1], ARGV[1]) + + if val == '' or val ~= ARGV[2] then + return {'NO'} + end + + redis.call('HDEL', KEYS[1], ARGV[1]) + + return {'OK'} +` diff --git a/locate/redis/watcher.go b/locate/redis/watcher.go index 506f99a..7392606 100644 --- a/locate/redis/watcher.go +++ b/locate/redis/watcher.go @@ -65,12 +65,11 @@ func (w *watcher) Stop() error { } type watcherMgr struct { - ctx context.Context - cancel context.CancelFunc - locator *Locator - key string - sub *redis.PubSub - + ctx context.Context + cancel context.CancelFunc + locator *Locator + key string + sub *redis.PubSub rw sync.RWMutex idx int64 watchers map[int64]*watcher