diff --git a/cloudstack_loadbalancer.go b/cloudstack_loadbalancer.go index 7a3fd6b0..00f536f0 100644 --- a/cloudstack_loadbalancer.go +++ b/cloudstack_loadbalancer.go @@ -22,6 +22,7 @@ package cloudstack import ( "context" "fmt" + "net" "strconv" "strings" @@ -41,9 +42,9 @@ const ( // service to enable the proxy protocol on a CloudStack load balancer. // Note that this protocol only applies to TCP service ports and // CloudStack >= 4.6 is required for it to work. - ServiceAnnotationLoadBalancerProxyProtocol = "service.beta.kubernetes.io/cloudstack-load-balancer-proxy-protocol" - + ServiceAnnotationLoadBalancerProxyProtocol = "service.beta.kubernetes.io/cloudstack-load-balancer-proxy-protocol" ServiceAnnotationLoadBalancerLoadbalancerHostname = "service.beta.kubernetes.io/cloudstack-load-balancer-hostname" + ServiceAnnotationLoadBalancerSourceCidrs = "service.beta.kubernetes.io/cloudstack-load-balancer-source-cidrs" ) type loadBalancer struct { @@ -162,7 +163,7 @@ func (cs *CSCloud) EnsureLoadBalancer(ctx context.Context, clusterName string, s } } else { klog.V(4).Infof("Creating load balancer rule: %v", lbRuleName) - lbRule, err = lb.createLoadBalancerRule(lbRuleName, port, protocol) + lbRule, err = lb.createLoadBalancerRule(lbRuleName, port, protocol, service) if err != nil { return nil, err } @@ -596,7 +597,7 @@ func (lb *loadBalancer) updateLoadBalancerRule(lbRuleName string, protocol LoadB } // createLoadBalancerRule creates a new load balancer rule and returns it's ID. -func (lb *loadBalancer) createLoadBalancerRule(lbRuleName string, port corev1.ServicePort, protocol LoadBalancerProtocol) (*cloudstack.LoadBalancerRule, error) { +func (lb *loadBalancer) createLoadBalancerRule(lbRuleName string, port corev1.ServicePort, protocol LoadBalancerProtocol, service *corev1.Service) (*cloudstack.LoadBalancerRule, error) { p := lb.LoadBalancer.NewCreateLoadBalancerRuleParams( lb.algorithm, lbRuleName, @@ -606,12 +607,30 @@ func (lb *loadBalancer) createLoadBalancerRule(lbRuleName string, port corev1.Se p.SetNetworkid(lb.networkID) p.SetPublicipid(lb.ipAddrID) - p.SetProtocol(protocol.CSProtocol()) // Do not open the firewall implicitly, we always create explicit firewall rules p.SetOpenfirewall(false) + // Read the source CIDR annotation + sourceCIDRs, ok := service.Annotations[ServiceAnnotationLoadBalancerSourceCidrs] + var cidrList []string + if ok && sourceCIDRs != "" { + cidrList = strings.Split(sourceCIDRs, ",") + for i, cidr := range cidrList { + cidr = strings.TrimSpace(cidr) + if _, _, err := net.ParseCIDR(cidr); err != nil { + return nil, fmt.Errorf("invalid CIDR in annotation %s: %s", ServiceAnnotationLoadBalancerSourceCidrs, cidr) + } + cidrList[i] = cidr + } + } else { + cidrList = []string{defaultAllowedCIDR} + } + + // Set the CIDR list in the parameters + p.SetCidrlist(cidrList) + // Create a new load balancer rule. r, err := lb.LoadBalancer.CreateLoadBalancerRule(p) if err != nil {