diff --git a/cli/cmd/get.go b/cli/cmd/get.go
index b3b3d5a0d0..67ddf1090b 100644
--- a/cli/cmd/get.go
+++ b/cli/cmd/get.go
@@ -382,7 +382,6 @@ func describeAPI(name string, resourcesRes *schema.GetResourcesResponse) (string
 	ctx := resourcesRes.Context
 	api := ctx.APIs[name]
-	model := ctx.Models[api.ModelName]
 
 	var staleReplicas int32
 	var ctxAPIStatus *resource.APIStatus
@@ -412,26 +411,29 @@
 	}
 
 	out += titleStr("Endpoint")
-	resIDs := strset.New()
-	combinedInput := []interface{}{model.Input, model.TrainingInput}
-	for _, res := range ctx.ExtractCortexResources(combinedInput, resource.ConstantType, resource.RawColumnType, resource.AggregateType, resource.TransformedColumnType) {
-		resIDs.Add(res.GetID())
-		resIDs.Merge(ctx.AllComputedResourceDependencies(res.GetID()))
-	}
-	var samplePlaceholderFields []string
-	for rawColumnName, rawColumn := range ctx.RawColumns {
-		if resIDs.Has(rawColumn.GetID()) {
-			fieldStr := fmt.Sprintf("\"%s\": %s", rawColumnName, rawColumn.GetColumnType().JSONPlaceholder())
-			samplePlaceholderFields = append(samplePlaceholderFields, fieldStr)
-		}
-	}
-	sort.Strings(samplePlaceholderFields)
-	samplesPlaceholderStr := `{ "samples": [ { ` + strings.Join(samplePlaceholderFields, ", ") + " } ] }"
 	out += "URL: " + urls.Join(resourcesRes.APIsBaseURL, anyAPIStatus.Path) + "\n"
 	out += "Method: POST\n"
 	out += `Header: "Content-Type: application/json"` + "\n"
-	out += "Payload: " + samplesPlaceholderStr + "\n"
+
+	if api.Model != nil {
+		model := ctx.Models[api.ModelName]
+		resIDs := strset.New()
+		combinedInput := []interface{}{model.Input, model.TrainingInput}
+		for _, res := range ctx.ExtractCortexResources(combinedInput, resource.ConstantType, resource.RawColumnType, resource.AggregateType, resource.TransformedColumnType) {
+			resIDs.Add(res.GetID())
+			resIDs.Merge(ctx.AllComputedResourceDependencies(res.GetID()))
+		}
+		var samplePlaceholderFields []string
+		for rawColumnName, rawColumn := range ctx.RawColumns {
+			if resIDs.Has(rawColumn.GetID()) {
+				fieldStr := fmt.Sprintf("\"%s\": %s", rawColumnName, rawColumn.GetColumnType().JSONPlaceholder())
+				samplePlaceholderFields = append(samplePlaceholderFields, fieldStr)
+			}
+		}
+		sort.Strings(samplePlaceholderFields)
+		samplesPlaceholderStr := `{ "samples": [ { ` + strings.Join(samplePlaceholderFields, ", ") + " } ] }"
+		out += "Payload: " + samplesPlaceholderStr + "\n"
+	}
 
 	if api != nil {
 		out += resourceStr(api.API)
 	}
diff --git a/cli/cmd/lib_cli_config.go b/cli/cmd/lib_cli_config.go
index 811c19f3d1..e881adcc2f 100644
--- a/cli/cmd/lib_cli_config.go
+++ b/cli/cmd/lib_cli_config.go
@@ -59,10 +59,11 @@ func getPromptValidation(defaults *CliConfig) *cr.PromptValidation {
 			PromptOpts: &cr.PromptOptions{
 				Prompt: "Enter Cortex operator endpoint",
 			},
-			StringValidation: cr.GetURLValidation(&cr.URLValidation{
-				Required: true,
-				Default:  defaults.CortexURL,
-			}),
+			StringValidation: &cr.StringValidation{
+				Required:  true,
+				Default:   defaults.CortexURL,
+				Validator: cr.GetURLValidator(false, false),
+			},
 		},
 		{
 			StructField: "AWSAccessKeyID",
@@ -97,9 +98,10 @@ var fileValidation = &cr.StructValidation{
 		{
 			Key:         "cortex_url",
 			StructField: "CortexURL",
-			StringValidation: cr.GetURLValidation(&cr.URLValidation{
-				Required: true,
-			}),
+			StringValidation: &cr.StringValidation{
+				Required:  true,
+				Validator: cr.GetURLValidator(false, false),
+			},
 		},
 		{
 			Key: "aws_access_key_id",
diff --git a/cli/cmd/predict.go b/cli/cmd/predict.go
index 4fe20925e1..78c8e7d7fb 100644
--- a/cli/cmd/predict.go
+++ b/cli/cmd/predict.go
@@ -109,6 +109,16 @@ var predictCmd = &cobra.Command{
 		}
 
 		for _, prediction := range predictResponse.Predictions {
+			if prediction.Prediction == nil {
+				prettyResp, err := json.Pretty(prediction.Response)
+				if err != nil {
+					errors.Exit(err)
+				}
+
+				fmt.Println(prettyResp)
+				continue
+			}
+
 			value := prediction.Prediction
 			if prediction.PredictionReversed != nil {
 				value = prediction.PredictionReversed
diff --git a/docs/applications/advanced/external-models.md b/docs/applications/advanced/external-models.md
new file mode 100644
index 0000000000..d82abdb466
--- /dev/null
+++ b/docs/applications/advanced/external-models.md
@@ -0,0 +1,29 @@
+# Importing External Models
+
+You can serve a model that was trained outside of Cortex as an API.
+
+1. Zip the exported estimator output in your checkpoint directory, e.g.
+
+```bash
+$ ls export/estimator
+saved_model.pb  variables/
+
+$ zip -r model.zip export/estimator
+```
+
+2. Upload the zipped file to Amazon S3, e.g.
+
+```bash
+$ aws s3 cp model.zip s3://your-bucket/model.zip
+```
+
+3. Specify `model_path` in an API, e.g.
+
+```yaml
+- kind: api
+  name: my-api
+  model_path: s3://your-bucket/model.zip
+  compute:
+    replicas: 5
+    gpu: 1
+```
diff --git a/docs/applications/resources/apis.md b/docs/applications/resources/apis.md
index aa2fb07183..29c8ac0638 100644
--- a/docs/applications/resources/apis.md
+++ b/docs/applications/resources/apis.md
@@ -8,6 +8,7 @@ Serve models at scale and use them to build smarter applications.
 - kind: api  # (required)
   name: # API name (required)
   model_name: # name of a Cortex model (required)
+  model_path: # path to a zipped model dir (optional)
   compute:
     replicas: # number of replicas to launch (default: 1)
     cpu: # CPU request (default: Null)
diff --git a/docs/summary.md b/docs/summary.md
index b169d99715..5de9956fc7 100644
--- a/docs/summary.md
+++ b/docs/summary.md
@@ -38,6 +38,7 @@
   * [Compute](applications/advanced/compute.md)
   * [Python Packages](applications/advanced/python-packages.md)
   * [Development](development.md)
+  * [Importing External Models](applications/advanced/external-models.md)
 
 ## Operator
diff --git a/examples/external-model/app.yaml b/examples/external-model/app.yaml
new file mode 100644
index 0000000000..a9c60cd5ed
--- /dev/null
+++ b/examples/external-model/app.yaml
@@ -0,0 +1,8 @@
+- kind: app
+  name: iris
+
+- kind: api
+  name: iris
+  model_path: s3://cortex-examples/iris-model.zip
+  compute:
+    replicas: 1
diff --git a/examples/external-model/samples.json b/examples/external-model/samples.json
new file mode 100644
index 0000000000..33d1e6a5b5
--- /dev/null
+++ b/examples/external-model/samples.json
@@ -0,0 +1,10 @@
+{
+  "samples": [
+    {
+      "sepal_length": 5.2,
+      "sepal_width": 3.6,
+      "petal_length": 1.4,
+      "petal_width": 0.3
+    }
+  ]
+}
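Note on step 1 of the new doc: for readers who would rather script the packaging, here is a minimal Go sketch of the same `zip -r model.zip export/estimator` step. This is a hypothetical helper, not part of this PR, and uses only the standard library:

```go
// zipmodel.go — hypothetical helper, not part of this PR.
// Packages an exported SavedModel directory (e.g. export/estimator)
// into model.zip, mirroring `zip -r model.zip export/estimator`.
package main

import (
	"archive/zip"
	"io"
	"log"
	"os"
	"path/filepath"
)

func zipDir(srcDir, dst string) error {
	f, err := os.Create(dst)
	if err != nil {
		return err
	}
	defer f.Close()

	w := zip.NewWriter(f)
	defer w.Close()

	return filepath.Walk(srcDir, func(path string, info os.FileInfo, err error) error {
		if err != nil || info.IsDir() {
			return err
		}
		// Store entries relative to the parent of srcDir,
		// e.g. estimator/saved_model.pb
		rel, err := filepath.Rel(filepath.Dir(srcDir), path)
		if err != nil {
			return err
		}
		entry, err := w.Create(filepath.ToSlash(rel))
		if err != nil {
			return err
		}
		src, err := os.Open(path)
		if err != nil {
			return err
		}
		defer src.Close()
		_, err = io.Copy(entry, src)
		return err
	})
}

func main() {
	if err := zipDir("export/estimator", "model.zip"); err != nil {
		log.Fatal(err)
	}
}
```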
diff --git a/pkg/lib/aws/errors.go b/pkg/lib/aws/errors.go
index 2a2d09e16d..634ba81e49 100644
--- a/pkg/lib/aws/errors.go
+++ b/pkg/lib/aws/errors.go
@@ -29,12 +29,14 @@ type ErrorKind int
 
 const (
 	ErrUnknown ErrorKind = iota
 	ErrInvalidS3aPath
+	ErrInvalidS3Path
 	ErrAuth
 )
 
 var errorKinds = []string{
 	"err_unknown",
 	"err_invalid_s3a_path",
+	"err_invalid_s3_path",
 	"err_auth",
 }
 
@@ -105,7 +107,14 @@ func (e Error) Error() string {
 func ErrorInvalidS3aPath(provided string) error {
 	return Error{
 		Kind:    ErrInvalidS3aPath,
-		message: fmt.Sprintf("%s is not a valid s3a path", s.UserStr(provided)),
+		message: fmt.Sprintf("%s is not a valid s3a path (e.g. s3a://cortex-examples/iris.csv is a valid s3a path)", s.UserStr(provided)),
+	}
+}
+
+func ErrorInvalidS3Path(provided string) error {
+	return Error{
+		Kind:    ErrInvalidS3Path,
+		message: fmt.Sprintf("%s is not a valid s3 path (e.g. s3://cortex-examples/iris-model.zip is a valid s3 path)", s.UserStr(provided)),
 	}
 }
diff --git a/pkg/lib/aws/s3.go b/pkg/lib/aws/s3.go
index 9c62935b2b..4bcd6a7c4f 100644
--- a/pkg/lib/aws/s3.go
+++ b/pkg/lib/aws/s3.go
@@ -233,6 +233,20 @@ func (c *Client) DeleteFromS3ByPrefix(prefix string, continueIfFailure bool) err
 	return errors.Wrap(err, prefix)
 }
 
+func IsValidS3Path(s3Path string) bool {
+	if !strings.HasPrefix(s3Path, "s3://") {
+		return false
+	}
+	parts := strings.Split(s3Path[5:], "/")
+	if len(parts) < 2 {
+		return false
+	}
+	if parts[0] == "" || parts[1] == "" {
+		return false
+	}
+	return true
+}
+
 func IsValidS3aPath(s3aPath string) bool {
 	if !strings.HasPrefix(s3aPath, "s3a://") {
 		return false
diff --git a/pkg/lib/configreader/float32_ptr.go b/pkg/lib/configreader/float32_ptr.go
index 540da3ccfb..bc069e2c70 100644
--- a/pkg/lib/configreader/float32_ptr.go
+++ b/pkg/lib/configreader/float32_ptr.go
@@ -33,7 +33,7 @@ type Float32PtrValidation struct {
 	GreaterThanOrEqualTo *float32
 	LessThan             *float32
 	LessThanOrEqualTo    *float32
-	Validator            func(*float32) (*float32, error)
+	Validator            func(float32) (float32, error)
 }
 
 func makeFloat32ValValidation(v *Float32PtrValidation) *Float32Validation {
@@ -171,8 +171,17 @@ func validateFloat32Ptr(val *float32, v *Float32PtrValidation) (*float32, error)
 		}
 	}
 
+	if val == nil {
+		return val, nil
+	}
+
 	if v.Validator != nil {
-		return v.Validator(val)
+		validated, err := v.Validator(*val)
+		if err != nil {
+			return nil, err
+		}
+		return &validated, nil
 	}
+
 	return val, nil
 }
diff --git a/pkg/lib/configreader/float64_ptr.go b/pkg/lib/configreader/float64_ptr.go
index 51808f6c63..d5678d43f5 100644
--- a/pkg/lib/configreader/float64_ptr.go
+++ b/pkg/lib/configreader/float64_ptr.go
@@ -33,7 +33,7 @@ type Float64PtrValidation struct {
 	GreaterThanOrEqualTo *float64
 	LessThan             *float64
 	LessThanOrEqualTo    *float64
-	Validator            func(*float64) (*float64, error)
+	Validator            func(float64) (float64, error)
 }
 
 func makeFloat64ValValidation(v *Float64PtrValidation) *Float64Validation {
@@ -171,8 +171,17 @@ func validateFloat64Ptr(val *float64, v *Float64PtrValidation) (*float64, error)
 		}
 	}
 
+	if val == nil {
+		return val, nil
+	}
+
 	if v.Validator != nil {
-		return v.Validator(val)
+		validated, err := v.Validator(*val)
+		if err != nil {
+			return nil, err
+		}
+		return &validated, nil
 	}
+
 	return val, nil
 }
diff --git a/pkg/lib/configreader/int32_ptr.go b/pkg/lib/configreader/int32_ptr.go
index e8f3f45d07..c9f8e91f0c 100644
--- a/pkg/lib/configreader/int32_ptr.go
+++ b/pkg/lib/configreader/int32_ptr.go
@@ -33,7 +33,7 @@ type Int32PtrValidation struct {
 	GreaterThanOrEqualTo *int32
 	LessThan             *int32
 	LessThanOrEqualTo    *int32
-	Validator            func(*int32) (*int32, error)
+	Validator            func(int32) (int32, error)
 }
 
 func makeInt32ValValidation(v *Int32PtrValidation) *Int32Validation {
@@ -171,8 +171,17 @@ func validateInt32Ptr(val *int32, v *Int32PtrValidation) (*int32, error) {
 		}
 	}
 
+	if val == nil {
+		return val, nil
+	}
+
 	if v.Validator != nil {
-		return v.Validator(val)
+		validated, err := v.Validator(*val)
+		if err != nil {
+			return nil, err
+		}
+		return &validated, nil
 	}
+
 	return val, nil
 }
diff --git a/pkg/lib/configreader/int64_ptr.go b/pkg/lib/configreader/int64_ptr.go
index 560403a7f7..c20c1571ff 100644
--- a/pkg/lib/configreader/int64_ptr.go
+++ b/pkg/lib/configreader/int64_ptr.go
@@ -33,7 +33,7 @@ type Int64PtrValidation struct {
 	GreaterThanOrEqualTo *int64
 	LessThan             *int64
 	LessThanOrEqualTo    *int64
-	Validator            func(*int64) (*int64, error)
+	Validator            func(int64) (int64, error)
 }
 
 func makeInt64ValValidation(v *Int64PtrValidation) *Int64Validation {
@@ -171,8 +171,17 @@ func validateInt64Ptr(val *int64, v *Int64PtrValidation) (*int64, error) {
 		}
 	}
 
+	if val == nil {
+		return val, nil
+	}
+
 	if v.Validator != nil {
-		return v.Validator(val)
+		validated, err := v.Validator(*val)
+		if err != nil {
+			return nil, err
+		}
+		return &validated, nil
 	}
+
 	return val, nil
 }
diff --git a/pkg/lib/configreader/int_ptr.go b/pkg/lib/configreader/int_ptr.go
index 5f7d43b5f8..f86e2e1f38 100644
--- a/pkg/lib/configreader/int_ptr.go
+++ b/pkg/lib/configreader/int_ptr.go
@@ -33,7 +33,7 @@ type IntPtrValidation struct {
 	GreaterThanOrEqualTo *int
 	LessThan             *int
 	LessThanOrEqualTo    *int
-	Validator            func(*int) (*int, error)
+	Validator            func(int) (int, error)
 }
 
 func makeIntValValidation(v *IntPtrValidation) *IntValidation {
@@ -171,8 +171,17 @@ func validateIntPtr(val *int, v *IntPtrValidation) (*int, error) {
 		}
 	}
 
+	if val == nil {
+		return val, nil
+	}
+
 	if v.Validator != nil {
-		return v.Validator(val)
+		validated, err := v.Validator(*val)
+		if err != nil {
+			return nil, err
+		}
+		return &validated, nil
 	}
+
 	return val, nil
 }
diff --git a/pkg/lib/configreader/string_ptr.go b/pkg/lib/configreader/string_ptr.go
index f9ee57e2e3..a18ce2f553 100644
--- a/pkg/lib/configreader/string_ptr.go
+++ b/pkg/lib/configreader/string_ptr.go
@@ -35,7 +35,7 @@ type StringPtrValidation struct {
 	DNS1123                bool
 	AllowCortexResources   bool
 	RequireCortexResources bool
-	Validator              func(*string) (*string, error)
+	Validator              func(string) (string, error)
 }
 
 func makeStringValValidation(v *StringPtrValidation) *StringValidation {
@@ -170,8 +170,17 @@ func validateStringPtr(val *string, v *StringPtrValidation) (*string, error) {
 		}
 	}
 
+	if val == nil {
+		return val, nil
+	}
+
 	if v.Validator != nil {
-		return v.Validator(val)
+		validated, err := v.Validator(*val)
+		if err != nil {
+			return nil, err
+		}
+		return &validated, nil
 	}
+
 	return val, nil
 }
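The signature change repeated across these files (`func(*T) (*T, error)` → `func(T) (T, error)`) moves the nil handling into the shared `validate*Ptr` path, so individual validators no longer guard against nil themselves. A standalone sketch of the new contract, with simplified types rather than the actual configreader package:

```go
// Standalone sketch of the new pointer-validator contract: the Validator
// receives and returns a value, and the shared path handles nil once.
package main

import (
	"fmt"
	"strings"
)

type stringPtrValidation struct {
	Validator func(string) (string, error)
}

func validateStringPtr(val *string, v *stringPtrValidation) (*string, error) {
	if val == nil {
		return nil, nil // nil pointers skip the Validator entirely
	}
	if v.Validator != nil {
		validated, err := v.Validator(*val)
		if err != nil {
			return nil, err
		}
		return &validated, nil
	}
	return val, nil
}

func main() {
	v := &stringPtrValidation{
		Validator: func(s string) (string, error) {
			if !strings.HasPrefix(s, "s3://") {
				return "", fmt.Errorf("%q is not a valid s3 path", s)
			}
			return s, nil
		},
	}

	path := "s3://cortex-examples/iris-model.zip"
	out, err := validateStringPtr(&path, v)
	fmt.Println(*out, err) // s3://cortex-examples/iris-model.zip <nil>

	out, err = validateStringPtr(nil, v) // the Validator never sees a nil value
	fmt.Println(out == nil, err)         // true <nil>
}
```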
diff --git a/pkg/lib/configreader/validators.go b/pkg/lib/configreader/validators.go
index 0e8e34bc8e..e6d7745299 100644
--- a/pkg/lib/configreader/validators.go
+++ b/pkg/lib/configreader/validators.go
@@ -31,69 +31,49 @@ func init() {
 	portRe = regexp.MustCompile(`:[0-9]+$`)
 }
 
-type PathValidation struct {
-	Required bool
-	Default  string
-	BaseDir  string
-}
-
-func GetFilePathValidation(v *PathValidation) *StringValidation {
-	validator := func(val string) (string, error) {
-		val = files.RelPath(val, v.BaseDir)
+func GetFilePathValidator(baseDir string) func(string) (string, error) {
+	return func(val string) (string, error) {
+		val = files.RelPath(val, baseDir)
 		if err := files.CheckFile(val); err != nil {
 			return "", err
 		}
 		return val, nil
 	}
-
-	return &StringValidation{
-		Required:  v.Required,
-		Default:   v.Default,
-		Validator: validator,
-	}
 }
 
-type S3aPathValidation struct {
-	Required bool
-	Default  string
-}
-
-func GetS3aPathValidation(v *S3aPathValidation) *StringValidation {
-	validator := func(val string) (string, error) {
+func GetS3aPathValidator() func(string) (string, error) {
+	return func(val string) (string, error) {
 		if !aws.IsValidS3aPath(val) {
 			return "", aws.ErrorInvalidS3aPath(val)
 		}
 		return val, nil
 	}
-
-	return &StringValidation{
-		Required:  v.Required,
-		Default:   v.Default,
-		Validator: validator,
-	}
 }
 
-type URLValidation struct {
-	Required    bool
-	Default     string
-	DefaultHTTP bool // Otherwise default is https
-	AddPort     bool
+func GetS3PathValidator() func(string) (string, error) {
+	return func(val string) (string, error) {
+		if !aws.IsValidS3Path(val) {
+			return "", aws.ErrorInvalidS3Path(val)
+		}
+		return val, nil
+	}
 }
 
-func GetURLValidation(v *URLValidation) *StringValidation {
-	validator := func(val string) (string, error) {
+// uses https unless defaultHTTP == true
+func GetURLValidator(defaultHTTP bool, addPort bool) func(string) (string, error) {
+	return func(val string) (string, error) {
 		urlStr := strings.TrimSpace(val)
 		if !strings.HasPrefix(strings.ToLower(urlStr), "http") {
-			if v.DefaultHTTP {
+			if defaultHTTP {
 				urlStr = "http://" + urlStr
 			} else {
 				urlStr = "https://" + urlStr
 			}
 		}
 
-		if v.AddPort {
+		if addPort {
 			if !portRe.MatchString(urlStr) {
 				if strings.HasPrefix(strings.ToLower(urlStr), "https") {
 					urlStr = urlStr + ":443"
@@ -109,10 +89,4 @@
 
 		return urlStr, nil
 	}
-
-	return &StringValidation{
-		Required:  v.Required,
-		Default:   v.Default,
-		Validator: validator,
-	}
 }
diff --git a/pkg/operator/api/context/apis.go b/pkg/operator/api/context/apis.go
index a6165d40eb..0a6dd7d2e1 100644
--- a/pkg/operator/api/context/apis.go
+++ b/pkg/operator/api/context/apis.go
@@ -26,7 +26,7 @@ type API struct {
 	*userconfig.API
 	*ComputedResourceFields
 	Path      string `json:"path"`
-	ModelName string `json:"model_name"` // This is just a convenience which removes the @ from userconfig.API.Model
+	ModelName string `json:"model_name"` // Removes the @ from userconfig.API.Model, or equals userconfig.API.ModelPath if the model is external
 }
 
 func APIPath(apiName string, appName string) string {
diff --git a/pkg/operator/api/context/dependencies.go b/pkg/operator/api/context/dependencies.go
index eb7a8fe3c9..4c80076f53 100644
--- a/pkg/operator/api/context/dependencies.go
+++ b/pkg/operator/api/context/dependencies.go
@@ -151,6 +151,9 @@ func (ctx *Context) modelDependencies(model *Model) strset.Set {
 }
 
 func (ctx *Context) apiDependencies(api *API) strset.Set {
+	if api.Model == nil {
+		return strset.New()
+	}
 	model := ctx.Models[api.ModelName]
 	return strset.New(model.ID)
 }
diff --git a/pkg/operator/api/context/serialize.go b/pkg/operator/api/context/serialize.go
index df147dc310..9a68e373e5 100644
--- a/pkg/operator/api/context/serialize.go
+++ b/pkg/operator/api/context/serialize.go
@@ -189,6 +189,10 @@ func (ctx *Context) castSchemaTypes() error {
 }
 
 func (ctx *Context) ToSerial() *Serial {
+	if ctx.Environment == nil {
+		return &Serial{Context: ctx}
+	}
+
 	serial := Serial{
 		Context:        ctx,
 		RawColumnSplit: ctx.splitRawColumns(),
@@ -201,17 +205,19 @@
 func (serial *Serial) ContextFromSerial() (*Context, error) {
 	ctx := serial.Context
 
-	ctx.RawColumns = serial.collectRawColumns()
+	if ctx.Environment != nil {
+		ctx.RawColumns = serial.collectRawColumns()
 
-	environment, err := serial.collectEnvironment()
-	if err != nil {
-		return nil, err
-	}
-	ctx.Environment = environment
+		environment, err := serial.collectEnvironment()
+		if err != nil {
+			return nil, err
+		}
+		ctx.Environment = environment
 
-	err = ctx.castSchemaTypes()
-	if err != nil {
-		return nil, err
+		err = ctx.castSchemaTypes()
+		if err != nil {
+			return nil, err
+		}
 	}
 
 	return ctx, nil
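For reference, `GetURLValidator(false, false)` — as wired into `lib_cli_config.go` above — normalizes bare hosts to https. Below is a condensed, standalone re-implementation for illustration only; note the `:80` fallback is an assumption on my part, since the non-https branch of the original function is elided from this hunk:

```go
// Condensed re-implementation of the GetURLValidator logic above:
// scheme defaulting plus optional port pinning. Illustration only.
package main

import (
	"fmt"
	"regexp"
	"strings"
)

var portRe = regexp.MustCompile(`:[0-9]+$`)

func urlValidator(defaultHTTP bool, addPort bool) func(string) (string, error) {
	return func(val string) (string, error) {
		urlStr := strings.TrimSpace(val)
		if !strings.HasPrefix(strings.ToLower(urlStr), "http") {
			if defaultHTTP {
				urlStr = "http://" + urlStr
			} else {
				urlStr = "https://" + urlStr
			}
		}
		if addPort && !portRe.MatchString(urlStr) {
			if strings.HasPrefix(strings.ToLower(urlStr), "https") {
				urlStr += ":443"
			} else {
				urlStr += ":80" // assumed fallback; elided in the hunk above
			}
		}
		return urlStr, nil
	}
}

func main() {
	v := urlValidator(false, true)
	out, _ := v("my-operator.example.com")
	fmt.Println(out) // https://my-operator.example.com:443
}
```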
diff --git a/pkg/operator/api/userconfig/apis.go b/pkg/operator/api/userconfig/apis.go
index e7cd27909d..fbc960d628 100644
--- a/pkg/operator/api/userconfig/apis.go
+++ b/pkg/operator/api/userconfig/apis.go
@@ -17,9 +17,8 @@ limitations under the License.
 package userconfig
 
 import (
-	"github.com/cortexlabs/yaml"
-
 	cr "github.com/cortexlabs/cortex/pkg/lib/configreader"
+	"github.com/cortexlabs/cortex/pkg/lib/errors"
 	"github.com/cortexlabs/cortex/pkg/operator/api/resource"
 )
 
@@ -27,9 +26,10 @@ type APIs []*API
 
 type API struct {
 	ResourceFields
-	Model   string      `json:"model" yaml:"model"`
-	Compute *APICompute `json:"compute" yaml:"compute"`
-	Tags    Tags        `json:"tags" yaml:"tags"`
+	Model     *string     `json:"model" yaml:"model"`
+	ModelPath *string     `json:"model_path" yaml:"model_path"`
+	Compute   *APICompute `json:"compute" yaml:"compute"`
+	Tags      Tags        `json:"tags" yaml:"tags"`
 }
 
 var apiValidation = &cr.StructValidation{
@@ -42,18 +42,17 @@ var apiValidation = &cr.StructValidation{
 			},
 		},
 		{
-			StructField:  "Model",
-			DefaultField: "Name",
-			DefaultFieldFunc: func(name interface{}) interface{} {
-				model := "@" + name.(string)
-				escapedModel, _ := yaml.EscapeAtSymbol(model)
-				return escapedModel
-			},
-			StringValidation: &cr.StringValidation{
-				Required:               false,
+			StructField: "Model",
+			StringPtrValidation: &cr.StringPtrValidation{
 				RequireCortexResources: true,
 			},
 		},
+		{
+			StructField: "ModelPath",
+			StringPtrValidation: &cr.StringPtrValidation{
+				Validator: cr.GetS3PathValidator(),
+			},
+		},
 		apiComputeFieldValidation,
 		tagsFieldValidation,
 		typeFieldValidation,
@@ -61,6 +60,12 @@ var apiValidation = &cr.StructValidation{
 }
 
 func (apis APIs) Validate() error {
+	for _, api := range apis {
+		if err := api.Validate(); err != nil {
+			return err
+		}
+	}
+
 	resources := make([]Resource, len(apis))
 	for i, res := range apis {
 		resources[i] = res
@@ -70,6 +75,19 @@ func (apis APIs) Validate() error {
 	if len(dups) > 0 {
 		return ErrorDuplicateResourceName(dups...)
 	}
+
+	return nil
+}
+
+func (api *API) Validate() error {
+	if api.ModelPath == nil && api.Model == nil {
+		return errors.Wrap(ErrorSpecifyOnlyOneMissing("model", "model_path"), Identify(api))
+	}
+
+	if api.ModelPath != nil && api.Model != nil {
+		return errors.Wrap(ErrorSpecifyOnlyOne("model", "model_path"), Identify(api))
+	}
+
 	return nil
 }
diff --git a/pkg/operator/api/userconfig/config.go b/pkg/operator/api/userconfig/config.go
index 49afdca6e2..df7523d796 100644
--- a/pkg/operator/api/userconfig/config.go
+++ b/pkg/operator/api/userconfig/config.go
@@ -74,6 +74,15 @@ func mergeConfigs(target *Config, source *Config) error {
 		target.App = source.App
 	}
 
+	if target.Resources == nil {
+		target.Resources = make(map[string][]Resource)
+	}
+	for resourceName, resources := range source.Resources {
+		for _, res := range resources {
+			target.Resources[resourceName] = append(target.Resources[resourceName], res)
+		}
+	}
+
 	return nil
 }
 
@@ -190,8 +199,27 @@ func (config *Config) Validate(envName string) error {
 			config.Environment = env
 		}
 	}
+
+	apisAllExternal := true
+	for _, api := range config.APIs {
+		if api.Model != nil {
+			apisAllExternal = false
+			break
+		}
+	}
+
 	if config.Environment == nil {
-		return ErrorUndefinedResource(envName, resource.EnvironmentType)
+		if !apisAllExternal || len(config.APIs) == 0 {
+			return ErrorUndefinedResource(envName, resource.EnvironmentType)
+		}
+
+		for _, resources := range config.Resources {
+			for _, res := range resources {
+				if res.GetResourceType() != resource.APIType {
+					return ErrorExtraResourcesWithExternalAPIs(res)
+				}
+			}
+		}
 	}
 
 	return nil
diff --git a/pkg/operator/api/userconfig/environments.go b/pkg/operator/api/userconfig/environments.go
index 78f2ac5355..0e2c5e956a 100644
--- a/pkg/operator/api/userconfig/environments.go
+++ b/pkg/operator/api/userconfig/environments.go
@@ -153,9 +153,10 @@ type ExternalData struct {
 var externalDataValidation = []*cr.StructFieldValidation{
 	{
 		StructField: "Path",
-		StringValidation: cr.GetS3aPathValidation(&cr.S3aPathValidation{
-			Required: true,
-		}),
+		StringValidation: &cr.StringValidation{
+			Required:  true,
+			Validator: cr.GetS3aPathValidator(),
+		},
 	},
 	{
 		StructField: "Region",
diff --git a/pkg/operator/api/userconfig/errors.go b/pkg/operator/api/userconfig/errors.go
index b5efdc55c7..19b1416996 100644
--- a/pkg/operator/api/userconfig/errors.go
+++ b/pkg/operator/api/userconfig/errors.go
@@ -73,6 +73,7 @@ const (
 	ErrPredictionKeyOnModelWithEstimator
 	ErrSpecifyOnlyOneMissing
 	ErrEnvSchemaMismatch
+	ErrExtraResourcesWithExternalAPIs
 	ErrImplDoesNotExist
 )
 
@@ -121,6 +122,7 @@ var errorKinds = []string{
 	"err_prediction_key_on_model_with_estimator",
 	"err_specify_only_one_missing",
 	"err_env_schema_mismatch",
+	"err_extra_resources_with_external_a_p_is",
 	"err_impl_does_not_exist",
 }
 
@@ -560,6 +562,13 @@ func ErrorEnvSchemaMismatch(env1, env2 *Environment) error {
 	}
 }
 
+func ErrorExtraResourcesWithExternalAPIs(res Resource) error {
+	return Error{
+		Kind:    ErrExtraResourcesWithExternalAPIs,
+		message: fmt.Sprintf("only APIs can be defined if the environment is not defined (found %s)", Identify(res)),
+	}
+}
+
 func ErrorImplDoesNotExist(path string) error {
 	return Error{
 		Kind:    ErrImplDoesNotExist,
diff --git a/pkg/operator/context/apis.go b/pkg/operator/context/apis.go
index b2024f5c93..c3383d49de 100644
--- a/pkg/operator/context/apis.go
+++ b/pkg/operator/context/apis.go
@@ -29,20 +29,31 @@ import (
 
 func getAPIs(config *userconfig.Config,
 	models context.Models,
+	datasetVersion string,
 ) (context.APIs, error) {
 	apis := context.APIs{}
 
 	for _, apiConfig := range config.APIs {
-		modelName, _ := yaml.ExtractAtSymbolText(apiConfig.Model)
-
-		model := models[modelName]
-		if model == nil {
-			return nil, errors.Wrap(userconfig.ErrorUndefinedResource(modelName, resource.ModelType), userconfig.Identify(apiConfig), userconfig.ModelNameKey)
-		}
-
 		var buf bytes.Buffer
+		var modelName string
 		buf.WriteString(apiConfig.Name)
-		buf.WriteString(model.ID)
+
+		if apiConfig.Model != nil {
+			modelName, _ = yaml.ExtractAtSymbolText(*apiConfig.Model)
+			model := models[modelName]
+			if model == nil {
+				return nil, errors.Wrap(userconfig.ErrorUndefinedResource(modelName, resource.ModelType), userconfig.Identify(apiConfig), userconfig.ModelNameKey)
+			}
+			buf.WriteString(model.ID)
+		}
+
+		if apiConfig.ModelPath != nil {
+			modelName = *apiConfig.ModelPath
+			buf.WriteString(datasetVersion)
+			buf.WriteString(*apiConfig.ModelPath)
+		}
+
 		id := hash.Bytes(buf.Bytes())
 
 		apis[apiConfig.Name] = &context.API{
diff --git a/pkg/operator/context/context.go b/pkg/operator/context/context.go
index 641d7f87e9..9742c715da 100644
--- a/pkg/operator/context/context.go
+++ b/pkg/operator/context/context.go
@@ -133,16 +133,24 @@ func New(
 	}
 	ctx.DatasetVersion = datasetVersion
 
-	ctx.Environment = getEnvironment(userconf, datasetVersion)
+	if userconf.Environment != nil {
+		ctx.Environment = getEnvironment(userconf, datasetVersion)
+	}
 
 	ctx.Root = filepath.Join(
 		consts.AppsDir,
 		ctx.App.Name,
 		consts.DataDir,
 		ctx.DatasetVersion,
-		ctx.Environment.ID,
 	)
 
+	if ctx.Environment != nil {
+		ctx.Root = filepath.Join(
+			ctx.Root,
+			ctx.Environment.ID,
+		)
+	}
+
 	ctx.MetadataRoot = filepath.Join(
 		ctx.Root,
 		consts.MetadataDir,
@@ -223,7 +231,7 @@ func New(
 	}
 	ctx.Models = models
 
-	apis, err := getAPIs(userconf, ctx.Models)
+	apis, err := getAPIs(userconf, ctx.Models, ctx.DatasetVersion)
 	if err != nil {
 		return nil, err
 	}
@@ -256,7 +264,9 @@ func calculateID(ctx *context.Context) string {
 	ids = append(ids, ctx.RawDataset.Key)
 	ids = append(ids, ctx.StatusPrefix)
 	ids = append(ids, ctx.App.ID)
-	ids = append(ids, ctx.Environment.ID)
+	if ctx.Environment != nil {
+		ids = append(ids, ctx.Environment.ID)
+	}
 
 	for _, resource := range ctx.AllResources() {
 		ids = append(ids, resource.GetID())
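The `getAPIs` change above gives external APIs a deterministic ID without a trained model: the hash input switches from the trained model's ID to the dataset version plus the S3 path. A simplified sketch of the two derivations, with `sha256` standing in for the repo's `hash.Bytes` helper (an assumption on my part; only the hash inputs matter here):

```go
// Simplified sketch of the two API ID derivations in getAPIs above.
package main

import (
	"bytes"
	"crypto/sha256"
	"fmt"
)

func apiID(name string, parts ...string) string {
	var buf bytes.Buffer
	buf.WriteString(name)
	for _, p := range parts {
		buf.WriteString(p)
	}
	return fmt.Sprintf("%x", sha256.Sum256(buf.Bytes()))[:10]
}

func main() {
	// Internal model: the ID tracks the trained model's ID, so retraining
	// produces a new API ID and triggers a redeploy.
	fmt.Println(apiID("iris", "model-id-abc123"))

	// External model: no model resource exists, so the dataset version and
	// the s3 path stand in for it instead.
	fmt.Println(apiID("iris", "2019-05-01-00-00-00", "s3://cortex-examples/iris-model.zip"))
}
```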
diff --git a/pkg/operator/workloads/workflow.go b/pkg/operator/workloads/workflow.go
index 8827bfd3a5..17c623b5a1 100644
--- a/pkg/operator/workloads/workflow.go
+++ b/pkg/operator/workloads/workflow.go
@@ -65,23 +65,25 @@ func Create(ctx *context.Context) (*awfv1.Workflow, error) {
 
 	var allSpecs []*WorkloadSpec
 
-	pythonPackageJobSpecs, err := pythonPackageWorkloadSpecs(ctx)
-	if err != nil {
-		return nil, err
-	}
-	allSpecs = append(allSpecs, pythonPackageJobSpecs...)
+	if ctx.Environment != nil {
+		pythonPackageJobSpecs, err := pythonPackageWorkloadSpecs(ctx)
+		if err != nil {
+			return nil, err
+		}
+		allSpecs = append(allSpecs, pythonPackageJobSpecs...)
 
-	dataJobSpecs, err := dataWorkloadSpecs(ctx)
-	if err != nil {
-		return nil, err
-	}
-	allSpecs = append(allSpecs, dataJobSpecs...)
+		dataJobSpecs, err := dataWorkloadSpecs(ctx)
+		if err != nil {
+			return nil, err
+		}
+		allSpecs = append(allSpecs, dataJobSpecs...)
 
-	trainingJobSpecs, err := trainingWorkloadSpecs(ctx)
-	if err != nil {
-		return nil, err
+		trainingJobSpecs, err := trainingWorkloadSpecs(ctx)
+		if err != nil {
+			return nil, err
+		}
+		allSpecs = append(allSpecs, trainingJobSpecs...)
 	}
-	allSpecs = append(allSpecs, trainingJobSpecs...)
 
 	apiSpecs, err := apiWorkloadSpecs(ctx)
 	if err != nil {
diff --git a/pkg/workloads/lib/context.py b/pkg/workloads/lib/context.py
index 9135147771..e2730161fb 100644
--- a/pkg/workloads/lib/context.py
+++ b/pkg/workloads/lib/context.py
@@ -99,11 +99,12 @@ def __init__(self, **kwargs):
             )
         )
 
-        self.columns = util.merge_dicts_overwrite(self.raw_columns, self.transformed_columns)
+        if self.environment is not None:
+            self.columns = util.merge_dicts_overwrite(self.raw_columns, self.transformed_columns)
 
-        self.raw_column_names = list(self.raw_columns.keys())
-        self.transformed_column_names = list(self.transformed_columns.keys())
-        self.column_names = list(self.columns.keys())
+            self.raw_column_names = list(self.raw_columns.keys())
+            self.transformed_column_names = list(self.transformed_columns.keys())
+            self.column_names = list(self.columns.keys())
 
         # Internal caches
         self._transformer_impls = {}
@@ -117,15 +118,14 @@ def __init__(self, **kwargs):
             os.environ["AWS_REGION"] = self.cortex_config.get("region", "")
 
         # Id map
-        self.pp_id_map = ResourceMap(self.python_packages)
-        self.rf_id_map = ResourceMap(self.raw_columns)
-        self.ag_id_map = ResourceMap(self.aggregates)
-        self.tf_id_map = ResourceMap(self.transformed_columns)
-        self.td_id_map = ResourceMap(self.training_datasets)
-        self.models_id_map = ResourceMap(self.models)
-        self.apis_id_map = ResourceMap(self.apis)
-        self.constants_id_map = ResourceMap(self.constants)
-
+        self.pp_id_map = ResourceMap(self.python_packages) if self.python_packages else None
+        self.rf_id_map = ResourceMap(self.raw_columns) if self.raw_columns else None
+        self.ag_id_map = ResourceMap(self.aggregates) if self.aggregates else None
+        self.tf_id_map = ResourceMap(self.transformed_columns) if self.transformed_columns else None
+        self.td_id_map = ResourceMap(self.training_datasets) if self.training_datasets else None
+        self.models_id_map = ResourceMap(self.models) if self.models else None
+        self.apis_id_map = ResourceMap(self.apis) if self.apis else None
+        self.constants_id_map = ResourceMap(self.constants) if self.constants else None
         self.id_map = util.merge_dicts_overwrite(
             self.pp_id_map,
             self.rf_id_map,
@@ -704,17 +704,19 @@ def _validate_required_fn_args(impl, fn_name, args):
 
 
 def _deserialize_raw_ctx(raw_ctx):
-    raw_columns = raw_ctx["raw_columns"]
-    raw_ctx["raw_columns"] = util.merge_dicts_overwrite(*raw_columns.values())
+    if raw_ctx.get("environment") is not None:
+        raw_columns = raw_ctx["raw_columns"]
+        raw_ctx["raw_columns"] = util.merge_dicts_overwrite(*raw_columns.values())
+
+        data_split = raw_ctx["environment_data"]
 
-    data_split = raw_ctx["environment_data"]
+        if data_split["csv_data"] is not None and data_split["parquet_data"] is None:
+            raw_ctx["environment"]["data"] = data_split["csv_data"]
+        elif data_split["parquet_data"] is not None and data_split["csv_data"] is None:
+            raw_ctx["environment"]["data"] = data_split["parquet_data"]
+        else:
+            raise CortexException("expected csv_data or parquet_data but found " + str(data_split))
 
-    if data_split["csv_data"] is not None and data_split["parquet_data"] is None:
-        raw_ctx["environment"]["data"] = data_split["csv_data"]
-    elif data_split["parquet_data"] is not None and data_split["csv_data"] is None:
-        raw_ctx["environment"]["data"] = data_split["parquet_data"]
-    else:
-        raise CortexException("expected csv_data or parquet_data but found " + data_split)
     return raw_ctx
diff --git a/pkg/workloads/lib/storage/s3.py b/pkg/workloads/lib/storage/s3.py
index afdb36b4b8..cd8dea6d58 100644
--- a/pkg/workloads/lib/storage/s3.py
+++ b/pkg/workloads/lib/storage/s3.py
@@ -175,3 +175,18 @@ def download_and_unzip(self, key, local_dir):
         local_zip = os.path.join(local_dir, "zip.zip")
         self.download_file(key, local_zip)
         util.extract_zip(local_zip, delete_zip_file=True)
+
+    def download_and_unzip_external(self, s3_path, local_dir):
+        util.mkdir_p(local_dir)
+        local_zip = os.path.join(local_dir, "zip.zip")
+        self.download_file_external(s3_path, local_zip)
+        util.extract_zip(local_zip, delete_zip_file=True)
+
+    def download_file_external(self, s3_path, local_path):
+        bucket, key = self.deconstruct_s3_path(s3_path)
+        try:
+            util.mkdir_p(os.path.dirname(local_path))
+            self.s3.download_file(bucket, key, local_path)
+            return local_path
+        except Exception as e:
+            raise CortexException("bucket " + bucket, "key " + key) from e
diff --git a/pkg/workloads/lib/util.py b/pkg/workloads/lib/util.py
index 4a35aec5b1..77103a0994 100644
--- a/pkg/workloads/lib/util.py
+++ b/pkg/workloads/lib/util.py
@@ -421,6 +421,12 @@ def merge_dicts_no_overwrite(*dicts):
 
 def merge_two_dicts_in_place_overwrite(x, y):
     """Merge y into x, with overwriting. x is updated in place"""
+    if x is None:
+        return y
+
+    if y is None:
+        return x
+
     for k, v in y.items():
         if k in x and isinstance(x[k], dict) and isinstance(y[k], collections.Mapping):
             merge_dicts_in_place_overwrite(x[k], y[k])
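On the Go side this PR only validates `s3://bucket/key` paths (`IsValidS3Path`); the Python worker splits and downloads them via `deconstruct_s3_path`. If a Go-side splitter were ever needed, it could look like this — a hypothetical helper, not part of the PR:

```go
// Hypothetical Go counterpart to the worker's deconstruct_s3_path:
// split a validated s3:// path into (bucket, key).
package main

import (
	"fmt"
	"strings"
)

func splitS3Path(s3Path string) (bucket string, key string, err error) {
	if !strings.HasPrefix(s3Path, "s3://") {
		return "", "", fmt.Errorf("%q is not a valid s3 path", s3Path)
	}
	parts := strings.SplitN(s3Path[5:], "/", 2)
	if len(parts) < 2 || parts[0] == "" || parts[1] == "" {
		return "", "", fmt.Errorf("%q is not a valid s3 path", s3Path)
	}
	return parts[0], parts[1], nil
}

func main() {
	bucket, key, _ := splitS3Path("s3://cortex-examples/iris-model.zip")
	fmt.Println(bucket, key) // cortex-examples iris-model.zip
}
```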
return {"response": outputs_simplified} + + def run_predict(sample): - transformed_sample = transform_sample(sample) - prediction_request = create_prediction_request(transformed_sample) - response_proto = local_cache["stub"].Predict(prediction_request, timeout=10.0) - result = parse_response_proto(response_proto) - util.log_indent("Raw sample:", indent=4) - util.log_pretty(sample, indent=6) - util.log_indent("Transformed sample:", indent=4) - util.log_pretty(transformed_sample, indent=6) - util.log_indent("Prediction:", indent=4) - util.log_pretty(result, indent=6) - - result["transformed_sample"] = transformed_sample + if local_cache["ctx"].environment is not None: + transformed_sample = transform_sample(sample) + prediction_request = create_prediction_request(transformed_sample) + response_proto = local_cache["stub"].Predict(prediction_request, timeout=10.0) + result = parse_response_proto(response_proto) + + util.log_indent("Raw sample:", indent=4) + util.log_pretty(sample, indent=6) + util.log_indent("Transformed sample:", indent=4) + util.log_pretty(transformed_sample, indent=6) + util.log_indent("Prediction:", indent=4) + util.log_pretty(result, indent=6) + + result["transformed_sample"] = transformed_sample + + else: + prediction_request = create_raw_prediction_request(sample) + response_proto = local_cache["stub"].Predict(prediction_request, timeout=10.0) + result = parse_response_proto_raw(response_proto) + util.log_indent("Sample:", indent=4) + util.log_pretty(sample, indent=6) + util.log_indent("Prediction:", indent=4) + util.log_pretty(result, indent=6) return result @@ -281,13 +330,14 @@ def predict(app_name, api_name): for i, sample in enumerate(payload["samples"]): util.log_indent("sample {}".format(i + 1), 2) - is_valid, reason = is_valid_sample(sample) - if not is_valid: - return prediction_failed(sample, reason) + if local_cache["ctx"].environment is not None: + is_valid, reason = is_valid_sample(sample) + if not is_valid: + return prediction_failed(sample, reason) - for column in local_cache["required_inputs"]: - column_type = ctx.get_inferred_column_type(column["name"]) - sample[column["name"]] = util.upcast(sample[column["name"]], column_type) + for column in local_cache["required_inputs"]: + column_type = ctx.get_inferred_column_type(column["name"]) + sample[column["name"]] = util.upcast(sample[column["name"]], column_type) try: result = run_predict(sample) @@ -317,40 +367,48 @@ def start(args): package.install_packages(ctx.python_packages, ctx.storage) api = ctx.apis_id_map[args.api] - model = ctx.models[api["model_name"]] - estimator = ctx.estimators[model["estimator"]] - tf_lib.set_logging_verbosity(ctx.environment["log_level"]["tensorflow"]) - local_cache["ctx"] = ctx local_cache["api"] = api - local_cache["model"] = model - local_cache["estimator"] = estimator - local_cache["target_col"] = ctx.columns[util.get_resource_ref(model["target_column"])] - local_cache["target_col_type"] = ctx.get_inferred_column_type( - util.get_resource_ref(model["target_column"]) - ) + local_cache["ctx"] = ctx - if not os.path.isdir(args.model_dir): - ctx.storage.download_and_unzip(model["key"], args.model_dir) + if ctx.environment is not None: + model = ctx.models[api["model_name"]] + estimator = ctx.estimators[model["estimator"]] - for column_name in ctx.extract_column_names([model["input"], model["target_column"]]): - if ctx.is_transformed_column(column_name): - trans_impl, _ = ctx.get_transformer_impl(column_name) - local_cache["trans_impls"][column_name] = trans_impl - 
transformed_column = ctx.transformed_columns[column_name] + local_cache["model"] = model + local_cache["estimator"] = estimator + local_cache["target_col"] = ctx.columns[util.get_resource_ref(model["target_column"])] + local_cache["target_col_type"] = ctx.get_inferred_column_type( + util.get_resource_ref(model["target_column"]) + ) + + tf_lib.set_logging_verbosity(ctx.environment["log_level"]["tensorflow"]) + + if not os.path.isdir(args.model_dir): + ctx.storage.download_and_unzip(model["key"], args.model_dir) + + for column_name in ctx.extract_column_names([model["input"], model["target_column"]]): + if ctx.is_transformed_column(column_name): + trans_impl, _ = ctx.get_transformer_impl(column_name) + local_cache["trans_impls"][column_name] = trans_impl + transformed_column = ctx.transformed_columns[column_name] - # cache aggregate values - for resource_name in util.extract_resource_refs(transformed_column["input"]): - if resource_name in ctx.aggregates: - ctx.get_obj(ctx.aggregates[resource_name]["key"]) + # cache aggregate values + for resource_name in util.extract_resource_refs(transformed_column["input"]): + if resource_name in ctx.aggregates: + ctx.get_obj(ctx.aggregates[resource_name]["key"]) + + local_cache["required_inputs"] = tf_lib.get_base_input_columns(model["name"], ctx) + + else: + if not os.path.isdir(args.model_dir): + ctx.storage.download_and_unzip_external(api["model_path"], args.model_dir) channel = grpc.insecure_channel("localhost:" + str(args.tf_serve_port)) local_cache["stub"] = prediction_service_pb2_grpc.PredictionServiceStub(channel) - local_cache["required_inputs"] = tf_lib.get_base_input_columns(model["name"], ctx) - # wait a bit for tf serving to start before querying metadata - limit = 600 + limit = 300 for i in range(limit): try: local_cache["metadata"] = run_get_model_metadata() @@ -364,7 +422,7 @@ def start(args): time.sleep(1) - logger.info("Serving model: {}".format(model["name"])) + logger.info("Serving model: {}".format(api["model_name"])) serve(app, listen="*:{}".format(args.port))
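End to end, an external-model API is called the same way as a Cortex-trained one, except that — per the `cli/cmd/predict.go` change above — there is no transformation pipeline, so the raw TF Serving outputs come back under a `response` key instead of a parsed prediction. A minimal client sketch; the endpoint URL is a placeholder (the real one comes from `cortex get api iris`):

```go
// Minimal client sketch for the iris external-model example.
// POSTs {"samples": [...]} with Content-Type: application/json and prints
// the decoded body; external models return the TF Serving outputs under
// a "response" object.
package main

import (
	"bytes"
	"encoding/json"
	"fmt"
	"log"
	"net/http"
)

func main() {
	payload := map[string]interface{}{
		"samples": []map[string]interface{}{
			{
				"sepal_length": 5.2,
				"sepal_width":  3.6,
				"petal_length": 1.4,
				"petal_width":  0.3,
			},
		},
	}

	body, err := json.Marshal(payload)
	if err != nil {
		log.Fatal(err)
	}

	// Placeholder URL; substitute the endpoint reported by the CLI.
	resp, err := http.Post(
		"https://<apis-endpoint>/iris/iris",
		"application/json",
		bytes.NewReader(body),
	)
	if err != nil {
		log.Fatal(err)
	}
	defer resp.Body.Close()

	var out map[string]interface{}
	if err := json.NewDecoder(resp.Body).Decode(&out); err != nil {
		log.Fatal(err)
	}
	fmt.Println(out)
}
```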