Skip to content

Commit 152f39e

Browse files
authored
Add region to external models (#161)
1 parent 8f9e567 commit 152f39e

File tree

9 files changed

+109
-24
lines changed

9 files changed

+109
-24
lines changed

docs/applications/advanced/external-models.md

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,12 +17,14 @@ $ zip -r model.zip export/estimator
1717
$ aws s3 cp model.zip s3://your-bucket/model.zip
1818
```
1919

20-
3. Specify `model_path` in an API, e.g.
20+
3. Specify `external_model` in an API, e.g.
2121

2222
```yaml
2323
- kind: api
2424
name: my-api
25-
model_path: s3://your-bucket/model.zip
25+
external_model:
26+
path: s3://your-bucket/model.zip
27+
region: us-west-2
2628
compute:
2729
replicas: 5
2830
gpu: 1

docs/applications/resources/apis.md

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,9 @@ Serve models at scale and use them to build smarter applications.
88
- kind: api # (required)
99
name: <string> # API name (required)
1010
model_name: <string> # name of a Cortex model (required)
11-
model_path: <string> # path to a zipped model dir (optional)
11+
external_model:
12+
path: <string> # path to a zipped model dir (optional)
13+
region: <string> # region of external model
1214
compute:
1315
replicas: <int> # number of replicas to launch (default: 1)
1416
cpu: <string> # CPU request (default: Null)

examples/external-model/app.yaml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33

44
- kind: api
55
name: iris
6-
model_path: s3://cortex-examples/iris-model.zip
6+
external_model:
7+
path: s3://cortex-examples/iris-model.zip
8+
region: us-west-2
79
compute:
810
replicas: 1

pkg/lib/aws/s3.go

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -265,7 +265,19 @@ func SplitS3aPath(s3aPath string) (string, string, error) {
265265
if !IsValidS3aPath(s3aPath) {
266266
return "", "", ErrorInvalidS3aPath(s3aPath)
267267
}
268-
fullPath := s3aPath[6:]
268+
fullPath := s3aPath[len("s3a://"):]
269+
slashIndex := strings.Index(fullPath, "/")
270+
bucket := fullPath[0:slashIndex]
271+
key := fullPath[slashIndex+1:]
272+
273+
return bucket, key, nil
274+
}
275+
276+
func SplitS3Path(s3Path string) (string, string, error) {
277+
if !IsValidS3Path(s3Path) {
278+
return "", "", ErrorInvalidS3aPath(s3Path)
279+
}
280+
fullPath := s3Path[len("s3://"):]
269281
slashIndex := strings.Index(fullPath, "/")
270282
bucket := fullPath[0:slashIndex]
271283
key := fullPath[slashIndex+1:]
@@ -291,6 +303,27 @@ func IsS3PrefixExternal(bucket string, prefix string, region string) (bool, erro
291303
return hasPrefix, nil
292304
}
293305

306+
func IsS3FileExternal(bucket string, key string, region string) (bool, error) {
307+
sess := session.Must(session.NewSession(&aws.Config{
308+
Region: aws.String(region),
309+
}))
310+
311+
_, err := s3.New(sess).HeadObject(&s3.HeadObjectInput{
312+
Bucket: aws.String(bucket),
313+
Key: aws.String(key),
314+
})
315+
316+
if IsNotFoundErr(err) {
317+
return false, nil
318+
}
319+
320+
if err != nil {
321+
return false, errors.Wrap(err, key)
322+
}
323+
324+
return true, nil
325+
}
326+
294327
func IsS3aPrefixExternal(s3aPath string, region string) (bool, error) {
295328
bucket, prefix, err := SplitS3aPath(s3aPath)
296329
if err != nil {

pkg/operator/api/userconfig/apis.go

Lines changed: 47 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ limitations under the License.
1717
package userconfig
1818

1919
import (
20+
"github.com/cortexlabs/cortex/pkg/lib/aws"
2021
cr "github.com/cortexlabs/cortex/pkg/lib/configreader"
2122
"github.com/cortexlabs/cortex/pkg/lib/errors"
2223
"github.com/cortexlabs/cortex/pkg/operator/api/resource"
@@ -26,10 +27,10 @@ type APIs []*API
2627

2728
type API struct {
2829
ResourceFields
29-
Model *string `json:"model" yaml:"model"`
30-
ModelPath *string `json:"model_path" yaml:"model_path"`
31-
Compute *APICompute `json:"compute" yaml:"compute"`
32-
Tags Tags `json:"tags" yaml:"tags"`
30+
Model *string `json:"model" yaml:"model"`
31+
ExternalModel *ExternalModel `json:"external_model" yaml:"external_model"`
32+
Compute *APICompute `json:"compute" yaml:"compute"`
33+
Tags Tags `json:"tags" yaml:"tags"`
3334
}
3435

3536
var apiValidation = &cr.StructValidation{
@@ -48,17 +49,40 @@ var apiValidation = &cr.StructValidation{
4849
},
4950
},
5051
{
51-
StructField: "ModelPath",
52-
StringPtrValidation: &cr.StringPtrValidation{
53-
Validator: cr.GetS3PathValidator(),
54-
},
52+
StructField: "ExternalModel",
53+
StructValidation: externalModelFieldValidation,
5554
},
5655
apiComputeFieldValidation,
5756
tagsFieldValidation,
5857
typeFieldValidation,
5958
},
6059
}
6160

61+
type ExternalModel struct {
62+
Path string `json:"path" yaml:"path"`
63+
Region string `json:"region" yaml:"region"`
64+
}
65+
66+
var externalModelFieldValidation = &cr.StructValidation{
67+
DefaultNil: true,
68+
StructFieldValidations: []*cr.StructFieldValidation{
69+
{
70+
StructField: "Path",
71+
StringValidation: &cr.StringValidation{
72+
Validator: cr.GetS3PathValidator(),
73+
Required: true,
74+
},
75+
},
76+
{
77+
StructField: "Region",
78+
StringValidation: &cr.StringValidation{
79+
Default: aws.DefaultS3Region,
80+
AllowedValues: aws.S3Regions.Slice(),
81+
},
82+
},
83+
},
84+
}
85+
6286
func (apis APIs) Validate() error {
6387
for _, api := range apis {
6488
if err := api.Validate(); err != nil {
@@ -80,12 +104,23 @@ func (apis APIs) Validate() error {
80104
}
81105

82106
func (api *API) Validate() error {
83-
if api.ModelPath == nil && api.Model == nil {
84-
return errors.Wrap(ErrorSpecifyOnlyOneMissing("model_name", "model_path"), Identify(api))
107+
if api.ExternalModel == nil && api.Model == nil {
108+
return errors.Wrap(ErrorSpecifyOnlyOneMissing(ModelKey, ExternalModelKey), Identify(api))
85109
}
86110

87-
if api.ModelPath != nil && api.Model != nil {
88-
return errors.Wrap(ErrorSpecifyOnlyOne("model_name", "model_path"), Identify(api))
111+
if api.ExternalModel != nil && api.Model != nil {
112+
return errors.Wrap(ErrorSpecifyOnlyOne(ModelKey, ExternalModelKey), Identify(api))
113+
}
114+
115+
if api.ExternalModel != nil {
116+
bucket, key, err := aws.SplitS3Path(api.ExternalModel.Path)
117+
if err != nil {
118+
return errors.Wrap(err, Identify(api), ExternalModelKey, PathKey)
119+
}
120+
121+
if ok, err := aws.IsS3FileExternal(bucket, key, api.ExternalModel.Region); err != nil || !ok {
122+
return errors.Wrap(ErrorExternalModelNotFound(api.ExternalModel.Path), Identify(api), ExternalModelKey, PathKey)
123+
}
89124
}
90125

91126
return nil

pkg/operator/api/userconfig/config_key.go

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,7 @@ const (
9393
DatasetComputeKey = "dataset_compute"
9494

9595
// API
96-
ModelKey = "model"
97-
ModelNameKey = "model_name"
96+
ModelKey = "model"
97+
ModelNameKey = "model_name"
98+
ExternalModelKey = "external_model"
9899
)

pkg/operator/api/userconfig/errors.go

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,7 @@ const (
7575
ErrEnvSchemaMismatch
7676
ErrExtraResourcesWithExternalAPIs
7777
ErrImplDoesNotExist
78+
ErrExternalModelNotFound
7879
)
7980

8081
var errorKinds = []string{
@@ -124,9 +125,10 @@ var errorKinds = []string{
124125
"err_env_schema_mismatch",
125126
"err_extra_resources_with_external_a_p_is",
126127
"err_impl_does_not_exist",
128+
"err_external_model_not_found",
127129
}
128130

129-
var _ = [1]int{}[int(ErrImplDoesNotExist)-(len(errorKinds)-1)] // Ensure list length matches
131+
var _ = [1]int{}[int(ErrExternalModelNotFound)-(len(errorKinds)-1)] // Ensure list length matches
130132

131133
func (t ErrorKind) String() string {
132134
return errorKinds[t]
@@ -575,3 +577,10 @@ func ErrorImplDoesNotExist(path string) error {
575577
message: fmt.Sprintf("%s: implementation file does not exist", path),
576578
}
577579
}
580+
581+
func ErrorExternalModelNotFound(path string) error {
582+
return Error{
583+
Kind: ErrExternalModelNotFound,
584+
message: fmt.Sprintf("%s: file not found or inaccessible", path),
585+
}
586+
}

pkg/operator/context/apis.go

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -48,10 +48,11 @@ func getAPIs(config *userconfig.Config,
4848
buf.WriteString(model.ID)
4949
}
5050

51-
if apiConfig.ModelPath != nil {
52-
modelName = *apiConfig.ModelPath
51+
if apiConfig.ExternalModel != nil {
52+
modelName = apiConfig.ExternalModel.Path
5353
buf.WriteString(datasetVersion)
54-
buf.WriteString(*apiConfig.ModelPath)
54+
buf.WriteString(apiConfig.ExternalModel.Path)
55+
buf.WriteString(apiConfig.ExternalModel.Region)
5556
}
5657

5758
id := hash.Bytes(buf.Bytes())

pkg/workloads/tf_api/api.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -402,7 +402,7 @@ def start(args):
402402

403403
else:
404404
if not os.path.isdir(args.model_dir):
405-
ctx.storage.download_and_unzip_external(api["model_path"], args.model_dir)
405+
ctx.storage.download_and_unzip_external(api["external_model"]["path"], args.model_dir)
406406

407407
channel = grpc.insecure_channel("localhost:" + str(args.tf_serve_port))
408408
local_cache["stub"] = prediction_service_pb2_grpc.PredictionServiceStub(channel)

0 commit comments

Comments
 (0)