Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 12 additions & 6 deletions pkg/controllers/multicluster/serviceexport_controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,15 +85,21 @@ func (r *ServiceExportReconciler) Reconcile(ctx context.Context, req ctrl.Reques
}
}

clusterProperties, err := r.ClusterUtils.GetClusterProperties(ctx)
if err != nil {
r.Log.Error(err, "unable to retrieve ClusterId and ClusterSetId")
return ctrl.Result{}, err
}

// Check if the service export is marked to be deleted
if isServiceExportMarkedForDelete {
return r.handleDelete(ctx, &serviceExport)
return r.handleDelete(ctx, clusterProperties.ClusterId(), &serviceExport)
}

return r.handleUpdate(ctx, &serviceExport, &service)
return r.handleUpdate(ctx, clusterProperties.ClusterId(), &serviceExport, &service)
}

func (r *ServiceExportReconciler) handleUpdate(ctx context.Context, serviceExport *multiclusterv1alpha1.ServiceExport, service *v1.Service) (ctrl.Result, error) {
func (r *ServiceExportReconciler) handleUpdate(ctx context.Context, clusterId string, serviceExport *multiclusterv1alpha1.ServiceExport, service *v1.Service) (ctrl.Result, error) {
err := r.addFinalizerAndOwnerRef(ctx, serviceExport, service)
if err != nil {
return ctrl.Result{}, err
Expand All @@ -114,7 +120,7 @@ func (r *ServiceExportReconciler) handleUpdate(ctx context.Context, serviceExpor

// Compute diff between Cloud Map and K8s endpoints, and apply changes
plan := model.Plan{
Current: cmService.Endpoints,
Current: cmService.GetEndpoints(clusterId),
Desired: endpoints,
}
changes := plan.CalculateChanges()
Expand Down Expand Up @@ -186,7 +192,7 @@ func (r *ServiceExportReconciler) createOrGetCloudMapService(ctx context.Context
return cmService, nil
}

func (r *ServiceExportReconciler) handleDelete(ctx context.Context, serviceExport *multiclusterv1alpha1.ServiceExport) (ctrl.Result, error) {
func (r *ServiceExportReconciler) handleDelete(ctx context.Context, clusterId string, serviceExport *multiclusterv1alpha1.ServiceExport) (ctrl.Result, error) {
if controllerutil.ContainsFinalizer(serviceExport, ServiceExportFinalizer) {
r.Log.Info("removing service export", "namespace", serviceExport.Namespace, "name", serviceExport.Name)

Expand All @@ -196,7 +202,7 @@ func (r *ServiceExportReconciler) handleDelete(ctx context.Context, serviceExpor
return ctrl.Result{}, err
}
if cmService != nil {
if err := r.CloudMap.DeleteEndpoints(ctx, cmService.Namespace, cmService.Name, cmService.Endpoints); err != nil {
if err := r.CloudMap.DeleteEndpoints(ctx, cmService.Namespace, cmService.Name, cmService.GetEndpoints(clusterId)); err != nil {
r.Log.Error(err, "error deleting Endpoints from Cloud Map", "namespace", cmService.Namespace, "name", cmService.Name)
return ctrl.Result{}, err
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ func TestServiceExportReconciler_Reconcile_DeleteExistingService(t *testing.T) {
Return(test.GetTestService(), nil)
// call to delete the endpoint in the cloudmap
mock.EXPECT().DeleteEndpoints(gomock.Any(), test.HttpNsName, test.SvcName,
[]*model.Endpoint{test.GetTestEndpoint1(), test.GetTestEndpoint2()}).Return(nil).Times(1)
test.GetTestService().GetEndpoints(test.ClusterId1)).Return(nil).Times(1)

request := ctrl.Request{
NamespacedName: types.NamespacedName{
Expand Down Expand Up @@ -178,8 +178,6 @@ func TestServiceExportReconciler_Reconcile_NoClusterProperty(t *testing.T) {

mockSDClient := cloudmapMock.NewMockServiceDiscoveryClient(mockController)

mockSDClient.EXPECT().GetService(gomock.Any(), test.HttpNsName, test.SvcName).Return(test.GetTestService(), nil)

reconciler := getServiceExportReconciler(t, mockSDClient, fakeClient)

request := ctrl.Request{
Expand Down
9 changes: 9 additions & 0 deletions pkg/model/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -287,6 +287,15 @@ func ConvertNamespaceType(nsType types.NamespaceType) (namespaceType NamespaceTy
}
}

func (svc *Service) GetEndpoints(clusterId string) (endpts []*Endpoint) {
for _, endpt := range svc.Endpoints {
if endpt.ClusterId == clusterId {
endpts = append(endpts, endpt)
}
}
return endpts
}

func (namespaceType *NamespaceType) IsUnsupported() bool {
return *namespaceType == UnsupportedNamespaceType
}
Expand Down
71 changes: 71 additions & 0 deletions pkg/model/types_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,16 @@ import (
"testing"

"github.com/aws/aws-sdk-go-v2/service/servicediscovery/types"
"github.com/google/go-cmp/cmp"
)

var instId = "my-instance"
var ip = "192.168.0.1"
var clusterId = "test-mcs-clusterId"
var clusterId2 = "test-mcs-clusterid-2"
var clusterId3 = "test-mcs-clusterid-3"
var namespaceName = "test-mcs-namespace"
var serviceName = "test-mcs-service"
var clusterSetId = "test-mcs-clusterSetId"
var serviceType = ClusterSetIPType.String()
var svcExportCreationTimestamp int64 = 1640995200000
Expand Down Expand Up @@ -333,3 +338,69 @@ func TestEndpoint_Equals(t *testing.T) {
})
}
}

func TestGetEndpoints(t *testing.T) {
firstEndpoint := Endpoint{
Id: instId + "-1",
IP: ip,
ServicePort: Port{
Port: 80,
},
ClusterId: clusterId,
}
secondEndpoint := Endpoint{
Id: instId + "2",
IP: ip,
ServicePort: Port{
Port: 80,
Name: "",
},
ClusterId: clusterId2,
}
thirdEndpoint := Endpoint{
Id: instId + "3",
IP: ip,
ServicePort: Port{
Port: 80,
Name: "",
},
ClusterId: clusterId2,
}

svc := Service{
Namespace: namespaceName,
Name: serviceName,
Endpoints: []*Endpoint{
&firstEndpoint, &secondEndpoint, &thirdEndpoint,
},
}

tests := []struct {
name string
x string
wantEndpts []*Endpoint
}{
{
name: "return-first-endpoint",
x: clusterId,
wantEndpts: []*Endpoint{&firstEndpoint},
},
{
name: "return-two-endpoints",
x: clusterId2,
wantEndpts: []*Endpoint{&secondEndpoint, &thirdEndpoint},
},
{
name: "return-nil",
x: clusterId3,
wantEndpts: nil,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if gotEndpts := svc.GetEndpoints(tt.x); !cmp.Equal(gotEndpts, tt.wantEndpts) {
t.Errorf("Equals() = %v, Want = %v", gotEndpts, tt.wantEndpts)
}
})
}
}