Skip to content

Commit 4146e2f

Browse files
1vndeliahu
authored andcommitted
External model inputs update (#159)
1 parent deb8360 commit 4146e2f

File tree

31 files changed

+486
-200
lines changed

31 files changed

+486
-200
lines changed

cli/cmd/get.go

Lines changed: 19 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -382,7 +382,6 @@ func describeAPI(name string, resourcesRes *schema.GetResourcesResponse) (string
382382

383383
ctx := resourcesRes.Context
384384
api := ctx.APIs[name]
385-
model := ctx.Models[api.ModelName]
386385

387386
var staleReplicas int32
388387
var ctxAPIStatus *resource.APIStatus
@@ -412,26 +411,29 @@ func describeAPI(name string, resourcesRes *schema.GetResourcesResponse) (string
412411
}
413412

414413
out += titleStr("Endpoint")
415-
resIDs := strset.New()
416-
combinedInput := []interface{}{model.Input, model.TrainingInput}
417-
for _, res := range ctx.ExtractCortexResources(combinedInput, resource.ConstantType, resource.RawColumnType, resource.AggregateType, resource.TransformedColumnType) {
418-
resIDs.Add(res.GetID())
419-
resIDs.Merge(ctx.AllComputedResourceDependencies(res.GetID()))
420-
}
421-
var samplePlaceholderFields []string
422-
for rawColumnName, rawColumn := range ctx.RawColumns {
423-
if resIDs.Has(rawColumn.GetID()) {
424-
fieldStr := fmt.Sprintf("\"%s\": %s", rawColumnName, rawColumn.GetColumnType().JSONPlaceholder())
425-
samplePlaceholderFields = append(samplePlaceholderFields, fieldStr)
426-
}
427-
}
428-
sort.Strings(samplePlaceholderFields)
429-
samplesPlaceholderStr := `{ "samples": [ { ` + strings.Join(samplePlaceholderFields, ", ") + " } ] }"
430414
out += "URL: " + urls.Join(resourcesRes.APIsBaseURL, anyAPIStatus.Path) + "\n"
431415
out += "Method: POST\n"
432416
out += `Header: "Content-Type: application/json"` + "\n"
433-
out += "Payload: " + samplesPlaceholderStr + "\n"
434417

418+
if api.Model != nil {
419+
model := ctx.Models[api.ModelName]
420+
resIDs := strset.New()
421+
combinedInput := []interface{}{model.Input, model.TrainingInput}
422+
for _, res := range ctx.ExtractCortexResources(combinedInput, resource.ConstantType, resource.RawColumnType, resource.AggregateType, resource.TransformedColumnType) {
423+
resIDs.Add(res.GetID())
424+
resIDs.Merge(ctx.AllComputedResourceDependencies(res.GetID()))
425+
}
426+
var samplePlaceholderFields []string
427+
for rawColumnName, rawColumn := range ctx.RawColumns {
428+
if resIDs.Has(rawColumn.GetID()) {
429+
fieldStr := fmt.Sprintf("\"%s\": %s", rawColumnName, rawColumn.GetColumnType().JSONPlaceholder())
430+
samplePlaceholderFields = append(samplePlaceholderFields, fieldStr)
431+
}
432+
}
433+
sort.Strings(samplePlaceholderFields)
434+
samplesPlaceholderStr := `{ "samples": [ { ` + strings.Join(samplePlaceholderFields, ", ") + " } ] }"
435+
out += "Payload: " + samplesPlaceholderStr + "\n"
436+
}
435437
if api != nil {
436438
out += resourceStr(api.API)
437439
}

cli/cmd/lib_cli_config.go

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -59,10 +59,11 @@ func getPromptValidation(defaults *CliConfig) *cr.PromptValidation {
5959
PromptOpts: &cr.PromptOptions{
6060
Prompt: "Enter Cortex operator endpoint",
6161
},
62-
StringValidation: cr.GetURLValidation(&cr.URLValidation{
63-
Required: true,
64-
Default: defaults.CortexURL,
65-
}),
62+
StringValidation: &cr.StringValidation{
63+
Required: true,
64+
Default: defaults.CortexURL,
65+
Validator: cr.GetURLValidator(false, false),
66+
},
6667
},
6768
{
6869
StructField: "AWSAccessKeyID",
@@ -97,9 +98,10 @@ var fileValidation = &cr.StructValidation{
9798
{
9899
Key: "cortex_url",
99100
StructField: "CortexURL",
100-
StringValidation: cr.GetURLValidation(&cr.URLValidation{
101-
Required: true,
102-
}),
101+
StringValidation: &cr.StringValidation{
102+
Required: true,
103+
Validator: cr.GetURLValidator(false, false),
104+
},
103105
},
104106
{
105107
Key: "aws_access_key_id",

cli/cmd/predict.go

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,16 @@ var predictCmd = &cobra.Command{
109109
}
110110

111111
for _, prediction := range predictResponse.Predictions {
112+
if prediction.Prediction == nil {
113+
prettyResp, err := json.Pretty(prediction.Response)
114+
if err != nil {
115+
errors.Exit(err)
116+
}
117+
118+
fmt.Println(prettyResp)
119+
continue
120+
}
121+
112122
value := prediction.Prediction
113123
if prediction.PredictionReversed != nil {
114124
value = prediction.PredictionReversed
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
# Importing External Models
2+
3+
You can serve a model that was trained outside of Cortex as an API.
4+
5+
1. Zip the exported estimator output in your checkpoint directory, e.g.
6+
7+
```bash
8+
$ ls export/estimator
9+
saved_model.pb variables/
10+
11+
$ zip -r model.zip export/estimator
12+
```
13+
14+
2. Upload the zipped file to Amazon S3, e.g.
15+
16+
```bash
17+
$ aws s3 cp model.zip s3://your-bucket/model.zip
18+
```
19+
20+
3. Specify `model_path` in an API, e.g.
21+
22+
```yaml
23+
- kind: api
24+
name: my-api
25+
model_path: s3://your-bucket/model.zip
26+
compute:
27+
replicas: 5
28+
gpu: 1
29+
```

docs/applications/resources/apis.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ Serve models at scale and use them to build smarter applications.
88
- kind: api # (required)
99
name: <string> # API name (required)
1010
model_name: <string> # name of a Cortex model (required)
11+
model_path: <string> # path to a zipped model dir (optional)
1112
compute:
1213
replicas: <int> # number of replicas to launch (default: 1)
1314
cpu: <string> # CPU request (default: Null)

docs/summary.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
* [Compute](applications/advanced/compute.md)
3939
* [Python Packages](applications/advanced/python-packages.md)
4040
* [Development](development.md)
41+
* [Importing External Models](applications/advanced/external-models.md)
4142

4243
## Operator
4344

examples/external-model/app.yaml

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
- kind: app
2+
name: iris
3+
4+
- kind: api
5+
name: iris
6+
model_path: s3://cortex-examples/iris-model.zip
7+
compute:
8+
replicas: 1

examples/external-model/samples.json

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
{
2+
"samples": [
3+
{
4+
"sepal_length": 5.2,
5+
"sepal_width": 3.6,
6+
"petal_length": 1.4,
7+
"petal_width": 0.3
8+
}
9+
]
10+
}

pkg/lib/aws/errors.go

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,12 +29,14 @@ type ErrorKind int
2929
const (
3030
ErrUnknown ErrorKind = iota
3131
ErrInvalidS3aPath
32+
ErrInvalidS3Path
3233
ErrAuth
3334
)
3435

3536
var errorKinds = []string{
3637
"err_unknown",
3738
"err_invalid_s3a_path",
39+
"err_invalid_s3_path",
3840
"err_auth",
3941
}
4042

@@ -105,7 +107,14 @@ func (e Error) Error() string {
105107
func ErrorInvalidS3aPath(provided string) error {
106108
return Error{
107109
Kind: ErrInvalidS3aPath,
108-
message: fmt.Sprintf("%s is not a valid s3a path", s.UserStr(provided)),
110+
message: fmt.Sprintf("%s is not a valid s3a path (e.g. s3a://cortex-examples/iris.csv is a valid s3a path)", s.UserStr(provided)),
111+
}
112+
}
113+
114+
func ErrorInvalidS3Path(provided string) error {
115+
return Error{
116+
Kind: ErrInvalidS3Path,
117+
message: fmt.Sprintf("%s is not a valid s3 path (e.g. s3://cortex-examples/iris-model.zip is a valid s3 path)", s.UserStr(provided)),
109118
}
110119
}
111120

pkg/lib/aws/s3.go

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -233,6 +233,20 @@ func (c *Client) DeleteFromS3ByPrefix(prefix string, continueIfFailure bool) err
233233
return errors.Wrap(err, prefix)
234234
}
235235

236+
func IsValidS3Path(s3Path string) bool {
237+
if !strings.HasPrefix(s3Path, "s3://") {
238+
return false
239+
}
240+
parts := strings.Split(s3Path[5:], "/")
241+
if len(parts) < 2 {
242+
return false
243+
}
244+
if parts[0] == "" || parts[1] == "" {
245+
return false
246+
}
247+
return true
248+
}
249+
236250
func IsValidS3aPath(s3aPath string) bool {
237251
if !strings.HasPrefix(s3aPath, "s3a://") {
238252
return false

pkg/lib/configreader/float32_ptr.go

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ type Float32PtrValidation struct {
3333
GreaterThanOrEqualTo *float32
3434
LessThan *float32
3535
LessThanOrEqualTo *float32
36-
Validator func(*float32) (*float32, error)
36+
Validator func(float32) (float32, error)
3737
}
3838

3939
func makeFloat32ValValidation(v *Float32PtrValidation) *Float32Validation {
@@ -171,8 +171,17 @@ func validateFloat32Ptr(val *float32, v *Float32PtrValidation) (*float32, error)
171171
}
172172
}
173173

174+
if val == nil {
175+
return val, nil
176+
}
177+
174178
if v.Validator != nil {
175-
return v.Validator(val)
179+
validated, err := v.Validator(*val)
180+
if err != nil {
181+
return nil, err
182+
}
183+
return &validated, nil
176184
}
185+
177186
return val, nil
178187
}

pkg/lib/configreader/float64_ptr.go

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ type Float64PtrValidation struct {
3333
GreaterThanOrEqualTo *float64
3434
LessThan *float64
3535
LessThanOrEqualTo *float64
36-
Validator func(*float64) (*float64, error)
36+
Validator func(float64) (float64, error)
3737
}
3838

3939
func makeFloat64ValValidation(v *Float64PtrValidation) *Float64Validation {
@@ -171,8 +171,17 @@ func validateFloat64Ptr(val *float64, v *Float64PtrValidation) (*float64, error)
171171
}
172172
}
173173

174+
if val == nil {
175+
return val, nil
176+
}
177+
174178
if v.Validator != nil {
175-
return v.Validator(val)
179+
validated, err := v.Validator(*val)
180+
if err != nil {
181+
return nil, err
182+
}
183+
return &validated, nil
176184
}
185+
177186
return val, nil
178187
}

pkg/lib/configreader/int32_ptr.go

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ type Int32PtrValidation struct {
3333
GreaterThanOrEqualTo *int32
3434
LessThan *int32
3535
LessThanOrEqualTo *int32
36-
Validator func(*int32) (*int32, error)
36+
Validator func(int32) (int32, error)
3737
}
3838

3939
func makeInt32ValValidation(v *Int32PtrValidation) *Int32Validation {
@@ -171,8 +171,17 @@ func validateInt32Ptr(val *int32, v *Int32PtrValidation) (*int32, error) {
171171
}
172172
}
173173

174+
if val == nil {
175+
return val, nil
176+
}
177+
174178
if v.Validator != nil {
175-
return v.Validator(val)
179+
validated, err := v.Validator(*val)
180+
if err != nil {
181+
return nil, err
182+
}
183+
return &validated, nil
176184
}
185+
177186
return val, nil
178187
}

pkg/lib/configreader/int64_ptr.go

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ type Int64PtrValidation struct {
3333
GreaterThanOrEqualTo *int64
3434
LessThan *int64
3535
LessThanOrEqualTo *int64
36-
Validator func(*int64) (*int64, error)
36+
Validator func(int64) (int64, error)
3737
}
3838

3939
func makeInt64ValValidation(v *Int64PtrValidation) *Int64Validation {
@@ -171,8 +171,17 @@ func validateInt64Ptr(val *int64, v *Int64PtrValidation) (*int64, error) {
171171
}
172172
}
173173

174+
if val == nil {
175+
return val, nil
176+
}
177+
174178
if v.Validator != nil {
175-
return v.Validator(val)
179+
validated, err := v.Validator(*val)
180+
if err != nil {
181+
return nil, err
182+
}
183+
return &validated, nil
176184
}
185+
177186
return val, nil
178187
}

pkg/lib/configreader/int_ptr.go

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ type IntPtrValidation struct {
3333
GreaterThanOrEqualTo *int
3434
LessThan *int
3535
LessThanOrEqualTo *int
36-
Validator func(*int) (*int, error)
36+
Validator func(int) (int, error)
3737
}
3838

3939
func makeIntValValidation(v *IntPtrValidation) *IntValidation {
@@ -171,8 +171,17 @@ func validateIntPtr(val *int, v *IntPtrValidation) (*int, error) {
171171
}
172172
}
173173

174+
if val == nil {
175+
return val, nil
176+
}
177+
174178
if v.Validator != nil {
175-
return v.Validator(val)
179+
validated, err := v.Validator(*val)
180+
if err != nil {
181+
return nil, err
182+
}
183+
return &validated, nil
176184
}
185+
177186
return val, nil
178187
}

pkg/lib/configreader/string_ptr.go

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ type StringPtrValidation struct {
3535
DNS1123 bool
3636
AllowCortexResources bool
3737
RequireCortexResources bool
38-
Validator func(*string) (*string, error)
38+
Validator func(string) (string, error)
3939
}
4040

4141
func makeStringValValidation(v *StringPtrValidation) *StringValidation {
@@ -170,8 +170,17 @@ func validateStringPtr(val *string, v *StringPtrValidation) (*string, error) {
170170
}
171171
}
172172

173+
if val == nil {
174+
return val, nil
175+
}
176+
173177
if v.Validator != nil {
174-
return v.Validator(val)
178+
validated, err := v.Validator(*val)
179+
if err != nil {
180+
return nil, err
181+
}
182+
return &validated, nil
175183
}
184+
176185
return val, nil
177186
}

0 commit comments

Comments
 (0)