Skip to content

Commit 7545330

Browse files
Add getAMI func to get AMI based on region
Signed-off-by: Carlos Eduardo Arango Gutierrez <[email protected]>
1 parent 10bdb05 commit 7545330

File tree

4 files changed

+169
-33
lines changed

4 files changed

+169
-33
lines changed

go.mod

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
11
module github.com/NVIDIA/holodeck
22

3-
go 1.21.5
3+
go 1.22.0
4+
45
toolchain go1.22.5
56

67
require (
78
github.com/aws/aws-sdk-go v1.55.1
9+
github.com/aws/aws-sdk-go-v2 v1.30.3
810
github.com/aws/aws-sdk-go-v2/config v1.27.27
911
github.com/aws/aws-sdk-go-v2/service/ec2 v1.171.0
1012
github.com/aws/aws-sdk-go-v2/service/route53 v1.42.3
@@ -18,7 +20,6 @@ require (
1820
)
1921

2022
require (
21-
github.com/aws/aws-sdk-go-v2 v1.30.3 // indirect
2223
github.com/aws/aws-sdk-go-v2/credentials v1.17.27 // indirect
2324
github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.11 // indirect
2425
github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.15 // indirect

pkg/provider/aws/create.go

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -281,6 +281,13 @@ func (a *Client) createEC2Instance(cache *AWS) error {
281281
a.log.Wg.Add(1)
282282
go a.log.Loading("Creating EC2 instance")
283283

284+
// Check if the image is provided, if not get the latest image
285+
err := a.setAMI()
286+
if err != nil {
287+
a.fail()
288+
return fmt.Errorf("error getting AMI: %w", err)
289+
}
290+
284291
instanceIn := &ec2.RunInstancesInput{
285292
ImageId: a.Spec.Image.ImageId,
286293
InstanceType: types.InstanceType(a.Spec.Instance.Type),

pkg/provider/aws/dryrun.go

Lines changed: 2 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -50,35 +50,6 @@ func (a *Client) checkInstanceTypes() error {
5050
return fmt.Errorf("instance type %s is not supported in the current region %s", string(a.Spec.Instance.Type), a.Spec.Instance.Region)
5151
}
5252

53-
func (a *Client) checkImages() error {
54-
var nextToken *string
55-
56-
for {
57-
// Use the DescribeImages API to get a list of supported images in the current region
58-
resp, err := a.ec2.DescribeImages(context.TODO(), &ec2.DescribeImagesInput{
59-
NextToken: nextToken,
60-
},
61-
)
62-
if err != nil {
63-
return err
64-
}
65-
66-
for _, image := range resp.Images {
67-
if *image.ImageId == *a.Spec.Instance.Image.ImageId {
68-
return nil
69-
}
70-
}
71-
72-
if resp.NextToken != nil {
73-
nextToken = resp.NextToken
74-
} else {
75-
break
76-
}
77-
}
78-
79-
return fmt.Errorf("image %s is not supported in the current region %s", *a.Spec.Instance.Image.ImageId, a.Spec.Instance.Region)
80-
}
81-
8253
func (a *Client) DryRun() error {
8354
// Check if the desired instance type is supported in the region
8455
a.log.Wg.Add(1)
@@ -92,11 +63,11 @@ func (a *Client) DryRun() error {
9263

9364
// Check if the desired image is supported in the region
9465
a.log.Wg.Add(1)
95-
go a.log.Loading("Checking if image %s is supported in region %s", *a.Spec.Instance.Image.ImageId, a.Spec.Instance.Region)
66+
go a.log.Loading("Checking for supported image in region %s", a.Spec.Instance.Region)
9667
err = a.checkImages()
9768
if err != nil {
9869
a.fail()
99-
return fmt.Errorf("failed to get images: %v", err)
70+
return fmt.Errorf("failed to check image: %w", err)
10071
}
10172
a.done()
10273

pkg/provider/aws/image.go

Lines changed: 157 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,157 @@
1+
/*
2+
* Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package aws
18+
19+
import (
20+
"context"
21+
"errors"
22+
"fmt"
23+
"sort"
24+
25+
"github.com/aws/aws-sdk-go-v2/aws"
26+
"github.com/aws/aws-sdk-go-v2/service/ec2"
27+
"github.com/aws/aws-sdk-go-v2/service/ec2/types"
28+
)
29+
30+
type ImageInfo struct {
31+
ImageID string
32+
CreationDate string
33+
}
34+
35+
type ByCreationDate []ImageInfo
36+
37+
func (a ByCreationDate) Len() int { return len(a) }
38+
func (a ByCreationDate) Swap(i, j int) { a[i], a[j] = a[j], a[i] }
39+
func (a ByCreationDate) Less(i, j int) bool { return a[i].CreationDate < a[j].CreationDate }
40+
41+
func (a *Client) checkImages() error {
42+
// Check if the given image is supported in the region
43+
if a.Spec.Instance.Image.ImageId != nil {
44+
return a.assertImageIdSupported()
45+
}
46+
47+
return a.setAMI()
48+
}
49+
50+
func (a *Client) setAMI() error {
51+
// If the image ID is already set by the user, return
52+
if a.Spec.Image.ImageId != nil {
53+
return nil
54+
}
55+
56+
// Default to the official Ubuntu images in the AWS Marketplace
57+
// TODO: Add support for other image OS types
58+
awsOwner := []string{"099720109477", "679593333241"}
59+
if a.Spec.Instance.Image.OwnerId != nil {
60+
awsOwner = []string{*a.Spec.Instance.Image.OwnerId}
61+
}
62+
63+
var filterNameValue []string
64+
var filterArchitectureValue []string
65+
66+
if a.Spec.Instance.Image.Architecture != "" {
67+
switch a.Spec.Instance.Image.Architecture {
68+
case "x86_64", "amd64":
69+
filterArchitectureValue = []string{"x86_64", "amd64"}
70+
case "arm64", "aarch64":
71+
filterArchitectureValue = []string{"arm64"}
72+
default:
73+
return fmt.Errorf("invalid architecture %s", a.Spec.Instance.Image.Architecture)
74+
}
75+
}
76+
77+
for _, arch := range filterArchitectureValue {
78+
filterNameValue = append(filterNameValue, fmt.Sprintf("ubuntu/images/hvm-ssd/ubuntu-jammy-22.04-%s-server-20*", arch))
79+
}
80+
81+
filter := []types.Filter{
82+
{
83+
Name: aws.String("name"),
84+
Values: filterNameValue,
85+
},
86+
{
87+
Name: aws.String("architecture"),
88+
Values: filterArchitectureValue,
89+
},
90+
{
91+
Name: aws.String("owner-id"),
92+
Values: awsOwner,
93+
},
94+
}
95+
96+
images, err := a.describeImages(filter)
97+
if err != nil {
98+
return fmt.Errorf("failed to describe images: %w", err)
99+
}
100+
101+
if len(images) == 0 {
102+
return fmt.Errorf("no images found")
103+
}
104+
sort.Slice(images, func(i, j int) bool {
105+
return images[i].CreationDate > images[j].CreationDate
106+
})
107+
a.Spec.Image.ImageId = &images[0].ImageID
108+
109+
return nil
110+
}
111+
112+
func (a *Client) assertImageIdSupported() error {
113+
images, err := a.describeImages([]types.Filter{})
114+
if err == nil {
115+
for _, image := range images {
116+
if image.ImageID == *a.Spec.Instance.Image.ImageId {
117+
return nil
118+
}
119+
}
120+
}
121+
122+
return errors.Join(err, fmt.Errorf("image %s is not supported in the current region %s", *a.Spec.Instance.Image.ImageId, a.Spec.Instance.Region))
123+
}
124+
125+
func (a *Client) describeImages(filter []types.Filter) ([]ImageInfo, error) {
126+
var images []ImageInfo
127+
var nextToken *string
128+
129+
for {
130+
// Use the DescribeImages API to get a list of supported images in the current region
131+
resp, err := a.ec2.DescribeImages(context.TODO(), &ec2.DescribeImagesInput{
132+
NextToken: nextToken,
133+
Filters: filter,
134+
})
135+
if err != nil {
136+
return images, fmt.Errorf("failed to describe images: %w", err)
137+
}
138+
if len(resp.Images) == 0 {
139+
return images, fmt.Errorf("no images found")
140+
}
141+
142+
for _, image := range resp.Images {
143+
images = append(images, ImageInfo{
144+
ImageID: *image.ImageId,
145+
CreationDate: *image.CreationDate,
146+
})
147+
}
148+
149+
if resp.NextToken != nil {
150+
nextToken = resp.NextToken
151+
} else {
152+
break
153+
}
154+
}
155+
156+
return images, nil
157+
}

0 commit comments

Comments
 (0)