diff --git a/README.md b/README.md index 237b161..2f9f1a3 100644 --- a/README.md +++ b/README.md @@ -494,25 +494,56 @@ Use function variable when you need to test a package level function: func GetIPNetDeviceByName(ifaceName string) (v4IPNet *net.IPNet, v6IPNet *net.IPNet, link *net.Interface, err error) { link, err = net.InterfaceByName(ifaceName) // External dependency makes it hard to test directly. ... + addrList, err := link.Addrs() // Calling original method of link makes it hard to test with custom output. + ... } // Declare a package level function variable. -var interfaceByName = net.InterfaceByName +var( + interfaceByName = net.InterfaceByName + netInterfaceAddrs = (*net.Interface).Addrs +) func GetIPNetDeviceByName(ifaceName string) (v4IPNet *net.IPNet, v6IPNet *net.IPNet, link *net.Interface, err error) { - link, err = interfaceByName(ifaceName) // Use the variable for actual call. + // Use the variables for actual call. + link, err = interfaceByName(ifaceName) ... + addrList, err := netInterfaceAddrs(link) + ... +} + +func mockNetInterfaceByName(testNetInterface *net.Interface, err error) func() { + // Mock the variable for test. + originalNetInterfaceByName := netInterfaceByName + netInterfaceByName = func(name string) (*net.Interface, error) { + return testNetInterface, err + } + return func() { + netInterfaceByName = originalNetInterfaceByName + } +} + +func mockNetInterfaceAddrs(testNetInterfaceAddrs []net.Addr, err error) func() { + // Mock the method for test. + originalNetInterfaceAddrs := netInterfaceAddrs + netInterfaceAddrs = func(i *net.Interface) ([]net.Addr, error) { + return testNetInterfaceAddrs, err + } + return func() { + netInterfaceAddrs = originalNetInterfaceAddrs + } } func TestGetIPNetDeviceByName(t *testing.T) { tests := []struct { - name string - interfaceName string - interface *net.Interface - wantV4IPNet *net.IPNet - wantV6IPNet *net.IPNet - wantLink *net.Interface - wantErr error + name string + interfaceName string + interface *net.Interface + interfaceAddrs []net.Addr + wantV4IPNet *net.IPNet + wantV6IPNet *net.IPNet + wantLink *net.Interface + wantErr error }{ {name: "case 1", ...}, {name: "case 2", ...}, @@ -521,13 +552,9 @@ func TestGetIPNetDeviceByName(t *testing.T) { for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { - // Mock the variable for test. - interfaceByName = func(name string) (*net.Interface, error) { - return tc.interface - } - defer func() { - interfaceByName = net.InterfaceByName - }() + // Set mocks, defer the returned functions for reset. + defer mockNetInterfaceByName(tc.interface, nil)() + defer mockNetInterfaceAddrs(tc.interfaceAddrs, nil)() gotV4IPNet, gotV6IPNet, gotLink, gotErr := GetIPNetDeviceByName(tc.name) assert.Equal(t, tc.wantV4IPNet, gotV4IPNet) assert.Equal(t, tc.wantV6IPNet, gotV6IPNet)