diff --git a/integration_tests/commands/async/client_name_test.go b/integration_tests/commands/async/client_name_test.go new file mode 100644 index 000000000..cd3488cfb --- /dev/null +++ b/integration_tests/commands/async/client_name_test.go @@ -0,0 +1,116 @@ +package async + +import ( + "strings" + "testing" + + "gotest.tools/v3/assert" +) + +func TestClientSetName(t *testing.T) { + conn := getLocalConnection() + defer conn.Close() + + tests := []struct { + name string + command string + expected string + }{ + { + name: "Set valid name without spaces", + command: "K", + expected: "OK", + }, + { + name: "Set valid name with trailing space", + command: "K ", + expected: "OK", + }, + { + name: "Too many arguments for SETNAME", + command: "K K", + expected: "ERR wrong number of arguments for 'client|setname' command", + }, + { + name: "Name with space between characters", + command: "\"K K\"", + expected: "ERR Client names cannot contain spaces, newlines or special characters.", + }, + { + name: "Empty name argument", + command: " ", + expected: "ERR wrong number of arguments for 'client|setname' command", + }, + { + name: "Missing name argument", + command: "", + expected: "ERR wrong number of arguments for 'client|setname' command", + }, + { + name: "Name with newline character", + command: "\n", + expected: "ERR Client names cannot contain spaces, newlines or special characters.", + }, + { + name: "Name with valid character followed by newline", + command: "K\n", + expected: "ERR Client names cannot contain spaces, newlines or special characters.", + }, + { + name: "Name with special character", + command: "K%", + expected: "ERR Client names cannot contain spaces, newlines or special characters.", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + result := FireCommand(conn, "CLIENT SETNAME "+tc.command) + assert.DeepEqual(t, tc.expected, result) + }) + } +} + +func TestClientGetName(t *testing.T) { + conn := getLocalConnection() + defer conn.Close() + + tests := []struct { + name string + command string + expected string + }{ + { + name: "GetName with invalid argument", + command: "CLIENT GETNAME invalid-arg", + expected: "ERR wrong number of arguments for 'client|getname' command", + }, + { + name: "GetName with no name set", + command: "CLIENT GETNAME", + expected: "(nil)", + }, + { + name: "SetName with invalid name containing space", + command: "CLIENT SETNAME \"K K\"; CLIENT GETNAME", + expected: "(nil)", + }, + { + name: "SetName with valid name", + command: "CLIENT SETNAME K; CLIENT GETNAME", + expected: "K", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Split multiple commands (like "CLIENT SETNAME" followed by "CLIENT GETNAME") if needed + commands := strings.Split(tt.command, "; ") + var result interface{} + for _, cmd := range commands { + result = FireCommand(conn, strings.TrimSpace(cmd)) + } + assert.DeepEqual(t, tt.expected, result) + }) + } +} diff --git a/internal/comm/client.go b/internal/comm/client.go index c24736c77..69fe8fe94 100644 --- a/internal/comm/client.go +++ b/internal/comm/client.go @@ -22,6 +22,7 @@ type QwatchResponse struct { type Client struct { io.ReadWriter + Name string HTTPQwatchResponseChan chan QwatchResponse // Response channel to send back the operation response Fd int Cqueue cmd.RedisCmds diff --git a/internal/errors/migrated_errors.go b/internal/errors/migrated_errors.go index fd1b000bd..46d616aad 100644 --- a/internal/errors/migrated_errors.go +++ b/internal/errors/migrated_errors.go @@ -35,6 +35,12 @@ var ( ErrKeyDoesNotExist = errors.New("ERR could not perform this operation on a key that doesn't exist") // Error generation functions for specific error messages with dynamic parameters. + + // Indicates an subcommand for a given command. + ErrUnknownSubcommand = func(command, subcommand string) error { + return fmt.Errorf("ERR unknown subcommand '%s'. Try %s HELP.", subcommand, strings.ToUpper(command)) //nolint: stylecheck + } + ErrWrongArgumentCount = func(command string) error { return fmt.Errorf("ERR wrong number of arguments for '%s' command", strings.ToLower(command)) // Indicates an incorrect number of arguments for a given command. } diff --git a/internal/eval/commands.go b/internal/eval/commands.go index 5c8b1ea51..eda0cb96b 100644 --- a/internal/eval/commands.go +++ b/internal/eval/commands.go @@ -446,7 +446,7 @@ var ( clientCmdMeta = DiceCmdMeta{ Name: "CLIENT", Info: `This is a container command for client connection commands.`, - Eval: evalCLIENT, + Eval: nil, Arity: -2, } latencyCmdMeta = DiceCmdMeta{ diff --git a/internal/eval/constants.go b/internal/eval/constants.go index a7d924b6a..5b148f8a3 100644 --- a/internal/eval/constants.go +++ b/internal/eval/constants.go @@ -50,4 +50,6 @@ const ( FILTERS string = "FILTER" ITEMS string = "ITEMS" EXPANSION string = "EXPANSION" + GETNAME string = "GETNAME" + SETNAME string = "SETNAME" ) diff --git a/internal/eval/eval.go b/internal/eval/eval.go index 11705958b..7c77d9557 100644 --- a/internal/eval/eval.go +++ b/internal/eval/eval.go @@ -410,11 +410,6 @@ func evalINFO(args []string, store *dstore.Store) []byte { return clientio.Encode(buf.String(), false) } -// TODO: Placeholder to support monitoring -func evalCLIENT(args []string, store *dstore.Store) []byte { - return clientio.RespOK -} - // TODO: Placeholder to support monitoring func evalLATENCY(args []string, store *dstore.Store) []byte { return clientio.Encode([]string{}, false) @@ -437,6 +432,36 @@ func evalSLEEP(args []string, store *dstore.Store) []byte { return clientio.RespOK } +func evalCLIENT(args []string, client *comm.Client) []byte { + if len(args) == 0 { + return clientio.Encode(diceerrors.ErrWrongArgumentCount("CLIENT"), false) + } + + subcommand := strings.ToUpper(args[0]) + switch subcommand { + case GETNAME: + if len(args) != 1 { + return clientio.Encode(diceerrors.ErrWrongArgumentCount("CLIENT|GETNAME"), false) + } + if client.Name == utils.EmptyStr { + return clientio.RespNIL + } + return clientio.Encode(client.Name, false) + case SETNAME: + if len(args) != 2 { + return clientio.Encode(diceerrors.ErrWrongArgumentCount("CLIENT|SETNAME"), false) + } + clientName := args[1] + if containsSpacesNewlinesOrSpecialChars(clientName) { + return clientio.Encode(diceerrors.NewErrWithMessage("Client names cannot contain spaces, newlines or special characters."), false) + } + client.Name = clientName + return clientio.RespOK + default: + return clientio.Encode(diceerrors.ErrUnknownSubcommand("CLIENT", subcommand), false) + } +} + // EvalQWATCH adds the specified key to the watch list for the caller client. // Every time a key in the watch list is modified, the client will be sent a response // containing the new value of the key along with the operation that was performed on it. @@ -699,7 +724,7 @@ func evalCommand(args []string, store *dstore.Store) []byte { case Docs: return evalCommandDocs(args[1:]) default: - return diceerrors.NewErrWithFormattedMessage("unknown subcommand '%s'. Try COMMAND HELP.", subcommand) + return clientio.Encode(diceerrors.ErrUnknownSubcommand("COMMAND", subcommand), false) } } diff --git a/internal/eval/execute.go b/internal/eval/execute.go index 4824a67bb..4c44ea3f6 100644 --- a/internal/eval/execute.go +++ b/internal/eval/execute.go @@ -70,6 +70,9 @@ func (e *Eval) ExecuteCommand() *EvalResponse { switch diceCmd.Name { // Old implementation kept as it is, but we will be moving // to the new implementation soon for all commands + + case "CLIENT": + return &EvalResponse{Result: evalCLIENT(e.cmd.Args, e.client), Error: nil} case "SUBSCRIBE", "Q.WATCH": return &EvalResponse{Result: EvalQWATCH(e.cmd.Args, e.isHTTPOperation, e.isWebSocketOperation, e.client, e.store), Error: nil} case "UNSUBSCRIBE", "Q.UNWATCH": diff --git a/internal/eval/type_string.go b/internal/eval/type_string.go index 722de5cc3..2a36de200 100644 --- a/internal/eval/type_string.go +++ b/internal/eval/type_string.go @@ -2,6 +2,7 @@ package eval import ( "strconv" + "unicode" dstore "github.com/dicedb/dice/internal/object" ) @@ -17,3 +18,12 @@ func deduceTypeEncoding(v string) (o, e uint8) { } return dstore.ObjTypeString, dstore.ObjEncodingRaw } + +func containsSpacesNewlinesOrSpecialChars(s string) bool { + for _, r := range s { + if unicode.IsSpace(r) || (!unicode.IsLetter(r) && !unicode.IsDigit(r) && r != '_') { + return true + } + } + return false +} diff --git a/internal/eval/type_string_test.go b/internal/eval/type_string_test.go index 9ba597b46..9c6fdac3d 100644 --- a/internal/eval/type_string_test.go +++ b/internal/eval/type_string_test.go @@ -1,9 +1,10 @@ package eval import ( - "github.com/dicedb/dice/internal/object" "testing" + "github.com/dicedb/dice/internal/object" + "github.com/dicedb/dice/internal/server/utils" ) @@ -56,3 +57,30 @@ func TestDeduceTypeEncoding(t *testing.T) { }) } } + +func TestContainsSpacesNewlinesOrSpecialChars(t *testing.T) { + tests := []struct { + input string + expected bool + }{ + {"NoSpecialChars123", false}, + {"HelloWorld123", false}, + {"1234567890", false}, + {"Hello_World", false}, + {"", false}, + {"₹₹", true}, + {"Hello, World!", true}, + {"Hello\nWorld", true}, + {"\tTabbedText", true}, + {"NormalText!", true}, + } + + for _, test := range tests { + t.Run(test.input, func(t *testing.T) { + result := containsSpacesNewlinesOrSpecialChars(test.input) + if result != test.expected { + t.Errorf("For input '%s', expected %v but got %v", test.input, test.expected, result) + } + }) + } +}