Skip to content

Commit 694be37

Browse files
committed
chore: refactor flag_loader
1 parent d033c59 commit 694be37

File tree

2 files changed

+254
-25
lines changed

2 files changed

+254
-25
lines changed

flag/flag.go

Lines changed: 68 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,45 +1,88 @@
1+
// Package flag provides utilities for managing command flags with environment variable fallback support.
2+
// It implements a priority system: flag value > environment variable > default value.
13
package flag
24

35
import (
4-
"github.com/0xPolygon/polygon-cli/util"
6+
"fmt"
7+
"os"
8+
59
"github.com/rs/zerolog/log"
610
"github.com/spf13/cobra"
7-
"github.com/spf13/viper"
811
)
912

1013
const (
11-
// RPCURL is the standard flag name for RPC endpoint URLs.
14+
// RPCURL is the flag name for RPC URL
1215
RPCURL = "rpc-url"
13-
// PrivateKey is the standard flag name for private keys.
14-
PrivateKey = "private-key"
15-
16-
// DefaultRPCURL is the default RPC endpoint URL.
16+
// RPCURLEnvVar is the environment variable name for RPC URL
17+
RPCURLEnvVar = "ETH_RPC_URL"
18+
// DefaultRPCURL is the default RPC URL when no flag or env var is set
1719
DefaultRPCURL = "http://localhost:8545"
20+
// PrivateKey is the flag name for private key
21+
PrivateKey = "private-key"
22+
// PrivateKeyEnvVar is the environment variable name for private key
23+
PrivateKeyEnvVar = "PRIVATE_KEY"
1824
)
1925

