Skip to content

Commit

Permalink
Fix bug of the redis locator
Browse files Browse the repository at this point in the history
  • Loading branch information
dobyte committed Feb 24, 2025
1 parent fd8775c commit 7cfd0bb
Show file tree
Hide file tree
Showing 3 changed files with 84 additions and 70 deletions.
116 changes: 52 additions & 64 deletions locate/redis/locator.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@ package redis
import (
"context"
"fmt"
"github.com/dobyte/due/v2/cluster"
"github.com/dobyte/due/v2/encoding/json"
"github.com/dobyte/due/v2/errors"
"github.com/dobyte/due/v2/locate"
"github.com/dobyte/due/v2/log"
"github.com/go-redis/redis/v8"
Expand All @@ -24,11 +26,13 @@ const name = "redis"
var _ locate.Locator = &Locator{}

type Locator struct {
ctx context.Context
cancel context.CancelFunc
opts *options
sfg singleflight.Group // singleFlight
watchers sync.Map
opts *options
ctx context.Context
cancel context.CancelFunc
sfg singleflight.Group
watchers sync.Map
unbindGateScript *redis.Script
unbindNodeScript *redis.Script
}

func NewLocator(opts ...Option) *Locator {
Expand All @@ -52,8 +56,10 @@ func NewLocator(opts ...Option) *Locator {
}

l := &Locator{}
l.ctx, l.cancel = context.WithCancel(o.ctx)
l.opts = o
l.ctx, l.cancel = context.WithCancel(o.ctx)
l.unbindGateScript = redis.NewScript(unbindGateScript)
l.unbindNodeScript = redis.NewScript(unbindNodeScript)

return l
}
Expand All @@ -66,9 +72,10 @@ func (l *Locator) Name() string {
// LocateGate 定位用户所在网关
func (l *Locator) LocateGate(ctx context.Context, uid int64) (string, error) {
key := fmt.Sprintf(userGateKey, l.opts.prefix, uid)

val, err, _ := l.sfg.Do(key, func() (interface{}, error) {
val, err := l.opts.client.Get(ctx, key).Result()
if err != nil && err != redis.Nil {
if err != nil && !errors.Is(err, redis.Nil) {
return "", err
}

Expand All @@ -84,9 +91,10 @@ func (l *Locator) LocateGate(ctx context.Context, uid int64) (string, error) {
// LocateNode 定位用户所在节点
func (l *Locator) LocateNode(ctx context.Context, uid int64, name string) (string, error) {
key := fmt.Sprintf(userNodeKey, l.opts.prefix, uid)

val, err, _ := l.sfg.Do(key+name, func() (interface{}, error) {
val, err := l.opts.client.HGet(ctx, key, name).Result()
if err != nil && err != redis.Nil {
if err != nil && !errors.Is(err, redis.Nil) {
return "", err
}

Expand All @@ -102,14 +110,13 @@ func (l *Locator) LocateNode(ctx context.Context, uid int64, name string) (strin
// BindGate 绑定网关
func (l *Locator) BindGate(ctx context.Context, uid int64, gid string) error {
key := fmt.Sprintf(userGateKey, l.opts.prefix, uid)
err := l.opts.client.Set(ctx, key, gid, redis.KeepTTL).Err()
if err != nil {

if err := l.opts.client.Set(ctx, key, gid, redis.KeepTTL).Err(); err != nil {
return err
}

err = l.publish(ctx, locate.BindGate, uid, gid)
if err != nil {
log.Errorf("location event publish failed: %v", err)
if err := l.broadcast(ctx, locate.BindGate, uid, gid); err != nil {
log.Errorf("location event broadcast failed: %v", err)
}

return nil
Expand All @@ -118,97 +125,75 @@ func (l *Locator) BindGate(ctx context.Context, uid int64, gid string) error {
// BindNode 绑定节点
func (l *Locator) BindNode(ctx context.Context, uid int64, name, nid string) error {
key := fmt.Sprintf(userNodeKey, l.opts.prefix, uid)
err := l.opts.client.HSet(ctx, key, name, nid).Err()
if err != nil {

if err := l.opts.client.HSet(ctx, key, name, nid).Err(); err != nil {
return err
}

err = l.publish(ctx, locate.BindNode, uid, nid, name)
if err != nil {
log.Errorf("location event publish failed: %v", err)
if err := l.broadcast(ctx, locate.BindNode, uid, nid, name); err != nil {
log.Errorf("location event broadcast failed: %v", err)
}

return nil
}

// UnbindGate 解绑网关
func (l *Locator) UnbindGate(ctx context.Context, uid int64, gid string) error {
oldGID, err := l.LocateGate(ctx, uid)
if err != nil {
return err
}

if oldGID == "" || oldGID != gid {
return nil
}

key := fmt.Sprintf(userGateKey, l.opts.prefix, uid)
err = l.opts.client.Del(ctx, key).Err()

rst, err := l.unbindGateScript.Run(ctx, l.opts.client, []string{key}, gid).StringSlice()
if err != nil {
return err
}

err = l.publish(ctx, locate.UnbindGate, uid, gid)
if err != nil {
log.Errorf("location event publish failed: %v", err)
if rst[0] == "OK" {
if err = l.broadcast(ctx, locate.UnbindGate, uid, gid); err != nil {
log.Errorf("location event broadcast failed: %v", err)
}
}

return nil
}

// UnbindNode 解绑节点
func (l *Locator) UnbindNode(ctx context.Context, uid int64, name string, nid string) error {
oldNID, err := l.LocateNode(ctx, uid, name)
if err != nil {
return err
}

if oldNID == "" || oldNID != nid {
return nil
}

func (l *Locator) UnbindNode(ctx context.Context, uid int64, name, nid string) error {
key := fmt.Sprintf(userNodeKey, l.opts.prefix, uid)
err = l.opts.client.Del(ctx, key).Err()

rst, err := l.unbindNodeScript.Run(ctx, l.opts.client, []string{key}, name, nid).StringSlice()
if err != nil {
return err
}

err = l.publish(ctx, locate.UnbindNode, uid, nid, name)
if err != nil {
log.Errorf("location event publish failed: %v", err)
if rst[0] == "OK" {
if err = l.broadcast(ctx, locate.UnbindNode, uid, nid, name); err != nil {
log.Errorf("location event broadcast failed: %v", err)
}
}

return nil
}

func (l *Locator) publish(ctx context.Context, typ locate.EventType, uid int64, insID string, insName ...string) error {
var (
kind string
name string
)
// 广播事件
func (l *Locator) broadcast(ctx context.Context, typ locate.EventType, uid int64, insID string, insName ...string) error {
evt := &locate.Event{UID: uid, Type: typ, InsID: insID}

switch typ {
case locate.BindGate, locate.UnbindGate:
kind = "gate"
evt.InsKind = cluster.Gate.String()
case locate.BindNode, locate.UnbindNode:
kind = "node"
evt.InsKind = cluster.Node.String()
}

if len(insName) > 0 {
name = insName[0]
evt.InsName = insName[0]
}

msg, err := marshal(&locate.Event{
UID: uid,
Type: typ,
InsID: insID,
InsKind: kind,
InsName: name,
})
msg, err := marshal(evt)
if err != nil {
return err
}

return l.opts.client.Publish(ctx, fmt.Sprintf(clusterEventKey, l.opts.prefix, kind), msg).Err()
return l.opts.client.Publish(ctx, fmt.Sprintf(clusterEventKey, l.opts.prefix, evt.InsKind), msg).Err()
}

func (l *Locator) toUniqueKey(kinds ...string) string {
Expand Down Expand Up @@ -248,13 +233,16 @@ func marshal(event *locate.Event) (string, error) {
if err != nil {
return "", err
}

return string(buf), nil
}

func unmarshal(data []byte) (*locate.Event, error) {
event := &locate.Event{}
if err := json.Unmarshal(data, event); err != nil {
evt := &locate.Event{}

if err := json.Unmarshal(data, evt); err != nil {
return nil, err
}
return event, nil

return evt, nil
}
27 changes: 27 additions & 0 deletions locate/redis/script.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
package redis

// 解绑网关脚本
const unbindGateScript = `
local val = redis.call('GET', KEYS[1])
if val == '' or val ~= ARGV[1] then
return {'NO'}
end
redis.call('DEL', KEYS[1])
return {'OK'}
`

// 解绑节点脚本
const unbindNodeScript = `
local val = redis.call('HGET', KEYS[1], ARGV[1])
if val == '' or val ~= ARGV[2] then
return {'NO'}
end
redis.call('HDEL', KEYS[1], ARGV[1])
return {'OK'}
`
11 changes: 5 additions & 6 deletions locate/redis/watcher.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,12 +65,11 @@ func (w *watcher) Stop() error {
}

type watcherMgr struct {
ctx context.Context
cancel context.CancelFunc
locator *Locator
key string
sub *redis.PubSub

ctx context.Context
cancel context.CancelFunc
locator *Locator
key string
sub *redis.PubSub
rw sync.RWMutex
idx int64
watchers map[int64]*watcher
Expand Down

0 comments on commit 7cfd0bb

Please sign in to comment.