20-
// GetFlag retrieves a flag value from Viper after binding it.
21-
// It binds the flag to enable environment variable fallback via Viper.
22-
func GetFlag(cmd *cobra.Command, flagName string) string {
23-
if err := viper.BindPFlag(flagName, cmd.Flags().Lookup(flagName)); err != nil {
24-
log.Fatal().Err(err).Str("flag", flagName).Msg("Failed to bind flag to viper")
25-
}
26-
return viper.GetString(flagName)
26+
// GetRPCURL retrieves the RPC URL from the command flag or environment variable.
27+
// Returns the flag value if set, otherwise the environment variable value, otherwise the default.
28+
// Returns empty string and nil error if none are set.
29+
func GetRPCURL(cmd *cobra.Command) (string, error) {
30+
return getValue(cmd, RPCURL, RPCURLEnvVar, false)
2731
}
2832

29-
// GetRPCURL retrieves the rpc-url flag value from Viper after binding it and validates
30-
// that it is a valid URL with a supported scheme (http, https, ws, wss).
31-
func GetRPCURL(cmd *cobra.Command) (string, error) {
32-
rpcURL := GetFlag(cmd, RPCURL)
33-
if err := util.ValidateUrl(rpcURL); err != nil {
34-
return "", err
35-
}
36-
return rpcURL, nil
33+
// GetRequiredRPCURL retrieves the RPC URL from the command flag or environment variable.
34+
// Returns an error if the value is not set or empty.
35+
func GetRequiredRPCURL(cmd *cobra.Command) (string, error) {
36+
return getValue(cmd, RPCURL, RPCURLEnvVar, true)
3737
}
3838

39-
// GetPrivateKey retrieves the private-key flag value from Viper after binding it.
40-
// This is a convenience wrapper around GetFlag for the standard private key flag.
39+
// GetPrivateKey retrieves the private key from the command flag or environment variable.
40+
// Returns the flag value if set, otherwise the environment variable value, otherwise the default.
41+
// Returns empty string and nil error if none are set.
4142
func GetPrivateKey(cmd *cobra.Command) (string, error) {
42-
return GetFlag(cmd, PrivateKey), nil
43+
return getValue(cmd, PrivateKey, PrivateKeyEnvVar, false)
44+
}
45+
46+
// GetRequiredPrivateKey retrieves the private key from the command flag or environment variable.
47+
// Returns an error if the value is not set or empty.
48+
func GetRequiredPrivateKey(cmd *cobra.Command) (string, error) {
49+
return getValue(cmd, PrivateKey, PrivateKeyEnvVar, true)
50+
}
51+
52+
// getValue retrieves a flag value with environment variable fallback support.
53+
// It implements a priority system where flag values take precedence over environment variables,
54+
// which take precedence over default values.
55+
//
56+
// Parameters:
57+
// - cmd: The cobra command to retrieve the flag from
58+
// - flagName: The name of the flag to retrieve
59+
// - envVarName: The environment variable name to check as fallback
60+
// - required: Whether the value is required (returns error if empty)
61+
//
62+
// Returns the resolved value and an error if required validation fails.
63+
func getValue(cmd *cobra.Command, flagName, envVarName string, required bool) (string, error) {
64+
flag := cmd.Flag(flagName)
65+
if flag == nil {
66+
return "", fmt.Errorf("flag %q not found", flagName)
67+
}
68+
69+
// Priority: flag > env var > default
70+
value := flag.DefValue
71+
72+
envVarValue := os.Getenv(envVarName)
73+
if envVarValue != "" {
74+
value = envVarValue
75+
}
76+
77+
if flag.Changed {
78+
value = flag.Value.String()
79+
}
80+
81+
if required && value == "" {
82+
return "", fmt.Errorf("required flag(s) %q not set", flagName)
83+
}
84+
85+
return value, nil
4386
}
4487

4588
// MarkFlagRequired marks a regular flag as required and logs a fatal error if marking fails.

flag/flag_test.go

Lines changed: 186 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,186 @@
1+
package flag
2+
3+
import (
4+
"fmt"
5+
"os"
6+
"strconv"
7+
"testing"
8+
9+
"github.com/spf13/cobra"
10+
"github.com/stretchr/testify/assert"
11+
)
12+
13+
// TestValuePriority tests the priority system for flag value resolution.
14+
// It verifies that flag values take precedence over environment variables,
15+
// which take precedence over default values. It also tests the required
16+
// flag validation logic.
17+
func TestValuePriority(t *testing.T) {
18+
type testCase struct {
19+
defaultValue *int
20+
envVarValue *int
21+
flagValue *int
22+
required bool
23+
24+
expectedValue *int
25+
expectedError error
26+
}
27+
28+
testCases := []testCase{
29+
// Test case: All three sources set - flag should win
30+
{
31+
defaultValue: ptr(1),
32+
envVarValue: ptr(2),
33+
flagValue: ptr(3),
34+
expectedValue: ptr(3),
35+
required: true,
36+
expectedError: nil,
37+
},
38+
// Test case: Flag set to same value as default - flag should still win
39+
{
40+
defaultValue: ptr(1),
41+
envVarValue: ptr(2),
42+
flagValue: ptr(1),
43+
expectedValue: ptr(1),
44+
required: true,
45+
expectedError: nil,
46+
},
47+
// Test case: Default and env var set - env var should win
48+
{
49+
defaultValue: ptr(1),
50+
envVarValue: ptr(2),
51+
flagValue: nil,
52+
expectedValue: ptr(2),
53+
required: true,
54+
expectedError: nil,
55+
},
56+
// Test case: Default and flag set - flag should win
57+
{
58+
defaultValue: ptr(1),
59+
envVarValue: nil,
60+
flagValue: ptr(3),
61+
expectedValue: ptr(3),
62+
required: true,
63+
expectedError: nil,
64+
},
65+
// Test case: Env var and flag set - flag should win
66+
{
67+
defaultValue: nil,
68+
envVarValue: ptr(2),
69+
flagValue: ptr(3),
70+
expectedValue: ptr(3),
71+
required: true,
72+
expectedError: nil,
73+
},
74+
// Test case: Only flag set
75+
{
76+
defaultValue: nil,
77+
envVarValue: nil,
78+
flagValue: ptr(3),
79+
expectedValue: ptr(3),
80+
required: true,
81+
expectedError: nil,
82+
},
83+
// Test case: Only default set (non-required)
84+
{
85+
defaultValue: ptr(1),
86+
envVarValue: nil,
87+
flagValue: nil,
88+
expectedValue: ptr(1),
89+
required: false,
90+
expectedError: nil,
91+
},
92+
// Test case: Only default set (required) - default should satisfy requirement
93+
{
94+
defaultValue: ptr(1),
95+
envVarValue: nil,
96+
flagValue: nil,
97+
expectedValue: ptr(1),
98+
required: true,
99+
expectedError: nil,
100+
},
101+
// Test case: Only env var set
102+
{
103+
defaultValue: nil,
104+
envVarValue: ptr(2),
105+
flagValue: nil,
106+
expectedValue: ptr(2),
107+
required: true,
108+
expectedError: nil,
109+
},
110+
// Test case: Nothing set (non-required) - should return empty
111+
{
112+
defaultValue: nil,
113+
envVarValue: nil,
114+
flagValue: nil,
115+
expectedValue: nil,
116+
required: false,
117+
expectedError: nil,
118+
},
119+
// Test case: Nothing set (required) - should return error
120+
{
121+
defaultValue: nil,
122+
envVarValue: nil,
123+
flagValue: nil,
124+
expectedValue: nil,
125+
required: true,
126+
expectedError: fmt.Errorf("required flag(s) \"flag\" not set"),
127+
},
128+
}
129+
130+
for _, tc := range testCases {
131+
var value *int
132+
cmd := &cobra.Command{
133+
Use: "test",
134+
PersistentPreRun: func(cmd *cobra.Command, args []string) {
135+
valueStr, err := getValue(cmd, "flag", "FLAG", tc.required)
136+
if tc.expectedError != nil {
137+
assert.EqualError(t, err, tc.expectedError.Error())
138+
return
139+
}
140+
assert.NoError(t, err)
141+
if valueStr != "" {
142+
valueInt, err := strconv.Atoi(valueStr)
143+
assert.NoError(t, err)
144+
value = &valueInt
145+
}
146+
},
147+
Run: func(cmd *cobra.Command, args []string) {
148+
if tc.expectedValue != nil {
149+
assert.NotNil(t, value)
150+
if value != nil {
151+
assert.Equal(t, *tc.expectedValue, *value)
152+
}
153+
} else {
154+
assert.Nil(t, value)
155+
}
156+
},
157+
}
158+
if tc.defaultValue != nil {
159+
cmd.Flags().Int("flag", *tc.defaultValue, "flag")
160+
} else {
161+
cmd.Flags().String("flag", "", "flag")
162+
}
163+
164+
os.Unsetenv("FLAG")
165+
if tc.envVarValue != nil {
166+
v := strconv.Itoa(*tc.envVarValue)
167+
os.Setenv("FLAG", v)
168+
}
169+
170+
if tc.flagValue != nil {
171+
v := strconv.Itoa(*tc.flagValue)
172+
cmd.SetArgs([]string{"--flag", v})
173+
}
174+
175+
err := cmd.Execute()
176+
assert.Nil(t, err)
177+
}
178+
179+
}
180+
181+
// ptr is a helper function to create a pointer to a value.
182+
// This is useful for test cases where we need to distinguish between
183+
// nil (not set) and a zero value (explicitly set to 0).
184+
func ptr[T any](v T) *T {
185+
return &v
186+
}

0 commit comments

Comments
 (0)