From a9184a104ea2b9a42d772e0060862892eeca8041 Mon Sep 17 00:00:00 2001 From: David Eliahu Date: Thu, 6 Jun 2019 16:27:38 -0700 Subject: [PATCH 01/44] Update input, add estimators (go only) --- cli/cmd/get.go | 7 +- dev/operator_local.sh | 3 +- go.mod | 5 +- go.sum | 12 +- pkg/consts/consts.go | 2 + pkg/lib/argo/argo.go | 2 +- pkg/lib/cast/interface.go | 16 + pkg/lib/configreader/errors.go | 22 +- pkg/lib/configreader/interface.go | 80 +- pkg/lib/configreader/interface_map.go | 28 +- pkg/lib/configreader/interface_map_list.go | 22 +- pkg/lib/configreader/reader.go | 2 +- pkg/lib/configreader/string.go | 12 + pkg/lib/configreader/string_list.go | 24 +- pkg/lib/configreader/string_map.go | 22 +- pkg/lib/configreader/string_ptr.go | 4 + pkg/lib/configreader/types.go | 2 + pkg/lib/strings/stringify.go | 9 +- pkg/operator/api/context/aggregates.go | 11 - pkg/operator/api/context/apis.go | 3 +- pkg/operator/api/context/columns.go | 66 +- pkg/operator/api/context/columns_test.go | 95 - pkg/operator/api/context/context.go | 62 +- pkg/operator/api/context/dependencies.go | 137 +- pkg/operator/api/context/dependencies_test.go | 78 + .../api/context/{values.go => estimators.go} | 27 +- pkg/operator/api/context/models.go | 34 - pkg/operator/api/context/raw_columns.go | 53 +- .../api/context/resource_fakes_test.go | 136 + .../api/context/transformed_columns.go | 24 +- pkg/operator/api/resource/type.go | 9 +- pkg/operator/api/userconfig/aggregates.go | 38 +- pkg/operator/api/userconfig/aggregators.go | 19 +- pkg/operator/api/userconfig/apis.go | 19 +- pkg/operator/api/userconfig/column_type.go | 10 +- pkg/operator/api/userconfig/columns.go | 70 - pkg/operator/api/userconfig/compound_type.go | 174 ++ .../api/userconfig/compound_type_test.go | 84 + pkg/operator/api/userconfig/config.go | 115 +- pkg/operator/api/userconfig/config_key.go | 97 +- pkg/operator/api/userconfig/constants.go | 27 +- pkg/operator/api/userconfig/environments.go | 34 +- pkg/operator/api/userconfig/errors.go | 260 +- pkg/operator/api/userconfig/estimators.go | 134 + pkg/operator/api/userconfig/inputs.go | 84 - pkg/operator/api/userconfig/model_type.go | 78 - pkg/operator/api/userconfig/models.go | 91 +- pkg/operator/api/userconfig/raw_columns.go | 10 +- .../api/userconfig/transformed_columns.go | 47 +- pkg/operator/api/userconfig/transformers.go | 16 +- pkg/operator/api/userconfig/types.go | 43 +- pkg/operator/api/userconfig/validators.go | 714 +++--- .../api/userconfig/validators_test.go | 2177 +++++++++++------ pkg/operator/api/userconfig/value_type.go | 41 +- .../api/userconfig/value_type_test.go | 49 + pkg/operator/context/aggregates.go | 105 +- pkg/operator/context/aggregators.go | 44 +- pkg/operator/context/apis.go | 19 +- pkg/operator/context/autogenerator.go | 129 - pkg/operator/context/constants.go | 34 +- pkg/operator/context/context.go | 51 +- pkg/operator/context/environment.go | 7 +- pkg/operator/context/errors.go | 85 - pkg/operator/context/estimators.go | 174 ++ pkg/operator/context/models.go | 163 +- pkg/operator/context/raw_columns.go | 10 +- pkg/operator/context/resource_fakes_test.go | 344 +++ pkg/operator/context/resources.go | 408 +++ pkg/operator/context/resources_test.go | 871 +++++++ pkg/operator/context/transformed_columns.go | 108 +- pkg/operator/context/transformers.go | 42 +- pkg/operator/workloads/workload_spec.go | 5 +- 72 files changed, 5448 insertions(+), 2591 deletions(-) delete mode 100644 pkg/operator/api/context/columns_test.go create mode 100644 pkg/operator/api/context/dependencies_test.go rename 
pkg/operator/api/context/{values.go => estimators.go} (62%) create mode 100644 pkg/operator/api/context/resource_fakes_test.go create mode 100644 pkg/operator/api/userconfig/compound_type.go create mode 100644 pkg/operator/api/userconfig/compound_type_test.go create mode 100644 pkg/operator/api/userconfig/estimators.go delete mode 100644 pkg/operator/api/userconfig/inputs.go delete mode 100644 pkg/operator/api/userconfig/model_type.go create mode 100644 pkg/operator/api/userconfig/value_type_test.go delete mode 100644 pkg/operator/context/autogenerator.go delete mode 100644 pkg/operator/context/errors.go create mode 100644 pkg/operator/context/estimators.go create mode 100644 pkg/operator/context/resource_fakes_test.go create mode 100644 pkg/operator/context/resources.go create mode 100644 pkg/operator/context/resources_test.go diff --git a/cli/cmd/get.go b/cli/cmd/get.go index c35771e70d..335791282a 100644 --- a/cli/cmd/get.go +++ b/cli/cmd/get.go @@ -412,9 +412,10 @@ func describeAPI(name string, resourcesRes *schema.GetResourcesResponse) (string out += titleStr("Endpoint") var samplePlaceholderFields []string - for _, colName := range ctx.RawColumnInputNames(model) { - column := ctx.GetColumn(colName) - fieldStr := `"` + colName + `": ` + column.GetType().JSONPlaceholder() + combinedInput := []interface{}{model.Input, model.TrainingInput} + for _, res := range ctx.ExtractCortexResources(combinedInput, resource.RawColumnType) { + rawColumn := res.(context.RawColumn) + fieldStr := `"` + rawColumn.GetName() + `": ` + rawColumn.GetColumnType().JSONPlaceholder() samplePlaceholderFields = append(samplePlaceholderFields, fieldStr) } samplesPlaceholderStr := `{ "samples": [ { ` + strings.Join(samplePlaceholderFields, ", ") + " } ] }" diff --git a/dev/operator_local.sh b/dev/operator_local.sh index 79bcc90197..c4d223b8c8 100755 --- a/dev/operator_local.sh +++ b/dev/operator_local.sh @@ -21,8 +21,9 @@ ROOT="$(cd "$(dirname "${BASH_SOURCE[0]}")"/.. 
>/dev/null && pwd)" source $ROOT/dev/config/cortex.sh -export CONST_OPERATOR_TRANSFORMERS_DIR=$ROOT/pkg/transformers export CONST_OPERATOR_AGGREGATORS_DIR=$ROOT/pkg/aggregators +export CONST_OPERATOR_TRANSFORMERS_DIR=$ROOT/pkg/transformers +export CONST_OPERATOR_ESTIMATORS_DIR=$ROOT/pkg/estimators export CONST_OPERATOR_IN_CLUSTER=false rerun -watch $ROOT/pkg $ROOT/cli -ignore $ROOT/vendor $ROOT/bin -run sh -c \ diff --git a/go.mod b/go.mod index 19953a7188..a8322b2a26 100644 --- a/go.mod +++ b/go.mod @@ -11,6 +11,7 @@ // go mod tidy // (replace versions of locked dependencies above) // go mod tidy +// check the diff in this file module github.com/cortexlabs/cortex @@ -18,6 +19,7 @@ require ( github.com/GoogleCloudPlatform/spark-on-k8s-operator v0.0.0-20181208011959-62db1d66dafa github.com/argoproj/argo v2.3.0+incompatible github.com/aws/aws-sdk-go v1.16.17 + github.com/cortexlabs/yaml v0.0.0-20190530233410-11baebde6c89 github.com/davecgh/go-spew v1.1.1 github.com/emicklei/go-restful v2.8.0+incompatible // indirect github.com/go-openapi/spec v0.18.0 // indirect @@ -33,6 +35,7 @@ require ( github.com/imdario/mergo v0.3.6 // indirect github.com/inconshreveable/mousetrap v1.0.0 // indirect github.com/json-iterator/go v1.1.5 // indirect + github.com/kr/pretty v0.1.0 // indirect github.com/mitchellh/go-homedir v1.0.0 github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect github.com/modern-go/reflect2 v1.0.1 // indirect @@ -49,9 +52,9 @@ require ( golang.org/x/oauth2 v0.0.0-20190110195249-fd3eaa146cbb // indirect golang.org/x/sys v0.0.0-20190109145017-48ac38b7c8cb // indirect golang.org/x/time v0.0.0-20181108054448-85acf8d2951c // indirect + gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 // indirect gopkg.in/inf.v0 v0.9.1 // indirect gopkg.in/robfig/cron.v2 v2.0.0-20150107220207-be2e0b0deed5 - gopkg.in/yaml.v2 v2.2.2 k8s.io/api v0.0.0-20181204000039-89a74a8d264d k8s.io/apimachinery v0.0.0-20181127025237-2b1284ed4c93 k8s.io/client-go v10.0.0+incompatible diff --git a/go.sum b/go.sum index 86213f5bda..5fc76a6c40 100644 --- a/go.sum +++ b/go.sum @@ -10,6 +10,8 @@ github.com/argoproj/argo v2.3.0+incompatible h1:L1OYZ86Q7NK19ahdl/eJOq78Mlf52wUK github.com/argoproj/argo v2.3.0+incompatible/go.mod h1:KJ0MB+tuhtAklR4jkPM10mIZXfRA0peTYJ1sLUnFLVU= github.com/aws/aws-sdk-go v1.16.17 h1:hHRKZhoB4qEY17aGNp71UxQFyYpx6WZXGMUzx9y/A4w= github.com/aws/aws-sdk-go v1.16.17/go.mod h1:KmX6BPdI08NWTb3/sm4ZGu5ShLoqVDhKgpiN924inxo= +github.com/cortexlabs/yaml v0.0.0-20190530233410-11baebde6c89 h1:NdtOtp57mz3NX/O8k0X9Qic0syGjOyT0b1+3GSq1uMU= +github.com/cortexlabs/yaml v0.0.0-20190530233410-11baebde6c89/go.mod h1:ZQaiMs8i2UxfSwf0gNtM0kfmXg6hQHArtgaI6aL0+/U= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= @@ -51,6 +53,11 @@ github.com/jmespath/go-jmespath v0.0.0-20180206201540-c2b33e8439af 
h1:pmfjZENx5i github.com/jmespath/go-jmespath v0.0.0-20180206201540-c2b33e8439af/go.mod h1:Nht3zPeWKUH0NzdCt2Blrr5ys8VGpn0CEB0cQHVjt7k= github.com/json-iterator/go v1.1.5 h1:gL2yXlmiIo4+t+y32d4WGwOjKGYcGOuyrg46vadswDE= github.com/json-iterator/go v1.1.5/go.mod h1:+SdeFBvtyEkXs7REEP0seUULqWtbJapLOCVDaaPEHmU= +github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI= +github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= +github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= +github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE= +github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= github.com/mailru/easyjson v0.0.0-20180823135443-60711f1a8329 h1:2gxZ0XQIU/5z3Z3bUBu+FXuk2pFbkN6tcwi/pjyaDic= github.com/mailru/easyjson v0.0.0-20180823135443-60711f1a8329/go.mod h1:C1wdFJiN94OJF2b5HbByQZoLdCWB1Yqtg26g4irojpc= github.com/mitchellh/go-homedir v1.0.0 h1:vKb8ShqSby24Yrqr/yDYkuFz8d0WUjys40rvnGC8aR0= @@ -100,13 +107,14 @@ google.golang.org/appengine v1.4.0 h1:/wp5JvzpHIxhs/dumFmF7BXTf3Z+dd4uXta4kVyO50 google.golang.org/appengine v1.4.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 h1:qIbj1fsPNlZgppZ+VLlY7N33q108Sa+fhmuc+sWQYwY= +gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/inf.v0 v0.9.1 h1:73M5CoZyi3ZLMOyDlQh031Cx6N9NDJ2Vvfl76EDAgDc= gopkg.in/inf.v0 v0.9.1/go.mod h1:cWUDdTG/fYaXco+Dcufb5Vnc6Gp2YChqWtbxRZE0mXw= gopkg.in/robfig/cron.v2 v2.0.0-20150107220207-be2e0b0deed5 h1:E846t8CnR+lv5nE+VuiKTDG/v1U2stad0QzddfJC7kY= gopkg.in/robfig/cron.v2 v2.0.0-20150107220207-be2e0b0deed5/go.mod h1:hiOFpYm0ZJbusNj2ywpbrXowU3G8U6GIQzqn2mw1UIE= +gopkg.in/yaml.v2 v2.2.1 h1:mUhvW9EsL+naU5Q3cakzfE91YhliOondGd6ZrsDBHQE= gopkg.in/yaml.v2 v2.2.1/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= -gopkg.in/yaml.v2 v2.2.2 h1:ZCJp+EgiOT7lHqUV2J862kp8Qj64Jo6az82+3Td9dZw= -gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= k8s.io/api v0.0.0-20181204000039-89a74a8d264d h1:HQoGWsWUe/FmRcX9BU440AAMnzBFEf+DBo4nbkQlNzs= k8s.io/api v0.0.0-20181204000039-89a74a8d264d/go.mod h1:iuAfoD4hCxJ8Onx9kaTIt30j7jUFS00AXQi6QMi99vA= k8s.io/apimachinery v0.0.0-20181127025237-2b1284ed4c93 h1:tT6oQBi0qwLbbZSfDkdIsb23EwaLY85hoAV4SpXfdao= diff --git a/pkg/consts/consts.go b/pkg/consts/consts.go index 85543ae63a..fd80d06442 100644 --- a/pkg/consts/consts.go +++ b/pkg/consts/consts.go @@ -23,6 +23,7 @@ import ( var ( CortexVersion = "master" // CORTEX_VERSION + TypeStrRegex = regexp.MustCompile(`"(INT|FLOAT|STRING|BOOL)(_COLUMN)?(\|(INT|FLOAT|STRING|BOOL)(_COLUMN)?)*"`) SingleTypeStrRegex = regexp.MustCompile(`"(INT|FLOAT|STRING|BOOL)(_COLUMN)?"`) CompoundTypeStrRegex = regexp.MustCompile(`"(INT|FLOAT|STRING|BOOL)(_COLUMN)?(\|(INT|FLOAT|STRING|BOOL)(_COLUMN)?)+"`) @@ -44,6 +45,7 @@ var ( AggregatorsDir = "aggregators" AggregatesDir = "aggregates" TransformersDir = "transformers" + EstimatorsDir = "estimators" ModelImplsDir = "model_implementations" 
PythonPackagesDir = "python_packages" ModelsDir = "models" diff --git a/pkg/lib/argo/argo.go b/pkg/lib/argo/argo.go index a1bdf2541b..f830662437 100644 --- a/pkg/lib/argo/argo.go +++ b/pkg/lib/argo/argo.go @@ -315,5 +315,5 @@ func (wfItem *WorkflowItem) Dependencies() strset.Set { return strset.New(wfItem.Task.Dependencies...) } - return make(strset.Set) + return strset.New() } diff --git a/pkg/lib/cast/interface.go b/pkg/lib/cast/interface.go index f2ebebd508..c15147ae97 100644 --- a/pkg/lib/cast/interface.go +++ b/pkg/lib/cast/interface.go @@ -742,3 +742,19 @@ func IsScalarType(in interface{}) bool { } return false } + +func ToScalarType(in interface{}) (interface{}, bool) { + if casted, ok := InterfaceToInt64(in); ok { + return casted, true + } + if casted, ok := InterfaceToFloat64(in); ok { + return casted, true + } + if casted, ok := in.(bool); ok { + return casted, true + } + if casted, ok := in.(string); ok { + return casted, true + } + return in, false +} diff --git a/pkg/lib/configreader/errors.go b/pkg/lib/configreader/errors.go index b6c4931a02..7e374ad989 100644 --- a/pkg/lib/configreader/errors.go +++ b/pkg/lib/configreader/errors.go @@ -51,7 +51,8 @@ const ( ErrMustBeDefined ErrMapMustBeDefined ErrMustBeEmpty - ErrNotAFile + ErrCortexResourceOnlyAllowed + ErrCortexResourceNotAllowed ) var errorKinds = []string{ @@ -80,10 +81,11 @@ var errorKinds = []string{ "err_must_be_defined", "err_map_must_be_defined", "err_must_be_empty", - "err_not_a_file", + "err_cortex_resource_only_allowed", + "err_cortex_resource_not_allowed", } -var _ = [1]int{}[int(ErrNotAFile)-(len(errorKinds)-1)] // Ensure list length matches +var _ = [1]int{}[int(ErrCortexResourceNotAllowed)-(len(errorKinds)-1)] // Ensure list length matches func (t ErrorKind) String() string { return errorKinds[t] @@ -300,3 +302,17 @@ func ErrorMustBeEmpty() error { message: "must be empty", } } + +func ErrorCortexResourceOnlyAllowed(invalidStr string) error { + return Error{ + Kind: ErrCortexResourceOnlyAllowed, + message: fmt.Sprintf("%s: only cortex resource references (which start with @) are allowed in this context", invalidStr), + } +} + +func ErrorCortexResourceNotAllowed(resourceName string) error { + return Error{ + Kind: ErrCortexResourceNotAllowed, + message: fmt.Sprintf("@%s: cortex resource references (which start with @) are not allowed in this context", resourceName), + } +} diff --git a/pkg/lib/configreader/interface.go b/pkg/lib/configreader/interface.go index 7e06b9697d..1f7ad11d01 100644 --- a/pkg/lib/configreader/interface.go +++ b/pkg/lib/configreader/interface.go @@ -17,6 +17,8 @@ limitations under the License. 
package configreader import ( + "github.com/cortexlabs/yaml" + "github.com/cortexlabs/cortex/pkg/lib/cast" "github.com/cortexlabs/cortex/pkg/lib/errors" "github.com/cortexlabs/cortex/pkg/lib/maps" @@ -26,10 +28,12 @@ import ( ) type InterfaceValidation struct { - Required bool - Default interface{} - AllowExplicitNull bool - Validator func(interface{}) (interface{}, error) + Required bool + Default interface{} + AllowExplicitNull bool + AllowCortexResources bool + RequireCortexResources bool + Validator func(interface{}) (interface{}, error) } func Interface(inter interface{}, v *InterfaceValidation) (interface{}, error) { @@ -67,12 +71,80 @@ func ValidateInterfaceProvided(val interface{}, v *InterfaceValidation) (interfa } func validateInterface(val interface{}, v *InterfaceValidation) (interface{}, error) { + if v.RequireCortexResources { + if err := checkOnlyCortexResources(val); err != nil { + return nil, err + } + } else if !v.AllowCortexResources { + if err := checkNoCortexResources(val); err != nil { + return nil, err + } + } + if v.Validator != nil { return v.Validator(val) } return val, nil } +func checkNoCortexResources(obj interface{}) error { + if objStr, ok := obj.(string); ok { + if resourceName, ok := yaml.ExtractAtSymbolText(objStr); ok { + return ErrorCortexResourceNotAllowed(resourceName) + } + } + + if objSlice, ok := cast.InterfaceToInterfaceSlice(obj); ok { + for i, objItem := range objSlice { + if err := checkNoCortexResources(objItem); err != nil { + return errors.Wrap(err, s.Index(i)) + } + } + } + + if objMap, ok := cast.InterfaceToInterfaceInterfaceMap(obj); ok { + for k, v := range objMap { + if err := checkNoCortexResources(k); err != nil { + return err + } + if err := checkNoCortexResources(v); err != nil { + return errors.Wrap(err, s.UserStrStripped(k)) + } + } + } + + return nil +} + +func checkOnlyCortexResources(obj interface{}) error { + if objStr, ok := obj.(string); ok { + if _, ok := yaml.ExtractAtSymbolText(objStr); !ok { + return ErrorCortexResourceOnlyAllowed(objStr) + } + } + + if objSlice, ok := cast.InterfaceToInterfaceSlice(obj); ok { + for i, objItem := range objSlice { + if err := checkOnlyCortexResources(objItem); err != nil { + return errors.Wrap(err, s.Index(i)) + } + } + } + + if objMap, ok := cast.InterfaceToInterfaceInterfaceMap(obj); ok { + for k, v := range objMap { + if err := checkOnlyCortexResources(k); err != nil { + return err + } + if err := checkOnlyCortexResources(v); err != nil { + return errors.Wrap(err, s.UserStrStripped(k)) + } + } + } + + return nil +} + // FlattenAllStrValues assumes that the order for maps is deterministic func FlattenAllStrValues(obj interface{}) ([]string, error) { obj = pointer.IndirectSafe(obj) diff --git a/pkg/lib/configreader/interface_map.go b/pkg/lib/configreader/interface_map.go index e6870167ad..598fa4628b 100644 --- a/pkg/lib/configreader/interface_map.go +++ b/pkg/lib/configreader/interface_map.go @@ -23,14 +23,16 @@ import ( ) type InterfaceMapValidation struct { - Required bool - Default map[string]interface{} - AllowExplicitNull bool - AllowEmpty bool - ScalarsOnly bool - StringLeavesOnly bool - AllowedLeafValues []string - Validator func(map[string]interface{}) (map[string]interface{}, error) + Required bool + Default map[string]interface{} + AllowExplicitNull bool + AllowEmpty bool + ScalarsOnly bool + StringLeavesOnly bool + AllowedLeafValues []string + AllowCortexResources bool + RequireCortexResources bool 
+ Validator func(map[string]interface{}) (map[string]interface{}, error) } func InterfaceMap(inter interface{}, v *InterfaceMapValidation) (map[string]interface{}, error) { @@ -72,6 +74,16 @@ func ValidateInterfaceMapProvided(val map[string]interface{}, v *InterfaceMapVal } func validateInterfaceMap(val map[string]interface{}, v *InterfaceMapValidation) (map[string]interface{}, error) { + if v.RequireCortexResources { + if err := checkOnlyCortexResources(val); err != nil { + return nil, err + } + } else if !v.AllowCortexResources { + if err := checkNoCortexResources(val); err != nil { + return nil, err + } + } + if !v.AllowEmpty { if val != nil && len(val) == 0 { return nil, ErrorCannotBeEmpty() diff --git a/pkg/lib/configreader/interface_map_list.go b/pkg/lib/configreader/interface_map_list.go index d6fd9f1fba..82c3833cb5 100644 --- a/pkg/lib/configreader/interface_map_list.go +++ b/pkg/lib/configreader/interface_map_list.go @@ -22,11 +22,13 @@ import ( ) type InterfaceMapListValidation struct { - Required bool - Default []map[string]interface{} - AllowExplicitNull bool - AllowEmpty bool - Validator func([]map[string]interface{}) ([]map[string]interface{}, error) + Required bool + Default []map[string]interface{} + AllowExplicitNull bool + AllowEmpty bool + AllowCortexResources bool + RequireCortexResources bool + Validator func([]map[string]interface{}) ([]map[string]interface{}, error) } func InterfaceMapList(inter interface{}, v *InterfaceMapListValidation) ([]map[string]interface{}, error) { @@ -68,6 +70,16 @@ func ValidateInterfaceMapListProvided(val []map[string]interface{}, v *Interface } func validateInterfaceMapList(val []map[string]interface{}, v *InterfaceMapListValidation) ([]map[string]interface{}, error) { + if v.RequireCortexResources { + if err := checkOnlyCortexResources(val); err != nil { + return nil, err + } + } else if !v.AllowCortexResources { + if err := checkNoCortexResources(val); err != nil { + return nil, err + } + } + if !v.AllowEmpty { if val != nil && len(val) == 0 { return nil, ErrorCannotBeEmpty() diff --git a/pkg/lib/configreader/reader.go b/pkg/lib/configreader/reader.go index 091efddad2..3b40cfdc34 100644 --- a/pkg/lib/configreader/reader.go +++ b/pkg/lib/configreader/reader.go @@ -22,8 +22,8 @@ import ( "reflect" "strings" + "github.com/cortexlabs/yaml" input "github.com/tcnksm/go-input" - yaml "gopkg.in/yaml.v2" "github.com/cortexlabs/cortex/pkg/lib/cast" "github.com/cortexlabs/cortex/pkg/lib/debug" diff --git a/pkg/lib/configreader/string.go b/pkg/lib/configreader/string.go index 4b13ead1a0..8ef0742163 100644 --- a/pkg/lib/configreader/string.go +++ b/pkg/lib/configreader/string.go @@ -38,6 +38,8 @@ type StringValidation struct { AlphaNumericDashUnderscore bool DNS1035 bool DNS1123 bool + AllowCortexResources bool + RequireCortexResources bool Validator func(string) (string, error) } @@ -162,6 +164,16 @@ func ValidateString(val string, v *StringValidation) (string, error) { } func ValidateStringVal(val string, v *StringValidation) error { + if v.RequireCortexResources { + if err := checkOnlyCortexResources(val); err != nil { + return err + } + } else if !v.AllowCortexResources { + if err := checkNoCortexResources(val); err != nil { + return err + } + } + if !v.AllowEmpty { if len(val) == 0 { return ErrorCannotBeEmpty() diff --git a/pkg/lib/configreader/string_list.go b/pkg/lib/configreader/string_list.go index b80b3573ca..9d123d8cbe 100644 --- 
a/pkg/lib/configreader/string_list.go +++ b/pkg/lib/configreader/string_list.go @@ -23,12 +23,14 @@ import ( ) type StringListValidation struct { - Required bool - Default []string - AllowExplicitNull bool - AllowEmpty bool - DisallowDups bool - Validator func([]string) ([]string, error) + Required bool + Default []string + AllowExplicitNull bool + AllowEmpty bool + DisallowDups bool + AllowCortexResources bool + RequireCortexResources bool + Validator func([]string) ([]string, error) } func StringList(inter interface{}, v *StringListValidation) ([]string, error) { @@ -70,6 +72,16 @@ func ValidateStringListProvided(val []string, v *StringListValidation) ([]string } func validateStringList(val []string, v *StringListValidation) ([]string, error) { + if v.RequireCortexResources { + if err := checkOnlyCortexResources(val); err != nil { + return nil, err + } + } else if !v.AllowCortexResources { + if err := checkNoCortexResources(val); err != nil { + return nil, err + } + } + if !v.AllowEmpty { if val != nil && len(val) == 0 { return nil, ErrorCannotBeEmpty() diff --git a/pkg/lib/configreader/string_map.go b/pkg/lib/configreader/string_map.go index 633b15ab90..5009ba7a04 100644 --- a/pkg/lib/configreader/string_map.go +++ b/pkg/lib/configreader/string_map.go @@ -22,11 +22,13 @@ import ( ) type StringMapValidation struct { - Required bool - Default map[string]string - AllowExplicitNull bool - AllowEmpty bool - Validator func(map[string]string) (map[string]string, error) + Required bool + Default map[string]string + AllowExplicitNull bool + AllowEmpty bool + AllowCortexResources bool + RequireCortexResources bool + Validator func(map[string]string) (map[string]string, error) } func StringMap(inter interface{}, v *StringMapValidation) (map[string]string, error) { @@ -68,6 +70,16 @@ func ValidateStringMapProvided(val map[string]string, v *StringMapValidation) (m } func validateStringMap(val map[string]string, v *StringMapValidation) (map[string]string, error) { + if v.RequireCortexResources { + if err := checkOnlyCortexResources(val); err != nil { + return nil, err + } + } else if !v.AllowCortexResources { + if err := checkNoCortexResources(val); err != nil { + return nil, err + } + } + if !v.AllowEmpty { if val != nil && len(val) == 0 { return nil, ErrorCannotBeEmpty() diff --git a/pkg/lib/configreader/string_ptr.go b/pkg/lib/configreader/string_ptr.go index 2c0cac4288..f9ee57e2e3 100644 --- a/pkg/lib/configreader/string_ptr.go +++ b/pkg/lib/configreader/string_ptr.go @@ -33,6 +33,8 @@ type StringPtrValidation struct { AlphaNumericDashUnderscore bool DNS1035 bool DNS1123 bool + AllowCortexResources bool + RequireCortexResources bool Validator func(*string) (*string, error) } @@ -45,6 +47,8 @@ func makeStringValValidation(v *StringPtrValidation) *StringValidation { AlphaNumericDashUnderscore: v.AlphaNumericDashUnderscore, DNS1035: v.DNS1035, DNS1123: v.DNS1123, + AllowCortexResources: v.AllowCortexResources, + RequireCortexResources: v.RequireCortexResources, } } diff --git a/pkg/lib/configreader/types.go b/pkg/lib/configreader/types.go index 46755fe191..c734ef204e 100644 --- a/pkg/lib/configreader/types.go +++ b/pkg/lib/configreader/types.go @@ -40,6 +40,8 @@ var ( PrimTypeStringToStringMap PrimitiveType = "map of strings to strings" ) +var PrimTypeScalars = []PrimitiveType{PrimTypeInt, PrimTypeFloat, PrimTypeString, PrimTypeBool} + func (ts PrimitiveTypes) StringList() []string { strs := make([]string, len(ts)) for i, t := range ts { diff --git a/pkg/lib/strings/stringify.go 
b/pkg/lib/strings/stringify.go index c1b53133a4..5e9bbf8f08 100644 --- a/pkg/lib/strings/stringify.go +++ b/pkg/lib/strings/stringify.go @@ -25,6 +25,8 @@ import ( "strconv" "strings" "time" + + "github.com/cortexlabs/yaml" ) var emptyTime time.Time @@ -147,7 +149,7 @@ func strIndent(val interface{}, indent string, currentIndent string, newlineChar if funcVal.IsValid() { t := funcVal.Type() if t.NumIn() == 0 && t.NumOut() == 1 && t.Out(0).Kind() == reflect.String { - return quoteStr + funcVal.Call(nil)[0].Interface().(string) + quoteStr + return strIndent(funcVal.Call(nil)[0].Interface().(string), indent, currentIndent, newlineChar, quoteStr) } } if _, ok := reflect.PtrTo(valueType).MethodByName("String"); ok { @@ -157,7 +159,7 @@ func strIndent(val interface{}, indent string, currentIndent string, newlineChar if funcVal.IsValid() { t := funcVal.Type() if t.NumIn() == 0 && t.NumOut() == 1 && t.Out(0).Kind() == reflect.String { - return quoteStr + funcVal.Call(nil)[0].Interface().(string) + quoteStr + return strIndent(funcVal.Call(nil)[0].Interface().(string), indent, currentIndent, newlineChar, quoteStr) } } } @@ -216,6 +218,9 @@ func strIndent(val interface{}, indent string, currentIndent string, newlineChar case reflect.String: var t string casted := value.Convert(reflect.TypeOf(t)).Interface().(string) + + casted, _ = yaml.UnescapeAtSymbol(casted) + switch val.(type) { case json.Number: return casted diff --git a/pkg/operator/api/context/aggregates.go b/pkg/operator/api/context/aggregates.go index 69c5afe16e..7daeda8712 100644 --- a/pkg/operator/api/context/aggregates.go +++ b/pkg/operator/api/context/aggregates.go @@ -17,7 +17,6 @@ limitations under the License. package context import ( - "github.com/cortexlabs/cortex/pkg/lib/cast" "github.com/cortexlabs/cortex/pkg/operator/api/userconfig" ) @@ -30,16 +29,6 @@ type Aggregate struct { Key string `json:"key"` } -func (aggregate *Aggregate) GetType() interface{} { - return aggregate.Type -} - -// Returns map[string]string because after autogen, arg values are constant or aggregate names -func (aggregate *Aggregate) Args() map[string]string { - args, _ := cast.InterfaceToStrStrMap(aggregate.Inputs.Args) - return args -} - func (aggregates Aggregates) OneByID(id string) *Aggregate { for _, aggregate := range aggregates { if aggregate.ID == id { diff --git a/pkg/operator/api/context/apis.go b/pkg/operator/api/context/apis.go index 9fcf247a72..a6165d40eb 100644 --- a/pkg/operator/api/context/apis.go +++ b/pkg/operator/api/context/apis.go @@ -25,7 +25,8 @@ type APIs map[string]*API type API struct { *userconfig.API *ComputedResourceFields - Path string `json:"path"` + Path string `json:"path"` + ModelName string `json:"model_name"` // This is just a convenience which removes the @ from userconfig.API.Model } func APIPath(apiName string, appName string) string { diff --git a/pkg/operator/api/context/columns.go b/pkg/operator/api/context/columns.go index acc4e33f99..0c55a974e6 100644 --- a/pkg/operator/api/context/columns.go +++ b/pkg/operator/api/context/columns.go @@ -17,13 +17,7 @@ limitations under the License. 
package context import ( - "github.com/cortexlabs/cortex/pkg/lib/cast" - "github.com/cortexlabs/cortex/pkg/lib/configreader" - "github.com/cortexlabs/cortex/pkg/lib/errors" - "github.com/cortexlabs/cortex/pkg/lib/hash" "github.com/cortexlabs/cortex/pkg/lib/sets/strset" - s "github.com/cortexlabs/cortex/pkg/lib/strings" - "github.com/cortexlabs/cortex/pkg/operator/api/resource" "github.com/cortexlabs/cortex/pkg/operator/api/userconfig" ) @@ -31,9 +25,8 @@ type Columns map[string]Column type Column interface { ComputedResource - GetType() userconfig.ColumnType + GetColumnType() userconfig.ColumnType IsRaw() bool - GetInputRawColumnNames() []string } func (ctx *Context) Columns() Columns { @@ -66,60 +59,3 @@ func (ctx *Context) GetColumn(name string) Column { } return nil } - -func (columns Columns) ID(columnNames []string) string { - columnIDMap := make(map[string]string) - for _, columnName := range columnNames { - columnIDMap[columnName] = columns[columnName].GetID() - } - return hash.Any(columnIDMap) -} - -func (columns Columns) IDWithTags(columnNames []string) string { - columnIDMap := make(map[string]string) - for _, columnName := range columnNames { - columnIDMap[columnName] = columns[columnName].GetIDWithTags() - } - return hash.Any(columnIDMap) -} - -func GetColumnRuntimeTypes( - columnInputValues map[string]interface{}, - rawColumns RawColumns, -) (map[string]interface{}, error) { - - err := userconfig.ValidateColumnInputValues(columnInputValues) - if err != nil { - return nil, err - } - - columnRuntimeTypes := make(map[string]interface{}, len(columnInputValues)) - - for inputName, columnInputValue := range columnInputValues { - if rawColumnName, ok := columnInputValue.(string); ok { - rawColumn, ok := rawColumns[rawColumnName] - if !ok { - return nil, errors.Wrap(userconfig.ErrorUndefinedResource(rawColumnName, resource.RawColumnType), inputName) - } - columnRuntimeTypes[inputName] = rawColumn.GetType() - continue - } - - if rawColumnNames, ok := cast.InterfaceToStrSlice(columnInputValue); ok { - rawColumnTypes := make([]userconfig.ColumnType, len(rawColumnNames)) - for i, rawColumnName := range rawColumnNames { - rawColumn, ok := rawColumns[rawColumnName] - if !ok { - return nil, errors.Wrap(userconfig.ErrorUndefinedResource(rawColumnName, resource.RawColumnType), inputName, s.Index(i)) - } - rawColumnTypes[i] = rawColumn.GetType() - } - columnRuntimeTypes[inputName] = rawColumnTypes - continue - } - - return nil, errors.Wrap(configreader.ErrorInvalidPrimitiveType(columnInputValue, configreader.PrimTypeString, configreader.PrimTypeStringList), inputName) // unexpected - } - - return columnRuntimeTypes, nil -} diff --git a/pkg/operator/api/context/columns_test.go b/pkg/operator/api/context/columns_test.go deleted file mode 100644 index d6e3aba150..0000000000 --- a/pkg/operator/api/context/columns_test.go +++ /dev/null @@ -1,95 +0,0 @@ -/* -Copyright 2019 Cortex Labs, Inc. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-See the License for the specific language governing permissions and -limitations under the License. -*/ - -package context - -import ( - "testing" - - "github.com/stretchr/testify/require" - - cr "github.com/cortexlabs/cortex/pkg/lib/configreader" - "github.com/cortexlabs/cortex/pkg/operator/api/userconfig" -) - -func TestGetColumnRuntimeTypes(t *testing.T) { - var columnInputValues map[string]interface{} - var expected map[string]interface{} - - rawColumns := RawColumns{ - "rfInt": &RawIntColumn{ - RawIntColumn: &userconfig.RawIntColumn{ - Type: userconfig.IntegerColumnType, - }, - }, - "rfFloat": &RawFloatColumn{ - RawFloatColumn: &userconfig.RawFloatColumn{ - Type: userconfig.FloatColumnType, - }, - }, - "rfStr": &RawStringColumn{ - RawStringColumn: &userconfig.RawStringColumn{ - Type: userconfig.StringColumnType, - }, - }, - } - - columnInputValues = cr.MustReadYAMLStrMap("in: rfInt") - expected = map[string]interface{}{"in": userconfig.IntegerColumnType} - checkTestGetColumnRuntimeTypes(columnInputValues, rawColumns, expected, t) - - columnInputValues = cr.MustReadYAMLStrMap("in: rfStr") - expected = map[string]interface{}{"in": userconfig.StringColumnType} - checkTestGetColumnRuntimeTypes(columnInputValues, rawColumns, expected, t) - - columnInputValues = cr.MustReadYAMLStrMap("in: [rfFloat]") - expected = map[string]interface{}{"in": []userconfig.ColumnType{userconfig.FloatColumnType}} - checkTestGetColumnRuntimeTypes(columnInputValues, rawColumns, expected, t) - - columnInputValues = cr.MustReadYAMLStrMap("in: [rfInt, rfFloat, rfStr, rfInt]") - expected = map[string]interface{}{"in": []userconfig.ColumnType{userconfig.IntegerColumnType, userconfig.FloatColumnType, userconfig.StringColumnType, userconfig.IntegerColumnType}} - checkTestGetColumnRuntimeTypes(columnInputValues, rawColumns, expected, t) - - columnInputValues = cr.MustReadYAMLStrMap("in1: [rfInt, rfFloat]\nin2: rfStr") - expected = map[string]interface{}{"in1": []userconfig.ColumnType{userconfig.IntegerColumnType, userconfig.FloatColumnType}, "in2": userconfig.StringColumnType} - checkTestGetColumnRuntimeTypes(columnInputValues, rawColumns, expected, t) - - columnInputValues = cr.MustReadYAMLStrMap("in: 1") - checkErrTestGetColumnRuntimeTypes(columnInputValues, rawColumns, t) - - columnInputValues = cr.MustReadYAMLStrMap("in: [1, 2, 3]") - checkErrTestGetColumnRuntimeTypes(columnInputValues, rawColumns, t) - - columnInputValues = cr.MustReadYAMLStrMap("in: {in: rfInt}") - checkErrTestGetColumnRuntimeTypes(columnInputValues, rawColumns, t) - - columnInputValues = cr.MustReadYAMLStrMap("in: rfMissing") - checkErrTestGetColumnRuntimeTypes(columnInputValues, rawColumns, t) - - columnInputValues = cr.MustReadYAMLStrMap("in: [rfMissing]") - checkErrTestGetColumnRuntimeTypes(columnInputValues, rawColumns, t) -} - -func checkTestGetColumnRuntimeTypes(columnInputValues map[string]interface{}, rawColumns RawColumns, expected map[string]interface{}, t *testing.T) { - runtimeTypes, err := GetColumnRuntimeTypes(columnInputValues, rawColumns) - require.NoError(t, err) - require.Equal(t, expected, runtimeTypes) -} - -func checkErrTestGetColumnRuntimeTypes(columnInputValues map[string]interface{}, rawColumns RawColumns, t *testing.T) { - _, err := GetColumnRuntimeTypes(columnInputValues, rawColumns) - require.Error(t, err) -} diff --git a/pkg/operator/api/context/context.go b/pkg/operator/api/context/context.go index 125077cac1..533d888093 100644 --- 
a/pkg/operator/api/context/context.go +++ b/pkg/operator/api/context/context.go @@ -44,6 +44,7 @@ type Context struct { Constants Constants `json:"constants"` Aggregators Aggregators `json:"aggregators"` Transformers Transformers `json:"transformers"` + Estimators Estimators `json:"estimators"` } type RawDataset struct { @@ -53,8 +54,6 @@ type RawDataset struct { type Resource interface { userconfig.Resource GetID() string - GetIDWithTags() string - GetResourceFields() *ResourceFields } type ComputedResource interface { @@ -63,14 +62,8 @@ type ComputedResource interface { SetWorkloadID(string) } -type ValueResource interface { - Resource - GetType() interface{} -} - type ResourceFields struct { ID string `json:"id"` - IDWithTags string `json:"id_with_tags"` ResourceType resource.Type `json:"resource_type"` } @@ -83,14 +76,6 @@ func (r *ResourceFields) GetID() string { return r.ID } -func (r *ResourceFields) GetIDWithTags() string { - return r.IDWithTags -} - -func (r *ResourceFields) GetResourceFields() *ResourceFields { - return r -} - func (r *ComputedResourceFields) GetWorkloadID() string { return r.WorkloadID } @@ -101,8 +86,8 @@ func (r *ComputedResourceFields) SetWorkloadID(workloadID string) { func ExtractResourceWorkloadIDs(resources []ComputedResource) map[string]string { resourceWorkloadIDs := make(map[string]string, len(resources)) - for _, resource := range resources { - resourceWorkloadIDs[resource.GetID()] = resource.GetWorkloadID() + for _, res := range resources { + resourceWorkloadIDs[res.GetID()] = res.GetWorkloadID() } return resourceWorkloadIDs } @@ -143,8 +128,8 @@ func (ctx *Context) ComputedResources() []ComputedResource { func (ctx *Context) AllResources() []Resource { var resources []Resource - for _, resource := range ctx.ComputedResources() { - resources = append(resources, resource) + for _, res := range ctx.ComputedResources() { + resources = append(resources, res) } for _, constant := range ctx.Constants { resources = append(resources, constant) @@ -155,13 +140,16 @@ func (ctx *Context) AllResources() []Resource { for _, transformer := range ctx.Transformers { resources = append(resources, transformer) } + for _, estimator := range ctx.Estimators { + resources = append(resources, estimator) + } return resources } func (ctx *Context) ComputedResourceIDs() strset.Set { - resourceIDs := make(strset.Set) - for _, resource := range ctx.ComputedResources() { - resourceIDs.Add(resource.GetID()) + resourceIDs := strset.New() + for _, res := range ctx.ComputedResources() { + resourceIDs.Add(res.GetID()) } return resourceIDs } @@ -188,27 +176,37 @@ func (ctx *Context) ComputedResourceWorkloadIDs() strset.Set { // Note: there may be >1 resources with the ID, this returns one of them func (ctx *Context) OneResourceByID(resourceID string) Resource { - for _, resource := range ctx.AllResources() { - if resource.GetID() == resourceID { - return resource + for _, res := range ctx.AllResources() { + if res.GetID() == resourceID { + return res } } return nil } +func (ctx *Context) AllResourcesByName(name string) []Resource { + var resources []Resource + for _, res := range ctx.AllResources() { + if res.GetName() == name { + resources = append(resources, res) + } + } + return resources +} + // Overwrites any existing workload IDs func (ctx *Context) PopulateWorkloadIDs(resourceWorkloadIDs map[string]string) { - for _, resource := range ctx.ComputedResources() { - if workloadID, ok := resourceWorkloadIDs[resource.GetID()]; ok { - resource.SetWorkloadID(workloadID) + for _, res := 
range ctx.ComputedResources() { + if workloadID, ok := resourceWorkloadIDs[res.GetID()]; ok { + res.SetWorkloadID(workloadID) } } } func (ctx *Context) CheckAllWorkloadIDsPopulated() error { - for _, resource := range ctx.ComputedResources() { - if resource.GetWorkloadID() == "" { - return errors.New(ctx.App.Name, "resource", resource.GetID(), "workload ID is missing") // unexpected + for _, res := range ctx.ComputedResources() { + if res.GetWorkloadID() == "" { + return errors.New(ctx.App.Name, "resource", res.GetID(), "workload ID is missing") // unexpected } } return nil diff --git a/pkg/operator/api/context/dependencies.go b/pkg/operator/api/context/dependencies.go index 58a2d192f0..2b9944c9d5 100644 --- a/pkg/operator/api/context/dependencies.go +++ b/pkg/operator/api/context/dependencies.go @@ -17,7 +17,14 @@ limitations under the License. package context import ( + "sort" + + "github.com/cortexlabs/yaml" + + "github.com/cortexlabs/cortex/pkg/lib/cast" + "github.com/cortexlabs/cortex/pkg/lib/errors" "github.com/cortexlabs/cortex/pkg/lib/sets/strset" + "github.com/cortexlabs/cortex/pkg/operator/api/resource" ) func (ctx *Context) AllComputedResourceDependencies(resourceID string) strset.Set { @@ -64,17 +71,17 @@ func (ctx *Context) DirectComputedResourceDependencies(resourceID string) strset return ctx.apiDependencies(api) } } - return make(strset.Set) + return strset.New() } func (ctx *Context) pythonPackageDependencies(pythonPackage *PythonPackage) strset.Set { - return make(strset.Set) + return strset.New() } func (ctx *Context) rawColumnDependencies(rawColumn RawColumn) strset.Set { // Currently python packages are a dependency on raw features because raw features share // the same workload as transformed features and aggregates. 
- dependencies := make(strset.Set) + dependencies := strset.New() for _, pythonPackage := range ctx.PythonPackages { dependencies.Add(pythonPackage.GetID()) } @@ -82,60 +89,58 @@ func (ctx *Context) rawColumnDependencies(rawColumn RawColumn) strset.Set { } func (ctx *Context) aggregatesDependencies(aggregate *Aggregate) strset.Set { - rawColumnNames := aggregate.InputColumnNames() - dependencies := make(strset.Set, len(rawColumnNames)) + dependencies := strset.New() + for _, pythonPackage := range ctx.PythonPackages { dependencies.Add(pythonPackage.GetID()) } - for _, rawColumnName := range rawColumnNames { - rawColumn := ctx.RawColumns[rawColumnName] - dependencies.Add(rawColumn.GetID()) + + for _, res := range ctx.ExtractCortexResources(aggregate.Input) { + dependencies.Add(res.GetID()) } + return dependencies } func (ctx *Context) transformedColumnDependencies(transformedColumn *TransformedColumn) strset.Set { - dependencies := make(strset.Set) + dependencies := strset.New() for _, pythonPackage := range ctx.PythonPackages { dependencies.Add(pythonPackage.GetID()) } - rawColumnNames := transformedColumn.InputColumnNames() - for _, rawColumnName := range rawColumnNames { - rawColumn := ctx.RawColumns[rawColumnName] - dependencies.Add(rawColumn.GetID()) - } - - aggregateNames := transformedColumn.InputAggregateNames(ctx) - for aggregateName := range aggregateNames { - aggregate := ctx.Aggregates[aggregateName] - dependencies.Add(aggregate.GetID()) + for _, res := range ctx.ExtractCortexResources(transformedColumn.Input) { + dependencies.Add(res.GetID()) } return dependencies } func (ctx *Context) trainingDatasetDependencies(model *Model) strset.Set { - dependencies := make(strset.Set) - for _, columnName := range model.AllColumnNames() { - column := ctx.GetColumn(columnName) + dependencies := strset.New() + + combinedInput := []interface{}{model.Input, model.TrainingInput, model.TargetColumn} + for _, column := range ctx.ExtractCortexResources(combinedInput, resource.RawColumnType, resource.TransformedColumnType) { dependencies.Add(column.GetID()) } + return dependencies } func (ctx *Context) modelDependencies(model *Model) strset.Set { - dependencies := make(strset.Set) + dependencies := strset.New() for _, pythonPackage := range ctx.PythonPackages { dependencies.Add(pythonPackage.GetID()) } dependencies.Add(model.Dataset.ID) - for _, aggregate := range model.Aggregates { - dependencies.Add(ctx.Aggregates[aggregate].GetID()) + + combinedInput := []interface{}{model.Input, model.TrainingInput, model.TargetColumn} + for _, res := range ctx.ExtractCortexResources(combinedInput) { + dependencies.Add(res.GetID()) } + return dependencies } @@ -143,3 +148,85 @@ func (ctx *Context) apiDependencies(api *API) strset.Set { model := ctx.Models[api.ModelName] return strset.New(model.ID) } + +func (ctx *Context) ExtractCortexResources( + input interface{}, + resourceTypes ...resource.Type, // indicates which resource types to include in the query; if none are passed in, no filter is applied +) []Resource { + + return ExtractCortexResources(input, ctx.AllResources(), resourceTypes...) 
+}
+
+func ExtractCortexResources(
+	input interface{},
+	validResources []Resource,
+	resourceTypes ...resource.Type, // indicates which resource types to include in the query; if none are passed in, no filter is applied
+) []Resource {
+
+	resourceTypeFilter := make(map[resource.Type]bool)
+	for _, resourceType := range resourceTypes {
+		resourceTypeFilter[resourceType] = true
+	}
+
+	validResourcesMap := make(map[string][]Resource)
+	for _, res := range validResources {
+		validResourcesMap[res.GetName()] = append(validResourcesMap[res.GetName()], res)
+	}
+
+	resources := make(map[string]Resource)
+	extractCortexResourcesHelper(input, validResourcesMap, resourceTypeFilter, resources)
+
+	// convert to slice and sort by ID
+	var resourceIDs []string
+	for resourceID := range resources {
+		resourceIDs = append(resourceIDs, resourceID)
+	}
+	sort.Strings(resourceIDs)
+	resourcesSlice := make([]Resource, len(resources))
+	for i, resourceID := range resourceIDs {
+		resourcesSlice[i] = resources[resourceID]
+	}
+
+	return resourcesSlice
+}
+
+func extractCortexResourcesHelper(
+	input interface{},
+	validResourcesMap map[string][]Resource, // key is resource name
+	resourceTypeFilter map[resource.Type]bool,
+	collectedResources map[string]Resource,
+) {
+
+	if input == nil {
+		return
+	}
+
+	if resourceName, ok := yaml.ExtractAtSymbolText(input); ok {
+		foundMatch := false
+		for _, res := range validResourcesMap[resourceName] {
+			if len(resourceTypeFilter) == 0 || resourceTypeFilter[res.GetResourceType()] == true {
+				if foundMatch {
+					errors.Panic("found multiple resources with the same name", resourceName) // unexpected
+				}
+				collectedResources[res.GetID()] = res
+				foundMatch = true
+			}
+		}
+		return
+	}
+
+	if inputSlice, ok := cast.InterfaceToInterfaceSlice(input); ok {
+		for _, elem := range inputSlice {
+			extractCortexResourcesHelper(elem, validResourcesMap, resourceTypeFilter, collectedResources)
+		}
+		return
+	}
+
+	if inputMap, ok := cast.InterfaceToInterfaceInterfaceMap(input); ok {
+		for key, val := range inputMap {
+			extractCortexResourcesHelper(key, validResourcesMap, resourceTypeFilter, collectedResources)
+			extractCortexResourcesHelper(val, validResourcesMap, resourceTypeFilter, collectedResources)
+		}
+		return
+	}
+}
diff --git a/pkg/operator/api/context/dependencies_test.go b/pkg/operator/api/context/dependencies_test.go
new file mode 100644
index 0000000000..46a1fc992f
--- /dev/null
+++ b/pkg/operator/api/context/dependencies_test.go
@@ -0,0 +1,78 @@
+/*
+Copyright 2019 Cortex Labs, Inc.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/ + +package context + +import ( + "testing" + + "github.com/stretchr/testify/require" + + cr "github.com/cortexlabs/cortex/pkg/lib/configreader" + "github.com/cortexlabs/cortex/pkg/operator/api/resource" +) + +func TestExtractCortexResources(t *testing.T) { + var resources []Resource + + resources = ExtractCortexResources(cr.MustReadYAMLStr(`@rc1`), rawCols) + require.Equal(t, []Resource{rc1}, resources) + + resources = ExtractCortexResources(cr.MustReadYAMLStr(`@rc1`), nil) + require.Equal(t, []Resource{}, resources) + + resources = ExtractCortexResources(cr.MustReadYAMLStr(`@rc1`), rawCols, resource.TransformedColumnType) + require.Equal(t, []Resource{}, resources) + + resources = ExtractCortexResources(cr.MustReadYAMLStr(`[@rc1, rc2, @tc1]`), allResources) + require.Equal(t, []Resource{rc1, tc1}, resources) + + resources = ExtractCortexResources(cr.MustReadYAMLStr(`[@rc1, rc2, @tc1]`), allResources, resource.RawColumnType) + require.Equal(t, []Resource{rc1}, resources) + + resources = ExtractCortexResources(cr.MustReadYAMLStr(`[@rc1, rc2, @tc1]`), allResources, resource.TransformedColumnType) + require.Equal(t, []Resource{tc1}, resources) + + // Check sorted by ID + resources = ExtractCortexResources(cr.MustReadYAMLStr(`[@tc1, rc2, @rc1]`), allResources) + require.Equal(t, []Resource{rc1, tc1}, resources) + + // Check duplicates + resources = ExtractCortexResources(cr.MustReadYAMLStr(`[@rc1, rc2, @tc1, @rc1]`), allResources) + require.Equal(t, []Resource{rc1, tc1}, resources) + + mixedInput := cr.MustReadYAMLStr( + ` + map: {@agg1: @c1} + str: @rc1 + floats: [@tc2] + map2: + map3: + lat: @c2 + lon: + @c3: agg2 + b: [@tc1, @agg3] + `) + + resources = ExtractCortexResources(mixedInput, allResources) + require.Equal(t, []Resource{c1, c2, c3, rc1, agg1, agg3, tc1, tc2}, resources) + + resources = ExtractCortexResources(mixedInput, allResources, resource.AggregateType) + require.Equal(t, []Resource{agg1, agg3}, resources) + + resources = ExtractCortexResources(mixedInput, allResources, resource.AggregateType, resource.TransformedColumnType) + require.Equal(t, []Resource{agg1, agg3, tc1, tc2}, resources) +} diff --git a/pkg/operator/api/context/values.go b/pkg/operator/api/context/estimators.go similarity index 62% rename from pkg/operator/api/context/values.go rename to pkg/operator/api/context/estimators.go index b0c247ce68..69f91d0150 100644 --- a/pkg/operator/api/context/values.go +++ b/pkg/operator/api/context/estimators.go @@ -17,22 +17,23 @@ limitations under the License. 
package context import ( - "github.com/cortexlabs/cortex/pkg/operator/api/resource" "github.com/cortexlabs/cortex/pkg/operator/api/userconfig" ) -func GetValueResource( - name string, - constants Constants, - aggregates Aggregates, -) (ValueResource, error) { +type Estimators map[string]*Estimator - if constant, ok := constants[name]; ok { - return constant, nil - } - if aggregate, ok := aggregates[name]; ok { - return aggregate, nil - } +type Estimator struct { + *userconfig.Estimator + *ResourceFields + Namespace *string `json:"namespace"` + ImplKey string `json:"impl_key"` +} - return nil, userconfig.ErrorUndefinedResource(name, resource.ConstantType, resource.AggregateType) +func (estimators Estimators) OneByID(id string) *Estimator { + for _, estimator := range estimators { + if estimator.ID == id { + return estimator + } + } + return nil } diff --git a/pkg/operator/api/context/models.go b/pkg/operator/api/context/models.go index 9df19a0bc2..401300dc3c 100644 --- a/pkg/operator/api/context/models.go +++ b/pkg/operator/api/context/models.go @@ -17,10 +17,6 @@ limitations under the License. package context import ( - "sort" - - "github.com/cortexlabs/cortex/pkg/lib/configreader" - "github.com/cortexlabs/cortex/pkg/lib/sets/strset" "github.com/cortexlabs/cortex/pkg/operator/api/resource" "github.com/cortexlabs/cortex/pkg/operator/api/userconfig" ) @@ -32,8 +28,6 @@ type Model struct { *userconfig.Model *ComputedResourceFields Key string `json:"key"` - ImplID string `json:"impl_id"` - ImplKey string `json:"impl_key"` Dataset *TrainingDataset `json:"dataset"` } @@ -74,31 +68,3 @@ func (models Models) GetTrainingDatasets() TrainingDatasets { } return trainingDatasets } - -func ValidateModelTargetType(targetType userconfig.ColumnType, modelType userconfig.ModelType) error { - switch modelType { - case userconfig.ClassificationModelType: - if targetType != userconfig.IntegerColumnType { - return userconfig.ErrorClassificationTargetType() - } - return nil - case userconfig.RegressionModelType: - if targetType != userconfig.IntegerColumnType && targetType != userconfig.FloatColumnType { - return userconfig.ErrorRegressionTargetType() - } - return nil - } - - return configreader.ErrorInvalidStr(modelType.String(), "classification", "regression") // unexpected -} - -func (ctx *Context) RawColumnInputNames(model *Model) []string { - rawColumnInputNames := strset.New() - for _, colName := range model.FeatureColumns { - col := ctx.GetColumn(colName) - rawColumnInputNames.Add(col.GetInputRawColumnNames()...) - } - columnNames := rawColumnInputNames.Slice() - sort.Strings(columnNames) - return columnNames -} diff --git a/pkg/operator/api/context/raw_columns.go b/pkg/operator/api/context/raw_columns.go index 6a6830a19a..02f61ccfcf 100644 --- a/pkg/operator/api/context/raw_columns.go +++ b/pkg/operator/api/context/raw_columns.go @@ -17,8 +17,6 @@ limitations under the License. 
package context import ( - "github.com/cortexlabs/cortex/pkg/lib/cast" - "github.com/cortexlabs/cortex/pkg/lib/hash" "github.com/cortexlabs/cortex/pkg/operator/api/userconfig" ) @@ -58,33 +56,8 @@ func (rawColumns RawColumns) OneByID(id string) RawColumn { return nil } -func (rawColumns RawColumns) columnInputsID(columnInputValues map[string]interface{}, includeTags bool) string { - columnIDMap := make(map[string]string) - for columnInputName, columnInputValue := range columnInputValues { - if columnName, ok := columnInputValue.(string); ok { - if includeTags { - columnIDMap[columnInputName] = rawColumns[columnName].GetIDWithTags() - } else { - columnIDMap[columnInputName] = rawColumns[columnName].GetID() - } - } - if columnNames, ok := cast.InterfaceToStrSlice(columnInputValue); ok { - var columnIDs string - for _, columnName := range columnNames { - if includeTags { - columnIDs = columnIDs + rawColumns[columnName].GetIDWithTags() - } else { - columnIDs = columnIDs + rawColumns[columnName].GetID() - } - } - columnIDMap[columnInputName] = columnIDs - } - } - return hash.Any(columnIDMap) -} - func GetRawColumnUserConfig(rawColumn RawColumn) userconfig.Resource { - switch rawColumn.GetType() { + switch rawColumn.GetColumnType() { case userconfig.IntegerColumnType: return rawColumn.(*RawIntColumn).RawIntColumn case userconfig.FloatColumnType: @@ -97,27 +70,3 @@ func GetRawColumnUserConfig(rawColumn RawColumn) userconfig.Resource { return nil } - -func (rawColumns RawColumns) ColumnInputsID(columnInputValues map[string]interface{}) string { - return rawColumns.columnInputsID(columnInputValues, false) -} - -func (rawColumns RawColumns) ColumnInputsIDWithTags(columnInputValues map[string]interface{}) string { - return rawColumns.columnInputsID(columnInputValues, true) -} - -func (rawColumn *RawIntColumn) GetInputRawColumnNames() []string { - return []string{rawColumn.GetName()} -} - -func (rawColumn *RawFloatColumn) GetInputRawColumnNames() []string { - return []string{rawColumn.GetName()} -} - -func (rawColumn *RawStringColumn) GetInputRawColumnNames() []string { - return []string{rawColumn.GetName()} -} - -func (rawColumn *RawInferredColumn) GetInputRawColumnNames() []string { - return []string{rawColumn.GetName()} -} diff --git a/pkg/operator/api/context/resource_fakes_test.go b/pkg/operator/api/context/resource_fakes_test.go new file mode 100644 index 0000000000..1a09ba178e --- /dev/null +++ b/pkg/operator/api/context/resource_fakes_test.go @@ -0,0 +1,136 @@ +/* +Copyright 2019 Cortex Labs, Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+*/ + +package context + +import ( + "github.com/cortexlabs/cortex/pkg/operator/api/resource" + "github.com/cortexlabs/cortex/pkg/operator/api/userconfig" +) + +type fakeResource struct { + Name string + ID string + ResourceType resource.Type +} + +func (r *fakeResource) GetID() string { + return r.ID +} + +func (r *fakeResource) GetName() string { + return r.Name +} + +func (r *fakeResource) GetResourceType() resource.Type { + return r.ResourceType +} + +func (r *fakeResource) GetIndex() int { + return -1 +} + +func (r *fakeResource) SetIndex(int) {} + +func (r *fakeResource) GetFilePath() string { + return "path/to/file" +} + +func (r *fakeResource) SetFilePath(string) {} + +func (r *fakeResource) GetEmbed() *userconfig.Embed { + return nil +} + +func (r *fakeResource) SetEmbed(*userconfig.Embed) {} + +var c1 = &fakeResource{ + Name: "c1", + ID: "a_c1", + ResourceType: resource.ConstantType, +} + +var c2 = &fakeResource{ + Name: "c2", + ID: "b_c2", + ResourceType: resource.ConstantType, +} + +var c3 = &fakeResource{ + Name: "c3", + ID: "c_c3", + ResourceType: resource.ConstantType, +} + +var rc1 = &fakeResource{ + Name: "rc1", + ID: "d_rc1", + ResourceType: resource.RawColumnType, +} + +var rc2 = &fakeResource{ + Name: "rc2", + ID: "e_rc2", + ResourceType: resource.RawColumnType, +} + +var rc3 = &fakeResource{ + Name: "rc3", + ID: "f_rc3", + ResourceType: resource.RawColumnType, +} + +var agg1 = &fakeResource{ + Name: "agg1", + ID: "g_agg1", + ResourceType: resource.AggregateType, +} + +var agg2 = &fakeResource{ + Name: "agg2", + ID: "h_agg2", + ResourceType: resource.AggregateType, +} + +var agg3 = &fakeResource{ + Name: "agg3", + ID: "i_agg3", + ResourceType: resource.AggregateType, +} + +var tc1 = &fakeResource{ + Name: "tc1", + ID: "j_tc1", + ResourceType: resource.TransformedColumnType, +} + +var tc2 = &fakeResource{ + Name: "tc2", + ID: "k_tc2", + ResourceType: resource.TransformedColumnType, +} + +var tc3 = &fakeResource{ + Name: "tc3", + ID: "l_tc3", + ResourceType: resource.TransformedColumnType, +} + +var constants = []Resource{c1, c2, c3} +var rawCols = []Resource{rc1, rc2, rc3} +var aggregates = []Resource{agg1, agg2, agg3} +var transformedCols = []Resource{tc1, tc2, tc3} +var allResources = []Resource{c1, c2, c3, rc1, rc2, rc3, agg1, agg2, agg3, tc1, tc2, tc3} diff --git a/pkg/operator/api/context/transformed_columns.go b/pkg/operator/api/context/transformed_columns.go index 5b6f9a0dc4..006ab39d5d 100644 --- a/pkg/operator/api/context/transformed_columns.go +++ b/pkg/operator/api/context/transformed_columns.go @@ -17,8 +17,6 @@ limitations under the License. 
package context import ( - "github.com/cortexlabs/cortex/pkg/lib/cast" - "github.com/cortexlabs/cortex/pkg/lib/sets/strset" "github.com/cortexlabs/cortex/pkg/operator/api/userconfig" ) @@ -30,26 +28,10 @@ type TransformedColumn struct { Type userconfig.ColumnType `json:"type"` } -func (column *TransformedColumn) GetType() userconfig.ColumnType { +func (column *TransformedColumn) GetColumnType() userconfig.ColumnType { return column.Type } -// Returns map[string]string because after autogen, arg values are constant or aggregate names -func (column *TransformedColumn) Args() map[string]string { - args, _ := cast.InterfaceToStrStrMap(column.Inputs.Args) - return args -} - -func (column *TransformedColumn) InputAggregateNames(ctx *Context) strset.Set { - inputAggregateNames := strset.New() - for _, valueResourceName := range column.Args() { - if _, ok := ctx.Aggregates[valueResourceName]; ok { - inputAggregateNames.Add(valueResourceName) - } - } - return inputAggregateNames -} - func (columns TransformedColumns) OneByID(id string) *TransformedColumn { for _, transformedColumn := range columns { if transformedColumn.ID == id { @@ -58,7 +40,3 @@ func (columns TransformedColumns) OneByID(id string) *TransformedColumn { } return nil } - -func (column *TransformedColumn) GetInputRawColumnNames() []string { - return column.InputColumnNames() -} diff --git a/pkg/operator/api/resource/type.go b/pkg/operator/api/resource/type.go index c7544f105c..8dd1b2f0e4 100644 --- a/pkg/operator/api/resource/type.go +++ b/pkg/operator/api/resource/type.go @@ -31,9 +31,10 @@ const ( AggregateType // 4 APIType // 5 ModelType // 6 - EnvironmentType // 8 - AggregatorType // 9 - TransformerType // 10 + EnvironmentType // 7 + AggregatorType // 8 + TransformerType // 9 + EstimatorType // 10 TemplateType // 11 EmbedType // 12 TrainingDatasetType // 13 @@ -53,6 +54,7 @@ var ( "environment", "aggregator", "transformer", + "estimator", "template", "embed", "training_dataset", @@ -71,6 +73,7 @@ var ( "environments", "aggregators", "transformers", + "estimators", "templates", "embeds", "training_datasets", diff --git a/pkg/operator/api/userconfig/aggregates.go b/pkg/operator/api/userconfig/aggregates.go index 613ad89b3a..e3137bab26 100644 --- a/pkg/operator/api/userconfig/aggregates.go +++ b/pkg/operator/api/userconfig/aggregates.go @@ -17,9 +17,8 @@ limitations under the License. 
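// The resource/type.go hunk above inserts EstimatorType into the iota enum and adds
// "estimator"/"estimators" to the two name slices at the matching position, keeping the
// slices index-aligned with the constants (hence the renumbered 7-10 comments). A minimal
// sketch of that enum-plus-parallel-slice pattern, using illustrative names only (these
// are not the real package symbols):

package main

import "fmt"

type kind int

const (
	unknownKind     kind = iota // 0
	aggregatorKind              // 1
	transformerKind             // 2
	estimatorKind               // 3 -- inserted here, so the slice below gains an entry at index 3
	templateKind                // 4
)

// kindStrings must stay index-aligned with the constants above.
var kindStrings = []string{"unknown", "aggregator", "transformer", "estimator", "template"}

func (k kind) String() string { return kindStrings[k] }

func main() {
	fmt.Println(estimatorKind) // prints "estimator"
	fmt.Println(templateKind)  // prints "template"
}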
package userconfig import ( - "sort" - "github.com/cortexlabs/cortex/pkg/lib/configreader" + "github.com/cortexlabs/cortex/pkg/lib/errors" "github.com/cortexlabs/cortex/pkg/operator/api/resource" ) @@ -29,7 +28,7 @@ type Aggregate struct { ResourceFields Aggregator string `json:"aggregator" yaml:"aggregator"` AggregatorPath *string `json:"aggregator_path" yaml:"aggregator_path"` - Inputs *Inputs `json:"inputs" yaml:"inputs"` + Input interface{} `json:"input" yaml:"input"` Compute *SparkCompute `json:"compute" yaml:"compute"` Tags Tags `json:"tags" yaml:"tags"` } @@ -54,7 +53,13 @@ var aggregateValidation = &configreader.StructValidation{ StructField: "AggregatorPath", StringPtrValidation: &configreader.StringPtrValidation{}, }, - inputValuesFieldValidation, + { + StructField: "Input", + InterfaceValidation: &configreader.InterfaceValidation{ + Required: true, + AllowCortexResources: true, + }, + }, sparkComputeFieldValidation("Compute"), tagsFieldValidation, typeFieldValidation, @@ -62,6 +67,12 @@ var aggregateValidation = &configreader.StructValidation{ } func (aggregates Aggregates) Validate() error { + for _, aggregate := range aggregates { + if err := aggregate.Validate(); err != nil { + return err + } + } + resources := make([]Resource, len(aggregates)) for i, res := range aggregates { resources[i] = res @@ -71,6 +82,19 @@ func (aggregates Aggregates) Validate() error { if len(dups) > 0 { return ErrorDuplicateResourceName(dups...) } + + return nil +} + +func (aggregate *Aggregate) Validate() error { + if aggregate.AggregatorPath == nil && aggregate.Aggregator == "" { + return errors.Wrap(ErrorSpecifyOnlyOneMissing("aggregator", "aggregator_path"), Identify(aggregate)) + } + + if aggregate.AggregatorPath != nil && aggregate.Aggregator != "" { + return errors.Wrap(ErrorSpecifyOnlyOne("aggregator", "aggregator_path"), Identify(aggregate)) + } + return nil } @@ -94,9 +118,3 @@ func (aggregates Aggregates) Get(name string) *Aggregate { } return nil } - -func (aggregate *Aggregate) InputColumnNames() []string { - inputs, _ := configreader.FlattenAllStrValues(aggregate.Inputs.Columns) - sort.Strings(inputs) - return inputs -} diff --git a/pkg/operator/api/userconfig/aggregators.go b/pkg/operator/api/userconfig/aggregators.go index 3eefddf752..8b7f429711 100644 --- a/pkg/operator/api/userconfig/aggregators.go +++ b/pkg/operator/api/userconfig/aggregators.go @@ -25,9 +25,9 @@ type Aggregators []*Aggregator type Aggregator struct { ResourceFields - Inputs *Inputs `json:"inputs" yaml:"inputs"` - OutputType interface{} `json:"output_type" yaml:"output_type"` - Path string `json:"path" yaml:"path"` + Input *InputSchema `json:"input" yaml:"input"` + OutputType OutputSchema `json:"output_type" yaml:"output_type"` + Path string `json:"path" yaml:"path"` } var aggregatorValidation = &cr.StructValidation{ @@ -51,12 +51,18 @@ var aggregatorValidation = &cr.StructValidation{ StructField: "OutputType", InterfaceValidation: &cr.InterfaceValidation{ Required: true, - Validator: func(outputType interface{}) (interface{}, error) { - return outputType, ValidateValueType(outputType) + Validator: func(t interface{}) (interface{}, error) { + return ValidateOutputSchema(t) }, }, }, - inputTypesFieldValidation, + { + StructField: "Input", + InterfaceValidation: &cr.InterfaceValidation{ + Required: true, + Validator: inputSchemaValidator, + }, + }, typeFieldValidation, }, } @@ -71,6 +77,7 @@ func (aggregators Aggregators) Validate() error { if len(dups) > 0 
{ return ErrorDuplicateResourceName(dups...) } + return nil } diff --git a/pkg/operator/api/userconfig/apis.go b/pkg/operator/api/userconfig/apis.go index 854dfee201..e7cd27909d 100644 --- a/pkg/operator/api/userconfig/apis.go +++ b/pkg/operator/api/userconfig/apis.go @@ -17,6 +17,8 @@ limitations under the License. package userconfig import ( + "github.com/cortexlabs/yaml" + cr "github.com/cortexlabs/cortex/pkg/lib/configreader" "github.com/cortexlabs/cortex/pkg/operator/api/resource" ) @@ -25,9 +27,9 @@ type APIs []*API type API struct { ResourceFields - ModelName string `json:"model_name" yaml:"model_name"` - Compute *APICompute `json:"compute" yaml:"compute"` - Tags Tags `json:"tags" yaml:"tags"` + Model string `json:"model" yaml:"model"` + Compute *APICompute `json:"compute" yaml:"compute"` + Tags Tags `json:"tags" yaml:"tags"` } var apiValidation = &cr.StructValidation{ @@ -40,11 +42,16 @@ var apiValidation = &cr.StructValidation{ }, }, { - StructField: "ModelName", + StructField: "Model", DefaultField: "Name", + DefaultFieldFunc: func(name interface{}) interface{} { + model := "@" + name.(string) + escapedModel, _ := yaml.EscapeAtSymbol(model) + return escapedModel + }, StringValidation: &cr.StringValidation{ - Required: false, - AlphaNumericDashUnderscore: true, + Required: false, + RequireCortexResources: true, }, }, apiComputeFieldValidation, diff --git a/pkg/operator/api/userconfig/column_type.go b/pkg/operator/api/userconfig/column_type.go index 9195df82be..53fff322d2 100644 --- a/pkg/operator/api/userconfig/column_type.go +++ b/pkg/operator/api/userconfig/column_type.go @@ -25,35 +25,35 @@ type ColumnTypes []ColumnType const ( UnknownColumnType ColumnType = iota + InferredColumnType IntegerColumnType FloatColumnType StringColumnType IntegerListColumnType FloatListColumnType StringListColumnType - InferredColumnType ) var columnTypes = []string{ "unknown", + "INFERRED_COLUMN", "INT_COLUMN", "FLOAT_COLUMN", "STRING_COLUMN", "INT_LIST_COLUMN", "FLOAT_LIST_COLUMN", "STRING_LIST_COLUMN", - "INFERRED_COLUMN", } var columnJSONPlaceholders = []string{ "_", + "VALUE", "INT", "FLOAT", "\"STRING\"", "[INT]", "[FLOAT]", "[\"STRING\"]", - "VALUE", } func ColumnTypeFromString(s string) ColumnType { @@ -65,8 +65,8 @@ func ColumnTypeFromString(s string) ColumnType { return UnknownColumnType } -func ColumnTypeStrings() []string { - return columnTypes[1:] +func ValidColumnTypeStrings() []string { + return columnTypes[2:] } func (t ColumnType) String() string { diff --git a/pkg/operator/api/userconfig/columns.go b/pkg/operator/api/userconfig/columns.go index a27f36c313..e31188b0fd 100644 --- a/pkg/operator/api/userconfig/columns.go +++ b/pkg/operator/api/userconfig/columns.go @@ -17,12 +17,7 @@ limitations under the License. 
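// In the column_type.go hunk above, INFERRED_COLUMN moves up to sit beside "unknown" at the
// front of the columnTypes slice, so the renamed ValidColumnTypeStrings() now slices from
// index 2 instead of 1 to return only the user-declarable types. A minimal sketch of that
// slicing, assuming the same ordering shown above (local example identifiers, not the real ones):

package main

import "fmt"

func main() {
	columnTypes := []string{
		"unknown",         // index 0: placeholder, never user-facing
		"INFERRED_COLUMN", // index 1: internal, not user-declarable
		"INT_COLUMN",
		"FLOAT_COLUMN",
		"STRING_COLUMN",
		"INT_LIST_COLUMN",
		"FLOAT_LIST_COLUMN",
		"STRING_LIST_COLUMN",
	}

	validColumnTypeStrings := columnTypes[2:] // skip "unknown" and "INFERRED_COLUMN"
	fmt.Println(validColumnTypeStrings)
	// Output: [INT_COLUMN FLOAT_COLUMN STRING_COLUMN INT_LIST_COLUMN FLOAT_LIST_COLUMN STRING_LIST_COLUMN]
}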
package userconfig import ( - "github.com/cortexlabs/cortex/pkg/lib/cast" - "github.com/cortexlabs/cortex/pkg/lib/configreader" - "github.com/cortexlabs/cortex/pkg/lib/errors" "github.com/cortexlabs/cortex/pkg/lib/slices" - s "github.com/cortexlabs/cortex/pkg/lib/strings" - "github.com/cortexlabs/cortex/pkg/operator/api/resource" ) type Column interface { @@ -30,71 +25,6 @@ type Column interface { IsRaw() bool } -func (config *Config) ValidateColumns() error { - columnResources := make([]Resource, len(config.RawColumns)+len(config.TransformedColumns)) - for i, res := range config.RawColumns { - columnResources[i] = res - } - - for i, res := range config.TransformedColumns { - columnResources[i+len(config.RawColumns)] = res - } - - dups := FindDuplicateResourceName(columnResources...) - if len(dups) > 0 { - return ErrorDuplicateResourceName(dups...) - } - - for _, aggregate := range config.Aggregates { - err := ValidateColumnInputsExistAndRaw(aggregate.Inputs.Columns, config) - if err != nil { - return errors.Wrap(err, Identify(aggregate), InputsKey, ColumnsKey) - } - } - - for _, transformedColumn := range config.TransformedColumns { - err := ValidateColumnInputsExistAndRaw(transformedColumn.Inputs.Columns, config) - if err != nil { - return errors.Wrap(err, Identify(transformedColumn), InputsKey, ColumnsKey) - } - } - - return nil -} - -func ValidateColumnInputsExistAndRaw(columnInputValues map[string]interface{}, config *Config) error { - for columnInputName, columnInputValue := range columnInputValues { - if columnName, ok := columnInputValue.(string); ok { - err := ValidateColumnNameExistsAndRaw(columnName, config) - if err != nil { - return errors.Wrap(err, columnInputName) - } - continue - } - if columnNames, ok := cast.InterfaceToStrSlice(columnInputValue); ok { - for i, columnName := range columnNames { - err := ValidateColumnNameExistsAndRaw(columnName, config) - if err != nil { - return errors.Wrap(err, columnInputName, s.Index(i)) - } - } - continue - } - return errors.Wrap(configreader.ErrorInvalidPrimitiveType(columnInputValue, configreader.PrimTypeString, configreader.PrimTypeStringList), columnInputName) // unexpected - } - return nil -} - -func ValidateColumnNameExistsAndRaw(columnName string, config *Config) error { - if config.IsTransformedColumn(columnName) { - return ErrorColumnMustBeRaw(columnName) - } - if !config.IsRawColumn(columnName) { - return ErrorUndefinedResource(columnName, resource.RawColumnType) - } - return nil -} - func (config *Config) ColumnNames() []string { return append(config.RawColumns.Names(), config.TransformedColumns.Names()...) } diff --git a/pkg/operator/api/userconfig/compound_type.go b/pkg/operator/api/userconfig/compound_type.go new file mode 100644 index 0000000000..0620d7ee66 --- /dev/null +++ b/pkg/operator/api/userconfig/compound_type.go @@ -0,0 +1,174 @@ +/* +Copyright 2019 Cortex Labs, Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+*/ + +package userconfig + +import ( + "strings" + + "github.com/cortexlabs/cortex/pkg/lib/cast" + "github.com/cortexlabs/cortex/pkg/lib/configreader" +) + +type CompoundType string + +type compoundTypeParsed struct { + Original string + columnTypes map[ColumnType]bool + valueTypes map[ValueType]bool +} + +func CompoundTypeFromString(val interface{}) (CompoundType, error) { + parsed, err := parseCompoundType(val) + if err != nil { + return "", err + } + return CompoundType(parsed.Original), nil +} + +func parseCompoundType(val interface{}) (*compoundTypeParsed, error) { + compoundStr, ok := val.(string) + if !ok { + return nil, ErrorInvalidCompoundType(val) + } + + parsed := &compoundTypeParsed{ + Original: compoundStr, + columnTypes: map[ColumnType]bool{}, + valueTypes: map[ValueType]bool{}, + } + + for _, str := range strings.Split(compoundStr, "|") { + columnType := ColumnTypeFromString(str) + if columnType != UnknownColumnType && columnType != InferredColumnType { + if parsed.columnTypes[columnType] == true { + return nil, ErrorDuplicateTypeInTypeString(columnType.String(), compoundStr) + } + parsed.columnTypes[columnType] = true + continue + } + + valueType := ValueTypeFromString(str) + if valueType != UnknownValueType { + if parsed.valueTypes[valueType] == true { + return nil, ErrorDuplicateTypeInTypeString(valueType.String(), compoundStr) + } + parsed.valueTypes[valueType] = true + continue + } + + return nil, ErrorInvalidCompoundType(compoundStr) + } + + if len(parsed.columnTypes) == 0 && len(parsed.valueTypes) == 0 { + return nil, ErrorInvalidCompoundType("") + } + + if len(parsed.columnTypes) > 0 && len(parsed.valueTypes) > 0 { + return nil, ErrorCannotMixValueAndColumnTypes(compoundStr) + } + + return parsed, nil +} + +func (compoundType *CompoundType) String() string { + return string(*compoundType) +} + +func (compoundType *CompoundType) IsColumns() bool { + parsed, _ := parseCompoundType(string(*compoundType)) + return len(parsed.columnTypes) > 0 +} + +func (compoundType *CompoundType) IsValues() bool { + parsed, _ := parseCompoundType(string(*compoundType)) + return len(parsed.valueTypes) > 0 +} + +func (compoundType *CompoundType) SupportsType(t interface{}) bool { + parsed, _ := parseCompoundType(string(*compoundType)) + + if columnType, ok := t.(ColumnType); ok { + if columnType == IntegerColumnType { + return parsed.columnTypes[IntegerColumnType] || parsed.columnTypes[FloatColumnType] + } + if columnType == IntegerListColumnType { + return parsed.columnTypes[IntegerListColumnType] || parsed.columnTypes[FloatListColumnType] + } + return parsed.columnTypes[columnType] || columnType == InferredColumnType + } + + if valueType, ok := t.(ValueType); ok { + if valueType == IntegerValueType { + return parsed.valueTypes[IntegerValueType] || parsed.valueTypes[FloatValueType] + } + return parsed.valueTypes[valueType] + } + + if typeStr, ok := t.(string); ok { + columnType := ColumnTypeFromString(typeStr) + if columnType != UnknownColumnType { + return compoundType.SupportsType(columnType) + } + valueType := ValueTypeFromString(typeStr) + if valueType != UnknownValueType { + return compoundType.SupportsType(valueType) + } + } + + return false +} + +func (compoundType *CompoundType) CastValue(value interface{}) (interface{}, error) { + parsed, _ := parseCompoundType(string(*compoundType)) + if len(parsed.columnTypes) > 0 { + return nil, ErrorColumnTypeLiteral(value) + } + + var validPrimitiveTypes []configreader.PrimitiveType + + if 
parsed.valueTypes[IntegerValueType] { + validPrimitiveTypes = append(validPrimitiveTypes, configreader.PrimTypeInt) + valueInt, ok := cast.InterfaceToInt64(value) + if ok { + return valueInt, nil + } + } + + if parsed.valueTypes[FloatValueType] { + validPrimitiveTypes = append(validPrimitiveTypes, configreader.PrimTypeFloat) + valueFloat, ok := cast.InterfaceToFloat64(value) + if ok { + return valueFloat, nil + } + } + + if parsed.valueTypes[StringValueType] { + validPrimitiveTypes = append(validPrimitiveTypes, configreader.PrimTypeString) + if valueStr, ok := value.(string); ok { + return valueStr, nil + } + } + + if parsed.valueTypes[BoolValueType] { + validPrimitiveTypes = append(validPrimitiveTypes, configreader.PrimTypeBool) + if valueBool, ok := value.(bool); ok { + return valueBool, nil + } + } + + return nil, ErrorUnsupportedLiteralType(value, compoundType.String()) +} diff --git a/pkg/operator/api/userconfig/compound_type_test.go b/pkg/operator/api/userconfig/compound_type_test.go new file mode 100644 index 0000000000..162699687d --- /dev/null +++ b/pkg/operator/api/userconfig/compound_type_test.go @@ -0,0 +1,84 @@ +/* +Copyright 2019 Cortex Labs, Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package userconfig + +import ( + "testing" + + "github.com/stretchr/testify/require" + + cr "github.com/cortexlabs/cortex/pkg/lib/configreader" +) + +func TestCompoundTypeFromString(t *testing.T) { + var err error + + _, err = CompoundTypeFromString("STRING") + require.Nil(t, err) + _, err = CompoundTypeFromString("INT") + require.Nil(t, err) + _, err = CompoundTypeFromString("BOOL|FLOAT") + require.Nil(t, err) + + _, err = CompoundTypeFromString("STRING_COLUMN") + require.Nil(t, err) + _, err = CompoundTypeFromString("INT_COLUMN") + require.Nil(t, err) + _, err = CompoundTypeFromString("INT_COLUMN|FLOAT_COLUMN") + require.Nil(t, err) + + _, err = CompoundTypeFromString("INT|INT") + require.Error(t, err) + + _, err = CompoundTypeFromString("INT|FLOAT_COLUMN") + require.Error(t, err) + + _, err = CompoundTypeFromString("test") + require.Error(t, err) + _, err = CompoundTypeFromString("|") + require.Error(t, err) + _, err = CompoundTypeFromString("") + require.Error(t, err) + _, err = CompoundTypeFromString(2) + require.Error(t, err) + + require.Error(t, err) +} + +func checkCompoundCastValueEqual(t *testing.T, typeStr string, valueYAML string, expected interface{}) { + compoundType, err := CompoundTypeFromString(typeStr) + require.NoError(t, err) + casted, err := compoundType.CastValue(cr.MustReadYAMLStr(valueYAML)) + require.NoError(t, err) + require.Equal(t, casted, expected) +} + +func checkCompoundCastValueError(t *testing.T, typeStr string, valueYAML string) { + compoundType, err := CompoundTypeFromString(typeStr) + require.NoError(t, err) + _, err = compoundType.CastValue(cr.MustReadYAMLStr(valueYAML)) + require.Error(t, err) +} + +func TestCompoundCastValue(t *testing.T) { + checkCompoundCastValueEqual(t, `INT|FLOAT`, `2`, int64(2)) + 
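+	// The assertions in this test exercise CompoundType.CastValue on YAML literals:
+	// an INT|FLOAT schema keeps 2 as int64, FLOAT promotes it to float64, STRING
+	// rejects the unquoted 2, and BOOL accepts only boolean literals.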
checkCompoundCastValueError(t, `STRING`, `2`) + checkCompoundCastValueEqual(t, `FLOAT`, `2`, float64(2)) + checkCompoundCastValueEqual(t, `STRING|FLOAT`, `2`, float64(2)) + checkCompoundCastValueError(t, `BOOL`, `2`) + checkCompoundCastValueEqual(t, `BOOL`, `true`, true) +} diff --git a/pkg/operator/api/userconfig/config.go b/pkg/operator/api/userconfig/config.go index e731697ac6..49afdca6e2 100644 --- a/pkg/operator/api/userconfig/config.go +++ b/pkg/operator/api/userconfig/config.go @@ -19,7 +19,6 @@ package userconfig import ( "fmt" "io/ioutil" - "strings" "github.com/cortexlabs/cortex/pkg/lib/cast" "github.com/cortexlabs/cortex/pkg/lib/configreader" @@ -42,9 +41,11 @@ type Config struct { APIs APIs `json:"apis" yaml:"apis"` Aggregators Aggregators `json:"aggregators" yaml:"aggregators"` Transformers Transformers `json:"transformers" yaml:"transformers"` + Estimators Estimators `json:"estimators" yaml:"estimators"` Constants Constants `json:"constants" yaml:"constants"` Templates Templates `json:"templates" yaml:"templates"` Embeds Embeds `json:"embeds" yaml:"embeds"` + Resources map[string][]Resource } var typeFieldValidation = &cr.StructFieldValidation{ @@ -61,6 +62,7 @@ func mergeConfigs(target *Config, source *Config) error { target.APIs = append(target.APIs, source.APIs...) target.Aggregators = append(target.Aggregators, source.Aggregators...) target.Transformers = append(target.Transformers, source.Transformers...) + target.Estimators = append(target.Estimators, source.Estimators...) target.Constants = append(target.Constants, source.Constants...) target.Templates = append(target.Templates, source.Templates...) target.Embeds = append(target.Embeds, source.Embeds...) @@ -121,6 +123,11 @@ func (config *Config) ValidatePartial() error { return err } } + if config.Estimators != nil { + if err := config.Estimators.Validate(); err != nil { + return err + } + } if config.Constants != nil { if err := config.Constants.Validate(); err != nil { return err @@ -145,15 +152,29 @@ func (config *Config) Validate(envName string) error { return ErrorUndefinedConfig(resource.AppType) } - err = config.ValidateColumns() - if err != nil { - return err + // Check for duplicate names across types that must have unique names + var resources []Resource + for _, res := range config.RawColumns { + resources = append(resources, res) + } + for _, res := range config.TransformedColumns { + resources = append(resources, res) + } + for _, res := range config.Constants { + resources = append(resources, res) + } + for _, res := range config.Aggregates { + resources = append(resources, res) + } + dups := FindDuplicateResourceName(resources...) + if len(dups) > 0 { + return ErrorDuplicateResourceName(dups...) 
} // Check ingested columns match raw columns rawColumnNames := config.RawColumns.Names() for _, env := range config.Environments { - ingestedColumnNames := env.Data.GetIngestedColumns() + ingestedColumnNames := env.Data.GetIngestedColumnNames() missingColumnNames := slices.SubtractStrSlice(rawColumnNames, ingestedColumnNames) if len(missingColumnNames) > 0 { return errors.Wrap(ErrorRawColumnNotInEnv(env.Name), Identify(config.RawColumns.Get(missingColumnNames[0]))) @@ -164,78 +185,6 @@ func (config *Config) Validate(envName string) error { } } - // Check model columns exist - columnNames := config.ColumnNames() - for _, model := range config.Models { - if !slices.HasString(columnNames, model.TargetColumn) { - return errors.Wrap(ErrorUndefinedResource(model.TargetColumn, resource.RawColumnType, resource.TransformedColumnType), - Identify(model), TargetColumnKey) - } - missingColumnNames := slices.SubtractStrSlice(model.FeatureColumns, columnNames) - if len(missingColumnNames) > 0 { - return errors.Wrap(ErrorUndefinedResource(missingColumnNames[0], resource.RawColumnType, resource.TransformedColumnType), - Identify(model), FeatureColumnsKey) - } - - missingAggregateNames := slices.SubtractStrSlice(model.Aggregates, config.Aggregates.Names()) - if len(missingAggregateNames) > 0 { - return errors.Wrap(ErrorUndefinedResource(missingAggregateNames[0], resource.AggregateType), - Identify(model), AggregatesKey) - } - - // check training columns - missingTrainingColumnNames := slices.SubtractStrSlice(model.TrainingColumns, columnNames) - if len(missingTrainingColumnNames) > 0 { - return errors.Wrap(ErrorUndefinedResource(missingTrainingColumnNames[0], resource.RawColumnType, resource.TransformedColumnType), - Identify(model), TrainingColumnsKey) - } - } - - // Check api models exist - modelNames := config.Models.Names() - for _, api := range config.APIs { - if !slices.HasString(modelNames, api.ModelName) { - return errors.Wrap(ErrorUndefinedResource(api.ModelName, resource.ModelType), - Identify(api), ModelNameKey) - } - } - - // Check local aggregators exist or a path to one is defined - aggregatorNames := config.Aggregators.Names() - for _, aggregate := range config.Aggregates { - if aggregate.AggregatorPath == nil && aggregate.Aggregator == "" { - return errors.Wrap(ErrorSpecifyOnlyOneMissing("aggregator", "aggregator_path"), Identify(aggregate)) - } - - if aggregate.AggregatorPath != nil && aggregate.Aggregator != "" { - return errors.Wrap(ErrorSpecifyOnlyOne("aggregator", "aggregator_path"), Identify(aggregate)) - } - - if aggregate.Aggregator != "" && - !strings.Contains(aggregate.Aggregator, ".") && - !slices.HasString(aggregatorNames, aggregate.Aggregator) { - return errors.Wrap(ErrorUndefinedResource(aggregate.Aggregator, resource.AggregatorType), Identify(aggregate), AggregatorKey) - } - } - - // Check local transformers exist or a path to one is defined - transformerNames := config.Transformers.Names() - for _, transformedColumn := range config.TransformedColumns { - if transformedColumn.TransformerPath == nil && transformedColumn.Transformer == "" { - return errors.Wrap(ErrorSpecifyOnlyOneMissing("transformer", "transformer_path"), Identify(transformedColumn)) - } - - if transformedColumn.TransformerPath != nil && transformedColumn.Transformer != "" { - return errors.Wrap(ErrorSpecifyOnlyOne("transformer", "transformer_path"), Identify(transformedColumn)) - } - - if transformedColumn.Transformer != "" && - !strings.Contains(transformedColumn.Transformer, ".") && - 
!slices.HasString(transformerNames, transformedColumn.Transformer) { - return errors.Wrap(ErrorUndefinedResource(transformedColumn.Transformer, resource.TransformerType), Identify(transformedColumn), TransformerKey) - } - } - for _, env := range config.Environments { if env.Name == envName { config.Environment = env @@ -352,6 +301,12 @@ func newPartial(configData interface{}, filePath string, emb *Embed, template *T if !errors.HasErrors(errs) { config.Transformers = append(config.Transformers, newResource.(*Transformer)) } + case resource.EstimatorType: + newResource = &Estimator{} + errs = cr.Struct(newResource, data, estimatorValidation) + if !errors.HasErrors(errs) { + config.Estimators = append(config.Estimators, newResource.(*Estimator)) + } case resource.TemplateType: if emb != nil { errs = []error{resource.ErrorTemplateInTemplate()} @@ -385,6 +340,10 @@ func newPartial(configData interface{}, filePath string, emb *Embed, template *T newResource.SetIndex(i) newResource.SetFilePath(filePath) newResource.SetEmbed(emb) + if config.Resources == nil { + config.Resources = make(map[string][]Resource) + } + config.Resources[newResource.GetName()] = append(config.Resources[newResource.GetName()], newResource) } } @@ -441,7 +400,7 @@ func New(configs map[string][]byte, envName string) (*Config, error) { } for _, env := range config.Environments { - ingestedColumnNames := env.Data.GetIngestedColumns() + ingestedColumnNames := env.Data.GetIngestedColumnNames() missingColumnNames := slices.SubtractStrSlice(ingestedColumnNames, config.RawColumns.Names()) for _, inferredColumnName := range missingColumnNames { inferredRawColumn := &RawInferredColumn{ diff --git a/pkg/operator/api/userconfig/config_key.go b/pkg/operator/api/userconfig/config_key.go index 0bdfeb7a5d..b5a5830e20 100644 --- a/pkg/operator/api/userconfig/config_key.go +++ b/pkg/operator/api/userconfig/config_key.go @@ -17,39 +17,82 @@ limitations under the License. 
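// The config.go hunk above starts recording every parsed resource in config.Resources,
// keyed by name, presumably so that later validation and lookups can work off a single
// name index instead of re-walking each typed slice. A minimal sketch of that grouping
// and how a duplicate name would surface, using illustrative types rather than the real
// userconfig interfaces:

package main

import "fmt"

type namedResource interface {
	GetName() string
}

type fakeRes struct {
	name string
	kind string
}

func (r fakeRes) GetName() string { return r.name }

func main() {
	parsed := []namedResource{
		fakeRes{name: "age", kind: "raw_column"},
		fakeRes{name: "age_normalized", kind: "transformed_column"},
		fakeRes{name: "age", kind: "constant"}, // reuses the raw column's name
	}

	// Group by name, mirroring config.Resources[name] = append(...).
	resources := map[string][]namedResource{}
	for _, res := range parsed {
		resources[res.GetName()] = append(resources[res.GetName()], res)
	}

	for name, group := range resources {
		if len(group) > 1 {
			fmt.Printf("name %q is shared by %d resources\n", name, len(group))
		}
	}
}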
package userconfig const ( - UnknownKey = "unknown" - NameKey = "name" - KindKey = "kind" - DataKey = "data" - SchemaKey = "schema" - ColumnsKey = "columns" - FeatureColumnsKey = "feature_columns" - TrainingColumnsKey = "training_columns" - TargetColumnKey = "target_column" - AggregatesKey = "aggregates" - ModelNameKey = "model_name" - InputsKey = "inputs" - ArgsKey = "args" - TypeKey = "type" - AggregatorKey = "aggregator" - TransformerKey = "transformer" - PathKey = "path" - ValueKey = "value" - YAMLKey = "yaml" + // Shared + UnknownKey = "unknown" + NameKey = "name" + KindKey = "kind" + InputKey = "input" + ComputeKey = "compute" + TypeKey = "type" + PathKey = "path" + OutputTypeKey = "output_type" + TagsKey = "tags" + + // input schema options + OptionalOptKey = "_optional" + DefaultOptKey = "_default" + MinCountOptKey = "_min_count" + MaxCountOptKey = "_max_count" // environment + DataKey = "data" + SchemaKey = "schema" + LogLevelKey = "log_level" LimitKey = "limit" NumRowsKey = "num_rows" FractionOfRowsKey = "fraction_of_rows" RandomizeKey = "randomize" RandomSeedKey = "random_seed" - // model - NumEpochsKey = "num_epochs" - NumStepsKey = "num_steps" - SaveCheckpointSecsKey = "save_checkpoints_secs" - SaveCheckpointStepsKey = "save_checkpoints_steps" - DataPartitionRatioKey = "data_partition_ratio" - TrainingKey = "training" - EvaluationKey = "evaluation" + // templates / embeds + TemplateKey = "template" + YAMLKey = "yaml" + ArgsKey = "args" + + // constants + ValueKey = "value" + + // raw columns + RequiredKey = "required" + MinKey = "min" + MaxKey = "max" + ValuesKey = "values" + + // aggregator / aggregate + AggregatorKey = "aggregator" + AggregatorPathKey = "aggregator_path" + + // transformer / transformed_column + TransformerKey = "transformer" + TransformerPathKey = "transformer_path" + + // estimator / model + EstimatorKey = "estimator" + EstimatorPathKey = "estimator_path" + TrainingInputKey = "training_input" + HparamsKey = "hparams" + TargetColumnKey = "target_column" + PredictionKeyKey = "prediction_key" + DataPartitionRatioKey = "data_partition_ratio" + TrainingKey = "training" + EvaluationKey = "evaluation" + BatchSizeKey = "batch_size" + NumStepsKey = "num_steps" + NumEpochsKey = "num_epochs" + ShuffleKey = "shuffle" + TfRandomSeedKey = "tf_random_seed" + TfRandomizeSeedKey = "tf_randomize_seed" + SaveSummaryStepsKey = "save_summary_steps" + SaveCheckpointsSecsKey = "save_checkpoints_secs" + SaveCheckpointsStepsKey = "save_checkpoints_steps" + LogStepCountStepsKey = "log_step_count_steps" + KeepCheckpointMaxKey = "keep_checkpoint_max" + KeepCheckpointEveryNHoursKey = "keep_checkpoint_every_n_hours" + StartDelaySecsKey = "start_delay_secs" + ThrottleSecsKey = "throttle_secs" + DatasetComputeKey = "dataset_compute" + + // API + ModelKey = "model" + ModelNameKey = "model_name" ) diff --git a/pkg/operator/api/userconfig/constants.go b/pkg/operator/api/userconfig/constants.go index a84c678940..44284cfa43 100644 --- a/pkg/operator/api/userconfig/constants.go +++ b/pkg/operator/api/userconfig/constants.go @@ -26,9 +26,9 @@ type Constants []*Constant type Constant struct { ResourceFields - Type interface{} `json:"type" yaml:"type"` - Value interface{} `json:"value" yaml:"value"` - Tags Tags `json:"tags" yaml:"tags"` + Type OutputSchema `json:"type" yaml:"type"` + Value interface{} `json:"value" yaml:"value"` + Tags Tags `json:"tags" yaml:"tags"` } var constantValidation = &cr.StructValidation{ @@ -43,9 +43,9 @@ var constantValidation = &cr.StructValidation{ { StructField: 
"Type", InterfaceValidation: &cr.InterfaceValidation{ - Required: true, + Required: false, Validator: func(t interface{}) (interface{}, error) { - return t, ValidateValueType(t) + return ValidateOutputSchema(t) }, }, }, @@ -53,9 +53,6 @@ var constantValidation = &cr.StructValidation{ StructField: "Value", InterfaceValidation: &cr.InterfaceValidation{ Required: true, - Validator: func(value interface{}) (interface{}, error) { - return value, ValidateValue(value) - }, }, }, tagsFieldValidation, @@ -84,19 +81,17 @@ func (constants Constants) Validate() error { } func (constant *Constant) Validate() error { - castedValue, err := CastValue(constant.Value, constant.Type) - if err != nil { - return errors.Wrap(err, Identify(constant), ValueKey) + if constant.Type != nil { + castedValue, err := CastOutputValue(constant.Value, constant.Type) + if err != nil { + return errors.Wrap(err, Identify(constant), ValueKey) + } + constant.Value = castedValue } - constant.Value = castedValue return nil } -func (constant *Constant) GetType() interface{} { - return constant.Type -} - func (constant *Constant) GetResourceType() resource.Type { return resource.ConstantType } diff --git a/pkg/operator/api/userconfig/environments.go b/pkg/operator/api/userconfig/environments.go index 58956033e9..c1c95a6f33 100644 --- a/pkg/operator/api/userconfig/environments.go +++ b/pkg/operator/api/userconfig/environments.go @@ -17,6 +17,8 @@ limitations under the License. package userconfig import ( + "github.com/cortexlabs/yaml" + "github.com/cortexlabs/cortex/pkg/lib/configreader" cr "github.com/cortexlabs/cortex/pkg/lib/configreader" "github.com/cortexlabs/cortex/pkg/lib/errors" @@ -119,7 +121,7 @@ var logLevelValidation = &cr.StructValidation{ } type Data interface { - GetIngestedColumns() []string + GetIngestedColumnNames() []string GetExternalPath() string Validate() error } @@ -181,7 +183,8 @@ var csvDataFieldValidations = []*cr.StructFieldValidation{ { StructField: "Schema", StringListValidation: &cr.StringListValidation{ - Required: true, + Required: true, + RequireCortexResources: true, }, }, { @@ -301,7 +304,7 @@ var parquetDataFieldValidations = []*cr.StructFieldValidation{ type ParquetColumn struct { ParquetColumnName string `json:"parquet_column_name" yaml:"parquet_column_name"` - RawColumnName string `json:"raw_column_name" yaml:"raw_column_name"` + RawColumn string `json:"raw_column" yaml:"raw_column"` } var parquetColumnValidation = &cr.StructValidation{ @@ -313,9 +316,10 @@ var parquetColumnValidation = &cr.StructValidation{ }, }, { - StructField: "RawColumnName", + StructField: "RawColumn", StringValidation: &cr.StringValidation{ - Required: true, + Required: true, + RequireCortexResources: true, }, }, }, @@ -338,9 +342,9 @@ func (environments Environments) Validate() error { return ErrorDuplicateResourceName(dups...) 
} - ingestedColumns := environments[0].Data.GetIngestedColumns() + ingestedColumns := environments[0].Data.GetIngestedColumnNames() for _, env := range environments[1:] { - if !strset.New(ingestedColumns...).IsEqual(strset.New(env.Data.GetIngestedColumns()...)) { + if !strset.New(ingestedColumns...).IsEqual(strset.New(env.Data.GetIngestedColumnNames()...)) { return ErrorEnvSchemaMismatch(environments[0], env) } } @@ -365,7 +369,7 @@ func (env *Environment) Validate() error { } } - dups := slices.FindDuplicateStrs(env.Data.GetIngestedColumns()) + dups := slices.FindDuplicateStrs(env.Data.GetIngestedColumnNames()) if len(dups) > 0 { return errors.Wrap(configreader.ErrorDuplicatedValue(dups[0]), Identify(env), DataKey, SchemaKey, "column name") } @@ -389,14 +393,20 @@ func (parqData *ParquetData) GetExternalPath() string { return parqData.Path } -func (csvData *CSVData) GetIngestedColumns() []string { - return csvData.Schema +func (csvData *CSVData) GetIngestedColumnNames() []string { + columnNames := make([]string, len(csvData.Schema)) + for i, col := range csvData.Schema { + colName, _ := yaml.ExtractAtSymbolText(col) + columnNames[i] = colName + } + return columnNames } -func (parqData *ParquetData) GetIngestedColumns() []string { +func (parqData *ParquetData) GetIngestedColumnNames() []string { columnNames := make([]string, len(parqData.Schema)) for i, parqCol := range parqData.Schema { - columnNames[i] = parqCol.RawColumnName + colName, _ := yaml.ExtractAtSymbolText(parqCol.RawColumn) + columnNames[i] = colName } return columnNames } diff --git a/pkg/operator/api/userconfig/errors.go b/pkg/operator/api/userconfig/errors.go index 0f1cbc1d98..58ae216bed 100644 --- a/pkg/operator/api/userconfig/errors.go +++ b/pkg/operator/api/userconfig/errors.go @@ -20,7 +20,7 @@ import ( "fmt" "strings" - "github.com/cortexlabs/cortex/pkg/lib/cast" + "github.com/cortexlabs/cortex/pkg/lib/configreader" "github.com/cortexlabs/cortex/pkg/lib/sets/strset" s "github.com/cortexlabs/cortex/pkg/lib/strings" "github.com/cortexlabs/cortex/pkg/operator/api/resource" @@ -30,7 +30,7 @@ type ErrorKind int const ( ErrUnknown ErrorKind = iota - ErrDuplicateConfigName + ErrDuplicateResourceName ErrDuplicateResourceValue ErrDuplicateConfig ErrMalformedConfig @@ -40,31 +40,45 @@ const ( ErrUndefinedConfig ErrRawColumnNotInEnv ErrUndefinedResource - ErrUndefinedResourceBuiltin - ErrColumnMustBeRaw + ErrResourceWrongType ErrSpecifyAllOrNone ErrSpecifyOnlyOne ErrOneOfPrerequisitesNotDefined ErrTemplateExtraArg ErrTemplateMissingArg - ErrInvalidColumnInputType - ErrInvalidColumnRuntimeType - ErrInvalidValueDataType - ErrUnsupportedColumnType - ErrUnsupportedDataType - ErrArgNameCannotBeType + ErrInvalidCompoundType + ErrDuplicateTypeInTypeString + ErrCannotMixValueAndColumnTypes + ErrColumnTypeLiteral + ErrColumnTypeNotAllowed + ErrCompoundTypeInOutputType + ErrUserKeysCannotStartWithUnderscore + ErrMixedInputArgOptionsAndUserKeys + ErrOptionOnNonIterable + ErrMinCountGreaterThanMaxCount + ErrTooManyElements + ErrTooFewElements + ErrInvalidInputType + ErrInvalidOutputType + ErrUnsupportedLiteralType + ErrUnsupportedLiteralMapKey + ErrUnsupportedOutputType + ErrMustBeDefined + ErrCannotBeNull ErrTypeListLength + ErrTypeMapZeroLength ErrGenericTypeMapLength ErrK8sQuantityMustBeInt - ErrRegressionTargetType - ErrClassificationTargetType + ErrTargetColumnIntOrFloat + ErrPredictionKeyOnModelWithEstimator ErrSpecifyOnlyOneMissing 
ErrEnvSchemaMismatch + ErrImplDoesNotExist ) var errorKinds = []string{ "err_unknown", - "err_duplicate_config_name", + "err_duplicate_resource_name", "err_duplicate_resource_value", "err_duplicate_config", "err_malformed_config", @@ -74,29 +88,43 @@ var errorKinds = []string{ "err_undefined_config", "err_raw_column_not_in_env", "err_undefined_resource", - "err_undefined_resource_builtin", - "err_column_must_be_raw", + "err_resource_wrong_type", "err_specify_all_or_none", "err_specify_only_one", "err_one_of_prerequisites_not_defined", "err_template_extra_arg", "err_template_missing_arg", - "err_invalid_column_input_type", - "err_invalid_column_runtime_type", - "err_invalid_value_data_type", - "err_unsupported_column_type", - "err_unsupported_data_type", - "err_arg_name_cannot_be_type", + "err_invalid_compound_type", + "err_duplicate_type_in_type_string", + "err_cannot_mix_value_and_column_types", + "err_column_type_literal", + "err_column_type_not_allowed", + "err_compound_type_in_output_type", + "err_user_keys_cannot_start_with_underscore", + "err_mixed_input_arg_options_and_user_keys", + "err_option_on_non_iterable", + "err_min_count_greater_than_max_count", + "err_too_many_elements", + "err_too_few_elements", + "err_invalid_input_type", + "err_invalid_output_type", + "err_unsupported_literal_type", + "err_unsupported_literal_map_key", + "err_unsupported_output_type", + "err_must_be_defined", + "err_cannot_be_null", "err_type_list_length", + "err_type_map_zero_length", "err_generic_type_map_length", "err_k8s_quantity_must_be_int", - "err_regression_target_type", - "err_classification_target_type", + "err_target_column_int_or_float", + "err_prediction_key_on_model_with_estimator", "err_specify_only_one_missing", "err_env_schema_mismatch", + "err_impl_does_not_exist", } -var _ = [1]int{}[int(ErrEnvSchemaMismatch)-(len(errorKinds)-1)] // Ensure list length matches +var _ = [1]int{}[int(ErrImplDoesNotExist)-(len(errorKinds)-1)] // Ensure list length matches func (t ErrorKind) String() string { return errorKinds[t] @@ -177,7 +205,7 @@ func ErrorDuplicateResourceName(resources ...Resource) error { pathStr := strings.Join(pathStrs, ", ") return Error{ - Kind: ErrDuplicateConfigName, + Kind: ErrDuplicateResourceName, message: fmt.Sprintf("name %s must be unique across %s (%s)", s.UserStr(resources[0].GetName()), s.StrsAnd(resourceTypes.Slice()), pathStr), } } @@ -243,23 +271,27 @@ func ErrorRawColumnNotInEnv(envName string) error { } func ErrorUndefinedResource(resourceName string, resourceTypes ...resource.Type) error { + message := fmt.Sprintf("%s %s is not defined", s.StrsOr(resource.Types(resourceTypes).StringList()), s.UserStr(resourceName)) + if strings.HasPrefix(resourceName, "cortex.") { + message = fmt.Sprintf("%s is not defined as a built-in %s in the Cortex namespace", s.UserStr(resourceName), s.StrsOr(resource.Types(resourceTypes).StringList())) + } + return Error{ Kind: ErrUndefinedResource, - message: fmt.Sprintf("%s %s is not defined", s.StrsOr(resource.Types(resourceTypes).StringList()), s.UserStr(resourceName)), + message: message, } } -func ErrorUndefinedResourceBuiltin(resourceName string, resourceTypes ...resource.Type) error { - return Error{ - Kind: ErrUndefinedResourceBuiltin, - message: fmt.Sprintf("%s %s is not defined in the Cortex namespace", s.StrsOr(resource.Types(resourceTypes).StringList()), s.UserStr(resourceName)), +func ErrorResourceWrongType(resources []Resource, validResourceTypes ...resource.Type) error { + name := resources[0].GetName() + resourceTypeStrs := 
make([]string, len(resources)) + for i, res := range resources { + resourceTypeStrs[i] = res.GetResourceType().String() } -} -func ErrorColumnMustBeRaw(columnName string) error { return Error{ - Kind: ErrColumnMustBeRaw, - message: fmt.Sprintf("%s is a transformed column, but only raw columns are allowed", s.UserStr(columnName)), + Kind: ErrResourceWrongType, + message: fmt.Sprintf("%s is a %s, but only %s are allowed in this context", s.UserStr(name), s.StrsAnd(resourceTypeStrs), s.StrsOr(resource.Types(validResourceTypes).PluralList())), } } @@ -310,46 +342,136 @@ func ErrorTemplateMissingArg(template *Template, argName string) error { } } -func ErrorInvalidColumnInputType(provided interface{}) error { +func ErrorInvalidCompoundType(provided interface{}) error { + return Error{ + Kind: ErrInvalidCompoundType, + message: fmt.Sprintf("invalid type (got %s, expected %s, or a combination of these types (separated by |)", DataTypeUserStr(provided), strings.Join(s.UserStrs(append(ValueTypeStrings(), ValidColumnTypeStrings()...)), ", ")), + } +} + +func ErrorDuplicateTypeInTypeString(duplicated string, provided string) error { + return Error{ + Kind: ErrDuplicateTypeInTypeString, + message: fmt.Sprintf("invalid type (%s is duplicated in %s)", DataTypeUserStr(duplicated), DataTypeUserStr(provided)), + } +} + +func ErrorCannotMixValueAndColumnTypes(provided interface{}) error { + return Error{ + Kind: ErrCannotMixValueAndColumnTypes, + message: fmt.Sprintf("invalid type (%s contains both column and value types)", DataTypeUserStr(provided)), + } +} + +func ErrorColumnTypeLiteral(provided interface{}) error { + return Error{ + Kind: ErrColumnTypeLiteral, + message: fmt.Sprintf("%s: literal values cannot be provided for column input types", s.UserStrStripped(provided)), + } +} + +func ErrorColumnTypeNotAllowed(provided interface{}) error { + return Error{ + Kind: ErrColumnTypeNotAllowed, + message: fmt.Sprintf("%s: column types cannot be used in this context, only value types are allowed (e.g. INT)", DataTypeUserStr(provided)), + } +} + +func ErrorCompoundTypeInOutputType(provided interface{}) error { + return Error{ + Kind: ErrCompoundTypeInOutputType, + message: fmt.Sprintf("%s: compound types (i.e. 
multiple types separated by \"|\") cannot be used in output type schemas", DataTypeUserStr(provided)), + } +} + +func ErrorUserKeysCannotStartWithUnderscore(key string) error { + return Error{ + Kind: ErrUserKeysCannotStartWithUnderscore, + message: fmt.Sprintf("%s: keys cannot start with underscores", key), + } +} + +func ErrorMixedInputArgOptionsAndUserKeys() error { return Error{ - Kind: ErrInvalidColumnInputType, - message: fmt.Sprintf("invalid column input type (got %s, expected %s, a combination of these types (separated by |), or a list of one of these types", DataTypeUserStr(provided), strings.Join(s.UserStrs(ColumnTypeStrings()), ", ")), + Kind: ErrMixedInputArgOptionsAndUserKeys, + message: "input arguments cannot contain both Cortex argument options (which start with underscores) and user-provided keys (which don't start with underscores)", } } -func ErrorInvalidColumnRuntimeType() error { +func ErrorOptionOnNonIterable(key string) error { return Error{ - Kind: ErrInvalidColumnRuntimeType, - message: fmt.Sprintf("invalid column runtime type (expected %s)", s.StrsOr(ColumnTypeStrings())), + Kind: ErrOptionOnNonIterable, + message: fmt.Sprintf("the %s option can only be used on list or maps", key), } } -func ErrorInvalidValueDataType(provided interface{}) error { +func ErrorMinCountGreaterThanMaxCount() error { return Error{ - Kind: ErrInvalidValueDataType, - message: fmt.Sprintf("invalid value data type (got %s, expected %s, a combination of these types (separated by |), a list of one of these types, or a map containing these types", DataTypeUserStr(provided), strings.Join(s.UserStrs(ValueTypeStrings()), ", ")), + Kind: ErrMinCountGreaterThanMaxCount, + message: fmt.Sprintf("the value provided for %s cannot be greater than the value provided for %s", MinCountOptKey, MaxCountOptKey), } } -func ErrorUnsupportedColumnType(provided interface{}, allowedTypes []string) error { - allowedTypesInterface, _ := cast.InterfaceToInterfaceSlice(allowedTypes) +func ErrorTooManyElements(t configreader.PrimitiveType, maxCount int64) error { return Error{ - Kind: ErrUnsupportedColumnType, - message: fmt.Sprintf("unsupported column type (got %s, expected %s)", DataTypeStr(provided), DataTypeStrsOr(allowedTypesInterface)), + Kind: ErrTooManyElements, + message: fmt.Sprintf("the provided %s contains more than the maximum allowed number of elements (%s), which is specified via %s", string(t), s.Int64(maxCount), MaxCountOptKey), } } -func ErrorUnsupportedDataType(provided interface{}, allowedType interface{}) error { +func ErrorTooFewElements(t configreader.PrimitiveType, minCount int64) error { return Error{ - Kind: ErrUnsupportedDataType, - message: fmt.Sprintf("unsupported data type (got %s, expected %s)", DataTypeStr(provided), DataTypeStr(allowedType)), + Kind: ErrTooFewElements, + message: fmt.Sprintf("the provided %s contains fewer than the minimum allowed number of elements (%s), which is specified via %s", string(t), s.Int64(minCount), MinCountOptKey), } } -func ErrorArgNameCannotBeType(provided string) error { +func ErrorInvalidInputType(provided interface{}) error { return Error{ - Kind: ErrArgNameCannotBeType, - message: fmt.Sprintf("data types cannot be used as arg names (got %s)", s.UserStr(provided)), + Kind: ErrInvalidInputType, + message: fmt.Sprintf("invalid type (got %s, expected %s, a combination of these types (separated by |), or a list or map containing these types", DataTypeUserStr(provided), strings.Join(s.UserStrs(append(ValueTypeStrings(), ValidColumnTypeStrings()...)), ", 
")), + } +} + +func ErrorInvalidOutputType(provided interface{}) error { + return Error{ + Kind: ErrInvalidOutputType, + message: fmt.Sprintf("invalid type (got %s, expected %s, or a list or map containing these types", DataTypeUserStr(provided), strings.Join(s.UserStrs(ValueTypeStrings()), ", ")), + } +} + +func ErrorUnsupportedLiteralType(provided interface{}, allowedType interface{}) error { + return Error{ + Kind: ErrUnsupportedLiteralType, + message: fmt.Sprintf("input value's type is not supported by the schema (got %s, expected input with type %s)", DataTypeStr(provided), DataTypeStr(allowedType)), + } +} + +func ErrorUnsupportedLiteralMapKey(key interface{}, allowedType interface{}) error { + return Error{ + Kind: ErrUnsupportedLiteralMapKey, + message: fmt.Sprintf("%s: map key is not supported by the schema (%s)", s.UserStrStripped(key), DataTypeStr(allowedType)), + } +} + +func ErrorUnsupportedOutputType(provided interface{}, allowedType interface{}) error { + return Error{ + Kind: ErrUnsupportedOutputType, + message: fmt.Sprintf("unsupported type (got %s, expected %s)", DataTypeStr(provided), DataTypeStr(allowedType)), + } +} + +func ErrorMustBeDefined(allowedType interface{}) error { + return Error{ + Kind: ErrMustBeDefined, + message: fmt.Sprintf("must be defined (and it's value must fit the schema %s)", DataTypeStr(allowedType)), + } +} + +func ErrorCannotBeNull() error { + return Error{ + Kind: ErrCannotBeNull, + message: "cannot be null", } } @@ -360,10 +482,17 @@ func ErrorTypeListLength(provided interface{}) error { } } +func ErrorTypeMapZeroLength(provided interface{}) error { + return Error{ + Kind: ErrTypeMapZeroLength, + message: fmt.Sprintf("type maps must cannot have zero length (got %s)", DataTypeStr(provided)), + } +} + func ErrorGenericTypeMapLength(provided interface{}) error { return Error{ Kind: ErrGenericTypeMapLength, - message: fmt.Sprintf("generic type maps must contain exactly one key (i.e. the desired data type of all keys in the map) (got %s)", DataTypeStr(provided)), + message: fmt.Sprintf("maps with type keys (e.g. \"STRING\") must contain exactly one element (got %s)", DataTypeStr(provided)), } } @@ -374,17 +503,17 @@ func ErrorK8sQuantityMustBeInt(quantityStr string) error { } } -func ErrorRegressionTargetType() error { +func ErrorTargetColumnIntOrFloat() error { return Error{ - Kind: ErrRegressionTargetType, - message: "regression models can only predict float target values", + Kind: ErrTargetColumnIntOrFloat, + message: "models can only predict values of type INT_COLUMN (i.e. classification) or FLOAT_COLUMN (i.e. regression)", } } -func ErrorClassificationTargetType() error { +func ErrorPredictionKeyOnModelWithEstimator() error { return Error{ - Kind: ErrClassificationTargetType, - message: "classification models can only predict integer target values (i.e. 
{0, 1, ..., num_classes-1})", + Kind: ErrPredictionKeyOnModelWithEstimator, + message: fmt.Sprintf("models which use a pre-defined \"%s\" cannot define \"%s\" themselves (\"%s\" should be defined on the \"%s\", not the \"%s\")", EstimatorKey, PredictionKeyKey, PredictionKeyKey, resource.EstimatorType.String(), resource.ModelType.String()), } } @@ -405,9 +534,16 @@ func ErrorEnvSchemaMismatch(env1, env2 *Environment) error { Kind: ErrEnvSchemaMismatch, message: fmt.Sprintf("schemas diverge between environments (%s lists %s, and %s lists %s)", env1.Name, - s.StrsAnd(env1.Data.GetIngestedColumns()), + s.StrsAnd(env1.Data.GetIngestedColumnNames()), env2.Name, - s.StrsAnd(env2.Data.GetIngestedColumns()), + s.StrsAnd(env2.Data.GetIngestedColumnNames()), ), } } + +func ErrorImplDoesNotExist(path string) error { + return Error{ + Kind: ErrImplDoesNotExist, + message: fmt.Sprintf("%s: implementation file does not exist", path), + } +} diff --git a/pkg/operator/api/userconfig/estimators.go b/pkg/operator/api/userconfig/estimators.go new file mode 100644 index 0000000000..83bf653ad5 --- /dev/null +++ b/pkg/operator/api/userconfig/estimators.go @@ -0,0 +1,134 @@ +/* +Copyright 2019 Cortex Labs, Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package userconfig + +import ( + cr "github.com/cortexlabs/cortex/pkg/lib/configreader" + "github.com/cortexlabs/cortex/pkg/operator/api/resource" +) + +type Estimators []*Estimator + +type Estimator struct { + ResourceFields + TargetColumn ColumnType `json:"target_column" yaml:"target_column"` + Input *InputSchema `json:"input" yaml:"input"` + TrainingInput *InputSchema `json:"training_input" yaml:"training_input"` + Hparams *InputSchema `json:"hparams" yaml:"hparams"` + PredictionKey string `json:"prediction_key" yaml:"prediction_key"` + Path string `json:"path" yaml:"path"` +} + +var estimatorValidation = &cr.StructValidation{ + StructFieldValidations: []*cr.StructFieldValidation{ + { + StructField: "Name", + StringValidation: &cr.StringValidation{ + Required: true, + AlphaNumericDashUnderscore: true, + }, + }, + { + StructField: "Path", + StringValidation: &cr.StringValidation{}, + DefaultField: "Name", + DefaultFieldFunc: func(name interface{}) interface{} { + return "implementations/estimators/" + name.(string) + ".py" + }, + }, + { + StructField: "TargetColumn", + StringValidation: &cr.StringValidation{ + Required: true, + Validator: func(col string) (string, error) { + colType := ColumnTypeFromString(col) + if colType != IntegerColumnType && colType != FloatColumnType { + return "", ErrorTargetColumnIntOrFloat() + } + return col, nil + }, + }, + Parser: func(str string) (interface{}, error) { + return ColumnTypeFromString(str), nil + }, + }, + { + StructField: "Input", + InterfaceValidation: &cr.InterfaceValidation{ + Required: true, + Validator: inputSchemaValidator, + }, + }, + { + StructField: "TrainingInput", + InterfaceValidation: &cr.InterfaceValidation{ + Required: false, + Validator: inputSchemaValidator, + }, + }, + { + 
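+			// Hparams may only contain value types (no column references), hence the
+			// value-types-only variant of the input schema validator below.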
StructField: "Hparams", + InterfaceValidation: &cr.InterfaceValidation{ + Required: false, + Validator: inputSchemaValidatorValueTypesOnly, + }, + }, + { + StructField: "PredictionKey", + StringValidation: &cr.StringValidation{ + Default: "", + AllowEmpty: true, + }, + }, + typeFieldValidation, + }, +} + +func (estimators Estimators) Validate() error { + resources := make([]Resource, len(estimators)) + for i, res := range estimators { + resources[i] = res + } + + dups := FindDuplicateResourceName(resources...) + if len(dups) > 0 { + return ErrorDuplicateResourceName(dups...) + } + + return nil +} + +func (estimators Estimators) Get(name string) *Estimator { + for _, estimator := range estimators { + if estimator.Name == name { + return estimator + } + } + return nil +} + +func (estimator *Estimator) GetResourceType() resource.Type { + return resource.EstimatorType +} + +func (estimators Estimators) Names() []string { + names := make([]string, len(estimators)) + for i, estimator := range estimators { + names[i] = estimator.Name + } + return names +} diff --git a/pkg/operator/api/userconfig/inputs.go b/pkg/operator/api/userconfig/inputs.go deleted file mode 100644 index 7664ad567e..0000000000 --- a/pkg/operator/api/userconfig/inputs.go +++ /dev/null @@ -1,84 +0,0 @@ -/* -Copyright 2019 Cortex Labs, Inc. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-*/ - -package userconfig - -import ( - cr "github.com/cortexlabs/cortex/pkg/lib/configreader" -) - -type Inputs struct { - Columns map[string]interface{} `json:"columns" yaml:"columns"` - Args map[string]interface{} `json:"args" yaml:"args"` -} - -var inputTypesFieldValidation = &cr.StructFieldValidation{ - StructField: "Inputs", - StructValidation: &cr.StructValidation{ - Required: true, - StructFieldValidations: []*cr.StructFieldValidation{ - { - StructField: "Columns", - InterfaceMapValidation: &cr.InterfaceMapValidation{ - AllowEmpty: true, - Default: make(map[string]interface{}), - Validator: func(columnTypes map[string]interface{}) (map[string]interface{}, error) { - return columnTypes, ValidateColumnInputTypes(columnTypes) - }, - }, - }, - { - StructField: "Args", - InterfaceMapValidation: &cr.InterfaceMapValidation{ - AllowEmpty: true, - Default: make(map[string]interface{}), - Validator: func(argTypes map[string]interface{}) (map[string]interface{}, error) { - return argTypes, ValidateArgTypes(argTypes) - }, - }, - }, - }, - }, -} - -var inputValuesFieldValidation = &cr.StructFieldValidation{ - StructField: "Inputs", - StructValidation: &cr.StructValidation{ - Required: true, - StructFieldValidations: []*cr.StructFieldValidation{ - { - StructField: "Columns", - InterfaceMapValidation: &cr.InterfaceMapValidation{ - AllowEmpty: true, - Default: make(map[string]interface{}), - Validator: func(columnInputValues map[string]interface{}) (map[string]interface{}, error) { - return columnInputValues, ValidateColumnInputValues(columnInputValues) - }, - }, - }, - { - StructField: "Args", - InterfaceMapValidation: &cr.InterfaceMapValidation{ - AllowEmpty: true, - Default: make(map[string]interface{}), - Validator: func(argValues map[string]interface{}) (map[string]interface{}, error) { - return argValues, ValidateArgValues(argValues) - }, - }, - }, - }, - }, -} diff --git a/pkg/operator/api/userconfig/model_type.go b/pkg/operator/api/userconfig/model_type.go deleted file mode 100644 index 800daf008e..0000000000 --- a/pkg/operator/api/userconfig/model_type.go +++ /dev/null @@ -1,78 +0,0 @@ -/* -Copyright 2019 Cortex Labs, Inc. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-*/ - -package userconfig - -type ModelType int - -const ( - UnknownModelType ModelType = iota - ClassificationModelType - RegressionModelType -) - -var modelTypes = []string{ - "unknown", - "classification", - "regression", -} - -func ModelTypeFromString(s string) ModelType { - for i := 0; i < len(modelTypes); i++ { - if s == modelTypes[i] { - return ModelType(i) - } - } - return UnknownModelType -} - -func ModelTypeStrings() []string { - return modelTypes[1:] -} - -func (t ModelType) String() string { - return modelTypes[t] -} - -// MarshalText satisfies TextMarshaler -func (t ModelType) MarshalText() ([]byte, error) { - return []byte(t.String()), nil -} - -// UnmarshalText satisfies TextUnmarshaler -func (t *ModelType) UnmarshalText(text []byte) error { - enum := string(text) - for i := 0; i < len(modelTypes); i++ { - if enum == modelTypes[i] { - *t = ModelType(i) - return nil - } - } - - *t = UnknownModelType - return nil -} - -// UnmarshalBinary satisfies BinaryUnmarshaler -// Needed for msgpack -func (t *ModelType) UnmarshalBinary(data []byte) error { - return t.UnmarshalText(data) -} - -// MarshalBinary satisfies BinaryMarshaler -func (t ModelType) MarshalBinary() ([]byte, error) { - return []byte(t.String()), nil -} diff --git a/pkg/operator/api/userconfig/models.go b/pkg/operator/api/userconfig/models.go index b1f7acfea1..9073ac6c91 100644 --- a/pkg/operator/api/userconfig/models.go +++ b/pkg/operator/api/userconfig/models.go @@ -22,7 +22,6 @@ import ( cr "github.com/cortexlabs/cortex/pkg/lib/configreader" "github.com/cortexlabs/cortex/pkg/lib/errors" "github.com/cortexlabs/cortex/pkg/lib/pointer" - "github.com/cortexlabs/cortex/pkg/lib/slices" "github.com/cortexlabs/cortex/pkg/operator/api/resource" ) @@ -30,14 +29,13 @@ type Models []*Model type Model struct { ResourceFields - Type ModelType `json:"type" yaml:"type"` - Path string `json:"path" yaml:"path"` + Estimator string `json:"estimator" yaml:"estimator"` + EstimatorPath *string `json:"estimator_path" yaml:"estimator_path"` TargetColumn string `json:"target_column" yaml:"target_column"` + Input interface{} `json:"input" yaml:"input"` + TrainingInput interface{} `json:"training_input" yaml:"training_input"` + Hparams interface{} `json:"hparams" yaml:"hparams"` PredictionKey string `json:"prediction_key" yaml:"prediction_key"` - FeatureColumns []string `json:"feature_columns" yaml:"feature_columns"` - TrainingColumns []string `json:"training_columns" yaml:"training_columns"` - Aggregates []string `json:"aggregates" yaml:"aggregates"` - Hparams map[string]interface{} `json:"hparams" yaml:"hparams"` DataPartitionRatio *ModelDataPartitionRatio `json:"data_partition_ratio" yaml:"data_partition_ratio"` Training *ModelTraining `json:"training" yaml:"training"` Evaluation *ModelEvaluation `json:"evaluation" yaml:"evaluation"` @@ -56,63 +54,48 @@ var modelValidation = &cr.StructValidation{ }, }, { - StructField: "Path", - StringValidation: &cr.StringValidation{}, - DefaultField: "Name", - DefaultFieldFunc: func(name interface{}) interface{} { - return "implementations/models/" + name.(string) + ".py" - }, - }, - { - StructField: "Type", + StructField: "Estimator", StringValidation: &cr.StringValidation{ - Default: ClassificationModelType.String(), - AllowedValues: ModelTypeStrings(), - }, - Parser: func(str string) (interface{}, error) { - return ModelTypeFromString(str), nil + AllowEmpty: true, + AlphaNumericDashDotUnderscoreOrEmpty: 
true, }, }, { - StructField: "TargetColumn", - StringValidation: &cr.StringValidation{ - Required: true, - }, + StructField: "EstimatorPath", + StringPtrValidation: &cr.StringPtrValidation{}, }, { - StructField: "PredictionKey", + StructField: "TargetColumn", StringValidation: &cr.StringValidation{ - Default: "", - AllowEmpty: true, + Required: true, + RequireCortexResources: true, }, }, { - StructField: "FeatureColumns", - StringListValidation: &cr.StringListValidation{ - Required: true, - DisallowDups: true, + StructField: "Input", + InterfaceValidation: &cr.InterfaceValidation{ + Required: true, + AllowCortexResources: true, }, }, { - StructField: "TrainingColumns", - StringListValidation: &cr.StringListValidation{ - AllowEmpty: true, - DisallowDups: true, - Default: make([]string, 0), + StructField: "TrainingInput", + InterfaceValidation: &cr.InterfaceValidation{ + Required: false, + AllowCortexResources: true, }, }, { - StructField: "Aggregates", - StringListValidation: &cr.StringListValidation{ - AllowEmpty: true, - Default: make([]string, 0), + StructField: "HParams", + InterfaceValidation: &cr.InterfaceValidation{ + Required: false, }, }, { - StructField: "Hparams", - InterfaceMapValidation: &cr.InterfaceMapValidation{ + StructField: "PredictionKey", + StringValidation: &cr.StringValidation{ + Default: "", AllowEmpty: true, - Default: make(map[string]interface{}), }, }, { @@ -342,7 +325,7 @@ func (model *Model) Validate() error { if model.Training.SaveCheckpointsSecs == nil && model.Training.SaveCheckpointsSteps == nil { model.Training.SaveCheckpointsSecs = pointer.Int64(600) } else if model.Training.SaveCheckpointsSecs != nil && model.Training.SaveCheckpointsSteps != nil { - return errors.Wrap(ErrorSpecifyOnlyOne(SaveCheckpointSecsKey, SaveCheckpointStepsKey), Identify(model), TrainingKey) + return errors.Wrap(ErrorSpecifyOnlyOne(SaveCheckpointsSecsKey, SaveCheckpointsStepsKey), Identify(model), TrainingKey) } if model.Training.NumSteps == nil && model.Training.NumEpochs == nil { @@ -357,17 +340,19 @@ func (model *Model) Validate() error { return errors.Wrap(ErrorSpecifyOnlyOne(NumEpochsKey, NumStepsKey), Identify(model), EvaluationKey) } - for _, trainingColumn := range model.TrainingColumns { - if slices.HasString(model.FeatureColumns, trainingColumn) { - return errors.Wrap(ErrorDuplicateResourceValue(trainingColumn, TrainingColumnsKey, FeatureColumnsKey), Identify(model)) - } + if model.EstimatorPath == nil && model.Estimator == "" { + return errors.Wrap(ErrorSpecifyOnlyOneMissing("estimator", "estimator_path"), Identify(model)) } - return nil -} + if model.EstimatorPath != nil && model.Estimator != "" { + return errors.Wrap(ErrorSpecifyOnlyOne("estimator", "estimator_path"), Identify(model)) + } -func (model *Model) AllColumnNames() []string { - return slices.MergeStrSlices(model.FeatureColumns, model.TrainingColumns, []string{model.TargetColumn}) + if model.Estimator != "" && model.PredictionKey != "" { + return ErrorPredictionKeyOnModelWithEstimator() + } + + return nil } func (model *Model) GetResourceType() resource.Type { diff --git a/pkg/operator/api/userconfig/raw_columns.go b/pkg/operator/api/userconfig/raw_columns.go index 8bda695782..6aeb740da8 100644 --- a/pkg/operator/api/userconfig/raw_columns.go +++ b/pkg/operator/api/userconfig/raw_columns.go @@ -23,7 +23,7 @@ import ( type RawColumn interface { Column - GetType() ColumnType + GetColumnType() ColumnType GetCompute() *SparkCompute } @@ -217,19 +217,19 @@ func (rawColumns RawColumns) Get(name string) RawColumn { 
return nil } -func (column *RawIntColumn) GetType() ColumnType { +func (column *RawIntColumn) GetColumnType() ColumnType { return column.Type } -func (column *RawFloatColumn) GetType() ColumnType { +func (column *RawFloatColumn) GetColumnType() ColumnType { return column.Type } -func (column *RawStringColumn) GetType() ColumnType { +func (column *RawStringColumn) GetColumnType() ColumnType { return column.Type } -func (column *RawInferredColumn) GetType() ColumnType { +func (column *RawInferredColumn) GetColumnType() ColumnType { return column.Type } diff --git a/pkg/operator/api/userconfig/transformed_columns.go b/pkg/operator/api/userconfig/transformed_columns.go index 476b71f368..5161f95175 100644 --- a/pkg/operator/api/userconfig/transformed_columns.go +++ b/pkg/operator/api/userconfig/transformed_columns.go @@ -17,9 +17,8 @@ limitations under the License. package userconfig import ( - "sort" - "github.com/cortexlabs/cortex/pkg/lib/configreader" + "github.com/cortexlabs/cortex/pkg/lib/errors" "github.com/cortexlabs/cortex/pkg/operator/api/resource" ) @@ -29,7 +28,7 @@ type TransformedColumn struct { ResourceFields Transformer string `json:"transformer" yaml:"transformer"` TransformerPath *string `json:"transformer_path" yaml:"transformer_path"` - Inputs *Inputs `json:"inputs" yaml:"inputs"` + Input interface{} `json:"input" yaml:"input"` Compute *SparkCompute `json:"compute" yaml:"compute"` Tags Tags `json:"tags" yaml:"tags"` } @@ -54,7 +53,13 @@ var transformedColumnValidation = &configreader.StructValidation{ StructField: "TransformerPath", StringPtrValidation: &configreader.StringPtrValidation{}, }, - inputValuesFieldValidation, + { + StructField: "Input", + InterfaceValidation: &configreader.InterfaceValidation{ + Required: true, + AllowCortexResources: true, + }, + }, sparkComputeFieldValidation("Compute"), tagsFieldValidation, typeFieldValidation, @@ -62,6 +67,12 @@ var transformedColumnValidation = &configreader.StructValidation{ } func (columns TransformedColumns) Validate() error { + for _, column := range columns { + if err := column.Validate(); err != nil { + return err + } + } + resources := make([]Resource, len(columns)) for i, res := range columns { resources[i] = res @@ -75,6 +86,18 @@ func (columns TransformedColumns) Validate() error { return nil } +func (column *TransformedColumn) Validate() error { + if column.TransformerPath == nil && column.Transformer == "" { + return errors.Wrap(ErrorSpecifyOnlyOneMissing("transformer", "transformer_path"), Identify(column)) + } + + if column.TransformerPath != nil && column.Transformer != "" { + return errors.Wrap(ErrorSpecifyOnlyOne("transformer", "transformer_path"), Identify(column)) + } + + return nil +} + func (column *TransformedColumn) IsRaw() bool { return false } @@ -85,23 +108,17 @@ func (column *TransformedColumn) GetResourceType() resource.Type { func (columns TransformedColumns) Names() []string { names := make([]string, len(columns)) - for i, transformedColumn := range columns { - names[i] = transformedColumn.GetName() + for i, column := range columns { + names[i] = column.GetName() } return names } func (columns TransformedColumns) Get(name string) *TransformedColumn { - for _, transformedColumn := range columns { - if transformedColumn.GetName() == name { - return transformedColumn + for _, column := range columns { + if column.GetName() == name { + return column } } return nil } - -func (column *TransformedColumn) InputColumnNames() []string { - inputs, 
_ := configreader.FlattenAllStrValues(column.Inputs.Columns) - sort.Strings(inputs) - return inputs -} diff --git a/pkg/operator/api/userconfig/transformers.go b/pkg/operator/api/userconfig/transformers.go index 6eada9effa..1d04cec18f 100644 --- a/pkg/operator/api/userconfig/transformers.go +++ b/pkg/operator/api/userconfig/transformers.go @@ -25,9 +25,9 @@ type Transformers []*Transformer type Transformer struct { ResourceFields - Inputs *Inputs `json:"inputs" yaml:"inputs"` - OutputType ColumnType `json:"output_type" yaml:"output_type"` - Path string `json:"path" yaml:"path"` + Input *InputSchema `json:"input" yaml:"input"` + OutputType ColumnType `json:"output_type" yaml:"output_type"` + Path string `json:"path" yaml:"path"` } var transformerValidation = &cr.StructValidation{ @@ -47,17 +47,23 @@ var transformerValidation = &cr.StructValidation{ return "implementations/transformers/" + name.(string) + ".py" }, }, + { + StructField: "Input", + InterfaceValidation: &cr.InterfaceValidation{ + Required: true, + Validator: inputSchemaValidator, + }, + }, { StructField: "OutputType", StringValidation: &cr.StringValidation{ Required: true, - AllowedValues: ColumnTypeStrings(), + AllowedValues: ValidColumnTypeStrings(), }, Parser: func(str string) (interface{}, error) { return ColumnTypeFromString(str), nil }, }, - inputTypesFieldValidation, typeFieldValidation, }, } diff --git a/pkg/operator/api/userconfig/types.go b/pkg/operator/api/userconfig/types.go index cea5ac756b..f764245aac 100644 --- a/pkg/operator/api/userconfig/types.go +++ b/pkg/operator/api/userconfig/types.go @@ -17,17 +17,13 @@ limitations under the License. package userconfig import ( - "regexp" "strings" + "github.com/cortexlabs/cortex/pkg/consts" + "github.com/cortexlabs/cortex/pkg/lib/cast" s "github.com/cortexlabs/cortex/pkg/lib/strings" ) -var ( - typeStrRegex = regexp.MustCompile(`"(INT|FLOAT|STRING|BOOL)(_COLUMN)?(\|(INT|FLOAT|STRING|BOOL)(_COLUMN)?)*"`) - singleDataTypeRegexp = regexp.MustCompile(`^\w*\w$`) -) - func DataTypeStrsOr(dataTypes []interface{}) string { dataTypeStrs := make([]string, len(dataTypes)) for i, dataType := range dataTypes { @@ -37,8 +33,8 @@ func DataTypeStrsOr(dataTypes []interface{}) string { } func DataTypeStr(dataType interface{}) string { - dataTypeStr := s.ObjFlat(dataType) - matches := typeStrRegex.FindAllString(dataTypeStr, -1) + dataTypeStr := s.ObjFlat(flattenTypeSchema(dataType)) + matches := consts.TypeStrRegex.FindAllString(dataTypeStr, -1) for _, match := range matches { trimmed := s.TrimPrefixAndSuffix(match, `"`) dataTypeStr = strings.Replace(dataTypeStr, match, trimmed, -1) @@ -48,8 +44,37 @@ func DataTypeStr(dataType interface{}) string { func DataTypeUserStr(dataType interface{}) string { dataTypeStr := DataTypeStr(dataType) - if singleDataTypeRegexp.MatchString(dataTypeStr) { + if consts.SingleTypeStrRegex.MatchString(dataTypeStr) { dataTypeStr = s.UserStr(dataTypeStr) } return dataTypeStr } + +// Remove cortex arg options from input schemas +func flattenTypeSchema(schema interface{}) interface{} { + if inputSchema, ok := schema.(InputSchema); ok { + return flattenTypeSchema(inputSchema.Type) + } + + if inputSchemaPtr, ok := schema.(*InputSchema); ok { + return flattenTypeSchema(inputSchemaPtr.Type) + } + + if schemaSlice, ok := cast.InterfaceToInterfaceSlice(schema); ok { + flattenedSlice := make([]interface{}, len(schemaSlice)) + for i := range schemaSlice { + flattenedSlice[i] = flattenTypeSchema(schemaSlice[i]) + } 
+ return flattenedSlice + } + + if schemaMap, ok := cast.InterfaceToInterfaceInterfaceMap(schema); ok { + flattenedMap := make(map[interface{}]interface{}, len(schemaMap)) + for k, v := range schemaMap { + flattenedMap[k] = flattenTypeSchema(v) + } + return flattenedMap + } + + return schema +} diff --git a/pkg/operator/api/userconfig/validators.go b/pkg/operator/api/userconfig/validators.go index 3a107a6bc7..ef39ce3cd5 100644 --- a/pkg/operator/api/userconfig/validators.go +++ b/pkg/operator/api/userconfig/validators.go @@ -21,321 +21,314 @@ import ( "github.com/cortexlabs/cortex/pkg/lib/cast" "github.com/cortexlabs/cortex/pkg/lib/configreader" + cr "github.com/cortexlabs/cortex/pkg/lib/configreader" "github.com/cortexlabs/cortex/pkg/lib/errors" - "github.com/cortexlabs/cortex/pkg/lib/maps" - "github.com/cortexlabs/cortex/pkg/lib/slices" + "github.com/cortexlabs/cortex/pkg/lib/pointer" s "github.com/cortexlabs/cortex/pkg/lib/strings" ) -func isValidColumnInputType(columnTypeStr string) bool { - for _, columnTypeStrItem := range strings.Split(columnTypeStr, "|") { - if !slices.HasString(ColumnTypeStrings(), columnTypeStrItem) { - return false - } - } - return true +type InputSchema struct { + Type InputTypeSchema `json:"_type" yaml:"_type"` + Optional bool `json:"_optional" yaml:"_optional"` + Default interface{} `json:"_default" yaml:"_default"` + AllowNull bool `json:"_allow_null" yaml:"_allow_null"` + MinCount *int64 `json:"_min_count" yaml:"_min_count"` + MaxCount *int64 `json:"_max_count" yaml:"_max_count"` } -func isValidValueType(valueTypeStr string) bool { - for _, valueTypeStrItem := range strings.Split(valueTypeStr, "|") { - if !slices.HasString(ValueTypeStrings(), valueTypeStrItem) { - return false - } - } - return true +type InputTypeSchema interface{} // CompundType, length-one array of *InputSchema, or map of {scalar|CompoundType -> *InputSchema} + +type OutputSchema interface{} // ValueType, length-one array of OutputSchema, or map of {scalar|ValueType -> OutputSchema} (no *_COLUMN types, compound types, or input options like _default) + +func inputSchemaValidator(in interface{}) (interface{}, error) { + return ValidateInputSchema(in, false) // This casts it to *InputSchema } -func ValidateColumnInputTypes(columnTypes map[string]interface{}) error { - for columnInputName, columnType := range columnTypes { - if columnTypeStr, ok := columnType.(string); ok { - if !isValidColumnInputType(columnTypeStr) { - return errors.Wrap(ErrorInvalidColumnInputType(columnTypeStr), columnInputName) - } - continue - } +func inputSchemaValidatorValueTypesOnly(in interface{}) (interface{}, error) { + return ValidateInputSchema(in, true) // This casts it to *InputSchema +} - if columnTypeStrs, ok := cast.InterfaceToStrSlice(columnType); ok { - if len(columnTypeStrs) != 1 { - return errors.Wrap(ErrorTypeListLength(columnTypeStrs), columnInputName) - } - if !isValidColumnInputType(columnTypeStrs[0]) { - return errors.Wrap(ErrorInvalidColumnInputType(columnTypeStrs), columnInputName) - } - continue +func ValidateInputSchema(in interface{}, disallowColumnTypes bool) (*InputSchema, error) { + // Check for cortex options vs short form + if inMap, ok := cast.InterfaceToStrInterfaceMap(in); ok { + foundUnderscore, foundNonUnderscore := false, false + for key := range inMap { + if strings.HasPrefix(key, "_") { + foundUnderscore = true + } else { + 
foundNonUnderscore = true + } + } + + if foundUnderscore { + if foundNonUnderscore { + return nil, ErrorMixedInputArgOptionsAndUserKeys() + } + + inputSchemaValidation := &cr.StructValidation{ + StructFieldValidations: []*cr.StructFieldValidation{ + { + StructField: "Type", + InterfaceValidation: &cr.InterfaceValidation{ + Required: true, + Validator: func(t interface{}) (interface{}, error) { + return validateInputTypeSchema(t, disallowColumnTypes) + }, + }, + }, + { + StructField: "Optional", + BoolValidation: &cr.BoolValidation{}, + }, + { + StructField: "Default", + InterfaceValidation: &cr.InterfaceValidation{}, + }, + { + StructField: "AllowNull", + InterfaceValidation: &cr.InterfaceValidation{}, + }, + { + StructField: "MinCount", + Int64PtrValidation: &cr.Int64PtrValidation{ + GreaterThanOrEqualTo: pointer.Int64(0), + }, + }, + { + StructField: "MaxCount", + Int64PtrValidation: &cr.Int64PtrValidation{ + GreaterThanOrEqualTo: pointer.Int64(0), + }, + }, + }, + } + inputSchema := &InputSchema{} + errs := cr.Struct(inputSchema, inMap, inputSchemaValidation) + + if errors.HasErrors(errs) { + return nil, errors.FirstError(errs...) + } + + if err := validateInputSchemaOptions(inputSchema); err != nil { + return nil, err + } + + return inputSchema, nil } - - return errors.Wrap(ErrorInvalidColumnInputType(columnType), columnInputName) } - return nil -} + typeSchema, err := validateInputTypeSchema(in, disallowColumnTypes) + if err != nil { + return nil, err + } + inputSchema := &InputSchema{ + Type: typeSchema, + } -func ValidateColumnInputValues(columnInputValues map[string]interface{}) error { - for columnInputName, columnInputValue := range columnInputValues { - if _, ok := columnInputValue.(string); ok { - continue - } - if columnNames, ok := cast.InterfaceToStrSlice(columnInputValue); ok { - if columnNames == nil { - return errors.Wrap(configreader.ErrorCannotBeNull(), columnInputName) - } - continue - } - return errors.Wrap( - configreader.ErrorInvalidPrimitiveType(columnInputValue, configreader.PrimTypeString, configreader.PrimTypeStringList), - columnInputName, - ) + if err := validateInputSchemaOptions(inputSchema); err != nil { + return nil, err } - return nil + return inputSchema, nil } -func ValidateColumnRuntimeTypes(columnRuntimeTypes map[string]interface{}) error { - for columnInputName, columnTypeInter := range columnRuntimeTypes { - if columnType, ok := columnTypeInter.(ColumnType); ok { - if columnType == UnknownColumnType { - return errors.Wrap(ErrorInvalidColumnRuntimeType(), columnInputName) // unexpected - } - continue +func validateInputTypeSchema(in interface{}, disallowColumnTypes bool) (InputTypeSchema, error) { + // String + if inStr, ok := in.(string); ok { + compoundType, err := CompoundTypeFromString(inStr) + if err != nil { + return nil, err } - if columnTypes, ok := columnTypeInter.([]ColumnType); ok { - for i, columnType := range columnTypes { - if columnType == UnknownColumnType { - return errors.Wrap(ErrorInvalidColumnRuntimeType(), columnInputName, s.Index(i)) // unexpected - } - } - continue + if disallowColumnTypes && compoundType.IsColumns() { + return nil, ErrorColumnTypeNotAllowed(inStr) } - return errors.Wrap(ErrorInvalidColumnRuntimeType(), columnInputName) // unexpected + return compoundType, nil } - return nil -} - -// columnRuntimeTypes is {string -> ColumnType or []ColumnType}, columnSchemaTypes is {string -> string or []string} -func CheckColumnRuntimeTypesMatch(columnRuntimeTypes map[string]interface{}, columnSchemaTypes 
map[string]interface{}) error { - err := ValidateColumnInputTypes(columnSchemaTypes) - if err != nil { - return err - } - err = ValidateColumnRuntimeTypes(columnRuntimeTypes) - if err != nil { - return err + // List + if inSlice, ok := cast.InterfaceToInterfaceSlice(in); ok { + if len(inSlice) != 1 { + return nil, ErrorTypeListLength(inSlice) + } + inputSchema, err := ValidateInputSchema(inSlice[0], disallowColumnTypes) + if err != nil { + return nil, errors.Wrap(err, s.Index(0)) + } + return []interface{}{inputSchema}, nil } - for columnInputName, columnSchemaType := range columnSchemaTypes { - if len(columnRuntimeTypes) == 0 { - return configreader.ErrorMapMustBeDefined(maps.InterfaceMapKeys(columnSchemaTypes)...) + // Map + if inMap, ok := cast.InterfaceToInterfaceInterfaceMap(in); ok { + if len(inMap) == 0 { + return nil, ErrorTypeMapZeroLength(inMap) } - columnRuntimeTypeInter, ok := columnRuntimeTypes[columnInputName] - if !ok { - return errors.Wrap(configreader.ErrorMustBeDefined(), columnInputName) + var typeKey CompoundType + var typeValue interface{} + for k, v := range inMap { + var err error + typeKey, err = CompoundTypeFromString(k) + if err == nil { + typeValue = v + break + } } - if columnSchemaTypeStr, ok := columnSchemaType.(string); ok { - validTypes := strings.Split(columnSchemaTypeStr, "|") - columnRuntimeType, ok := columnRuntimeTypeInter.(ColumnType) - if !ok { - return errors.Wrap(ErrorUnsupportedColumnType(columnRuntimeTypeInter, validTypes), columnInputName) + // Generic map + if typeValue != nil { + if len(inMap) != 1 { + return nil, ErrorGenericTypeMapLength(inMap) } - - if columnRuntimeType == InferredColumnType { - continue + if disallowColumnTypes && typeKey.IsColumns() { + return nil, ErrorColumnTypeNotAllowed(typeKey) } - - if !slices.HasString(validTypes, columnRuntimeType.String()) { - return errors.Wrap(ErrorUnsupportedColumnType(columnRuntimeType, validTypes), columnInputName) + valueInputSchema, err := ValidateInputSchema(typeValue, disallowColumnTypes) + if err != nil { + return nil, errors.Wrap(err, string(typeKey)) } - continue + return map[interface{}]interface{}{typeKey: valueInputSchema}, nil } - if columnSchemaTypeStrs, ok := cast.InterfaceToStrSlice(columnSchemaType); ok { - validTypes := strings.Split(columnSchemaTypeStrs[0], "|") - columnRuntimeTypeSlice, ok := columnRuntimeTypeInter.([]ColumnType) - if !ok { - return errors.Wrap(ErrorUnsupportedColumnType(columnRuntimeTypeInter, columnSchemaTypeStrs), columnInputName) + // Fixed map + outMap := map[interface{}]interface{}{} + for key, value := range inMap { + if !cast.IsScalarType(key) { + return nil, configreader.ErrorInvalidPrimitiveType(key, configreader.PrimTypeScalars...) 
} - for i, columnRuntimeType := range columnRuntimeTypeSlice { - if !slices.HasString(validTypes, columnRuntimeType.String()) { - return errors.Wrap(ErrorUnsupportedColumnType(columnRuntimeType, validTypes), columnInputName, s.Index(i)) + if keyStr, ok := key.(string); ok { + if strings.HasPrefix(keyStr, "_") { + return nil, ErrorUserKeysCannotStartWithUnderscore(keyStr) } } - continue - } - return errors.Wrap(ErrorInvalidColumnInputType(columnSchemaType), columnInputName) // unexpected - } - - for columnInputName := range columnRuntimeTypes { - if _, ok := columnSchemaTypes[columnInputName]; !ok { - return configreader.ErrorUnsupportedKey(columnInputName) + valueInputSchema, err := ValidateInputSchema(value, disallowColumnTypes) + if err != nil { + return nil, errors.Wrap(err, s.UserStrStripped(key)) + } + outMap[key] = valueInputSchema } + return outMap, nil } - return nil + return nil, ErrorInvalidInputType(in) } -func ValidateArgTypes(argTypes map[string]interface{}) error { - for argName, valueType := range argTypes { - if isValidValueType(argName) { - return ErrorArgNameCannotBeType(argName) - } - err := ValidateValueType(valueType) - if err != nil { - return errors.Wrap(err, argName) - } +func validateInputSchemaOptions(inputSchema *InputSchema) error { + if inputSchema.Default != nil { + inputSchema.Optional = true } - return nil -} -func ValidateValueType(valueType interface{}) error { - if valueTypeStr, ok := valueType.(string); ok { - if !isValidValueType(valueTypeStr) { - return ErrorInvalidValueDataType(valueTypeStr) + _, isSlice := cast.InterfaceToInterfaceSlice(inputSchema.Type) + isGenericMap := false + if interfaceMap, ok := cast.InterfaceToInterfaceInterfaceMap(inputSchema.Type); ok { + for k := range interfaceMap { + _, isGenericMap = k.(CompoundType) + break } - return nil } - if valueTypeStrs, ok := cast.InterfaceToStrSlice(valueType); ok { - if len(valueTypeStrs) != 1 { - return errors.Wrap(ErrorTypeListLength(valueTypeStrs)) - } - if !isValidValueType(valueTypeStrs[0]) { - return ErrorInvalidValueDataType(valueTypeStrs[0]) + if inputSchema.MinCount != nil { + if !isGenericMap && !isSlice { + return ErrorOptionOnNonIterable(MinCountOptKey) } - return nil } - if valueTypeMap, ok := cast.InterfaceToInterfaceInterfaceMap(valueType); ok { - foundGenericKey := false - for key := range valueTypeMap { - if strKey, ok := key.(string); ok { - if isValidValueType(strKey) { - foundGenericKey = true - break - } - } - } - if foundGenericKey && len(valueTypeMap) != 1 { - return ErrorGenericTypeMapLength(valueTypeMap) + if inputSchema.MaxCount != nil { + if !isGenericMap && !isSlice { + return ErrorOptionOnNonIterable(MaxCountOptKey) } - - for key, val := range valueTypeMap { - if foundGenericKey { - err := ValidateValueType(key) - if err != nil { - return err - } - } - err := ValidateValueType(val) - if err != nil { - return errors.Wrap(err, s.UserStrStripped(key)) - } - } - return nil } - return ErrorInvalidValueDataType(valueType) -} + if inputSchema.MinCount != nil && inputSchema.MaxCount != nil && *inputSchema.MinCount > *inputSchema.MaxCount { + return ErrorMinCountGreaterThanMaxCount() + } -func ValidateArgValues(argValues map[string]interface{}) error { - for argName, value := range argValues { - err := ValidateValue(value) + // Validate default against schema + if inputSchema.Default != nil { + var err error + inputSchema.Default, err = CastInputValue(inputSchema.Default, inputSchema) if err != nil { - return errors.Wrap(err, argName) + return errors.Wrap(err, 
DefaultOptKey) } } - return nil -} -func ValidateValue(value interface{}) error { return nil } -func CastValue(value interface{}, valueType interface{}) (interface{}, error) { - err := ValidateValueType(valueType) - if err != nil { - return nil, err - } - err = ValidateValue(value) - if err != nil { - return nil, err - } - +func CastInputValue(value interface{}, inputSchema *InputSchema) (interface{}, error) { + // Check for null if value == nil { - return nil, nil + if inputSchema.AllowNull { + return nil, nil + } + return nil, ErrorCannotBeNull() } - if valueTypeStr, ok := valueType.(string); ok { - validTypes := strings.Split(valueTypeStr, "|") - var validTypeNames []configreader.PrimitiveType + typeSchema := inputSchema.Type - if slices.HasString(validTypes, IntegerValueType.String()) { - validTypeNames = append(validTypeNames, configreader.PrimTypeInt) - valueInt, ok := cast.InterfaceToInt64(value) - if ok { - return valueInt, nil - } + // CompoundType + if compoundType, ok := typeSchema.(CompoundType); ok { + return compoundType.CastValue(value) + } + + // array of *InputSchema + if inputSchemas, ok := cast.InterfaceToInterfaceSlice(typeSchema); ok { + values, ok := cast.InterfaceToInterfaceSlice(value) + if !ok { + return nil, ErrorUnsupportedLiteralType(value, typeSchema) } - if slices.HasString(validTypes, FloatValueType.String()) { - validTypeNames = append(validTypeNames, configreader.PrimTypeFloat) - valueFloat, ok := cast.InterfaceToFloat64(value) - if ok { - return valueFloat, nil - } + + if inputSchema.MinCount != nil && int64(len(values)) < *inputSchema.MinCount { + return nil, ErrorTooFewElements(configreader.PrimTypeList, *inputSchema.MinCount) } - if slices.HasString(validTypes, StringValueType.String()) { - validTypeNames = append(validTypeNames, configreader.PrimTypeString) - if valueStr, ok := value.(string); ok { - return valueStr, nil - } + if inputSchema.MaxCount != nil && int64(len(values)) > *inputSchema.MaxCount { + return nil, ErrorTooManyElements(configreader.PrimTypeList, *inputSchema.MaxCount) } - if slices.HasString(validTypes, BoolValueType.String()) { - validTypeNames = append(validTypeNames, configreader.PrimTypeBool) - if valueBool, ok := value.(bool); ok { - return valueBool, nil + + valuesCasted := make([]interface{}, len(values)) + for i, valueItem := range values { + valueItemCasted, err := CastInputValue(valueItem, inputSchemas[0].(*InputSchema)) + if err != nil { + return nil, errors.Wrap(err, s.Index(i)) } + valuesCasted[i] = valueItemCasted } - return nil, configreader.ErrorInvalidPrimitiveType(value, validTypeNames...) 
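Rounding out the two entry points above, a minimal usage sketch assuming it sits in package userconfig alongside this code (using the existing cr import); the schema and value literals are hypothetical:

// Sketch only: parse a long-form input schema (with cortex "_" options), then cast a literal.
func exampleCastInput() (interface{}, error) {
	schema, err := ValidateInputSchema(cr.MustReadYAMLStr(`
_type: [FLOAT]
_min_count: 2
_default: [0.0, 0.0]
`), true) // true: value types only, so *_COLUMN types are rejected
	if err != nil {
		return nil, err
	}
	// _default implies _optional, and the default itself was cast against _type above.
	return CastInputValue(cr.MustReadYAMLStr(`[1, 2.5, 3]`), schema) // each element is cast to FLOAT
}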
+ return valuesCasted, nil } - if valueTypeMap, ok := cast.InterfaceToInterfaceInterfaceMap(valueType); ok { + // Map + if typeSchemaMap, ok := cast.InterfaceToInterfaceInterfaceMap(typeSchema); ok { valueMap, ok := cast.InterfaceToInterfaceInterfaceMap(value) if !ok { - return nil, configreader.ErrorInvalidPrimitiveType(value, configreader.PrimTypeMap) + return nil, ErrorUnsupportedLiteralType(value, typeSchema) } - if len(valueTypeMap) == 0 { - if len(valueMap) == 0 { - return make(map[interface{}]interface{}), nil + var genericKey CompoundType + var genericValue *InputSchema + for k, v := range typeSchemaMap { + ok := false + if genericKey, ok = k.(CompoundType); ok { + genericValue = v.(*InputSchema) } - return nil, errors.Wrap(configreader.ErrorMustBeEmpty(), s.UserStr(valueMap)) } - isGenericMap := false - var genericMapKeyType string - var genericMapValueType interface{} - if len(valueTypeMap) == 1 { - for valueTypeKey, valueTypeVal := range valueTypeMap { // Will only be length one - if valueTypeKeyStr, ok := valueTypeKey.(string); ok { - if isValidValueType(valueTypeKeyStr) { - isGenericMap = true - genericMapKeyType = valueTypeKeyStr - genericMapValueType = valueTypeVal - } - } + valueMapCasted := make(map[interface{}]interface{}, len(valueMap)) + + // Generic map + if genericValue != nil { + if inputSchema.MinCount != nil && int64(len(valueMap)) < *inputSchema.MinCount { + return nil, ErrorTooFewElements(configreader.PrimTypeMap, *inputSchema.MinCount) + } + if inputSchema.MaxCount != nil && int64(len(valueMap)) > *inputSchema.MaxCount { + return nil, ErrorTooManyElements(configreader.PrimTypeMap, *inputSchema.MaxCount) } - } - if isGenericMap { - valueMapCasted := make(map[interface{}]interface{}, len(valueMap)) for valueKey, valueVal := range valueMap { - valueKeyCasted, err := CastValue(valueKey, genericMapKeyType) + valueKeyCasted, err := CastInputValue(valueKey, &InputSchema{Type: genericKey}) if err != nil { return nil, err } - valueValCasted, err := CastValue(valueVal, genericMapValueType) + valueValCasted, err := CastInputValue(valueVal, genericValue) if err != nil { return nil, errors.Wrap(err, s.UserStrStripped(valueKey)) } @@ -344,163 +337,204 @@ func CastValue(value interface{}, valueType interface{}) (interface{}, error) { return valueMapCasted, nil } - // Non-generic map - valueMapCasted := make(map[interface{}]interface{}, len(valueMap)) - for valueKey, valueType := range valueTypeMap { - valueVal, ok := valueMap[valueKey] - if !ok { - return nil, errors.Wrap(configreader.ErrorMustBeDefined(), s.UserStrStripped(valueKey)) - } - valueValCasted, err := CastValue(valueVal, valueType) - if err != nil { - return nil, errors.Wrap(err, s.UserStrStripped(valueKey)) + // Fixed map + for typeSchemaKey, typeSchemaValue := range typeSchemaMap { + valueVal, ok := valueMap[typeSchemaKey] + if ok { + valueValCasted, err := CastInputValue(valueVal, typeSchemaValue.(*InputSchema)) + if err != nil { + return nil, errors.Wrap(err, s.UserStrStripped(typeSchemaKey)) + } + valueMapCasted[typeSchemaKey] = valueValCasted + } else { + if !typeSchemaValue.(*InputSchema).Optional { + return nil, ErrorMustBeDefined(typeSchemaValue) + } + // don't set default (python has to) } - valueMapCasted[valueKey] = valueValCasted } for valueKey := range valueMap { - if _, ok := valueTypeMap[valueKey]; !ok { - return nil, configreader.ErrorUnsupportedKey(valueKey) + if _, ok := typeSchemaMap[valueKey]; !ok { + return nil, ErrorUnsupportedLiteralMapKey(valueKey, typeSchemaMap) } } return valueMapCasted, 
nil } - if valueTypeStrs, ok := cast.InterfaceToStrSlice(valueType); ok { - valueTypeStr := valueTypeStrs[0] - valueSlice, ok := cast.InterfaceToInterfaceSlice(value) - if !ok { - return nil, configreader.ErrorInvalidPrimitiveType(value, configreader.PrimTypeList) - } - valueSliceCasted := make([]interface{}, len(valueSlice)) - for i, valueItem := range valueSlice { - valueItemCasted, err := CastValue(valueItem, valueTypeStr) - if err != nil { - return nil, errors.Wrap(err, s.Index(i)) + return nil, ErrorInvalidInputType(typeSchema) // unexpected +} + +func ValidateOutputSchema(in interface{}) (OutputSchema, error) { + // String + if inStr, ok := in.(string); ok { + valueType := ValueTypeFromString(inStr) + if valueType == UnknownValueType { + if colType := ColumnTypeFromString(inStr); colType != UnknownColumnType && colType != InferredColumnType { + return nil, ErrorColumnTypeNotAllowed(inStr) + } + if _, err := CompoundTypeFromString(inStr); err == nil { + return nil, ErrorCompoundTypeInOutputType(inStr) } - valueSliceCasted[i] = valueItemCasted + return nil, ErrorInvalidOutputType(inStr) } - return valueSliceCasted, nil + return valueType, nil } - return nil, ErrorInvalidValueDataType(valueType) // unexpected -} - -func CheckArgRuntimeTypesMatch(argRuntimeTypes map[string]interface{}, argSchemaTypes map[string]interface{}) error { - err := ValidateArgTypes(argSchemaTypes) - if err != nil { - return err - } - err = ValidateArgTypes(argRuntimeTypes) - if err != nil { - return err + // List + if inSlice, ok := cast.InterfaceToInterfaceSlice(in); ok { + if len(inSlice) != 1 { + return nil, ErrorTypeListLength(inSlice) + } + outputSchema, err := ValidateOutputSchema(inSlice[0]) + if err != nil { + return nil, errors.Wrap(err, s.Index(0)) + } + return []interface{}{outputSchema}, nil } - for argName, argSchemaType := range argSchemaTypes { - if len(argRuntimeTypes) == 0 { - return configreader.ErrorMapMustBeDefined(maps.InterfaceMapKeys(argSchemaTypes)...) + // Map + if inMap, ok := cast.InterfaceToInterfaceInterfaceMap(in); ok { + if len(inMap) == 0 { + return nil, ErrorTypeMapZeroLength(inMap) } - argRuntimeType, ok := argRuntimeTypes[argName] - if !ok { - return errors.Wrap(configreader.ErrorMustBeDefined(), argName) + var typeKey ValueType + var typeValue interface{} + for k, v := range inMap { + if kStr, ok := k.(string); ok { + typeKey = ValueTypeFromString(kStr) + if typeKey != UnknownValueType { + typeValue = v + break + } + if colType := ColumnTypeFromString(kStr); colType != UnknownColumnType && colType != InferredColumnType { + return nil, ErrorColumnTypeNotAllowed(kStr) + } + if _, err := CompoundTypeFromString(kStr); err == nil { + return nil, ErrorCompoundTypeInOutputType(kStr) + } + } } - err := CheckValueRuntimeTypesMatch(argRuntimeType, argSchemaType) - if err != nil { - return errors.Wrap(err, argName) + + // Generic map + if typeValue != nil { + if len(inMap) != 1 { + return nil, ErrorGenericTypeMapLength(inMap) + } + valueOutputSchema, err := ValidateOutputSchema(typeValue) + if err != nil { + return nil, errors.Wrap(err, string(typeKey)) + } + return map[interface{}]interface{}{typeKey: valueOutputSchema}, nil } - } - for argName := range argRuntimeTypes { - if _, ok := argSchemaTypes[argName]; !ok { - return configreader.ErrorUnsupportedKey(argName) + // Fixed map + castedSchemaMap := map[interface{}]interface{}{} + for key, value := range inMap { + if !cast.IsScalarType(key) { + return nil, configreader.ErrorInvalidPrimitiveType(key, configreader.PrimTypeScalars...) 
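And similarly for output schemas, a minimal sketch; the value literal is hypothetical, while the schema literal matches one of the test cases below:

// Sketch only: output schemas allow plain value types (no *_COLUMN types, no compound
// types such as INT|FLOAT, and no "_" options); casting requires every fixed key to be present.
func exampleCastOutput() (interface{}, error) {
	schema, err := ValidateOutputSchema(cr.MustReadYAMLStr(`{mean: FLOAT, stddev: FLOAT}`))
	if err != nil {
		return nil, err
	}
	return CastOutputValue(cr.MustReadYAMLStr(`{mean: 0.5, stddev: 1.2}`), schema)
}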
+ } + if keyStr, ok := key.(string); ok { + if strings.HasPrefix(keyStr, "_") { + return nil, ErrorUserKeysCannotStartWithUnderscore(keyStr) + } + } + + valueOutputSchema, err := ValidateOutputSchema(value) + if err != nil { + return nil, errors.Wrap(err, s.UserStrStripped(key)) + } + castedSchemaMap[key] = valueOutputSchema } + return castedSchemaMap, nil } - return nil + return nil, ErrorInvalidOutputType(in) } -func CheckValueRuntimeTypesMatch(runtimeType interface{}, schemaType interface{}) error { - if schemaTypeStr, ok := schemaType.(string); ok { - validTypes := strings.Split(schemaTypeStr, "|") - runtimeTypeStr, ok := runtimeType.(string) +func CastOutputValue(value interface{}, outputSchema OutputSchema) (interface{}, error) { + // Check for missing + if value == nil { + return nil, ErrorCannotBeNull() + } + + // ValueType + if valueType, ok := outputSchema.(ValueType); ok { + return valueType.CastValue(value) + } + + // Array + if typeSchemas, ok := cast.InterfaceToInterfaceSlice(outputSchema); ok { + values, ok := cast.InterfaceToInterfaceSlice(value) if !ok { - return ErrorUnsupportedDataType(runtimeType, schemaTypeStr) + return nil, ErrorUnsupportedLiteralType(value, outputSchema) } - for _, runtimeTypeOption := range strings.Split(runtimeTypeStr, "|") { - if !slices.HasString(validTypes, runtimeTypeOption) { - return ErrorUnsupportedDataType(runtimeTypeStr, schemaTypeStr) + valuesCasted := make([]interface{}, len(values)) + for i, valueItem := range values { + valueItemCasted, err := CastOutputValue(valueItem, typeSchemas[0]) + if err != nil { + return nil, errors.Wrap(err, s.Index(i)) } + valuesCasted[i] = valueItemCasted } - return nil + return valuesCasted, nil } - if schemaTypeMap, ok := cast.InterfaceToInterfaceInterfaceMap(schemaType); ok { - runtimeTypeMap, ok := cast.InterfaceToInterfaceInterfaceMap(runtimeType) + // Map + if typeSchemaMap, ok := cast.InterfaceToInterfaceInterfaceMap(outputSchema); ok { + valueMap, ok := cast.InterfaceToInterfaceInterfaceMap(value) if !ok { - return ErrorUnsupportedDataType(runtimeType, schemaTypeMap) - } - - isGenericMap := false - var genericMapKeyType string - var genericMapValueType interface{} - if len(schemaTypeMap) == 1 { - for schemaTypeKey, schemaTypeValue := range schemaTypeMap { // Will only be length one - if schemaTypeMapStr, ok := schemaTypeKey.(string); ok { - if isValidValueType(schemaTypeMapStr) { - isGenericMap = true - genericMapKeyType = schemaTypeMapStr - genericMapValueType = schemaTypeValue - } - } + return nil, ErrorUnsupportedLiteralType(value, outputSchema) + } + + isGeneric := false + var genericKey ValueType + var genericValue interface{} + for k, v := range typeSchemaMap { + ok := false + if genericKey, ok = k.(ValueType); ok { + isGeneric = true + genericValue = v } } - if isGenericMap { - for runtimeTypeKey, runtimeTypeValue := range runtimeTypeMap { // Should only be one item - err := CheckValueRuntimeTypesMatch(runtimeTypeKey, genericMapKeyType) + valueMapCasted := make(map[interface{}]interface{}, len(valueMap)) + + // Generic map + if isGeneric { + for valueKey, valueVal := range valueMap { + valueKeyCasted, err := CastOutputValue(valueKey, genericKey) if err != nil { - return err + return nil, err } - err = CheckValueRuntimeTypesMatch(runtimeTypeValue, genericMapValueType) + valueValCasted, err := CastOutputValue(valueVal, genericValue) if err != nil { - return errors.Wrap(err, s.UserStrStripped(runtimeTypeKey)) + return nil, errors.Wrap(err, s.UserStrStripped(valueKey)) } + 
valueMapCasted[valueKeyCasted] = valueValCasted } - return nil + return valueMapCasted, nil } - // Non-generic map - for schemaTypeKey, schemaTypeValue := range schemaTypeMap { - runtimeTypeValue, ok := runtimeTypeMap[schemaTypeKey] + // Fixed map + for typeSchemaKey, typeSchemaValue := range typeSchemaMap { + valueVal, ok := valueMap[typeSchemaKey] if !ok { - return errors.Wrap(configreader.ErrorMustBeDefined(), s.UserStrStripped(schemaTypeKey)) + return nil, ErrorMustBeDefined(typeSchemaValue) } - err := CheckValueRuntimeTypesMatch(runtimeTypeValue, schemaTypeValue) + valueValCasted, err := CastOutputValue(valueVal, typeSchemaValue) if err != nil { - return errors.Wrap(err, s.UserStrStripped(schemaTypeKey)) + return nil, errors.Wrap(err, s.UserStrStripped(typeSchemaKey)) } + valueMapCasted[typeSchemaKey] = valueValCasted } - for runtimeTypeKey := range runtimeTypeMap { - if _, ok := schemaTypeMap[runtimeTypeKey]; !ok { - return configreader.ErrorUnsupportedKey(runtimeTypeKey) - } - } - return nil - } - - if schemaTypeStrs, ok := cast.InterfaceToStrSlice(schemaType); ok { - validTypes := strings.Split(schemaTypeStrs[0], "|") - runtimeTypeStrs, ok := cast.InterfaceToStrSlice(runtimeType) - if !ok { - return ErrorUnsupportedDataType(runtimeType, schemaTypeStrs) - } - for _, runtimeTypeOption := range strings.Split(runtimeTypeStrs[0], "|") { - if !slices.HasString(validTypes, runtimeTypeOption) { - return ErrorUnsupportedDataType(runtimeTypeStrs, schemaTypeStrs) + for valueKey := range valueMap { + if _, ok := typeSchemaMap[valueKey]; !ok { + return nil, ErrorUnsupportedLiteralMapKey(valueKey, typeSchemaMap) } } - return nil + return valueMapCasted, nil } - return ErrorInvalidValueDataType(schemaType) // unexpected + return nil, ErrorInvalidOutputType(outputSchema) // unexpected } diff --git a/pkg/operator/api/userconfig/validators_test.go b/pkg/operator/api/userconfig/validators_test.go index 8dc660dc24..79f057edab 100644 --- a/pkg/operator/api/userconfig/validators_test.go +++ b/pkg/operator/api/userconfig/validators_test.go @@ -21,530 +21,152 @@ import ( "github.com/stretchr/testify/require" - "github.com/cortexlabs/cortex/pkg/lib/cast" cr "github.com/cortexlabs/cortex/pkg/lib/configreader" ) -func TestValidateColumnInputTypes(t *testing.T) { - var columnTypes map[string]interface{} - - columnTypes = cr.MustReadYAMLStrMap("num: FLOAT_COLUMN") - require.NoError(t, ValidateColumnInputTypes(columnTypes)) - - columnTypes = cr.MustReadYAMLStrMap( - ` - float: FLOAT_COLUMN - int: INT_COLUMN - str: STRING_COLUMN - int_list: FLOAT_LIST_COLUMN - float_list: INT_LIST_COLUMN - str_list: STRING_LIST_COLUMN - `) - require.NoError(t, ValidateColumnInputTypes(columnTypes)) - - columnTypes = cr.MustReadYAMLStrMap( - ` - num1: FLOAT_COLUMN|INT_COLUMN - num2: INT_COLUMN|FLOAT_COLUMN - num3: STRING_COLUMN|INT_COLUMN - num4: INT_COLUMN|FLOAT_COLUMN|STRING_COLUMN - num5: STRING_COLUMN|INT_COLUMN|FLOAT_COLUMN - num6: STRING_LIST_COLUMN|INT_LIST_COLUMN|FLOAT_LIST_COLUMN - num7: STRING_COLUMN|INT_LIST_COLUMN|FLOAT_LIST_COLUMN - `) - require.NoError(t, ValidateColumnInputTypes(columnTypes)) - - columnTypes = cr.MustReadYAMLStrMap( - ` - nums1: [INT_COLUMN] - nums2: [FLOAT_COLUMN] - nums3: [INT_COLUMN|FLOAT_COLUMN] - nums4: [FLOAT_COLUMN|INT_COLUMN] - nums5: [STRING_COLUMN|INT_COLUMN|FLOAT_COLUMN] - nums6: [INT_LIST_COLUMN] - nums7: [INT_LIST_COLUMN|STRING_LIST_COLUMN] - nums8: [INT_LIST_COLUMN|STRING_COLUMN] - strs: [STRING_COLUMN] - num1: FLOAT_COLUMN - 
num2: INT_COLUMN - str_list: STRING_LIST_COLUMN - `) - require.NoError(t, ValidateColumnInputTypes(columnTypes)) - - columnTypes = cr.MustReadYAMLStrMap("num: bad") - require.Error(t, ValidateColumnInputTypes(columnTypes)) - - columnTypes = cr.MustReadYAMLStrMap("num: BOOL") - require.Error(t, ValidateColumnInputTypes(columnTypes)) - - columnTypes = cr.MustReadYAMLStrMap("num: [STRING_COLUMN, INT_COLUMN]") - require.Error(t, ValidateColumnInputTypes(columnTypes)) - - columnTypes = cr.MustReadYAMLStrMap("num: FLOAT") - require.Error(t, ValidateColumnInputTypes(columnTypes)) - columnTypes = cr.MustReadYAMLStrMap("num: FLOAT_COLUMNs") - require.Error(t, ValidateColumnInputTypes(columnTypes)) - - columnTypes = cr.MustReadYAMLStrMap("num: 1") - require.Error(t, ValidateColumnInputTypes(columnTypes)) - - columnTypes = cr.MustReadYAMLStrMap("num: [1]") - require.Error(t, ValidateColumnInputTypes(columnTypes)) - - columnTypes = cr.MustReadYAMLStrMap("num: {nested: STRING_COLUMN}") - require.Error(t, ValidateColumnInputTypes(columnTypes)) -} - -func TestValidateColumnInputValues(t *testing.T) { - var columnInputValues map[string]interface{} - - columnInputValues = cr.MustReadYAMLStrMap("num: age") - require.NoError(t, ValidateColumnInputValues(columnInputValues)) - - columnInputValues = cr.MustReadYAMLStrMap( - ` - num1: age - num2: income - str: prior_default - `) - require.NoError(t, ValidateColumnInputValues(columnInputValues)) - - columnInputValues = cr.MustReadYAMLStrMap( - ` - num1: age - num2: income - `) - require.NoError(t, ValidateColumnInputValues(columnInputValues)) - - columnInputValues = cr.MustReadYAMLStrMap( - ` - nums1: [age, income] - nums2: [income] - nums3: [age, income, years_employed] - strs: [prior_default, approved] - num1: age - num2: income - str: prior_default - `) - require.NoError(t, ValidateColumnInputValues(columnInputValues)) - - columnInputValues = cr.MustReadYAMLStrMap("num: 1") - require.Error(t, ValidateColumnInputValues(columnInputValues)) - - columnInputValues = cr.MustReadYAMLStrMap("num: [1]") - require.Error(t, ValidateColumnInputValues(columnInputValues)) - - columnInputValues = cr.MustReadYAMLStrMap("num: {nested: STRING_COLUMN}") - require.Error(t, ValidateColumnInputValues(columnInputValues)) -} - -func TestCheckColumnRuntimeTypesMatch(t *testing.T) { - var columnTypes map[string]interface{} - var runtimeTypes map[string]interface{} - - columnTypes = cr.MustReadYAMLStrMap("in: INT_COLUMN") - runtimeTypes = readRuntimeTypes("in: INT_COLUMN") - require.NoError(t, CheckColumnRuntimeTypesMatch(runtimeTypes, columnTypes)) - runtimeTypes = readRuntimeTypes("in: FLOAT_COLUMN") - require.Error(t, CheckColumnRuntimeTypesMatch(runtimeTypes, columnTypes)) - runtimeTypes = readRuntimeTypes("in: [INT_COLUMN]") - require.Error(t, CheckColumnRuntimeTypesMatch(runtimeTypes, columnTypes)) - - columnTypes = cr.MustReadYAMLStrMap("in: INT_COLUMN|FLOAT_COLUMN") - runtimeTypes = readRuntimeTypes("in: INT_COLUMN") - require.NoError(t, CheckColumnRuntimeTypesMatch(runtimeTypes, columnTypes)) - runtimeTypes = readRuntimeTypes("in: FLOAT_COLUMN") - require.NoError(t, CheckColumnRuntimeTypesMatch(runtimeTypes, columnTypes)) - runtimeTypes = readRuntimeTypes("in: STRING_COLUMN") - require.Error(t, CheckColumnRuntimeTypesMatch(runtimeTypes, columnTypes)) - - columnTypes = cr.MustReadYAMLStrMap("in: STRING_COLUMN|INT_COLUMN|FLOAT_COLUMN") - runtimeTypes = readRuntimeTypes("in: INT_COLUMN") - require.NoError(t, CheckColumnRuntimeTypesMatch(runtimeTypes, columnTypes)) - runtimeTypes = 
readRuntimeTypes("in: FLOAT_COLUMN") - require.NoError(t, CheckColumnRuntimeTypesMatch(runtimeTypes, columnTypes)) - runtimeTypes = readRuntimeTypes("in: STRING_COLUMN") - require.NoError(t, CheckColumnRuntimeTypesMatch(runtimeTypes, columnTypes)) - runtimeTypes = readRuntimeTypes("in: BAD_COLUMN") - require.Error(t, CheckColumnRuntimeTypesMatch(runtimeTypes, columnTypes)) - - columnTypes = cr.MustReadYAMLStrMap("in: [INT_COLUMN]") - runtimeTypes = readRuntimeTypes("in: [INT_COLUMN]") - require.NoError(t, CheckColumnRuntimeTypesMatch(runtimeTypes, columnTypes)) - runtimeTypes = readRuntimeTypes("in: [INT_COLUMN, INT_COLUMN, INT_COLUMN]") - require.NoError(t, CheckColumnRuntimeTypesMatch(runtimeTypes, columnTypes)) - runtimeTypes = readRuntimeTypes("in: INT_COLUMN") - require.Error(t, CheckColumnRuntimeTypesMatch(runtimeTypes, columnTypes)) - runtimeTypes = readRuntimeTypes("in: [FLOAT_COLUMN]") - require.Error(t, CheckColumnRuntimeTypesMatch(runtimeTypes, columnTypes)) - runtimeTypes = readRuntimeTypes("in: [INT_COLUMN, FLOAT_COLUMN, INT_COLUMN]") - require.Error(t, CheckColumnRuntimeTypesMatch(runtimeTypes, columnTypes)) - - columnTypes = cr.MustReadYAMLStrMap("in: [INT_COLUMN|FLOAT_COLUMN]") - runtimeTypes = readRuntimeTypes("in: [INT_COLUMN]") - require.NoError(t, CheckColumnRuntimeTypesMatch(runtimeTypes, columnTypes)) - runtimeTypes = readRuntimeTypes("in: [INT_COLUMN, INT_COLUMN, INT_COLUMN]") - require.NoError(t, CheckColumnRuntimeTypesMatch(runtimeTypes, columnTypes)) - runtimeTypes = readRuntimeTypes("in: INT_COLUMN") - require.Error(t, CheckColumnRuntimeTypesMatch(runtimeTypes, columnTypes)) - runtimeTypes = readRuntimeTypes("in: [FLOAT_COLUMN]") - require.NoError(t, CheckColumnRuntimeTypesMatch(runtimeTypes, columnTypes)) - runtimeTypes = readRuntimeTypes("in: [INT_COLUMN, FLOAT_COLUMN, INT_COLUMN]") - require.NoError(t, CheckColumnRuntimeTypesMatch(runtimeTypes, columnTypes)) - runtimeTypes = readRuntimeTypes("in: [INT_COLUMN, FLOAT_COLUMN, STRING_COLUMN]") - require.Error(t, CheckColumnRuntimeTypesMatch(runtimeTypes, columnTypes)) - - columnTypes = cr.MustReadYAMLStrMap("in: [STRING_COLUMN|INT_COLUMN|FLOAT_COLUMN]") - runtimeTypes = readRuntimeTypes("in: [INT_COLUMN]") - require.NoError(t, CheckColumnRuntimeTypesMatch(runtimeTypes, columnTypes)) - runtimeTypes = readRuntimeTypes("in: [INT_COLUMN, INT_COLUMN, INT_COLUMN]") - require.NoError(t, CheckColumnRuntimeTypesMatch(runtimeTypes, columnTypes)) - runtimeTypes = readRuntimeTypes("in: INT_COLUMN") - require.Error(t, CheckColumnRuntimeTypesMatch(runtimeTypes, columnTypes)) - runtimeTypes = readRuntimeTypes("in: [BAD_COLUMN]") - require.Error(t, CheckColumnRuntimeTypesMatch(runtimeTypes, columnTypes)) - runtimeTypes = readRuntimeTypes("in: [STRING_COLUMN]") - require.NoError(t, CheckColumnRuntimeTypesMatch(runtimeTypes, columnTypes)) - runtimeTypes = readRuntimeTypes("in: [INT_COLUMN, FLOAT_COLUMN, STRING_COLUMN]") - require.NoError(t, CheckColumnRuntimeTypesMatch(runtimeTypes, columnTypes)) - - columnTypes = cr.MustReadYAMLStrMap("in1: [INT_COLUMN]\nin2: STRING_COLUMN") - runtimeTypes = readRuntimeTypes("in1: [INT_COLUMN]\nin2: STRING_COLUMN") - require.NoError(t, CheckColumnRuntimeTypesMatch(runtimeTypes, columnTypes)) - runtimeTypes = readRuntimeTypes("in1: [INT_COLUMN, INT_COLUMN]\nin2: STRING_COLUMN") - require.NoError(t, CheckColumnRuntimeTypesMatch(runtimeTypes, columnTypes)) - runtimeTypes = readRuntimeTypes("in2: STRING_COLUMN\nin1: [INT_COLUMN, INT_COLUMN]") - require.NoError(t, 
CheckColumnRuntimeTypesMatch(runtimeTypes, columnTypes)) - runtimeTypes = readRuntimeTypes("in1: [INT_COLUMN]") - require.Error(t, CheckColumnRuntimeTypesMatch(runtimeTypes, columnTypes)) - runtimeTypes = readRuntimeTypes("in1: [INT_COLUMN]\nin2: STRING_COLUMN\nin3: INT_COLUMN") - require.Error(t, CheckColumnRuntimeTypesMatch(runtimeTypes, columnTypes)) - - columnTypes = cr.MustReadYAMLStrMap("in1: [INT_COLUMN|FLOAT_COLUMN|STRING_COLUMN]\nin2: STRING_COLUMN") - runtimeTypes = readRuntimeTypes("in1: [INT_COLUMN]\nin2: STRING_COLUMN") - require.NoError(t, CheckColumnRuntimeTypesMatch(runtimeTypes, columnTypes)) - runtimeTypes = readRuntimeTypes("in1: [INT_COLUMN, FLOAT_COLUMN, STRING_COLUMN, FLOAT_COLUMN]\nin2: STRING_COLUMN") - require.NoError(t, CheckColumnRuntimeTypesMatch(runtimeTypes, columnTypes)) -} - -func readRuntimeTypes(yamlStr string) map[string]interface{} { - runtimeTypes := make(map[string]interface{}) - runtimeTypesStr := cr.MustReadYAMLStrMap(yamlStr) - - for k, v := range runtimeTypesStr { - if runtimeTypeStr, ok := v.(string); ok { - runtimeTypes[k] = ColumnTypeFromString(runtimeTypeStr) - } else if runtimeTypeStrs, ok := cast.InterfaceToStrSlice(v); ok { - runtimeTypesSlice := make([]ColumnType, len(runtimeTypeStrs)) - for i, runtimeTypeStr := range runtimeTypeStrs { - runtimeTypesSlice[i] = ColumnTypeFromString(runtimeTypeStr) - } - runtimeTypes[k] = runtimeTypesSlice - } - } - - return runtimeTypes -} - -func TestValidateArgTypes(t *testing.T) { - var argTypes map[string]interface{} - - argTypes = cr.MustReadYAMLStrMap("STRING: FLOAT") - require.Error(t, ValidateArgTypes(argTypes)) -} - -func TestValidateValueType(t *testing.T) { - var valueType interface{} - - valueType = "FLOAT" - require.NoError(t, ValidateValueType(valueType)) - - valueType = "FLOAT|INT|BOOL|STRING" - require.NoError(t, ValidateValueType(valueType)) - - valueType = []string{"INT|FLOAT"} - require.NoError(t, ValidateValueType(valueType)) - - valueType = cr.MustReadYAMLStrMap("STRING: FLOAT") - require.NoError(t, ValidateValueType(valueType)) - - valueType = cr.MustReadYAMLStrMap("num: FLOAT") - require.NoError(t, ValidateValueType(valueType)) - - valueType = cr.MustReadYAMLStrMap("bools: [BOOL]") - require.NoError(t, ValidateValueType(valueType)) - - valueType = cr.MustReadYAMLStrMap("bools: [BOOL|FLOAT|INT]") - require.NoError(t, ValidateValueType(valueType)) - - valueType = cr.MustReadYAMLStrMap("STRING: INT") - require.NoError(t, ValidateValueType(valueType)) - - valueType = cr.MustReadYAMLStrMap("map: {BOOL|FLOAT: INT|STRING}") - require.NoError(t, ValidateValueType(valueType)) - - valueType = cr.MustReadYAMLStrMap("map: {mean: FLOAT, stddev: FLOAT}") - require.NoError(t, ValidateValueType(valueType)) - - valueType = cr.MustReadYAMLStrMap("map: {STRING: {lat: FLOAT, lon: FLOAT}}") - require.NoError(t, ValidateValueType(valueType)) - - valueType = cr.MustReadYAMLStrMap("map: {STRING: {lat: FLOAT, lon: [FLOAT]}}") - require.NoError(t, ValidateValueType(valueType)) - - valueType = cr.MustReadYAMLStrMap("map: {STRING: {FLOAT: INT}}") - require.NoError(t, ValidateValueType(valueType)) - - valueType = cr.MustReadYAMLStrMap("map: {STRING: {FLOAT: [INT]}}") - require.NoError(t, ValidateValueType(valueType)) - - valueType = cr.MustReadYAMLStrMap("map: {STRING: {lat: FLOAT, lon: {lat2: FLOAT, lon2: INT}}}") - require.NoError(t, ValidateValueType(valueType)) - - valueType = cr.MustReadYAMLStrMap("map6: {STRING: {lat: FLOAT, lon: {lat2: FLOAT, lon2: {INT: STRING}}}}") - require.NoError(t, 
ValidateValueType(valueType)) - - valueType = cr.MustReadYAMLStrMap("map6: {STRING: {lat: FLOAT, lon: {lat2: FLOAT, lon2: {INT: STRING}, mean: BOOL}}}") - require.NoError(t, ValidateValueType(valueType)) - - valueType = cr.MustReadYAMLStrMap( - ` - num: [INT] - str: STRING - map1: {STRING: INT} - map2: {mean: FLOAT, stddev: FLOAT} - map3: {STRING: {lat: FLOAT, lon: FLOAT}} - map3: {STRING: {lat: FLOAT, lon: [FLOAT]}} - map4: {STRING: {FLOAT: INT}} - map5: {STRING: {BOOL: [INT]}} - map6: {STRING: {lat: FLOAT, lon: {lat2: FLOAT, lon2: INT}}} - map6: {STRING: {lat: FLOAT, lon: {lat2: FLOAT, lon2: {INT: STRING}, mean: BOOL}}} - `) - require.NoError(t, ValidateValueType(valueType)) - - valueType = "FLOAT|INT|BAD" - require.Error(t, ValidateValueType(valueType)) - - valueType = []string{"INT", "FLOAT"} - require.Error(t, ValidateValueType(valueType)) - - valueType = cr.MustReadYAMLStrMap("num: FLOATs") - require.Error(t, ValidateValueType(valueType)) - - valueType = 1 - require.Error(t, ValidateValueType(valueType)) - - valueType = cr.MustReadYAMLStrMap("num: 1") - require.Error(t, ValidateValueType(valueType)) - - valueType = cr.MustReadYAMLStrMap("num: [1]") - require.Error(t, ValidateValueType(valueType)) - - valueType = cr.MustReadYAMLStrMap("STRING: test") - require.Error(t, ValidateValueType(valueType)) - - valueType = cr.MustReadYAMLStrMap("map: {STRING: INT, INT: FLOAT}") - require.Error(t, ValidateValueType(valueType)) - - valueType = cr.MustReadYAMLStrMap("map: {STRING: INT, INT: [FLOAT]}") - require.Error(t, ValidateValueType(valueType)) - - valueType = cr.MustReadYAMLStrMap("map: {mean: FLOAT, INT: FLOAT}") - require.Error(t, ValidateValueType(valueType)) - - valueType = cr.MustReadYAMLStrMap("map: {mean: FLOAT, INT: [FLOAT]}") - require.Error(t, ValidateValueType(valueType)) - - valueType = cr.MustReadYAMLStrMap("map: {STRING: {lat: FLOAT, STRING: FLOAT}}") - require.Error(t, ValidateValueType(valueType)) - - valueType = cr.MustReadYAMLStrMap("map: {STRING: {STRING: test}}") - require.Error(t, ValidateValueType(valueType)) -} - -func TestCastValue(t *testing.T) { - var valueType interface{} - var value interface{} - var casted interface{} +func TestValidateOutputSchema(t *testing.T) { var err error - valueType = "INT" - value = int64(2) - _, err = CastValue(value, valueType) + _, err = ValidateOutputSchema(cr.MustReadYAMLStr( + `STRING`)) require.NoError(t, err) - value = float64(2.2) - _, err = CastValue(value, valueType) + + _, err = ValidateOutputSchema(cr.MustReadYAMLStr( + `STRING|INT`)) require.Error(t, err) - value = nil - _, err = CastValue(value, valueType) - require.NoError(t, err) - valueType = "FLOAT" - value = float64(2.2) - _, err = CastValue(value, valueType) - require.NoError(t, err) - value = int64(2) - casted, err = CastValue(value, valueType) - require.NoError(t, err) - require.Equal(t, casted, float64(2)) + _, err = ValidateOutputSchema(cr.MustReadYAMLStr( + `STRING_COLUMN`)) + require.Error(t, err) - valueType = "BOOL" - value = false - _, err = CastValue(value, valueType) - require.NoError(t, err) - value = 2 - _, err = CastValue(value, valueType) + _, err = ValidateOutputSchema(cr.MustReadYAMLStr( + `bad`)) require.Error(t, err) - valueType = "FLOAT|INT" - value = float64(2.2) - _, err = CastValue(value, valueType) + _, err = ValidateOutputSchema(cr.MustReadYAMLStr( + `[STRING]`)) require.NoError(t, err) - value = int64(2) - casted, err = CastValue(value, valueType) - require.NoError(t, err) - require.Equal(t, casted, int64(2)) - valueType = 
cr.MustReadYAMLStrMap("STRING: FLOAT") - value = cr.MustReadYAMLStrMap("test: 2.2") - _, err = CastValue(value, valueType) - require.NoError(t, err) - value = cr.MustReadYAMLStrMap("test: 2.2\ntest2: 4.4") - casted, err = CastValue(value, valueType) - require.NoError(t, err) - require.Equal(t, casted, map[interface{}]interface{}{"test": 2.2, "test2": 4.4}) - value = cr.MustReadYAMLStrMap("test: test2") - _, err = CastValue(value, valueType) + _, err = ValidateOutputSchema(cr.MustReadYAMLStr( + `[STRING, INT]`)) require.Error(t, err) - value = map[int]float64{2: 2.2} - _, err = CastValue(value, valueType) + + _, err = ValidateOutputSchema(cr.MustReadYAMLStr( + `[STRING|INT]`)) require.Error(t, err) - value = make(map[string]float64) - _, err = CastValue(value, valueType) - require.NoError(t, err) - value = cr.MustReadYAMLStrMap("test: 2") // YAML - casted, err = CastValue(value, valueType) - require.NoError(t, err) - require.Equal(t, casted, map[interface{}]interface{}{"test": float64(2)}) - value = cr.MustReadJSONStr(`{"test": 2}`) // JSON - casted, err = CastValue(value, valueType) - require.NoError(t, err) - require.Equal(t, casted, map[interface{}]interface{}{"test": float64(2)}) - value = cr.MustReadYAMLStrMap("test: 2.0") // YAML - casted, err = CastValue(value, valueType) - require.NoError(t, err) - require.Equal(t, casted, map[interface{}]interface{}{"test": float64(2)}) - value = cr.MustReadJSONStr(`{"test": 2.0}`) // JSON - casted, err = CastValue(value, valueType) - require.NoError(t, err) - require.Equal(t, casted, map[interface{}]interface{}{"test": float64(2)}) - valueType = cr.MustReadYAMLStrMap("STRING: INT") - value = cr.MustReadYAMLStrMap("test: 2.2") - _, err = CastValue(value, valueType) + _, err = ValidateOutputSchema(cr.MustReadYAMLStr( + `[STRING_COLUMN]`)) require.Error(t, err) - value = cr.MustReadYAMLStrMap("test: 2\ntest2: 2.2") - _, err = CastValue(value, valueType) + + _, err = ValidateOutputSchema(cr.MustReadYAMLStr( + `[bad]`)) require.Error(t, err) - value = cr.MustReadYAMLStrMap("test: 2") // YAML - casted, err = CastValue(value, valueType) - require.NoError(t, err) - require.Equal(t, casted, map[interface{}]interface{}{"test": int64(2)}) - value = cr.MustReadJSONStr(`{"test": 2}`) // JSON - casted, err = CastValue(value, valueType) + + _, err = ValidateOutputSchema(cr.MustReadYAMLStr( + `{mean: FLOAT, stddev: FLOAT}`)) require.NoError(t, err) - require.Equal(t, casted, map[interface{}]interface{}{"test": int64(2)}) - value = cr.MustReadYAMLStrMap("test: 2.0") // YAML - _, err = CastValue(value, valueType) + + _, err = ValidateOutputSchema(cr.MustReadYAMLStr( + `{_type: FLOAT}`)) require.Error(t, err) - value = cr.MustReadJSONStr(`{"test": 2.0}`) // JSON - _, err = CastValue(value, valueType) + + _, err = ValidateOutputSchema(cr.MustReadYAMLStr( + `{_mean: FLOAT}`)) require.Error(t, err) - valueType = cr.MustReadYAMLStrMap("STRING: INT|FLOAT") - value = cr.MustReadYAMLStrMap("test: 2.2") - casted, err = CastValue(value, valueType) + _, err = ValidateOutputSchema(cr.MustReadYAMLStr( + `{INT: FLOAT}`)) require.NoError(t, err) - require.Equal(t, casted, map[interface{}]interface{}{"test": 2.2}) - value = map[string]int{"test": 2} - casted, err = CastValue(value, valueType) - require.NoError(t, err) - require.Equal(t, casted, map[interface{}]interface{}{"test": int64(2)}) - valueType = cr.MustReadYAMLStrMap("mean: FLOAT\nsum: INT") - value = cr.MustReadYAMLStrMap("mean: 2.2\nsum: 4") - _, err = CastValue(value, valueType) - require.NoError(t, err) - value = 
cr.MustReadYAMLStrMap("mean: 2.2\nsum: 4.4") - _, err = CastValue(value, valueType) + _, err = ValidateOutputSchema(cr.MustReadYAMLStr( + `{INT: INT|FLOAT}`)) require.Error(t, err) - value = cr.MustReadYAMLStrMap("mean: test\nsum: 4") - _, err = CastValue(value, valueType) - require.Error(t, err) - value = cr.MustReadYAMLStrMap("mean: 2.2") - _, err = CastValue(value, valueType) + + _, err = ValidateOutputSchema(cr.MustReadYAMLStr( + `{INT|FLOAT: FLOAT}`)) require.Error(t, err) - value = cr.MustReadYAMLStrMap("mean: 2.2\nsum: null") - _, err = CastValue(value, valueType) - require.NoError(t, err) - value = cr.MustReadYAMLStrMap("mean: 2.2\nsum: 4\nextra: test") - _, err = CastValue(value, valueType) + + _, err = ValidateOutputSchema(cr.MustReadYAMLStr( + `{INT: FLOAT_COLUMN}`)) require.Error(t, err) - valueType = []string{"INT"} - value = []int{1, 2, 3} - casted, err = CastValue(value, valueType) - require.NoError(t, err) - require.Equal(t, casted, []interface{}{int64(1), int64(2), int64(3)}) - value = []float64{1.1, 2.2, 3.3} - _, err = CastValue(value, valueType) + _, err = ValidateOutputSchema(cr.MustReadYAMLStr( + `{INT_COLUMN: FLOAT}`)) require.Error(t, err) - value = []float64{1, 2, 3} - _, err = CastValue(value, valueType) + + _, err = ValidateOutputSchema(cr.MustReadYAMLStr( + `{INT: FLOAT, FLOAT: FLOAT}`)) require.Error(t, err) +} - valueType = []string{"FLOAT"} - value = []float64{1.1, 2.2, 3.3} - casted, err = CastValue(value, valueType) - require.NoError(t, err) - require.Equal(t, casted, []interface{}{1.1, 2.2, 3.3}) - value = []float64{1, 2, 3} - casted, err = CastValue(value, valueType) +func checkCastOutputValueEqual(t *testing.T, outputSchemaYAML string, valueYAML string, expected interface{}) { + outputSchema, err := ValidateOutputSchema(cr.MustReadYAMLStr(outputSchemaYAML)) require.NoError(t, err) - require.Equal(t, casted, []interface{}{float64(1), float64(2), float64(3)}) - value = []int{1, 2, 3} - casted, err = CastValue(value, valueType) + casted, err := CastOutputValue(cr.MustReadYAMLStr(valueYAML), outputSchema) require.NoError(t, err) - require.Equal(t, casted, []interface{}{float64(1), float64(2), float64(3)}) + require.Equal(t, casted, expected) - valueType = []string{"FLOAT|INT|BOOL"} - value = []float64{1.1, 2.2, 3.3} - casted, err = CastValue(value, valueType) - require.NoError(t, err) - require.Equal(t, casted, []interface{}{float64(1.1), float64(2.2), float64(3.3)}) - value = []int{1, 2, 3} - casted, err = CastValue(value, valueType) + // All output schemas are valid input schemas, so test those too + checkCastInputValueEqual(t, outputSchemaYAML, valueYAML, expected) +} + +func checkCastOutputValueNoError(t *testing.T, outputSchemaYAML string, valueYAML string) { + outputSchema, err := ValidateOutputSchema(cr.MustReadYAMLStr(outputSchemaYAML)) require.NoError(t, err) - require.Equal(t, casted, []interface{}{int64(1), int64(2), int64(3)}) - value = []float64{1, 2, 3} - casted, err = CastValue(value, valueType) + _, err = CastOutputValue(cr.MustReadYAMLStr(valueYAML), outputSchema) require.NoError(t, err) - require.Equal(t, casted, []interface{}{float64(1), float64(2), float64(3)}) - value = []interface{}{int64(1), float64(2), float64(2.2), true, false} - casted, err = CastValue(value, valueType) + + // All output schemas are valid input schemas, so test those too + checkCastInputValueNoError(t, outputSchemaYAML, valueYAML) +} + +func checkCastOutputValueError(t *testing.T, outputSchemaYAML string, valueYAML string) { + outputSchema, err := 
ValidateOutputSchema(cr.MustReadYAMLStr(outputSchemaYAML)) require.NoError(t, err) - require.Equal(t, casted, []interface{}{int64(1), float64(2), float64(2.2), true, false}) - value = []interface{}{true, "str"} - _, err = CastValue(value, valueType) + _, err = CastOutputValue(cr.MustReadYAMLStr(valueYAML), outputSchema) require.Error(t, err) - valueType = []string{"FLOAT|INT|BOOL|STRING"} - value = []interface{}{int64(1), float64(2), float64(2.2), "str", false} - casted, err = CastValue(value, valueType) - require.NoError(t, err) - require.Equal(t, casted, []interface{}{int64(1), float64(2), float64(2.2), "str", false}) + // All output schemas are valid input schemas, so test those too + checkCastInputValueError(t, outputSchemaYAML, valueYAML) +} - valueType = cr.MustReadYAMLStrMap( +func TestCastOutputValue(t *testing.T) { + checkCastOutputValueEqual(t, `INT`, `2`, int64(2)) + checkCastOutputValueError(t, `INT`, `test`) + checkCastOutputValueError(t, `INT`, `2.2`) + checkCastOutputValueEqual(t, `FLOAT`, `2`, float64(2)) + checkCastOutputValueError(t, `FLOAT`, `test`) + checkCastOutputValueEqual(t, `BOOL`, `true`, true) + checkCastOutputValueEqual(t, `STRING`, `str`, "str") + checkCastOutputValueError(t, `STRING`, `1`) + + checkCastOutputValueEqual(t, `{STRING: FLOAT}`, `{test: 2.2, test2: 4.4}`, + map[interface{}]interface{}{"test": 2.2, "test2": 4.4}) + checkCastOutputValueError(t, `{STRING: FLOAT}`, `{test: test2}`) + checkCastOutputValueEqual(t, `{STRING: FLOAT}`, `{test: 2}`, + map[interface{}]interface{}{"test": float64(2)}) + checkCastOutputValueEqual(t, `{STRING: INT}`, `{test: 2}`, + map[interface{}]interface{}{"test": int64(2)}) + checkCastOutputValueError(t, `{STRING: INT}`, `{test: 2.0}`) + + checkCastOutputValueEqual(t, `{mean: FLOAT, sum: INT}`, `{mean: 2.2, sum: 4}`, + map[interface{}]interface{}{"mean": float64(2.2), "sum": int64(4)}) + checkCastOutputValueError(t, `{mean: FLOAT, sum: INT}`, `{mean: 2.2, sum: test}`) + checkCastOutputValueError(t, `{mean: FLOAT, sum: INT}`, `{mean: false, sum: 4}`) + checkCastOutputValueError(t, `{mean: FLOAT, sum: INT}`, `{mean: 2.2, 2: 4}`) + checkCastOutputValueError(t, `{mean: FLOAT, sum: INT}`, `{mean: 2.2, sum: Null}`) + checkCastOutputValueError(t, `{mean: FLOAT, sum: INT}`, `{mean: 2.2}`) + checkCastOutputValueError(t, `{mean: FLOAT, sum: INT}`, `{mean: 2.2, sum: Null}`) + checkCastOutputValueError(t, `{mean: FLOAT, sum: INT}`, `{mean: 2.2, sum: 4, stddev: 2}`) + + checkCastOutputValueEqual(t, `[INT]`, `[1, 2]`, + []interface{}{int64(1), int64(2)}) + checkCastOutputValueError(t, `[INT]`, `[1.0, 2]`) + checkCastOutputValueEqual(t, `[FLOAT]`, `[1.0, 2]`, + []interface{}{float64(1), float64(2)}) + + outputSchemaYAML := ` map: {STRING: FLOAT} str: STRING @@ -557,11 +179,9 @@ func TestCastValue(t *testing.T) { b: [STRING] c: {mean: FLOAT, sum: [INT], stddev: {STRING: INT}} bools: [BOOL] - anything: [BOOL|INT|FLOAT|STRING] - `) + ` - value = cr.MustReadYAMLStrMap( - ` + checkCastOutputValueNoError(t, outputSchemaYAML, ` map: {a: 2.2, b: 3} str: test1 floats: [2.2, 3.3, 4.4] @@ -573,7 +193,6 @@ func TestCastValue(t *testing.T) { b: [test1, test2, test3] c: {mean: 8.8, sum: [3, 2, 1], stddev: {a: 1, b: 2}} bools: [true] - anything: [] testB: lat: 3.14 lon: @@ -581,40 +200,9 @@ func TestCastValue(t *testing.T) { b: [testX, testY, testZ] c: {mean: 1.7, sum: [1], stddev: {z: 12}} bools: [true, false, true] - anything: [10, 2.2, test, false] - `) - _, err = CastValue(value, valueType) - require.NoError(t, err) - - value = cr.MustReadYAMLStrMap( 
- ` - map: {a: 2.2, b: 3} - str: test1 - floats: [2.2, 3.3, 4.4] - map2: - testA: - lat: 9.9 - lon: - a: 17 - b: [test1, test2, test3] - c: null - bools: [true] - anything: [] - testB: - lat: null - lon: - a: 88 - b: null - c: {mean: 1.7, sum: [1], stddev: {z: 12}} - bools: [true, false, true] - anything: [10, 2.2, test, false] - testC: null `) - _, err = CastValue(value, valueType) - require.NoError(t, err) - value = cr.MustReadYAMLStrMap( - ` + checkCastOutputValueError(t, outputSchemaYAML, ` map: {a: 2.2, b: 3} str: test1 floats: [2.2, 3.3, 4.4] @@ -626,20 +214,15 @@ func TestCastValue(t *testing.T) { b: [test1, test2, test3] c: {mean: 8.8, sum: [3, 2, 1], stddev: {a: 1, b: 2}} bools: [true] - anything: [] testB: lat: 3.14 lon: b: [testX, testY, testZ] c: {mean: 1.7, sum: [1], stddev: {z: 12}} bools: [true, false, true] - anything: [10, 2.2, test, false] `) - _, err = CastValue(value, valueType) - require.Error(t, err) - value = cr.MustReadYAMLStrMap( - ` + checkCastOutputValueError(t, outputSchemaYAML, ` map: {a: 2.2, b: 3} str: test1 floats: [2.2, 3.3, 4.4] @@ -651,7 +234,6 @@ func TestCastValue(t *testing.T) { b: [test1, test2, test3] c: {mean: 8.8, sum: [3, 2, 1], stddev: {a: 1, b: 2}} bools: [true] - anything: [] testB: lat: 3.14 lon: @@ -659,13 +241,9 @@ func TestCastValue(t *testing.T) { b: [testX, testY, testZ] c: {mean: 1.7, sum: [1], stddev: {z: 12}} bools: [true, false, true] - anything: [10, 2.2, test, false] `) - _, err = CastValue(value, valueType) - require.Error(t, err) - value = cr.MustReadYAMLStrMap( - ` + checkCastOutputValueError(t, outputSchemaYAML, ` map: {a: 2.2, b: 3} str: test1 floats: [2.2, 3.3, 4.4] @@ -677,7 +255,6 @@ func TestCastValue(t *testing.T) { b: [test1, test2, test3] c: {mean: 8.8, sum: [3, 2, 1], stddev: {a: 1, b: 2}} bools: [true] - anything: [] testB: lat: 3.14 lon: @@ -685,13 +262,9 @@ func TestCastValue(t *testing.T) { b: [testX, testY, 2] c: {mean: 1.7, sum: [1], stddev: {z: 12}} bools: [true, false, true] - anything: [10, 2.2, test, false] `) - _, err = CastValue(value, valueType) - require.Error(t, err) - value = cr.MustReadYAMLStrMap( - ` + checkCastOutputValueError(t, outputSchemaYAML, ` map: {a: 2.2, b: 3} str: test1 floats: [2.2, 3.3, 4.4] @@ -703,7 +276,6 @@ func TestCastValue(t *testing.T) { b: [test1, test2, test3] c: {mean: 8.8, sum: [3, 2, 1], stddev: {a: 1, b: test}} bools: [true] - anything: [] testB: lat: 3.14 lon: @@ -711,13 +283,9 @@ func TestCastValue(t *testing.T) { b: [testX, testY, testZ] c: {mean: 1.7, sum: [1], stddev: {z: 12}} bools: [true, false, true] - anything: [10, 2.2, test, false] `) - _, err = CastValue(value, valueType) - require.Error(t, err) - value = cr.MustReadYAMLStrMap( - ` + checkCastOutputValueError(t, outputSchemaYAML, ` map: {a: 2.2, b: 3} str: test1 floats: [2.2, 3.3, 4.4] @@ -729,7 +297,6 @@ func TestCastValue(t *testing.T) { b: [test1, test2, test3] c: {mean: 8.8, sum: [3, 2, 1], stddev: {a: 1, b: 2}} bools: [true] - anything: [] testB: lat: 3.14 lon: @@ -737,13 +304,9 @@ func TestCastValue(t *testing.T) { b: [testX, testY, testZ] c: {mean: 1.7, sum: [1], stddev: {z: 12}} bools: true - anything: [10, 2.2, test, false] `) - _, err = CastValue(value, valueType) - require.Error(t, err) - value = cr.MustReadYAMLStrMap( - ` + checkCastOutputValueError(t, outputSchemaYAML, ` map: {a: 2.2, b: 3} str: test1 floats: [2.2, 3.3, 4.4] @@ -755,7 +318,6 @@ func TestCastValue(t *testing.T) { b: [test1, test2, test3] c: {mean: 8.8, sum: [3, 2, 1], stddev: {a: 1, b: 2}} bools: [true] - anything: [] testB: lat: 3.14 
lon: @@ -763,203 +325,1324 @@ func TestCastValue(t *testing.T) { b: [testX, testY, testZ] c: {mean: 1.7, sum: [1], stddev: {z: 12}} bools: [1, 2, 3] - anything: [10, 2.2, test, false] `) - _, err = CastValue(value, valueType) +} + +func checkCastInputValueEqual(t *testing.T, inputSchemaYAML string, valueYAML string, expected interface{}) { + inputSchema, err := ValidateInputSchema(cr.MustReadYAMLStr(inputSchemaYAML), false) + require.NoError(t, err) + casted, err := CastInputValue(cr.MustReadYAMLStr(valueYAML), inputSchema) + require.NoError(t, err) + require.Equal(t, casted, expected) +} + +func checkCastInputValueError(t *testing.T, inputSchemaYAML string, valueYAML string) { + inputSchema, err := ValidateInputSchema(cr.MustReadYAMLStr(inputSchemaYAML), false) + require.NoError(t, err) + _, err = CastInputValue(cr.MustReadYAMLStr(valueYAML), inputSchema) require.Error(t, err) } -func TestCheckValueRuntimeTypesMatch(t *testing.T) { - var schemaType interface{} - var runtimeType interface{} - - schemaType = "INT" - runtimeType = "INT" - require.NoError(t, CheckValueRuntimeTypesMatch(runtimeType, schemaType)) - runtimeType = "FLOAT" - require.Error(t, CheckValueRuntimeTypesMatch(runtimeType, schemaType)) - runtimeType = "FLOAT|INT" - require.Error(t, CheckValueRuntimeTypesMatch(runtimeType, schemaType)) - - schemaType = "FLOAT|INT" - runtimeType = "FLOAT|INT" - require.NoError(t, CheckValueRuntimeTypesMatch(runtimeType, schemaType)) - runtimeType = "INT|FLOAT" - require.NoError(t, CheckValueRuntimeTypesMatch(runtimeType, schemaType)) - runtimeType = "FLOAT" - require.NoError(t, CheckValueRuntimeTypesMatch(runtimeType, schemaType)) - runtimeType = "INT" - require.NoError(t, CheckValueRuntimeTypesMatch(runtimeType, schemaType)) - runtimeType = "STRING" - require.Error(t, CheckValueRuntimeTypesMatch(runtimeType, schemaType)) - - schemaType = []string{"INT"} - runtimeType = []string{"INT"} - require.NoError(t, CheckValueRuntimeTypesMatch(runtimeType, schemaType)) - runtimeType = []string{"STRING"} - require.Error(t, CheckValueRuntimeTypesMatch(runtimeType, schemaType)) - - schemaType = []string{"BOOL"} - runtimeType = []string{"BOOL"} - require.NoError(t, CheckValueRuntimeTypesMatch(runtimeType, schemaType)) - runtimeType = []string{"INT"} - require.Error(t, CheckValueRuntimeTypesMatch(runtimeType, schemaType)) - - schemaType = []string{"FLOAT|INT"} - runtimeType = []string{"INT|FLOAT"} - require.NoError(t, CheckValueRuntimeTypesMatch(runtimeType, schemaType)) - runtimeType = []string{"INT"} - require.NoError(t, CheckValueRuntimeTypesMatch(runtimeType, schemaType)) - - schemaType = []string{"BOOL|FLOAT|INT"} - runtimeType = []string{"FLOAT|INT|BOOL"} - require.NoError(t, CheckValueRuntimeTypesMatch(runtimeType, schemaType)) - runtimeType = []string{"INT|FLOAT"} - require.NoError(t, CheckValueRuntimeTypesMatch(runtimeType, schemaType)) - runtimeType = []string{"FLOAT"} - require.NoError(t, CheckValueRuntimeTypesMatch(runtimeType, schemaType)) - runtimeType = []string{"INT"} - require.NoError(t, CheckValueRuntimeTypesMatch(runtimeType, schemaType)) - runtimeType = []string{"BOOL"} - require.NoError(t, CheckValueRuntimeTypesMatch(runtimeType, schemaType)) - runtimeType = []string{"STRING"} - require.Error(t, CheckValueRuntimeTypesMatch(runtimeType, schemaType)) - runtimeType = "FLOAT" - require.Error(t, CheckValueRuntimeTypesMatch(runtimeType, schemaType)) - - schemaType = cr.MustReadYAMLStrMap("STRING: FLOAT") - runtimeType = cr.MustReadYAMLStrMap("STRING: FLOAT") - require.NoError(t, 
CheckValueRuntimeTypesMatch(runtimeType, schemaType)) - runtimeType = cr.MustReadYAMLStrMap("STRING: INT") - require.Error(t, CheckValueRuntimeTypesMatch(runtimeType, schemaType)) - - schemaType = cr.MustReadYAMLStrMap("STRING: [INT|FLOAT]") - runtimeType = cr.MustReadYAMLStrMap("STRING: [FLOAT|INT]") - require.NoError(t, CheckValueRuntimeTypesMatch(runtimeType, schemaType)) - runtimeType = cr.MustReadYAMLStrMap("STRING: [INT]") - require.NoError(t, CheckValueRuntimeTypesMatch(runtimeType, schemaType)) - runtimeType = cr.MustReadYAMLStrMap("STRING: [BOOL]") - require.Error(t, CheckValueRuntimeTypesMatch(runtimeType, schemaType)) - runtimeType = cr.MustReadYAMLStrMap("STRING: INT") - require.Error(t, CheckValueRuntimeTypesMatch(runtimeType, schemaType)) - - schemaType = cr.MustReadYAMLStrMap("INT|FLOAT: STRING") - runtimeType = cr.MustReadYAMLStrMap("FLOAT|INT: STRING") - require.NoError(t, CheckValueRuntimeTypesMatch(runtimeType, schemaType)) - runtimeType = cr.MustReadYAMLStrMap("INT: STRING") - require.NoError(t, CheckValueRuntimeTypesMatch(runtimeType, schemaType)) - runtimeType = cr.MustReadYAMLStrMap("BOOL: STRING") - require.Error(t, CheckValueRuntimeTypesMatch(runtimeType, schemaType)) - - schemaType = cr.MustReadYAMLStrMap("mean: FLOAT\nsum: INT") - runtimeType = cr.MustReadYAMLStrMap("mean: FLOAT\nsum: INT") - require.NoError(t, CheckValueRuntimeTypesMatch(runtimeType, schemaType)) - runtimeType = cr.MustReadYAMLStrMap("sum: INT\nmean: FLOAT") - require.NoError(t, CheckValueRuntimeTypesMatch(runtimeType, schemaType)) - runtimeType = cr.MustReadYAMLStrMap("sum: INT\nmean: INT") - require.Error(t, CheckValueRuntimeTypesMatch(runtimeType, schemaType)) - runtimeType = cr.MustReadYAMLStrMap("mean: FLOAT") - require.Error(t, CheckValueRuntimeTypesMatch(runtimeType, schemaType)) - runtimeType = cr.MustReadYAMLStrMap("mean: FLOAT\nsum: INT\nextra: STRING") - require.Error(t, CheckValueRuntimeTypesMatch(runtimeType, schemaType)) - - schemaType = cr.MustReadYAMLStrMap("mean: FLOAT\nsum: INT|FLOAT") - runtimeType = cr.MustReadYAMLStrMap("mean: FLOAT\nsum: FLOAT|INT") - require.NoError(t, CheckValueRuntimeTypesMatch(runtimeType, schemaType)) - runtimeType = cr.MustReadYAMLStrMap("sum: FLOAT\nmean: FLOAT") - require.NoError(t, CheckValueRuntimeTypesMatch(runtimeType, schemaType)) - runtimeType = cr.MustReadYAMLStrMap("sum: INT\nmean: FLOAT") - require.NoError(t, CheckValueRuntimeTypesMatch(runtimeType, schemaType)) - runtimeType = cr.MustReadYAMLStrMap("sum: INT\nmean: INT") - require.Error(t, CheckValueRuntimeTypesMatch(runtimeType, schemaType)) - - schemaType = cr.MustReadYAMLStrMap( - ` - map: {STRING: FLOAT} - str: STRING - floats: [FLOAT] - map2: - STRING: - lat: FLOAT - lon: - a: INT|FLOAT - b: [STRING] - c: {mean: FLOAT, sum: [INT], stddev: {STRING: INT|FLOAT}} - d: [BOOL] - `) - runtimeType = cr.MustReadYAMLStrMap( - ` - floats: [FLOAT] - str: STRING - map2: - STRING: - lat: FLOAT - lon: - c: {sum: [INT], mean: FLOAT, stddev: {STRING: FLOAT|INT}} - b: [STRING] - a: FLOAT|INT - d: [BOOL] - map: {STRING: FLOAT} - `) - require.NoError(t, CheckValueRuntimeTypesMatch(runtimeType, schemaType)) - runtimeType = cr.MustReadYAMLStrMap( - ` - floats: [FLOAT] - str: STRING - map2: - STRING: - lat: FLOAT - lon: - c: {sum: [INT], mean: FLOAT, stddev: {STRING: FLOAT|INT}} - b: [STRING] - a: INT - d: [BOOL] - map: {STRING: FLOAT} - `) - require.NoError(t, CheckValueRuntimeTypesMatch(runtimeType, schemaType)) - runtimeType = cr.MustReadYAMLStrMap( - ` - floats: [FLOAT] - str: STRING - map2: - STRING: - 
lat: FLOAT - lon: - c: {sum: [INT], mean: FLOAT, stddev: {STRING: FLOAT|INT}} - b: STRING - a: FLOAT|INT - d: [BOOL] - map: {STRING: FLOAT} - `) - require.Error(t, CheckValueRuntimeTypesMatch(runtimeType, schemaType)) - runtimeType = cr.MustReadYAMLStrMap( - ` - floats: [FLOAT] - str: STRING - map2: - STRING: - lat: FLOAT - lon: - c: {sum: [INT], stddev: {STRING: FLOAT|INT}} - b: [STRING] - a: FLOAT|INT - d: [BOOL] - map: {STRING: FLOAT} - `) - require.Error(t, CheckValueRuntimeTypesMatch(runtimeType, schemaType)) - runtimeType = cr.MustReadYAMLStrMap( - ` - floats: [FLOAT] - str: STRING - map2: - STRING: - lat: FLOAT - lon: - c: {sum: [INT], mean: FLOAT, stddev: {STRING: FLOAT|INT}} - b: [STRING] - a: FLOAT|INT - d: BOOL - map: {STRING: FLOAT} - `) - require.Error(t, CheckValueRuntimeTypesMatch(runtimeType, schemaType)) +func checkCastInputValueNoError(t *testing.T, inputSchemaYAML string, valueYAML string) { + inputSchema, err := ValidateInputSchema(cr.MustReadYAMLStr(inputSchemaYAML), false) + require.NoError(t, err) + _, err = CastInputValue(cr.MustReadYAMLStr(valueYAML), inputSchema) + require.NoError(t, err) +} + +func TestCastInputValue(t *testing.T) { + // Note: all test cases in TestCastOutputValue also test CastInputValue() since output schemas are valid input schemas + + checkCastInputValueEqual(t, `FLOAT|INT`, `2`, int64(2)) + checkCastInputValueEqual(t, `INT|FLOAT`, `2`, int64(2)) + checkCastInputValueEqual(t, `FLOAT|INT`, `2.2`, float64(2.2)) + checkCastInputValueEqual(t, `INT|FLOAT`, `2.2`, float64(2.2)) + checkCastInputValueError(t, `STRING`, `2`) + checkCastInputValueEqual(t, `STRING|FLOAT`, `2`, float64(2)) + checkCastInputValueEqual(t, `{_type: [INT], _max_count: 2}`, `[2]`, []interface{}{int64(2)}) + checkCastInputValueError(t, `{_type: [INT], _max_count: 2}`, `[2, 3, 4]`) + checkCastInputValueEqual(t, `{_type: [INT], _min_count: 2}`, `[2, 3, 4]`, []interface{}{int64(2), int64(3), int64(4)}) + checkCastInputValueError(t, `{_type: [INT], _min_count: 2}`, `[2]`) + checkCastInputValueError(t, `{_type: INT, _optional: true}`, `Null`) + checkCastInputValueError(t, `{_type: INT, _optional: true}`, ``) + checkCastInputValueEqual(t, `{_type: INT, _allow_null: true}`, `Null`, nil) + checkCastInputValueEqual(t, `{_type: INT, _allow_null: true}`, ``, nil) + checkCastInputValueError(t, `{_type: {a: INT}}`, `Null`) + checkCastInputValueError(t, `{_type: {a: INT}, _optional: true}`, `Null`) + checkCastInputValueEqual(t, `{_type: {a: INT}, _allow_null: true}`, `Null`, nil) + checkCastInputValueEqual(t, `{_type: {a: INT}}`, `{a: 2}`, map[interface{}]interface{}{"a": int64(2)}) + checkCastInputValueError(t, `{_type: {a: INT}}`, `{a: Null}`) + checkCastInputValueError(t, `{a: {_type: INT, _optional: false}}`, `{a: Null}`) + checkCastInputValueError(t, `{a: {_type: INT, _optional: false}}`, `{}`) + checkCastInputValueError(t, `{a: {_type: INT, _optional: true}}`, `{a: Null}`) + checkCastInputValueEqual(t, `{a: {_type: INT, _optional: true}}`, `{}`, map[interface{}]interface{}{}) + checkCastInputValueEqual(t, `{a: {_type: INT, _allow_null: true}}`, `{a: Null}`, map[interface{}]interface{}{"a": nil}) + checkCastInputValueError(t, `{a: {_type: INT, _allow_null: true}}`, `{}`) + checkCastInputValueEqual(t, `{a: {_type: INT, _allow_null: true, _optional: true}}`, `{}`, map[interface{}]interface{}{}) +} + +func TestValidateInputSchema(t *testing.T) { + var inputSchema, inputSchema2, inputSchema3, inputSchema4 *InputSchema + var err error + + inputSchema, err = 
ValidateInputSchema(cr.MustReadYAMLStr( + `STRING`), false) + require.NoError(t, err) + inputSchema2, err = ValidateInputSchema(cr.MustReadYAMLStr( + `_type: STRING`), false) + require.NoError(t, err) + require.Equal(t, inputSchema, inputSchema2) + + _, err = ValidateInputSchema(cr.MustReadYAMLStr( + `STRING_COLUMN`), false) + require.NoError(t, err) + + _, err = ValidateInputSchema(cr.MustReadYAMLStr( + `STRING_COLUMN`), true) + require.Error(t, err) + + _, err = ValidateInputSchema(cr.MustReadYAMLStr( + `INFERRED_COLUMN`), false) + require.Error(t, err) + + _, err = ValidateInputSchema(cr.MustReadYAMLStr( + `BAD_COLUMN`), false) + require.Error(t, err) + + _, err = ValidateInputSchema(cr.MustReadYAMLStr( + ` + _type: STRING + _default: test + `), false) + require.NoError(t, err) + + _, err = ValidateInputSchema(cr.MustReadYAMLStr( + ` + _type: STRING_COLUMN + _default: test + `), false) + require.Error(t, err) + + _, err = ValidateInputSchema(cr.MustReadYAMLStr( + ` + _type: STRING + _default: Null + `), false) + require.Error(t, err) + + _, err = ValidateInputSchema(cr.MustReadYAMLStr( + ` + _type: STRING + _default: 2 + `), false) + require.Error(t, err) + + // Lists + + inputSchema, err = ValidateInputSchema(cr.MustReadYAMLStr( + `[STRING]`), false) + require.NoError(t, err) + inputSchema2, err = ValidateInputSchema(cr.MustReadYAMLStr( + `_type: [STRING]`), false) + require.NoError(t, err) + inputSchema3, err = ValidateInputSchema(cr.MustReadYAMLStr( + ` + _type: + - _type: STRING + `), false) + require.NoError(t, err) + require.Equal(t, inputSchema, inputSchema2) + require.Equal(t, inputSchema, inputSchema3) + + _, err = ValidateInputSchema(cr.MustReadYAMLStr( + `[STRING|INT]`), false) + require.NoError(t, err) + + _, err = ValidateInputSchema(cr.MustReadYAMLStr( + `[STRING_COLUMN|INT_COLUMN]`), false) + require.NoError(t, err) + + _, err = ValidateInputSchema(cr.MustReadYAMLStr( + `[STRING_COLUMN|INT_COLUMN]`), true) + require.Error(t, err) + + _, err = ValidateInputSchema(cr.MustReadYAMLStr( + `[STRING|INT_COLUMN]`), false) + require.Error(t, err) + + _, err = ValidateInputSchema(cr.MustReadYAMLStr( + ` + _type: [STRING] + _default: [test1, test2, test3] + `), false) + require.NoError(t, err) + + _, err = ValidateInputSchema(cr.MustReadYAMLStr( + ` + _type: [STRING_COLUMN] + _default: [test1, test2, test3] + `), false) + require.Error(t, err) + + _, err = ValidateInputSchema(cr.MustReadYAMLStr( + ` + _type: [STRING] + _default: [test1, 2, test3] + `), false) + require.Error(t, err) + + _, err = ValidateInputSchema(cr.MustReadYAMLStr( + ` + _type: [STRING|INT] + _default: [test1, 2, test3] + `), false) + require.NoError(t, err) + + _, err = ValidateInputSchema(cr.MustReadYAMLStr( + ` + _type: [STRING|FLOAT] + _default: [test1, 2, test3] + `), false) + require.NoError(t, err) + + _, err = ValidateInputSchema(cr.MustReadYAMLStr( + ` + _type: [STRING] + _default: test1 + `), false) + require.Error(t, err) + + _, err = ValidateInputSchema(cr.MustReadYAMLStr( + ` + _type: [STRING] + _min_count: 2 + `), false) + require.NoError(t, err) + + _, err = ValidateInputSchema(cr.MustReadYAMLStr( + ` + _type: [STRING] + _min_count: 2 + _max_count: 2 + `), false) + require.NoError(t, err) + + _, err = ValidateInputSchema(cr.MustReadYAMLStr( + ` + _type: [STRING] + _min_count: 2 + _max_count: 1 + `), false) + require.Error(t, err) + + _, err = ValidateInputSchema(cr.MustReadYAMLStr( + ` + _type: [STRING] + _default: [test1] + _min_count: 2 + `), false) + require.Error(t, err) + + _, err = 
ValidateInputSchema(cr.MustReadYAMLStr( + ` + _type: [STRING] + _default: [test1, test2] + _min_count: 2 + `), false) + require.NoError(t, err) + + _, err = ValidateInputSchema(cr.MustReadYAMLStr( + ` + _type: [STRING] + _min_count: -1 + `), false) + require.Error(t, err) + + _, err = ValidateInputSchema(cr.MustReadYAMLStr( + ` + _type: [STRING] + _min_count: test + `), false) + require.Error(t, err) + + _, err = ValidateInputSchema(cr.MustReadYAMLStr( + ` + _type: [STRING] + _default: [test1, test2, test3] + _max_count: 2 + `), false) + require.Error(t, err) + + _, err = ValidateInputSchema(cr.MustReadYAMLStr( + ` + _type: [STRING] + _default: [test1, test2] + _max_count: 2 + `), false) + require.NoError(t, err) + + // Maps + + inputSchema, err = ValidateInputSchema(cr.MustReadYAMLStr( + `arg1: STRING`), false) + require.NoError(t, err) + inputSchema2, err = ValidateInputSchema(cr.MustReadYAMLStr( + ` + arg1: + _type: STRING + `), false) + require.NoError(t, err) + inputSchema3, err = ValidateInputSchema(cr.MustReadYAMLStr( + ` + _type: {arg1: STRING} + `), false) + require.NoError(t, err) + require.Equal(t, inputSchema, inputSchema2) + require.Equal(t, inputSchema, inputSchema3) + + _, err = ValidateInputSchema(cr.MustReadYAMLStr( + `arg1: STRING_COLUMN`), false) + require.NoError(t, err) + + _, err = ValidateInputSchema(cr.MustReadYAMLStr( + `arg1: STRING_COLUMN`), true) + require.Error(t, err) + + _, err = ValidateInputSchema(cr.MustReadYAMLStr( + `STRING_COLUMN: STRING`), false) + require.NoError(t, err) + + _, err = ValidateInputSchema(cr.MustReadYAMLStr( + `STRING_COLUMN: STRING`), true) + require.Error(t, err) + + inputSchema, err = ValidateInputSchema(cr.MustReadYAMLStr( + `_arg1: STRING`), false) + require.Error(t, err) + + inputSchema, err = ValidateInputSchema(cr.MustReadYAMLStr( + `arg1: test`), false) + require.Error(t, err) + + inputSchema, err = ValidateInputSchema(cr.MustReadYAMLStr( + `STRING: test`), false) + require.Error(t, err) + + inputSchema, err = ValidateInputSchema(cr.MustReadYAMLStr( + `STRING_COLUMN: test`), false) + require.Error(t, err) + + _, err = ValidateInputSchema(cr.MustReadYAMLStr( + ` + _type: {arg1: STRING} + _min_count: 2 + `), false) + require.Error(t, err) + + _, err = ValidateInputSchema(cr.MustReadYAMLStr( + ` + _type: {STRING: INT} + _default: {test: 2} + `), false) + require.NoError(t, err) + + _, err = ValidateInputSchema(cr.MustReadYAMLStr( + ` + _type: {FLOAT: INT} + _default: {2: 2} + `), false) + require.NoError(t, err) + + _, err = ValidateInputSchema(cr.MustReadYAMLStr( + ` + _type: {STRING: INT} + _default: {test: test} + `), false) + require.Error(t, err) + + _, err = ValidateInputSchema(cr.MustReadYAMLStr( + ` + _type: {STRING: INT|STRING} + _default: {test: test} + `), false) + require.NoError(t, err) + + _, err = ValidateInputSchema(cr.MustReadYAMLStr( + ` + _type: {STRING: INT_COLUMN} + _min_count: 2 + `), false) + require.NoError(t, err) + + _, err = ValidateInputSchema(cr.MustReadYAMLStr( + ` + _type: {STRING_COLUMN: INT} + _min_count: 2 + `), false) + require.NoError(t, err) + + _, err = ValidateInputSchema(cr.MustReadYAMLStr( + ` + _type: {STRING_COLUMN: INT} + _default: {test: 2} + `), false) + require.Error(t, err) + + _, err = ValidateInputSchema(cr.MustReadYAMLStr( + ` + _type: {STRING_COLUMN: INT_COLUMN} + _min_count: 2 + `), false) + require.NoError(t, err) + + _, err = ValidateInputSchema(cr.MustReadYAMLStr( + ` + _type: {STRING_COLUMN: INT_COLUMN|STRING} + `), false) + require.Error(t, err) + + _, err = 
ValidateInputSchema(cr.MustReadYAMLStr( + ` + _type: {STRING_COLUMN: INT_COLUMN|STRING_COLUMN} + `), false) + require.NoError(t, err) + + _, err = ValidateInputSchema(cr.MustReadYAMLStr( + ` + _type: {STRING: INT} + _min_count: 2 + `), false) + require.NoError(t, err) + + _, err = ValidateInputSchema(cr.MustReadYAMLStr( + ` + arg1: + _type: STRING + _optional: true + `), false) + require.NoError(t, err) + + _, err = ValidateInputSchema(cr.MustReadYAMLStr( + ` + arg1: + _type: STRING + _default: test + `), false) + require.NoError(t, err) + + _, err = ValidateInputSchema(cr.MustReadYAMLStr( + ` + arg1: + _type: STRING + _default: 2 + `), false) + require.Error(t, err) + + _, err = ValidateInputSchema(cr.MustReadYAMLStr( + ` + arg1: + _type: STRING + _default: Null + `), false) + require.Error(t, err) + + _, err = ValidateInputSchema(cr.MustReadYAMLStr( + ` + arg1: + _type: STRING + _default: Null + `), false) + require.Error(t, err) + + // Mixed + + _, err = ValidateInputSchema(cr.MustReadYAMLStr( + `[[STRING]]`), false) + require.NoError(t, err) + + _, err = ValidateInputSchema(cr.MustReadYAMLStr( + ` + _type: [[STRING]] + _default: [[test1, test2]] + `), false) + require.NoError(t, err) + + _, err = ValidateInputSchema(cr.MustReadYAMLStr( + ` + _type: [[STRING_COLUMN]] + _default: [[test1, test2]] + `), false) + require.Error(t, err) + + inputSchema, err = ValidateInputSchema(cr.MustReadYAMLStr( + ` + - arg1: STRING + arg2: INT + `), false) + require.NoError(t, err) + inputSchema2, err = ValidateInputSchema(cr.MustReadYAMLStr( + ` + _type: + - arg1: STRING + arg2: INT + `), false) + require.NoError(t, err) + inputSchema3, err = ValidateInputSchema(cr.MustReadYAMLStr( + ` + - arg1: {_type: STRING} + arg2: {_type: INT} + `), false) + require.NoError(t, err) + inputSchema4, err = ValidateInputSchema(cr.MustReadYAMLStr( + ` + _type: + - arg1: + _type: STRING + arg2: + _type: INT + `), false) + require.NoError(t, err) + require.Equal(t, inputSchema, inputSchema2) + require.Equal(t, inputSchema, inputSchema3) + require.Equal(t, inputSchema, inputSchema4) + + _, err = ValidateInputSchema(cr.MustReadYAMLStr( + ` + - arg1: + _type: STRING + _default: test + arg2: + _type: INT + _default: 2 + `), false) + require.NoError(t, err) + + _, err = ValidateInputSchema(cr.MustReadYAMLStr( + ` + - arg1: + _type: + arg_a: STRING + arg_b: + _type: INT + _default: 1 + arg2: + _type: INT + _default: 2 + `), false) + require.NoError(t, err) + + _, err = ValidateInputSchema(cr.MustReadYAMLStr( + ` + arg_a: + arg1: + _type: + arg_a: STRING + arg_b: + _type: INT + _default: 1 + arg2: + _type: INT + _default: 2 + `), false) + require.NoError(t, err) + + _, err = ValidateInputSchema(cr.MustReadYAMLStr( + ` + arg1: + INT: + arg_a: INT + arg_b: + _type: STRING + _default: test + `), false) + require.NoError(t, err) + + _, err = ValidateInputSchema(cr.MustReadYAMLStr( + ` + - arg1: + INT: + arg_a: INT + arg_b: + _type: STRING + _default: test + `), false) + require.NoError(t, err) + + _, err = ValidateInputSchema(cr.MustReadYAMLStr( + ` + arg1: + - INT: + arg_a: INT + arg_b: + _type: STRING + _default: test + `), false) + require.NoError(t, err) + + _, err = ValidateInputSchema(cr.MustReadYAMLStr( + ` + arg1: + 2: STRING + `), false) + require.NoError(t, err) + + _, err = ValidateInputSchema(cr.MustReadYAMLStr( + ` + arg1: + _type: {2: STRING} + _default: {2: test} + `), false) + require.NoError(t, err) + + _, err = ValidateInputSchema(cr.MustReadYAMLStr( + ` + arg1: + 2: + _type: STRING + _default: test + `), false) + 
require.NoError(t, err) + + inputSchema, err = ValidateInputSchema(cr.MustReadYAMLStr( + `[{INT_COLUMN: STRING|INT}]`), false) + require.NoError(t, err) + inputSchema2, err = ValidateInputSchema(cr.MustReadYAMLStr( + ` + _type: [{INT_COLUMN: STRING|INT}] + `), false) + require.NoError(t, err) + require.Equal(t, inputSchema, inputSchema2) + + _, err = ValidateInputSchema(cr.MustReadYAMLStr( + `map: {BOOL|FLOAT: INT|STRING}`), false) + require.NoError(t, err) + + _, err = ValidateInputSchema(cr.MustReadYAMLStr( + `map: {mean: FLOAT, stddev: FLOAT}`), false) + require.NoError(t, err) + + _, err = ValidateInputSchema(cr.MustReadYAMLStr( + `map: {STRING: {lat: FLOAT, lon: FLOAT}}`), false) + require.NoError(t, err) + + _, err = ValidateInputSchema(cr.MustReadYAMLStr( + `map: {STRING: {lat: FLOAT, lon: [FLOAT]}}`), false) + require.NoError(t, err) + + _, err = ValidateInputSchema(cr.MustReadYAMLStr( + `map: {STRING: {FLOAT: INT}}`), false) + require.NoError(t, err) + + _, err = ValidateInputSchema(cr.MustReadYAMLStr( + `map: {STRING: {FLOAT: [INT]}}`), false) + require.NoError(t, err) + + _, err = ValidateInputSchema(cr.MustReadYAMLStr( + `map: {STRING: {lat: FLOAT, lon: {lat2: FLOAT, lon2: INT}}}`), false) + require.NoError(t, err) + + _, err = ValidateInputSchema(cr.MustReadYAMLStr( + `map6: {STRING: {lat: FLOAT, lon: {lat2: FLOAT, lon2: {INT: STRING}}}}`), false) + require.NoError(t, err) + + _, err = ValidateInputSchema(cr.MustReadYAMLStr( + `map6: {STRING: {lat: FLOAT, lon: {lat2: FLOAT, lon2: {INT: STRING}, mean: BOOL}}}`), false) + require.NoError(t, err) + + _, err = ValidateInputSchema(cr.MustReadYAMLStr( + ` + num: [INT] + str: STRING + map1: {STRING: INT} + map2: {mean: FLOAT, stddev: FLOAT} + map3: {STRING: {lat: FLOAT, lon: FLOAT}} + map3: {STRING: {lat: FLOAT, lon: [FLOAT]}} + map4: {STRING: {FLOAT: INT}} + map5: {STRING: {BOOL: [INT]}} + map6: {STRING: {lat: FLOAT, lon: {lat2: FLOAT, lon2: INT}}} + map6: {STRING: {lat: FLOAT, lon: {lat2: FLOAT, lon2: {INT: STRING}, mean: BOOL}}} + `), false) + require.NoError(t, err) + + _, err = ValidateInputSchema(cr.MustReadYAMLStr( + `map: {STRING: INT, INT: FLOAT}`), false) + require.Error(t, err) + + _, err = ValidateInputSchema(cr.MustReadYAMLStr( + `map: {STRING: INT, INT: [FLOAT]}`), false) + require.Error(t, err) + + _, err = ValidateInputSchema(cr.MustReadYAMLStr( + `map: {mean: FLOAT, INT: FLOAT}`), false) + require.Error(t, err) + + _, err = ValidateInputSchema(cr.MustReadYAMLStr( + `map: {mean: FLOAT, INT: [FLOAT]}`), false) + require.Error(t, err) + + _, err = ValidateInputSchema(cr.MustReadYAMLStr( + `map: {STRING: {lat: FLOAT, STRING: FLOAT}}`), false) + require.Error(t, err) + + _, err = ValidateInputSchema(cr.MustReadYAMLStr( + `map: {STRING: {STRING: test}}`), false) + require.Error(t, err) + + _, err = ValidateInputSchema(cr.MustReadYAMLStr( + `cols: [STRING_COLUMN, INT_COLUMN]`), false) + require.Error(t, err) + + _, err = ValidateInputSchema(cr.MustReadYAMLStr( + `cols: [STRING_COLUMNs]`), false) + require.Error(t, err) + + _, err = ValidateInputSchema(cr.MustReadYAMLStr( + `cols: [STRING_COLUMN|BAD]`), false) + require.Error(t, err) + + _, err = ValidateInputSchema(cr.MustReadYAMLStr( + `cols: Null`), false) + require.Error(t, err) + + _, err = ValidateInputSchema(cr.MustReadYAMLStr( + `cols: 1`), false) + require.Error(t, err) + + _, err = ValidateInputSchema(cr.MustReadYAMLStr( + `cols: [1]`), false) + require.Error(t, err) + + _, err = ValidateInputSchema(cr.MustReadYAMLStr( + `cols: []`), false) + require.Error(t, err) 
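
The cases above are consistent with reading ValidateInputSchema's second argument as a switch that forbids column types in the schema: every schema containing a *_COLUMN type validates when the flag is false and is rejected when it is true. A minimal sketch of a caller relying on that reading (the variable names and the "column types forbidden" interpretation are illustrative, not taken from this diff):

    // Mixed schema: a column-typed argument plus a scalar map argument.
    yamlStr := `{col: INT_COLUMN|FLOAT_COLUMN, weights: {STRING: FLOAT}}`

    // Second argument false: column types permitted, so this is expected to validate.
    if _, err := ValidateInputSchema(cr.MustReadYAMLStr(yamlStr), false); err != nil {
        // unexpected: the schema should be accepted here
    }

    // Second argument true: expected to be rejected, since the schema
    // references INT_COLUMN and FLOAT_COLUMN.
    if _, err := ValidateInputSchema(cr.MustReadYAMLStr(yamlStr), true); err == nil {
        // unexpected: the schema should be rejected here
    }
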
+ + _, err = ValidateInputSchema(cr.MustReadYAMLStr( + ` + float: FLOAT_COLUMN + int: INT_COLUMN + str: STRING_COLUMN + int_list: FLOAT_LIST_COLUMN + float_list: INT_LIST_COLUMN + str_list: STRING_LIST_COLUMN + num1: FLOAT_COLUMN|INT_COLUMN + num2: INT_COLUMN|FLOAT_COLUMN + num3: STRING_COLUMN|INT_COLUMN + num4: INT_COLUMN|FLOAT_COLUMN|STRING_COLUMN + num5: STRING_COLUMN|INT_COLUMN|FLOAT_COLUMN + num6: STRING_LIST_COLUMN|INT_LIST_COLUMN|FLOAT_LIST_COLUMN + num7: STRING_COLUMN|INT_LIST_COLUMN|FLOAT_LIST_COLUMN + nums1: [INT_COLUMN] + nums2: [FLOAT_COLUMN] + nums3: [INT_COLUMN|FLOAT_COLUMN] + nums4: [FLOAT_COLUMN|INT_COLUMN] + nums5: [STRING_COLUMN|INT_COLUMN|FLOAT_COLUMN] + nums6: [INT_LIST_COLUMN] + nums7: [INT_LIST_COLUMN|STRING_LIST_COLUMN] + nums8: [INT_LIST_COLUMN|STRING_COLUMN] + float_scalar: FLOAT + int_scalar: INT + str_scalar: STRING + bool_scalar: BOOL + num1_scalar: FLOAT|INT + num2_scalar: INT|FLOAT + num3_scalar: STRING|INT + num4_scalar: INT|FLOAT|STRING + num5_scalar: STRING|INT|FLOAT + nums1_scalar: [INT] + nums2_scalar: [FLOAT] + nums3_scalar: [INT|FLOAT] + nums4_scalar: [FLOAT|INT] + nums5_scalar: [STRING|INT|FLOAT] + nums6_scalar: [STRING|INT|FLOAT|BOOL] + `), false) + require.NoError(t, err) + + // Casting defaults + + inputSchema, err = ValidateInputSchema(cr.MustReadYAMLStr( + ` + _type: INT + _default: 2 + `), false) + require.NoError(t, err) + require.Equal(t, inputSchema.Default, int64(2)) + + inputSchema, err = ValidateInputSchema(cr.MustReadYAMLStr( + ` + _type: INT + _default: test + `), false) + require.Error(t, err) + + inputSchema, err = ValidateInputSchema(cr.MustReadYAMLStr( + ` + _type: INT + _default: 2.2 + `), false) + require.Error(t, err) + + inputSchema, err = ValidateInputSchema(cr.MustReadYAMLStr( + ` + _type: FLOAT + _default: 2 + `), false) + require.NoError(t, err) + require.Equal(t, inputSchema.Default, float64(2)) + + inputSchema, err = ValidateInputSchema(cr.MustReadYAMLStr( + ` + _type: FLOAT|INT + _default: 2 + `), false) + require.NoError(t, err) + require.Equal(t, inputSchema.Default, int64(2)) + + inputSchema, err = ValidateInputSchema(cr.MustReadYAMLStr( + ` + _type: BOOL + _default: true + `), false) + require.NoError(t, err) + require.Equal(t, inputSchema.Default, true) + + inputSchema, err = ValidateInputSchema(cr.MustReadYAMLStr( + ` + _type: {STRING: FLOAT} + _default: {test: 2.2, test2: 4.4} + `), false) + require.NoError(t, err) + require.Equal(t, inputSchema.Default, map[interface{}]interface{}{"test": 2.2, "test2": 4.4}) + + _, err = ValidateInputSchema(cr.MustReadYAMLStr( + ` + _type: {STRING: FLOAT} + _default: {test: test2} + `), false) + require.Error(t, err) + + inputSchema, err = ValidateInputSchema(cr.MustReadYAMLStr( + ` + _type: {STRING: FLOAT} + _default: {test: 2} + `), false) + require.NoError(t, err) + require.Equal(t, inputSchema.Default, map[interface{}]interface{}{"test": float64(2)}) + + inputSchema, err = ValidateInputSchema(cr.MustReadYAMLStr( + ` + _type: {STRING: FLOAT} + _default: {test: 2.0} + `), false) + require.NoError(t, err) + require.Equal(t, inputSchema.Default, map[interface{}]interface{}{"test": float64(2)}) + + inputSchema, err = ValidateInputSchema(cr.MustReadYAMLStr( + ` + _type: {STRING: INT} + _default: {test: 2} + `), false) + require.NoError(t, err) + require.Equal(t, inputSchema.Default, map[interface{}]interface{}{"test": int64(2)}) + + inputSchema, err = ValidateInputSchema(cr.MustReadYAMLStr( + ` + _type: {STRING: INT} + _default: {test: 2.0} + `), false) + require.Error(t, err) + + 
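
In the default-casting cases above and below, the value given for _default is cast against the declared _type, and the casted result is what lands on the returned schema's Default field (for example, _default: 2 under _type: FLOAT is stored as float64(2), while _default: {test: 2.0} under _type: {STRING: INT} is rejected). A short sketch of reading a casted default, mirroring those cases (the type assertion assumes the caller knows the declared type, as these tests do):

    schema, err := ValidateInputSchema(cr.MustReadYAMLStr(`{_type: FLOAT, _default: 2}`), false)
    if err == nil {
        mean := schema.Default.(float64) // the integer default 2 is stored as float64(2)
        _ = mean
    }
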
inputSchema, err = ValidateInputSchema(cr.MustReadYAMLStr( + ` + _type: {mean: FLOAT, sum: INT} + _default: {mean: 2.2, sum: 4} + `), false) + require.NoError(t, err) + require.Equal(t, inputSchema.Default, map[interface{}]interface{}{"mean": float64(2.2), "sum": int64(4)}) + + inputSchema, err = ValidateInputSchema(cr.MustReadYAMLStr( + ` + _type: {mean: FLOAT, sum: INT} + _default: {mean: 2.2, sum: test} + `), false) + require.Error(t, err) + + inputSchema, err = ValidateInputSchema(cr.MustReadYAMLStr( + ` + _type: {mean: FLOAT, sum: INT} + _default: {mean: false, sum: 4} + `), false) + require.Error(t, err) + + inputSchema, err = ValidateInputSchema(cr.MustReadYAMLStr( + ` + _type: {mean: FLOAT, sum: INT} + _default: {mean: 2.2, 2: 4} + `), false) + require.Error(t, err) + + inputSchema, err = ValidateInputSchema(cr.MustReadYAMLStr( + ` + _type: {mean: FLOAT, sum: INT} + _default: {mean: 2.2, sum: Null} + `), false) + require.Error(t, err) + + inputSchema, err = ValidateInputSchema(cr.MustReadYAMLStr( + ` + _type: {mean: FLOAT, sum: INT} + _default: {mean: 2.2} + `), false) + require.Error(t, err) + + inputSchema, err = ValidateInputSchema(cr.MustReadYAMLStr( + ` + _type: {mean: FLOAT, sum: INT} + _default: {mean: 2.2, sum: 4, stddev: 2} + `), false) + require.Error(t, err) + + inputSchema, err = ValidateInputSchema(cr.MustReadYAMLStr( + ` + _type: [INT] + _default: [1, 2] + `), false) + require.NoError(t, err) + require.Equal(t, inputSchema.Default, []interface{}{int64(1), int64(2)}) + + inputSchema, err = ValidateInputSchema(cr.MustReadYAMLStr( + ` + _type: [INT] + _default: [1.0, 2] + `), false) + require.Error(t, err) + + inputSchema, err = ValidateInputSchema(cr.MustReadYAMLStr( + ` + _type: [FLOAT] + _default: [1.0, 2] + `), false) + require.NoError(t, err) + require.Equal(t, inputSchema.Default, []interface{}{float64(1), float64(2)}) + + inputSchema, err = ValidateInputSchema(cr.MustReadYAMLStr( + ` + _type: [FLOAT|INT] + _default: [1.0, 2] + `), false) + require.NoError(t, err) + require.Equal(t, inputSchema.Default, []interface{}{float64(1), int64(2)}) + + inputSchema, err = ValidateInputSchema(cr.MustReadYAMLStr( + ` + _type: [FLOAT|INT|BOOL] + _default: [1.0, 2, true, test] + `), false) + require.Error(t, err) + + inputSchema, err = ValidateInputSchema(cr.MustReadYAMLStr( + ` + _type: [FLOAT|INT|BOOL|STRING] + _default: [1.0, 2, true, test] + `), false) + require.NoError(t, err) + require.Equal(t, inputSchema.Default, []interface{}{float64(1), int64(2), true, "test"}) + + inputSchema, err = ValidateInputSchema(cr.MustReadYAMLStr( + ` + _type: + STRING: + a: + _type: INT + _optional: true + b: + _type: [STRING] + _optional: true + c: + _type: {mean: FLOAT, sum: [INT], stddev: {STRING: INT}} + _optional: true + d: + _type: INT + _default: 2 + _default: + testA: {} + testB: + a: 88 + b: [testX, testY, testZ] + c: {mean: 1.7, sum: [1], stddev: {z: 12}} + d: 17 + `), false) + require.NoError(t, err) + require.Equal(t, inputSchema.Default, map[interface{}]interface{}{ + "testA": map[interface{}]interface{}{}, + "testB": map[interface{}]interface{}{ + "a": int64(88), + "b": []interface{}{"testX", "testY", "testZ"}, + "c": map[interface{}]interface{}{ + "mean": float64(1.7), + "sum": []interface{}{int64(1)}, + "stddev": map[interface{}]interface{}{"z": int64(12)}, + }, + "d": int64(17), + }, + }) + + _, err = ValidateInputSchema(cr.MustReadYAMLStr( + ` + _type: + STRING: + a: + _type: INT + _optional: true + b: + _type: [STRING] + c: + _type: {mean: FLOAT, sum: [INT], stddev: 
{STRING: INT}} + _optional: true + d: + _type: INT + _default: 2 + _default: + testA: Null + `), false) + require.Error(t, err) + + inputSchema, err = ValidateInputSchema(cr.MustReadYAMLStr( + ` + _type: + STRING: + _allow_null: true + _type: + a: + _type: INT + _optional: true + b: + _type: [STRING] + c: + _type: {mean: FLOAT, sum: [INT], stddev: {STRING: INT}} + _optional: true + d: + _type: INT + _default: 2 + _default: + testA: Null + `), false) + require.NoError(t, err) + require.Equal(t, inputSchema.Default, map[interface{}]interface{}{ + "testA": nil, + }) + + _, err = ValidateInputSchema(cr.MustReadYAMLStr( + ` + _type: + STRING: + a: + _type: INT + _optional: true + b: + _type: [STRING] + c: + _type: {mean: FLOAT, sum: [INT], stddev: {STRING: INT}} + _optional: true + d: + _type: INT + _default: 2 + _default: + testA: {} + `), false) + require.Error(t, err) + + _, err = ValidateInputSchema(cr.MustReadYAMLStr( + ` + _type: + STRING: + a: + _type: INT + _optional: true + b: + _type: [STRING] + c: + _type: {mean: FLOAT, sum: [INT], stddev: {STRING: INT}} + _optional: true + d: + _type: INT + _default: 2 + _default: + testA: + a: 88 + c: {mean: 1.7, sum: [1], stddev: {z: 12}} + d: 17 + `), false) + require.Error(t, err) + + _, err = ValidateInputSchema(cr.MustReadYAMLStr( + ` + _type: + map: {STRING: FLOAT} + str: STRING + floats: [FLOAT] + map2: + STRING: + lat: FLOAT + lon: + a: INT + b: [STRING] + c: {mean: FLOAT, sum: [INT], stddev: {STRING: INT}} + bools: [BOOL] + anything: [BOOL|INT|FLOAT|STRING] + _default: + map: {a: 2.2, b: 3} + str: test1 + floats: [2.2, 3.3, 4.4] + map2: + testA: + lat: 9.9 + lon: + a: 17 + b: [test1, test2, test3] + c: {mean: 8.8, sum: [3, 2, 1], stddev: {a: 1, b: 2}} + bools: [true] + anything: [] + testB: + lat: 3.14 + lon: + a: 88 + b: [testX, testY, testZ] + c: {mean: 1.7, sum: [1], stddev: {z: 12}} + bools: [true, false, true] + anything: [10, 2.2, test, false] + `), false) + require.NoError(t, err) + + _, err = ValidateInputSchema(cr.MustReadYAMLStr( + ` + _type: + map: {STRING: FLOAT} + str: STRING + floats: [FLOAT] + map2: + STRING: + lat: FLOAT + lon: + a: INT + b: [STRING] + c: {mean: FLOAT, sum: [INT], stddev: {STRING: INT}} + bools: [BOOL] + anything: [BOOL|INT|FLOAT|STRING] + _default: + map: {a: 2.2, b: 3} + str: test1 + floats: [2.2, 3.3, 4.4] + map2: + testA: + lat: 9.9 + lon: + a: 17 + b: [test1, test2, test3] + c: {mean: 8.8, sum: [3, 2, 1], stddev: {a: 1, b: 2}} + bools: [true] + anything: [] + testB: + lat: 3.14 + lon: + b: [testX, testY, testZ] + c: {mean: 1.7, sum: [1], stddev: {z: 12}} + bools: [true, false, true] + anything: [10, 2.2, test, false] + `), false) + require.Error(t, err) + + _, err = ValidateInputSchema(cr.MustReadYAMLStr( + ` + _type: + map: {STRING: FLOAT} + str: STRING + floats: [FLOAT] + map2: + STRING: + lat: FLOAT + lon: + a: INT + b: [STRING] + c: {mean: FLOAT, sum: [INT], stddev: {STRING: INT}} + bools: [BOOL] + anything: [BOOL|INT|FLOAT|STRING] + _default: + map: {a: 2.2, b: 3} + str: test1 + floats: [2.2, 3.3, 4.4] + map2: + testA: + lat: 9.9 + lon: + a: 17 + b: [test1, test2, test3] + c: {mean: 8.8, sum: [3, 2, 1], stddev: {a: 1, b: 2}} + bools: [true] + anything: [] + testB: + lat: 3.14 + lon: + a: 88.8 + b: [testX, testY, testZ] + c: {mean: 1.7, sum: [1], stddev: {z: 12}} + bools: [true, false, true] + anything: [10, 2.2, test, false] + `), false) + require.Error(t, err) + + _, err = ValidateInputSchema(cr.MustReadYAMLStr( + ` + _type: + map: {STRING: FLOAT} + str: STRING + floats: [FLOAT] + map2: + STRING: 
+ lat: FLOAT + lon: + a: INT + b: [STRING] + c: {mean: FLOAT, sum: [INT], stddev: {STRING: INT}} + bools: [BOOL] + anything: [BOOL|INT|FLOAT|STRING] + _default: + map: {a: 2.2, b: 3} + str: test1 + floats: [2.2, 3.3, 4.4] + map2: + testA: + lat: 9.9 + lon: + a: 17 + b: [test1, test2, test3] + c: {mean: 8.8, sum: [3, 2, 1], stddev: {a: 1, b: 2}} + bools: [true] + anything: [] + testB: + lat: 3.14 + lon: + a: 88 + b: [testX, testY, 2] + c: {mean: 1.7, sum: [1], stddev: {z: 12}} + bools: [true, false, true] + anything: [10, 2.2, test, false] + `), false) + require.Error(t, err) + + _, err = ValidateInputSchema(cr.MustReadYAMLStr( + ` + _type: + map: {STRING: FLOAT} + str: STRING + floats: [FLOAT] + map2: + STRING: + lat: FLOAT + lon: + a: INT + b: [STRING] + c: {mean: FLOAT, sum: [INT], stddev: {STRING: INT}} + bools: [BOOL] + anything: [BOOL|INT|FLOAT|STRING] + _default: + map: {a: 2.2, b: 3} + str: test1 + floats: [2.2, 3.3, 4.4] + map2: + testA: + lat: 9.9 + lon: + a: 17 + b: [test1, test2, test3] + c: {mean: 8.8, sum: [3, 2, 1], stddev: {a: 1, b: test}} + bools: [true] + anything: [] + testB: + lat: 3.14 + lon: + a: 88 + b: [testX, testY, testZ] + c: {mean: 1.7, sum: [1], stddev: {z: 12}} + bools: [true, false, true] + anything: [10, 2.2, test, false] + `), false) + require.Error(t, err) + + _, err = ValidateInputSchema(cr.MustReadYAMLStr( + ` + _type: + map: {STRING: FLOAT} + str: STRING + floats: [FLOAT] + map2: + STRING: + lat: FLOAT + lon: + a: INT + b: [STRING] + c: {mean: FLOAT, sum: [INT], stddev: {STRING: INT}} + bools: [BOOL] + anything: [BOOL|INT|FLOAT|STRING] + _default: + map: {a: 2.2, b: 3} + str: test1 + floats: [2.2, 3.3, 4.4] + map2: + testA: + lat: 9.9 + lon: + a: 17 + b: [test1, test2, test3] + c: {mean: 8.8, sum: [3, 2, 1], stddev: {a: 1, b: 2}} + bools: [true] + anything: [] + testB: + lat: 3.14 + lon: + a: 88 + b: [testX, testY, testZ] + c: {mean: 1.7, sum: [1], stddev: {z: 12}} + bools: true + anything: [10, 2.2, test, false] + `), false) + require.Error(t, err) + + _, err = ValidateInputSchema(cr.MustReadYAMLStr( + ` + _type: + map: {STRING: FLOAT} + str: STRING + floats: [FLOAT] + map2: + STRING: + lat: FLOAT + lon: + a: INT + b: [STRING] + c: {mean: FLOAT, sum: [INT], stddev: {STRING: INT}} + bools: [BOOL] + anything: [BOOL|INT|FLOAT|STRING] + _default: + map: {a: 2.2, b: 3} + str: test1 + floats: [2.2, 3.3, 4.4] + map2: + testA: + lat: 9.9 + lon: + a: 17 + b: [test1, test2, test3] + c: {mean: 8.8, sum: [3, 2, 1], stddev: {a: 1, b: 2}} + bools: [true] + anything: [] + testB: + lat: 3.14 + lon: + a: 88 + b: [testX, testY, testZ] + c: {mean: 1.7, sum: [1], stddev: {z: 12}} + bools: [1, 2, 3] + anything: [10, 2.2, test, false] + `), false) + require.Error(t, err) } diff --git a/pkg/operator/api/userconfig/value_type.go b/pkg/operator/api/userconfig/value_type.go index d99848619a..a7a0254efb 100644 --- a/pkg/operator/api/userconfig/value_type.go +++ b/pkg/operator/api/userconfig/value_type.go @@ -16,7 +16,12 @@ limitations under the License. 
package userconfig -import "strings" +import ( + "strings" + + "github.com/cortexlabs/cortex/pkg/lib/cast" + "github.com/cortexlabs/cortex/pkg/lib/errors" +) type ValueType int type ValueTypes []ValueType @@ -95,3 +100,37 @@ func (ts ValueTypes) StringList() []string { func (ts ValueTypes) String() string { return strings.Join(ts.StringList(), ", ") } + +func (t *ValueType) CastValue(value interface{}) (interface{}, error) { + switch *t { + case IntegerValueType: + valueInt, ok := cast.InterfaceToInt64(value) + if !ok { + return nil, ErrorUnsupportedLiteralType(value, t.String()) + } + return valueInt, nil + + case FloatValueType: + valueFloat, ok := cast.InterfaceToFloat64(value) + if !ok { + return nil, ErrorUnsupportedLiteralType(value, t.String()) + } + return valueFloat, nil + + case StringValueType: + valueStr, ok := value.(string) + if !ok { + return nil, ErrorUnsupportedLiteralType(value, t.String()) + } + return valueStr, nil + + case BoolValueType: + valueBool, ok := value.(bool) + if !ok { + return nil, ErrorUnsupportedLiteralType(value, t.String()) + } + return valueBool, nil + } + + return nil, errors.New(t.String(), "unimplemented ValueType") // unexpected +} diff --git a/pkg/operator/api/userconfig/value_type_test.go b/pkg/operator/api/userconfig/value_type_test.go new file mode 100644 index 0000000000..0221551e3d --- /dev/null +++ b/pkg/operator/api/userconfig/value_type_test.go @@ -0,0 +1,49 @@ +/* +Copyright 2019 Cortex Labs, Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+*/ + +package userconfig + +import ( + "testing" + + "github.com/stretchr/testify/require" + + cr "github.com/cortexlabs/cortex/pkg/lib/configreader" +) + +func checkValueCastValueEqual(t *testing.T, typeStr string, valueYAML string, expected interface{}) { + valueType := ValueTypeFromString(typeStr) + require.NotEqual(t, valueType, UnknownValueType) + casted, err := valueType.CastValue(cr.MustReadYAMLStr(valueYAML)) + require.NoError(t, err) + require.Equal(t, casted, expected) +} + +func checkValueCastValueError(t *testing.T, typeStr string, valueYAML string) { + valueType := ValueTypeFromString(typeStr) + require.NotEqual(t, valueType, UnknownValueType) + _, err := valueType.CastValue(cr.MustReadYAMLStr(valueYAML)) + require.Error(t, err) +} + +func TestValueCastValue(t *testing.T) { + checkValueCastValueEqual(t, `INT`, `2`, int64(2)) + checkValueCastValueError(t, `STRING`, `2`) + checkValueCastValueEqual(t, `STRING`, `test`, "test") + checkValueCastValueEqual(t, `FLOAT`, `2`, float64(2)) + checkValueCastValueError(t, `BOOL`, `2`) + checkValueCastValueEqual(t, `BOOL`, `true`, true) +} diff --git a/pkg/operator/context/aggregates.go b/pkg/operator/context/aggregates.go index 318718a61c..22cf50b5e6 100644 --- a/pkg/operator/context/aggregates.go +++ b/pkg/operator/context/aggregates.go @@ -23,7 +23,6 @@ import ( "github.com/cortexlabs/cortex/pkg/consts" "github.com/cortexlabs/cortex/pkg/lib/errors" "github.com/cortexlabs/cortex/pkg/lib/hash" - s "github.com/cortexlabs/cortex/pkg/lib/strings" "github.com/cortexlabs/cortex/pkg/operator/api/context" "github.com/cortexlabs/cortex/pkg/operator/api/resource" "github.com/cortexlabs/cortex/pkg/operator/api/userconfig" @@ -33,53 +32,42 @@ func getAggregates( config *userconfig.Config, constants context.Constants, rawColumns context.RawColumns, - userAggregators map[string]*context.Aggregator, + aggregators context.Aggregators, root string, ) (context.Aggregates, error) { aggregates := context.Aggregates{} for _, aggregateConfig := range config.Aggregates { - if _, ok := constants[aggregateConfig.Name]; ok { - return nil, userconfig.ErrorDuplicateResourceName(aggregateConfig, constants[aggregateConfig.Name]) - } + aggregator := aggregators[aggregateConfig.Aggregator] - aggregator, err := getAggregator(aggregateConfig.Aggregator, userAggregators) - if err != nil { - return nil, errors.Wrap(err, userconfig.Identify(aggregateConfig), userconfig.AggregatorKey) + var validInputResources []context.Resource + for _, res := range constants { + validInputResources = append(validInputResources, res) } - - err = validateAggregateInputs(aggregateConfig, constants, rawColumns, aggregator) - if err != nil { - return nil, errors.WithStack(err) + for _, res := range rawColumns { + validInputResources = append(validInputResources, res) } - constantIDMap := make(map[string]string, len(aggregateConfig.Inputs.Args)) - constantIDWithTagsMap := make(map[string]string, len(aggregateConfig.Inputs.Args)) - for argName, constantName := range aggregateConfig.Inputs.Args { - constantNameStr := constantName.(string) - constant, ok := constants[constantNameStr] - if !ok { - return nil, errors.Wrap(userconfig.ErrorUndefinedResource(constantNameStr, resource.ConstantType), - userconfig.Identify(aggregateConfig), userconfig.InputsKey, userconfig.ArgsKey, argName) - } - constantIDMap[argName] = 
constant.ID - constantIDWithTagsMap[argName] = constant.IDWithTags + castedInput, inputID, err := ValidateInput( + aggregateConfig.Input, + aggregator.Input, + []resource.Type{resource.RawColumnType, resource.ConstantType}, + validInputResources, + config.Resources, + nil, + nil, + ) + if err != nil { + return nil, errors.Wrap(err, userconfig.Identify(aggregateConfig), userconfig.InputKey) } + aggregateConfig.Input = castedInput var buf bytes.Buffer - buf.WriteString(rawColumns.ColumnInputsID(aggregateConfig.Inputs.Columns)) - buf.WriteString(s.Obj(constantIDMap)) + buf.WriteString(inputID) buf.WriteString(aggregator.ID) id := hash.Bytes(buf.Bytes()) - buf.Reset() - buf.WriteString(rawColumns.ColumnInputsIDWithTags(aggregateConfig.Inputs.Columns)) - buf.WriteString(s.Obj(constantIDWithTagsMap)) - buf.WriteString(aggregator.IDWithTags) - buf.WriteString(aggregateConfig.Tags.ID()) - idWithTags := hash.Bytes(buf.Bytes()) - aggregateKey := filepath.Join( root, consts.AggregatesDir, @@ -90,7 +78,6 @@ func getAggregates( ComputedResourceFields: &context.ComputedResourceFields{ ResourceFields: &context.ResourceFields{ ID: id, - IDWithTags: idWithTags, ResourceType: resource.AggregateType, }, }, @@ -102,55 +89,3 @@ func getAggregates( return aggregates, nil } - -func validateAggregateInputs( - aggregateConfig *userconfig.Aggregate, - constants context.Constants, - rawColumns context.RawColumns, - aggregator *context.Aggregator, -) error { - if aggregateConfig.AggregatorPath != nil { - return nil - } - - columnRuntimeTypes, err := context.GetColumnRuntimeTypes(aggregateConfig.Inputs.Columns, rawColumns) - if err != nil { - return errors.Wrap(err, userconfig.Identify(aggregateConfig), userconfig.InputsKey, userconfig.ColumnsKey) - } - err = userconfig.CheckColumnRuntimeTypesMatch(columnRuntimeTypes, aggregator.Inputs.Columns) - if err != nil { - return errors.Wrap(err, userconfig.Identify(aggregateConfig), userconfig.InputsKey, userconfig.ColumnsKey) - } - - argTypes, err := getAggregateArgTypes(aggregateConfig.Inputs.Args, constants) - if err != nil { - return errors.Wrap(err, userconfig.Identify(aggregateConfig), userconfig.InputsKey, userconfig.ArgsKey) - } - err = userconfig.CheckArgRuntimeTypesMatch(argTypes, aggregator.Inputs.Args) - if err != nil { - return errors.Wrap(err, userconfig.Identify(aggregateConfig), userconfig.InputsKey, userconfig.ArgsKey) - } - - return nil -} - -func getAggregateArgTypes( - args map[string]interface{}, - constants context.Constants, -) (map[string]interface{}, error) { - - if len(args) == 0 { - return nil, nil - } - - argTypes := make(map[string]interface{}, len(args)) - for argName, constantName := range args { - constantNameStr := constantName.(string) - constant, ok := constants[constantNameStr] - if !ok { - return nil, errors.Wrap(userconfig.ErrorUndefinedResource(constantNameStr, resource.ConstantType), argName) - } - argTypes[argName] = constant.Type - } - return argTypes, nil -} diff --git a/pkg/operator/context/aggregators.go b/pkg/operator/context/aggregators.go index 2094e2e302..26414a6948 100644 --- a/pkg/operator/context/aggregators.go +++ b/pkg/operator/context/aggregators.go @@ -39,7 +39,7 @@ func loadUserAggregators( for _, aggregatorConfig := range config.Aggregators { impl, ok := impls[aggregatorConfig.Path] if !ok { - return nil, errors.Wrap(ErrorImplDoesNotExist(aggregatorConfig.Path), userconfig.Identify(aggregatorConfig)) + return nil, errors.Wrap(userconfig.ErrorImplDoesNotExist(aggregatorConfig.Path), 
userconfig.Identify(aggregatorConfig)) } aggregator, err := newAggregator(*aggregatorConfig, impl, nil, pythonPackages) if err != nil { @@ -55,7 +55,7 @@ func loadUserAggregators( impl, ok := impls[*aggregateConfig.AggregatorPath] if !ok { - return nil, errors.Wrap(ErrorImplDoesNotExist(*aggregateConfig.AggregatorPath), userconfig.Identify(aggregateConfig)) + return nil, errors.Wrap(userconfig.ErrorImplDoesNotExist(*aggregateConfig.AggregatorPath), userconfig.Identify(aggregateConfig)) } implHash := hash.Bytes(impl) @@ -67,7 +67,8 @@ func loadUserAggregators( ResourceFields: userconfig.ResourceFields{ Name: implHash, }, - Path: *aggregateConfig.AggregatorPath, + OutputType: nil, + Path: *aggregateConfig.AggregatorPath, } aggregator, err := newAggregator(*anonAggregatorConfig, impl, nil, pythonPackages) if err != nil { @@ -91,7 +92,7 @@ func newAggregator( implID := hash.Bytes(impl) var buf bytes.Buffer - buf.WriteString(context.DataTypeID(aggregatorConfig.Inputs)) + buf.WriteString(context.DataTypeID(aggregatorConfig.Input)) buf.WriteString(context.DataTypeID(aggregatorConfig.OutputType)) buf.WriteString(implID) @@ -104,14 +105,12 @@ func newAggregator( aggregator := &context.Aggregator{ ResourceFields: &context.ResourceFields{ ID: id, - IDWithTags: id, ResourceType: resource.AggregatorType, }, Aggregator: &aggregatorConfig, Namespace: namespace, ImplKey: filepath.Join(consts.AggregatorsDir, implID+".py"), } - aggregator.Aggregator.Path = "" if err := uploadAggregator(aggregator, impl); err != nil { return nil, err @@ -141,20 +140,6 @@ func uploadAggregator(aggregator *context.Aggregator, impl []byte) error { return nil } -func getAggregator( - name string, - userAggregators map[string]*context.Aggregator, -) (*context.Aggregator, error) { - - if aggregator, ok := builtinAggregators[name]; ok { - return aggregator, nil - } - if aggregator, ok := userAggregators[name]; ok { - return aggregator, nil - } - return nil, userconfig.ErrorUndefinedResourceBuiltin(name, resource.AggregatorType) -} - func getAggregators( config *userconfig.Config, userAggregators map[string]*context.Aggregator, @@ -162,14 +147,23 @@ func getAggregators( aggregators := context.Aggregators{} for _, aggregateConfig := range config.Aggregates { - if _, ok := aggregators[aggregateConfig.Aggregator]; ok { + name := aggregateConfig.Aggregator + + if _, ok := aggregators[name]; ok { continue } - aggregator, err := getAggregator(aggregateConfig.Aggregator, userAggregators) - if err != nil { - return nil, errors.Wrap(err, userconfig.Identify(aggregateConfig), userconfig.AggregatorKey) + + if aggregator, ok := builtinAggregators[name]; ok { + aggregators[name] = aggregator + continue } - aggregators[aggregateConfig.Aggregator] = aggregator + + if aggregator, ok := userAggregators[name]; ok { + aggregators[name] = aggregator + continue + } + + return nil, errors.Wrap(userconfig.ErrorUndefinedResource(name, resource.AggregatorType), userconfig.Identify(aggregateConfig), userconfig.AggregatorKey) } return aggregators, nil diff --git a/pkg/operator/context/apis.go b/pkg/operator/context/apis.go index 1ff5045575..b2024f5c93 100644 --- a/pkg/operator/context/apis.go +++ b/pkg/operator/context/apis.go @@ -19,10 +19,12 @@ package context import ( "bytes" + "github.com/cortexlabs/cortex/pkg/lib/errors" "github.com/cortexlabs/cortex/pkg/lib/hash" "github.com/cortexlabs/cortex/pkg/operator/api/context" "github.com/cortexlabs/cortex/pkg/operator/api/resource" 
"github.com/cortexlabs/cortex/pkg/operator/api/userconfig" + "github.com/cortexlabs/yaml" ) func getAPIs(config *userconfig.Config, @@ -31,27 +33,28 @@ func getAPIs(config *userconfig.Config, apis := context.APIs{} for _, apiConfig := range config.APIs { - model := models[apiConfig.ModelName] + modelName, _ := yaml.ExtractAtSymbolText(apiConfig.Model) + + model := models[modelName] + if model == nil { + return nil, errors.Wrap(userconfig.ErrorUndefinedResource(modelName, resource.ModelType), userconfig.Identify(apiConfig), userconfig.ModelNameKey) + } var buf bytes.Buffer buf.WriteString(apiConfig.Name) buf.WriteString(model.ID) id := hash.Bytes(buf.Bytes()) - buf.WriteString(model.IDWithTags) - buf.WriteString(apiConfig.Tags.ID()) - idWithTags := hash.Bytes(buf.Bytes()) - apis[apiConfig.Name] = &context.API{ ComputedResourceFields: &context.ComputedResourceFields{ ResourceFields: &context.ResourceFields{ ID: id, - IDWithTags: idWithTags, ResourceType: resource.APIType, }, }, - API: apiConfig, - Path: context.APIPath(apiConfig.Name, config.App.Name), + API: apiConfig, + Path: context.APIPath(apiConfig.Name, config.App.Name), + ModelName: modelName, } } return apis, nil diff --git a/pkg/operator/context/autogenerator.go b/pkg/operator/context/autogenerator.go deleted file mode 100644 index a28e0d7e6a..0000000000 --- a/pkg/operator/context/autogenerator.go +++ /dev/null @@ -1,129 +0,0 @@ -/* -Copyright 2019 Cortex Labs, Inc. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-*/ - -package context - -import ( - "strings" - - "github.com/cortexlabs/cortex/pkg/lib/configreader" - "github.com/cortexlabs/cortex/pkg/lib/errors" - s "github.com/cortexlabs/cortex/pkg/lib/strings" - "github.com/cortexlabs/cortex/pkg/operator/api/context" - "github.com/cortexlabs/cortex/pkg/operator/api/resource" - "github.com/cortexlabs/cortex/pkg/operator/api/userconfig" -) - -func autoGenerateConfig( - config *userconfig.Config, - userAggregators map[string]*context.Aggregator, - userTransformers map[string]*context.Transformer, -) error { - - for _, aggregate := range config.Aggregates { - for argName, argVal := range aggregate.Inputs.Args { - if argValStr, ok := argVal.(string); ok { - if s.HasPrefixAndSuffix(argValStr, "\"") { - argVal = s.TrimPrefixAndSuffix(argValStr, "\"") - } else { - continue // assume it's a reference to a constant - } - } - - aggregator, err := getAggregator(aggregate.Aggregator, userAggregators) - if err != nil { - return errors.Wrap(err, userconfig.Identify(aggregate), userconfig.AggregatorKey) - } - - var argType interface{} - if aggregator.Inputs != nil { - var ok bool - argType, ok = aggregator.Inputs.Args[argName] - if !ok { - return errors.Wrap(configreader.ErrorUnsupportedKey(argName), userconfig.Identify(aggregate), userconfig.InputsKey, userconfig.ArgsKey) - } - } - - constantName := strings.Join([]string{ - resource.AggregateType.String(), - aggregate.Name, - userconfig.InputsKey, - userconfig.ArgsKey, - argName, - }, "/") - - constant := &userconfig.Constant{ - ResourceFields: userconfig.ResourceFields{ - Name: constantName, - }, - Type: argType, - Value: argVal, - Tags: make(map[string]interface{}), - } - config.Constants = append(config.Constants, constant) - - aggregate.Inputs.Args[argName] = constantName - } - } - - for _, transformedColumn := range config.TransformedColumns { - for argName, argVal := range transformedColumn.Inputs.Args { - if argValStr, ok := argVal.(string); ok { - if s.HasPrefixAndSuffix(argValStr, "\"") { - argVal = s.TrimPrefixAndSuffix(argValStr, "\"") - } else { - continue // assume it's a reference to a constant or aggregate - } - } - - transformer, err := getTransformer(transformedColumn.Transformer, userTransformers) - if err != nil { - return errors.Wrap(err, userconfig.Identify(transformedColumn), userconfig.TransformerKey) - } - - var argType interface{} - if transformer.Inputs != nil { - var ok bool - argType, ok = transformer.Inputs.Args[argName] - if !ok { - return errors.Wrap(configreader.ErrorUnsupportedKey(argName), userconfig.Identify(transformedColumn), userconfig.InputsKey, userconfig.ArgsKey) - } - } - - constantName := strings.Join([]string{ - resource.TransformedColumnType.String(), - transformedColumn.Name, - userconfig.InputsKey, - userconfig.ArgsKey, - argName, - }, "/") - - constant := &userconfig.Constant{ - ResourceFields: userconfig.ResourceFields{ - Name: constantName, - }, - Type: argType, - Value: argVal, - Tags: make(map[string]interface{}), - } - config.Constants = append(config.Constants, constant) - - transformedColumn.Inputs.Args[argName] = constantName - } - } - - return nil -} diff --git a/pkg/operator/context/constants.go b/pkg/operator/context/constants.go index bdb2edcba1..53114edfa9 100644 --- a/pkg/operator/context/constants.go +++ b/pkg/operator/context/constants.go @@ -21,20 +21,17 @@ import ( "path/filepath" 
"github.com/cortexlabs/cortex/pkg/consts" - "github.com/cortexlabs/cortex/pkg/lib/errors" "github.com/cortexlabs/cortex/pkg/lib/hash" - "github.com/cortexlabs/cortex/pkg/lib/msgpack" "github.com/cortexlabs/cortex/pkg/lib/sets/strset" s "github.com/cortexlabs/cortex/pkg/lib/strings" "github.com/cortexlabs/cortex/pkg/operator/api/context" "github.com/cortexlabs/cortex/pkg/operator/api/resource" "github.com/cortexlabs/cortex/pkg/operator/api/userconfig" - "github.com/cortexlabs/cortex/pkg/operator/config" ) var uploadedConstants = strset.New() -func loadConstants(constantConfigs userconfig.Constants) (context.Constants, error) { +func getConstants(constantConfigs userconfig.Constants) (context.Constants, error) { constants := context.Constants{} for _, constantConfig := range constantConfigs { constant, err := newConstant(*constantConfig) @@ -52,44 +49,15 @@ func newConstant(constantConfig userconfig.Constant) (*context.Constant, error) buf.WriteString(context.DataTypeID(constantConfig.Type)) buf.WriteString(s.Obj(constantConfig.Value)) id := hash.Bytes(buf.Bytes()) - idWithTags := hash.String(id + constantConfig.Tags.ID()) constant := &context.Constant{ ResourceFields: &context.ResourceFields{ ID: id, - IDWithTags: idWithTags, ResourceType: resource.ConstantType, }, Constant: &constantConfig, Key: filepath.Join(consts.ConstantsDir, id+".msgpack"), } - if err := uploadConstant(constant); err != nil { - return nil, err - } - - constant.Constant.Value = nil return constant, nil } - -func uploadConstant(constant *context.Constant) error { - if uploadedConstants.Has(constant.ID) { - return nil - } - - isUploaded, err := config.AWS.IsS3File(constant.Key) - if err != nil { - return errors.Wrap(err, userconfig.Identify(constant), "upload") - } - - if !isUploaded { - serializedConstant := msgpack.MustMarshal(constant.Value) - err = config.AWS.UploadBytesToS3(serializedConstant, constant.Key) - if err != nil { - return errors.Wrap(err, userconfig.Identify(constant), "upload") - } - } - - uploadedConstants.Add(constant.ID) - return nil -} diff --git a/pkg/operator/context/context.go b/pkg/operator/context/context.go index 562b062ea3..641d7f87e9 100644 --- a/pkg/operator/context/context.go +++ b/pkg/operator/context/context.go @@ -38,14 +38,20 @@ var ( uploadedAggregators = strset.New() builtinTransformers = make(map[string]*context.Transformer) uploadedTransformers = strset.New() + builtinEstimators = make(map[string]*context.Estimator) + uploadedEstimators = strset.New() + OperatorAggregatorsDir = configreader.MustStringFromEnv( + "CONST_OPERATOR_AGGREGATORS_DIR", + &configreader.StringValidation{Default: "/src/aggregators"}, + ) OperatorTransformersDir = configreader.MustStringFromEnv( "CONST_OPERATOR_TRANSFORMERS_DIR", &configreader.StringValidation{Default: "/src/transformers"}, ) - OperatorAggregatorsDir = configreader.MustStringFromEnv( - "CONST_OPERATOR_AGGREGATORS_DIR", - &configreader.StringValidation{Default: "/src/aggregators"}, + OperatorEstimatorsDir = configreader.MustStringFromEnv( + "CONST_OPERATOR_ESTIMATORS_DIR", + &configreader.StringValidation{Default: "/src/estimators"}, ) ) @@ -88,6 +94,25 @@ func Init() error { builtinTransformers["cortex."+transConfig.Name] = transformer } + estimatorConfigPath := filepath.Join(OperatorEstimatorsDir, "estimators.yaml") + estimatorConfig, err := 
userconfig.NewPartialPath(estimatorConfigPath) + if err != nil { + return err + } + + for _, estimatorConfig := range estimatorConfig.Estimators { + implPath := filepath.Join(OperatorEstimatorsDir, estimatorConfig.Path) + impl, err := files.ReadFileBytes(implPath) + if err != nil { + return errors.Wrap(err, userconfig.Identify(estimatorConfig)) + } + estimator, err := newEstimator(*estimatorConfig, impl, pointer.String("cortex"), nil) + if err != nil { + return err + } + builtinEstimators["cortex."+estimatorConfig.Name] = estimator + } + return nil } @@ -135,22 +160,22 @@ func New( } ctx.PythonPackages = pythonPackages - userTransformers, err := loadUserTransformers(userconf, files, pythonPackages) + userAggregators, err := loadUserAggregators(userconf, files, pythonPackages) if err != nil { return nil, err } - userAggregators, err := loadUserAggregators(userconf, files, pythonPackages) + userTransformers, err := loadUserTransformers(userconf, files, pythonPackages) if err != nil { return nil, err } - err = autoGenerateConfig(userconf, userAggregators, userTransformers) + userEstimators, err := loadUserEstimators(userconf, files, pythonPackages) if err != nil { return nil, err } - constants, err := loadConstants(userconf.Constants) + constants, err := getConstants(userconf.Constants) if err != nil { return nil, err } @@ -168,25 +193,31 @@ func New( } ctx.Transformers = transformers + estimators, err := getEstimators(userconf, userEstimators) + if err != nil { + return nil, err + } + ctx.Estimators = estimators + rawColumns, err := getRawColumns(userconf, ctx.Environment) if err != nil { return nil, err } ctx.RawColumns = rawColumns - aggregates, err := getAggregates(userconf, constants, rawColumns, userAggregators, ctx.Root) + aggregates, err := getAggregates(userconf, constants, rawColumns, aggregators, ctx.Root) if err != nil { return nil, err } ctx.Aggregates = aggregates - transformedColumns, err := getTransformedColumns(userconf, constants, rawColumns, ctx.Aggregates, userTransformers, ctx.Root) + transformedColumns, err := getTransformedColumns(userconf, constants, rawColumns, aggregates, aggregators, transformers, ctx.Root) if err != nil { return nil, err } ctx.TransformedColumns = transformedColumns - models, err := getModels(userconf, aggregates, ctx.Columns(), files, ctx.Root, pythonPackages) + models, err := getModels(userconf, constants, ctx.Columns(), aggregates, transformedColumns, aggregators, transformers, estimators, ctx.Root) if err != nil { return nil, err } diff --git a/pkg/operator/context/environment.go b/pkg/operator/context/environment.go index fcd0c635d1..b00414f0e6 100644 --- a/pkg/operator/context/environment.go +++ b/pkg/operator/context/environment.go @@ -19,6 +19,8 @@ package context import ( "bytes" + "github.com/cortexlabs/yaml" + "github.com/cortexlabs/cortex/pkg/lib/hash" s "github.com/cortexlabs/cortex/pkg/lib/strings" "github.com/cortexlabs/cortex/pkg/operator/api/context" @@ -38,7 +40,7 @@ func dataID(config *userconfig.Config, datasetVersion string) string { rawColumnTypeMap := make(map[string]userconfig.ColumnType, len(config.RawColumns)) for _, rawColumnConfig := range config.RawColumns { - rawColumnTypeMap[rawColumnConfig.GetName()] = rawColumnConfig.GetType() + rawColumnTypeMap[rawColumnConfig.GetName()] = rawColumnConfig.GetColumnType() } buf.WriteString(s.Obj(config.Environment.Limit)) buf.WriteString(s.Obj(rawColumnTypeMap)) @@ -53,7 +55,8 @@ func dataID(config 
*userconfig.Config, datasetVersion string) string { buf.WriteString(s.Bool(typedData.DropNull)) schemaMap := map[string]string{} // use map to sort keys for _, parqCol := range typedData.Schema { - schemaMap[parqCol.RawColumnName] = parqCol.ParquetColumnName + colName, _ := yaml.ExtractAtSymbolText(parqCol.RawColumn) + schemaMap[colName] = parqCol.ParquetColumnName } buf.WriteString(s.Obj(schemaMap)) } diff --git a/pkg/operator/context/errors.go b/pkg/operator/context/errors.go deleted file mode 100644 index 603da7b834..0000000000 --- a/pkg/operator/context/errors.go +++ /dev/null @@ -1,85 +0,0 @@ -/* -Copyright 2019 Cortex Labs, Inc. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ - -package context - -import ( - "fmt" -) - -type ErrorKind int - -const ( - ErrUnknown ErrorKind = iota - ErrImplDoesNotExist -) - -var errorKinds = []string{ - "err_unknown", - "err_impl_does_not_exist", -} - -var _ = [1]int{}[int(ErrImplDoesNotExist)-(len(errorKinds)-1)] // Ensure list length matches - -func (t ErrorKind) String() string { - return errorKinds[t] -} - -// MarshalText satisfies TextMarshaler -func (t ErrorKind) MarshalText() ([]byte, error) { - return []byte(t.String()), nil -} - -// UnmarshalText satisfies TextUnmarshaler -func (t *ErrorKind) UnmarshalText(text []byte) error { - enum := string(text) - for i := 0; i < len(errorKinds); i++ { - if enum == errorKinds[i] { - *t = ErrorKind(i) - return nil - } - } - - *t = ErrUnknown - return nil -} - -// UnmarshalBinary satisfies BinaryUnmarshaler -// Needed for msgpack -func (t *ErrorKind) UnmarshalBinary(data []byte) error { - return t.UnmarshalText(data) -} - -// MarshalBinary satisfies BinaryMarshaler -func (t ErrorKind) MarshalBinary() ([]byte, error) { - return []byte(t.String()), nil -} - -type Error struct { - Kind ErrorKind - message string -} - -func (e Error) Error() string { - return e.message -} - -func ErrorImplDoesNotExist(path string) error { - return Error{ - Kind: ErrImplDoesNotExist, - message: fmt.Sprintf("%s: implementation file does not exist", path), - } -} diff --git a/pkg/operator/context/estimators.go b/pkg/operator/context/estimators.go new file mode 100644 index 0000000000..14da666cdd --- /dev/null +++ b/pkg/operator/context/estimators.go @@ -0,0 +1,174 @@ +/* +Copyright 2019 Cortex Labs, Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+*/ + +package context + +import ( + "bytes" + "path/filepath" + + "github.com/cortexlabs/cortex/pkg/consts" + "github.com/cortexlabs/cortex/pkg/lib/errors" + "github.com/cortexlabs/cortex/pkg/lib/hash" + "github.com/cortexlabs/cortex/pkg/operator/api/context" + "github.com/cortexlabs/cortex/pkg/operator/api/resource" + "github.com/cortexlabs/cortex/pkg/operator/api/userconfig" + "github.com/cortexlabs/cortex/pkg/operator/config" +) + +func loadUserEstimators( + config *userconfig.Config, + impls map[string][]byte, + pythonPackages context.PythonPackages, +) (map[string]*context.Estimator, error) { + + userEstimators := make(map[string]*context.Estimator) + for _, estimatorConfig := range config.Estimators { + impl, ok := impls[estimatorConfig.Path] + if !ok { + return nil, errors.Wrap(userconfig.ErrorImplDoesNotExist(estimatorConfig.Path), userconfig.Identify(estimatorConfig)) + } + estimator, err := newEstimator(*estimatorConfig, impl, nil, pythonPackages) + if err != nil { + return nil, err + } + userEstimators[estimator.Name] = estimator + } + + for _, modelConfig := range config.Models { + if modelConfig.EstimatorPath == nil { + continue + } + + impl, ok := impls[*modelConfig.EstimatorPath] + if !ok { + return nil, errors.Wrap(userconfig.ErrorImplDoesNotExist(*modelConfig.EstimatorPath), userconfig.Identify(modelConfig)) + } + + implHash := hash.Bytes(impl) + if _, ok := userEstimators[implHash]; ok { + continue + } + + anonEstimatorConfig := &userconfig.Estimator{ + ResourceFields: userconfig.ResourceFields{ + Name: implHash, + }, + TargetColumn: userconfig.InferredColumnType, + PredictionKey: modelConfig.PredictionKey, + Path: *modelConfig.EstimatorPath, + } + estimator, err := newEstimator(*anonEstimatorConfig, impl, nil, pythonPackages) + if err != nil { + return nil, err + } + + modelConfig.Estimator = estimator.Name + userEstimators[anonEstimatorConfig.Name] = estimator + } + + return userEstimators, nil +} + +func newEstimator( + estimatorConfig userconfig.Estimator, + impl []byte, + namespace *string, + pythonPackages context.PythonPackages, +) (*context.Estimator, error) { + + implID := hash.Bytes(impl) + + var buf bytes.Buffer + buf.WriteString(context.DataTypeID(estimatorConfig.TargetColumn)) + buf.WriteString(context.DataTypeID(estimatorConfig.Input)) + buf.WriteString(context.DataTypeID(estimatorConfig.TrainingInput)) + buf.WriteString(context.DataTypeID(estimatorConfig.Hparams)) + buf.WriteString(estimatorConfig.PredictionKey) + buf.WriteString(implID) + + for _, pythonPackage := range pythonPackages { + buf.WriteString(pythonPackage.GetID()) + } + + id := hash.Bytes(buf.Bytes()) + + estimator := &context.Estimator{ + ResourceFields: &context.ResourceFields{ + ID: id, + ResourceType: resource.EstimatorType, + }, + Estimator: &estimatorConfig, + Namespace: namespace, + ImplKey: filepath.Join(consts.EstimatorsDir, implID+".py"), + } + + if err := uploadEstimator(estimator, impl); err != nil { + return nil, err + } + + return estimator, nil +} + +func uploadEstimator(estimator *context.Estimator, impl []byte) error { + if uploadedEstimators.Has(estimator.ID) { + return nil + } + + isUploaded, err := config.AWS.IsS3File(estimator.ImplKey) + if err != nil { + return errors.Wrap(err, userconfig.Identify(estimator), "upload") + } + + if !isUploaded { + err = config.AWS.UploadBytesToS3(impl, estimator.ImplKey) + if err != nil { + return 
errors.Wrap(err, userconfig.Identify(estimator), "upload") + } + } + + uploadedEstimators.Add(estimator.ID) + return nil +} + +func getEstimators( + config *userconfig.Config, + userEstimators map[string]*context.Estimator, +) (context.Estimators, error) { + + estimators := context.Estimators{} + for _, modelConfig := range config.Models { + name := modelConfig.Estimator + + if _, ok := estimators[name]; ok { + continue + } + + if estimator, ok := builtinEstimators[name]; ok { + estimators[name] = estimator + continue + } + + if estimator, ok := userEstimators[name]; ok { + estimators[name] = estimator + continue + } + + return nil, errors.Wrap(userconfig.ErrorUndefinedResource(name, resource.EstimatorType), userconfig.Identify(modelConfig), userconfig.EstimatorKey) + } + + return estimators, nil +} diff --git a/pkg/operator/context/models.go b/pkg/operator/context/models.go index d0323c5cf7..7a731fc614 100644 --- a/pkg/operator/context/models.go +++ b/pkg/operator/context/models.go @@ -21,68 +21,126 @@ import ( "path/filepath" "strings" + "github.com/cortexlabs/yaml" + "github.com/cortexlabs/cortex/pkg/consts" "github.com/cortexlabs/cortex/pkg/lib/errors" "github.com/cortexlabs/cortex/pkg/lib/hash" - "github.com/cortexlabs/cortex/pkg/lib/sets/strset" s "github.com/cortexlabs/cortex/pkg/lib/strings" "github.com/cortexlabs/cortex/pkg/operator/api/context" "github.com/cortexlabs/cortex/pkg/operator/api/resource" "github.com/cortexlabs/cortex/pkg/operator/api/userconfig" - "github.com/cortexlabs/cortex/pkg/operator/config" ) -var uploadedModels = strset.New() - func getModels( config *userconfig.Config, - aggregates context.Aggregates, + constants context.Constants, columns context.Columns, - impls map[string][]byte, + aggregates context.Aggregates, + transformedColumns context.TransformedColumns, + aggregators context.Aggregators, + transformers context.Transformers, + estimators context.Estimators, root string, - pythonPackages context.PythonPackages, ) (context.Models, error) { models := context.Models{} for _, modelConfig := range config.Models { - modelImplID, modelImplKey, err := getModelImplID(modelConfig.Path, impls) - if err != nil { - return nil, errors.Wrap(err, userconfig.Identify(modelConfig), userconfig.PathKey) + estimator := estimators[modelConfig.Estimator] + + var validInputResources []context.Resource + for _, res := range constants { + validInputResources = append(validInputResources, res) + } + for _, res := range columns { + validInputResources = append(validInputResources, res) + } + for _, res := range aggregates { + validInputResources = append(validInputResources, res) + } + for _, res := range transformedColumns { + validInputResources = append(validInputResources, res) } - targetDataType := columns[modelConfig.TargetColumn].GetType() - err = context.ValidateModelTargetType(targetDataType, modelConfig.Type) + // Input + castedInput, inputID, err := ValidateInput( + modelConfig.Input, + estimator.Input, + []resource.Type{resource.RawColumnType, resource.TransformedColumnType, resource.ConstantType, resource.AggregateType, resource.TransformedColumnType}, + validInputResources, + config.Resources, + aggregators, + transformers, + ) if err != nil { - return nil, errors.Wrap(err, userconfig.Identify(modelConfig)) + return nil, errors.Wrap(err, 
userconfig.Identify(modelConfig), userconfig.InputKey) + } + modelConfig.Input = castedInput + + // TrainingInput + castedTrainingInput, trainingInputID, err := ValidateInput( + modelConfig.TrainingInput, + estimator.TrainingInput, + []resource.Type{resource.RawColumnType, resource.TransformedColumnType, resource.ConstantType, resource.AggregateType, resource.TransformedColumnType}, + validInputResources, + config.Resources, + aggregators, + transformers, + ) + if err != nil { + return nil, errors.Wrap(err, userconfig.Identify(modelConfig), userconfig.TrainingInputKey) + } + modelConfig.TrainingInput = castedTrainingInput + + // Hparams + if estimator.Hparams != nil { + castedHparams, err := userconfig.CastInputValue(modelConfig.Hparams, estimator.Hparams) + if err != nil { + return nil, errors.Wrap(err, userconfig.Identify(modelConfig), userconfig.HparamsKey) + } + modelConfig.Hparams = castedHparams } - var buf bytes.Buffer - buf.WriteString(modelConfig.Type.String()) - buf.WriteString(modelImplID) - for _, pythonPackage := range pythonPackages { - buf.WriteString(pythonPackage.GetID()) + // TargetColumn + targetColumnName, _ := yaml.UnescapeAtSymbol(modelConfig.TargetColumn) + targetColumn := columns[targetColumnName] + if targetColumn == nil { + return nil, errors.Wrap(userconfig.ErrorUndefinedResource(targetColumnName, resource.RawColumnType, resource.TransformedColumnType), userconfig.Identify(modelConfig), userconfig.TargetColumnKey) + } + if targetColumn.GetColumnType() != userconfig.IntegerColumnType && targetColumn.GetColumnType() != userconfig.FloatColumnType { + return nil, userconfig.ErrorTargetColumnIntOrFloat() + } + if estimator.TargetColumn != userconfig.InferredColumnType { + if targetColumn.GetColumnType() != estimator.TargetColumn { + return nil, errors.Wrap(userconfig.ErrorUnsupportedOutputType(targetColumn.GetColumnType(), estimator.TargetColumn), userconfig.Identify(modelConfig), userconfig.TargetColumnKey) + } } - buf.WriteString(modelConfig.PredictionKey) + + // Model ID + var buf bytes.Buffer + buf.WriteString(inputID) + buf.WriteString(trainingInputID) + buf.WriteString(targetColumn.GetID()) buf.WriteString(s.Obj(modelConfig.Hparams)) buf.WriteString(s.Obj(modelConfig.DataPartitionRatio)) buf.WriteString(s.Obj(modelConfig.Training)) buf.WriteString(s.Obj(modelConfig.Evaluation)) - buf.WriteString(columns.IDWithTags(modelConfig.AllColumnNames())) // A change in tags can invalidate the model - - for _, aggregate := range modelConfig.Aggregates { - buf.WriteString(aggregates[aggregate].GetID()) - } - buf.WriteString(modelConfig.Tags.ID()) - + buf.WriteString(estimator.ID) modelID := hash.Bytes(buf.Bytes()) + // Dataset ID buf.Reset() buf.WriteString(s.Obj(modelConfig.DataPartitionRatio)) - buf.WriteString(columns.ID(modelConfig.AllColumnNames())) + combinedInput := []interface{}{modelConfig.Input, modelConfig.TrainingInput, modelConfig.TargetColumn} + var columnSlice []context.Resource + for _, res := range columns { + columnSlice = append(columnSlice, res) + } + for _, col := range context.ExtractCortexResources(combinedInput, columnSlice, resource.RawColumnType, resource.TransformedColumnType) { + buf.WriteString(col.GetID()) + } datasetID := hash.Bytes(buf.Bytes()) - buf.WriteString(columns.IDWithTags(modelConfig.AllColumnNames())) - datasetIDWithTags := hash.Bytes(buf.Bytes()) datasetRoot := filepath.Join(root, consts.TrainingDataDir, datasetID) trainingDatasetName := strings.Join([]string{ @@ -94,14 +152,11 @@ func getModels( ComputedResourceFields: 
&context.ComputedResourceFields{ ResourceFields: &context.ResourceFields{ ID: modelID, - IDWithTags: modelID, ResourceType: resource.ModelType, }, }, - Model: modelConfig, - Key: filepath.Join(root, consts.ModelsDir, modelID+".zip"), - ImplID: modelImplID, - ImplKey: modelImplKey, + Model: modelConfig, + Key: filepath.Join(root, consts.ModelsDir, modelID+".zip"), Dataset: &context.TrainingDataset{ ResourceFields: userconfig.ResourceFields{ Name: trainingDatasetName, @@ -111,7 +166,6 @@ func getModels( ComputedResourceFields: &context.ComputedResourceFields{ ResourceFields: &context.ResourceFields{ ID: datasetID, - IDWithTags: datasetIDWithTags, ResourceType: resource.TrainingDatasetType, }, }, @@ -124,42 +178,3 @@ func getModels( return models, nil } - -func getModelImplID(implPath string, impls map[string][]byte) (string, string, error) { - impl, ok := impls[implPath] - if !ok { - return "", "", ErrorImplDoesNotExist(implPath) - } - modelImplID := hash.Bytes(impl) - modelImplKey, err := uploadModelImpl(modelImplID, impl) - if err != nil { - return "", "", errors.Wrap(err, implPath) - } - return modelImplID, modelImplKey, nil -} - -func uploadModelImpl(modelImplID string, impl []byte) (string, error) { - modelImplKey := filepath.Join( - consts.ModelImplsDir, - modelImplID+".py", - ) - - if uploadedModels.Has(modelImplID) { - return modelImplKey, nil - } - - isUploaded, err := config.AWS.IsS3File(modelImplKey) - if err != nil { - return "", errors.Wrap(err, "upload") - } - - if !isUploaded { - err = config.AWS.UploadBytesToS3(impl, modelImplKey) - if err != nil { - return "", errors.Wrap(err, "upload") - } - } - - uploadedModels.Add(modelImplID) - return modelImplKey, nil -} diff --git a/pkg/operator/context/raw_columns.go b/pkg/operator/context/raw_columns.go index 9e236ba913..4ee4b702d6 100644 --- a/pkg/operator/context/raw_columns.go +++ b/pkg/operator/context/raw_columns.go @@ -40,7 +40,7 @@ func getRawColumns( var buf bytes.Buffer buf.WriteString(env.ID) buf.WriteString(columnConfig.GetName()) - buf.WriteString(columnConfig.GetType().String()) + buf.WriteString(columnConfig.GetColumnType().String()) var rawColumn context.RawColumn switch typedColumnConfig := columnConfig.(type) { @@ -50,12 +50,10 @@ func getRawColumns( buf.WriteString(s.Obj(typedColumnConfig.Max)) buf.WriteString(s.Obj(slices.SortInt64sCopy(typedColumnConfig.Values))) id := hash.Bytes(buf.Bytes()) - idWithTags := hash.String(id + typedColumnConfig.Tags.ID()) rawColumn = &context.RawIntColumn{ ComputedResourceFields: &context.ComputedResourceFields{ ResourceFields: &context.ResourceFields{ ID: id, - IDWithTags: idWithTags, ResourceType: resource.RawColumnType, }, }, @@ -67,12 +65,10 @@ func getRawColumns( buf.WriteString(s.Obj(typedColumnConfig.Max)) buf.WriteString(s.Obj(slices.SortFloat32sCopy(typedColumnConfig.Values))) id := hash.Bytes(buf.Bytes()) - idWithTags := hash.String(id + typedColumnConfig.Tags.ID()) rawColumn = &context.RawFloatColumn{ ComputedResourceFields: &context.ComputedResourceFields{ ResourceFields: &context.ResourceFields{ ID: id, - IDWithTags: idWithTags, ResourceType: resource.RawColumnType, }, }, @@ -82,12 +78,10 @@ func getRawColumns( buf.WriteString(s.Bool(typedColumnConfig.Required)) buf.WriteString(s.Obj(slices.SortStrsCopy(typedColumnConfig.Values))) id := hash.Bytes(buf.Bytes()) - idWithTags := hash.String(id + typedColumnConfig.Tags.ID()) rawColumn = &context.RawStringColumn{ ComputedResourceFields: &context.ComputedResourceFields{ ResourceFields: &context.ResourceFields{ ID: id, - 
IDWithTags: idWithTags, ResourceType: resource.RawColumnType, }, }, @@ -106,7 +100,7 @@ func getRawColumns( RawInferredColumn: typedColumnConfig, } default: - return nil, errors.Wrap(configreader.ErrorInvalidStr(typedColumnConfig.GetType().String(), userconfig.IntegerColumnType.String(), userconfig.FloatColumnType.String(), userconfig.StringColumnType.String()), userconfig.Identify(columnConfig)) // unexpected error + return nil, errors.Wrap(configreader.ErrorInvalidStr(typedColumnConfig.GetColumnType().String(), userconfig.IntegerColumnType.String(), userconfig.FloatColumnType.String(), userconfig.StringColumnType.String()), userconfig.Identify(columnConfig)) // unexpected error } rawColumns[columnConfig.GetName()] = rawColumn diff --git a/pkg/operator/context/resource_fakes_test.go b/pkg/operator/context/resource_fakes_test.go new file mode 100644 index 0000000000..6b06bddf94 --- /dev/null +++ b/pkg/operator/context/resource_fakes_test.go @@ -0,0 +1,344 @@ +/* +Copyright 2019 Cortex Labs, Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package context + +import ( + cr "github.com/cortexlabs/cortex/pkg/lib/configreader" + "github.com/cortexlabs/cortex/pkg/operator/api/context" + "github.com/cortexlabs/cortex/pkg/operator/api/resource" + "github.com/cortexlabs/cortex/pkg/operator/api/userconfig" +) + +func mustValidateOutputSchema(yamlStr string) userconfig.OutputSchema { + outputSchema, err := userconfig.ValidateOutputSchema(cr.MustReadYAMLStr(yamlStr)) + if err != nil { + panic(err) + } + return outputSchema +} + +func genConst(id string, outputType string, value string) *context.Constant { + var outType userconfig.OutputSchema = nil + if outputType != "" { + outType = mustValidateOutputSchema(outputType) + } + + constant := &context.Constant{ + Constant: &userconfig.Constant{ + ResourceFields: userconfig.ResourceFields{ + Name: "c" + id, + }, + Type: outType, + Value: cr.MustReadYAMLStr(value), + }, + ResourceFields: &context.ResourceFields{ + ID: "a_c" + id, + ResourceType: resource.ConstantType, + }, + } + + err := constant.Validate() + if err != nil { + panic(err) + } + + return constant +} + +func genAgg(id string, aggregatorType string) (*context.Aggregate, *context.Aggregator) { + var outputType userconfig.OutputSchema = nil + if aggregatorType != "" { + outputType = mustValidateOutputSchema(aggregatorType) + } + + aggregator := &context.Aggregator{ + Aggregator: &userconfig.Aggregator{ + ResourceFields: userconfig.ResourceFields{ + Name: "aggregator" + id, + }, + OutputType: outputType, + }, + ResourceFields: &context.ResourceFields{ + ID: "d_aggregator" + id, + ResourceType: resource.AggregatorType, + }, + } + + aggregate := &context.Aggregate{ + Aggregate: &userconfig.Aggregate{ + ResourceFields: userconfig.ResourceFields{ + Name: "agg" + id, + }, + Aggregator: "aggregator" + id, + }, + Type: aggregator.OutputType, + ComputedResourceFields: &context.ComputedResourceFields{ + ResourceFields: 
&context.ResourceFields{ + ID: "c_agg" + id, + ResourceType: resource.AggregateType, + }, + }, + } + + return aggregate, aggregator +} + +func genTrans(id string, transformerType userconfig.ColumnType) (*context.TransformedColumn, *context.Transformer) { + transformer := &context.Transformer{ + Transformer: &userconfig.Transformer{ + ResourceFields: userconfig.ResourceFields{ + Name: "transformer" + id, + }, + OutputType: transformerType, + }, + ResourceFields: &context.ResourceFields{ + ID: "f_transformer" + id, + ResourceType: resource.TransformerType, + }, + } + + transformedCol := &context.TransformedColumn{ + TransformedColumn: &userconfig.TransformedColumn{ + ResourceFields: userconfig.ResourceFields{ + Name: "tc" + id, + }, + Transformer: "transformer" + id, + }, + Type: transformer.OutputType, + ComputedResourceFields: &context.ComputedResourceFields{ + ResourceFields: &context.ResourceFields{ + ID: "e_tc" + id, + ResourceType: resource.TransformedColumnType, + }, + }, + } + + return transformedCol, transformer +} + +var c1 = genConst("1", `STRING`, `test`) +var c2 = genConst("2", ``, `test`) +var c3 = genConst("3", ` + map: {INT: FLOAT} + map2: {a: FLOAT, b: FLOAT, c: INT} + str: STRING + floats: [FLOAT] + list: + - STRING: + lat: INT + lon: + a: [STRING] + `, ` + map: {2: 2.2, 3: 3} + map2: {a: 2.2, b: 3, c: 4} + str: test + floats: [1.1, 2.2, 3.3] + list: + - key_1: + lat: 17 + lon: + a: [test1, test2, test3] + key_2: + lat: 88 + lon: + a: [test4, test5, test6] + - key_a: + lat: 12 + lon: + a: [test7, test8, test9] + `) +var c4 = genConst("4", ``, ` + map: {2: 2.2, 3: 3} + map2: {a: 2, b: 3, c: 4} + str: test + floats: [1.1, 2.2, 3.3] + list: + - key_1: + lat: 17 + lon: + a: [test1, test2, test3] + key_2: + lat: 88 + lon: + a: [test4, test5, test6] + - key_a: + lat: 12 + lon: + a: [test7, test8, test9] + `) +var c5 = genConst("5", `FLOAT`, `2`) +var c6 = genConst("6", `INT`, `2`) +var c7 = genConst("7", ``, `2`) +var c8 = genConst("8", `{a: INT, b: INT}`, `{a: 1, b: 2}`) +var c9 = genConst("9", `{STRING: INT}`, `{a: 1, b: 2}`) +var ca = genConst("a", ``, `{a: 1, b: 2}`) +var cb = genConst("b", `[FLOAT]`, `[2]`) +var cc = genConst("c", `[INT]`, `[2]`) +var cd = genConst("d", ``, `[2]`) +var ce = genConst("e", `[STRING]`, `[a, b, c]`) +var cf = genConst("f", ``, `{a: [a, b, c]}`) + +var rc1 = &context.RawIntColumn{ + RawIntColumn: &userconfig.RawIntColumn{ + ResourceFields: userconfig.ResourceFields{ + Name: "rc1", + }, + Type: userconfig.IntegerColumnType, + }, + ComputedResourceFields: &context.ComputedResourceFields{ + ResourceFields: &context.ResourceFields{ + ID: "b_rc1", + ResourceType: resource.RawColumnType, + }, + }, +} + +// Type not specified +var rc2 = &context.RawInferredColumn{ + RawInferredColumn: &userconfig.RawInferredColumn{ + ResourceFields: userconfig.ResourceFields{ + Name: "rc2", + }, + Type: userconfig.InferredColumnType, + }, + ComputedResourceFields: &context.ComputedResourceFields{ + ResourceFields: &context.ResourceFields{ + ID: "b_rc2", + ResourceType: resource.RawColumnType, + }, + }, +} + +var rc3 = &context.RawStringColumn{ + RawStringColumn: &userconfig.RawStringColumn{ + ResourceFields: userconfig.ResourceFields{ + Name: "rc3", + }, + Type: userconfig.StringColumnType, + }, + ComputedResourceFields: &context.ComputedResourceFields{ + ResourceFields: &context.ResourceFields{ + ID: "b_rc3", + ResourceType: resource.RawColumnType, + }, + }, +} + +var rc4 = &context.RawFloatColumn{ + RawFloatColumn: &userconfig.RawFloatColumn{ + ResourceFields: 
userconfig.ResourceFields{ + Name: "rc4", + }, + Type: userconfig.FloatColumnType, + }, + ComputedResourceFields: &context.ComputedResourceFields{ + ResourceFields: &context.ResourceFields{ + ID: "b_rc4", + ResourceType: resource.RawColumnType, + }, + }, +} + +var agg1, aggregator1 = genAgg("1", `STRING`) +var agg2, aggregator2 = genAgg("2", ``) // Nil output type +var agg3, aggregator3 = genAgg("3", ` + map: {INT: FLOAT} + map2: {a: FLOAT, b: FLOAT, c: INT} + str: STRING + floats: [FLOAT] + list: + - STRING: + lat: INT + lon: + a: [STRING] + `) +var agg4, aggregator4 = genAgg("4", `INT`) +var agg5, aggregator5 = genAgg("5", `FLOAT`) +var agg6, aggregator6 = genAgg("6", `{a: INT, b: INT}`) +var agg7, aggregator7 = genAgg("7", `{STRING: INT}`) +var agg8, aggregator8 = genAgg("8", `[INT]`) +var agg9, aggregator9 = genAgg("9", `[FLOAT]`) +var agga, aggregatora = genAgg("a", `[STRING]`) +var aggb, aggregatorb = genAgg("b", `{a: [STRING]}`) + +var tc1, transformer1 = genTrans("1", userconfig.StringColumnType) +var tc2, transformer2 = genTrans("2", userconfig.InferredColumnType) +var tc3, transformer3 = genTrans("3", userconfig.IntegerColumnType) +var tc4, transformer4 = genTrans("4", userconfig.FloatColumnType) +var tc5, transformer5 = genTrans("5", userconfig.StringListColumnType) +var tc6, transformer6 = genTrans("6", userconfig.IntegerListColumnType) +var tc7, transformer7 = genTrans("7", userconfig.FloatListColumnType) + +var constants = []context.Resource{c1, c2, c3, c4, c5, c6, c7, c8, c9, ca, cb, cc, cd, ce, cf} +var rawCols = []context.Resource{rc1, rc2, rc3, rc4} +var aggregates = []context.Resource{agg1, agg2, agg3, agg4, agg5, agg6, agg7, agg8, agg9, agga, aggb} +var transformedCols = []context.Resource{tc1, tc2, tc3, tc4, tc5, tc6, tc7} + +var aggregators = context.Aggregators{ + aggregator1.GetName(): aggregator1, + aggregator2.GetName(): aggregator2, + aggregator3.GetName(): aggregator3, + aggregator4.GetName(): aggregator4, + aggregator5.GetName(): aggregator5, + aggregator6.GetName(): aggregator6, + aggregator7.GetName(): aggregator7, + aggregator8.GetName(): aggregator8, + aggregator9.GetName(): aggregator9, + aggregatora.GetName(): aggregatora, + aggregatorb.GetName(): aggregatorb, +} +var transformers = context.Transformers{ + transformer1.GetName(): transformer1, + transformer2.GetName(): transformer2, + transformer3.GetName(): transformer3, + transformer4.GetName(): transformer4, + transformer5.GetName(): transformer5, + transformer6.GetName(): transformer6, + transformer7.GetName(): transformer7, +} + +var allResources = append(append(append(constants, rawCols...), aggregates...), transformedCols...) 
+ +var constantsMap = make(map[string]context.Resource) +var rawColsMap = make(map[string]context.Resource) +var aggregatesMap = make(map[string]context.Resource) +var transformedColsMap = make(map[string]context.Resource) +var allResourcesMap = make(map[string]context.Resource) + +var allResourceConfigsMap = make(map[string][]userconfig.Resource) + +func init() { + for _, res := range constants { + constantsMap[res.GetName()] = res + } + for _, res := range rawCols { + rawColsMap[res.GetName()] = res + } + for _, res := range aggregates { + aggregatesMap[res.GetName()] = res + } + for _, res := range transformedCols { + transformedColsMap[res.GetName()] = res + } + for _, res := range allResources { + allResourcesMap[res.GetName()] = res + } + + for _, res := range allResources { + allResourceConfigsMap[res.GetName()] = []userconfig.Resource{res} + } +} diff --git a/pkg/operator/context/resources.go b/pkg/operator/context/resources.go new file mode 100644 index 0000000000..6230e83898 --- /dev/null +++ b/pkg/operator/context/resources.go @@ -0,0 +1,408 @@ +/* +Copyright 2019 Cortex Labs, Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package context + +import ( + "github.com/cortexlabs/yaml" + + "github.com/cortexlabs/cortex/pkg/lib/cast" + "github.com/cortexlabs/cortex/pkg/lib/configreader" + "github.com/cortexlabs/cortex/pkg/lib/errors" + "github.com/cortexlabs/cortex/pkg/lib/hash" + s "github.com/cortexlabs/cortex/pkg/lib/strings" + "github.com/cortexlabs/cortex/pkg/operator/api/context" + "github.com/cortexlabs/cortex/pkg/operator/api/resource" + "github.com/cortexlabs/cortex/pkg/operator/api/userconfig" +) + +func ValidateInput( + input interface{}, + schema *userconfig.InputSchema, + validResourceTypes []resource.Type, // this is just used for error messages + validResources []context.Resource, + allResources map[string][]userconfig.Resource, // key is resource name + aggregators context.Aggregators, + transformers context.Transformers, +) (interface{}, string, error) { + + validResourcesMap := make(map[string]context.Resource, len(validResources)) + for _, res := range validResources { + validResourcesMap[res.GetName()] = res + } + + inputWithIDs, err := validateResourceReferences(input, validResourceTypes, validResourcesMap, allResources) + if err != nil { + return nil, "", err + } + + castedInput := input + + // Skip validation if schema is nil (i.e. 
user didn't define the aggregator/transformer/estimator) + if schema != nil { + castedInput, err = validateRuntimeTypes(input, schema, validResourcesMap, aggregators, transformers, false) + if err != nil { + return nil, "", err + } + } + + return castedInput, hash.Any(inputWithIDs), nil +} + +// Return a copy of the input with all resource references replaced by their IDs +func validateResourceReferences( + input interface{}, + validResourceTypes []resource.Type, // this is just used for error messages + validResources map[string]context.Resource, // key is resource name + allResources map[string][]userconfig.Resource, // key is resource name +) (interface{}, error) { + + if input == nil { + return nil, nil + } + + if resourceName, ok := yaml.ExtractAtSymbolText(input); ok { + if res, ok := validResources[resourceName]; ok { + return res.GetID(), nil + } + + if len(allResources[resourceName]) > 0 { + return nil, userconfig.ErrorResourceWrongType(allResources[resourceName], validResourceTypes...) + } + + return nil, userconfig.ErrorUndefinedResource(resourceName, validResourceTypes...) + } + + if inputSlice, ok := cast.InterfaceToInterfaceSlice(input); ok { + sliceWithIDs := make([]interface{}, len(inputSlice)) + for i, elem := range inputSlice { + elemWithIDs, err := validateResourceReferences(elem, validResourceTypes, validResources, allResources) + if err != nil { + return nil, errors.Wrap(err, s.Index(i)) + } + sliceWithIDs[i] = elemWithIDs + } + return sliceWithIDs, nil + } + + if inputMap, ok := cast.InterfaceToInterfaceInterfaceMap(input); ok { + mapWithIDs := make(map[interface{}]interface{}, len(inputMap)) + for key, val := range inputMap { + keyWithIDs, err := validateResourceReferences(key, validResourceTypes, validResources, allResources) + if err != nil { + return nil, err + } + valWithIDs, err := validateResourceReferences(val, validResourceTypes, validResources, allResources) + if err != nil { + return nil, errors.Wrap(err, s.UserStrStripped(key)) + } + mapWithIDs[keyWithIDs] = valWithIDs + } + return mapWithIDs, nil + } + + return input, nil +} + +// resource references have already been validated to exist in validResources +func validateRuntimeTypes( + input interface{}, + schema *userconfig.InputSchema, + validResources map[string]context.Resource, // key is resource name + aggregators context.Aggregators, + transformers context.Transformers, + isNestedConstant bool, +) (interface{}, error) { + + // Check for null + if input == nil { + if schema.AllowNull { + return nil, nil + } + return nil, userconfig.ErrorCannotBeNull() + } + + // Check if input is Cortex resource + if resourceName, ok := yaml.ExtractAtSymbolText(input); ok { + res := validResources[resourceName] + if res == nil { + return nil, errors.New(resourceName, "missing resource") // unexpected + } + switch res.GetResourceType() { + case resource.ConstantType: + constant := res.(*context.Constant) + _, err := validateRuntimeTypes(constant.Value, schema, validResources, aggregators, transformers, true) + if err != nil { + return nil, errors.Wrap(err, userconfig.Identify(constant), userconfig.ValueKey) + } + return input, nil + case resource.RawColumnType: + rawColumn := res.(context.RawColumn) + if err := validateInputRuntimeOutputTypes(rawColumn.GetColumnType(), schema); err != nil { + return nil, errors.Wrap(err, userconfig.Identify(rawColumn), userconfig.TypeKey) + } + return input, nil + case resource.AggregateType: + aggregate := res.(*context.Aggregate) + aggregator := aggregators[aggregate.Aggregator] + if 
aggregator.OutputType != nil { + if err := validateInputRuntimeOutputTypes(aggregator.OutputType, schema); err != nil { + return nil, errors.Wrap(err, userconfig.Identify(aggregate), userconfig.Identify(aggregator), userconfig.OutputTypeKey) + } + } + return input, nil + case resource.TransformedColumnType: + transformedColumn := res.(*context.TransformedColumn) + transformer := transformers[transformedColumn.Transformer] + if err := validateInputRuntimeOutputTypes(transformer.OutputType, schema); err != nil { + return nil, errors.Wrap(err, userconfig.Identify(transformedColumn), userconfig.Identify(transformer), userconfig.OutputTypeKey) + } + return input, nil + default: + return nil, errors.New(res.GetResourceType().String(), "unsupported resource type") // unexpected + } + } + + typeSchema := schema.Type + + // CompoundType + if compoundType, ok := typeSchema.(userconfig.CompoundType); ok { + return compoundType.CastValue(input) + } + + // array of *InputSchema + if inputSchemas, ok := cast.InterfaceToInterfaceSlice(typeSchema); ok { + values, ok := cast.InterfaceToInterfaceSlice(input) + if !ok { + return nil, userconfig.ErrorUnsupportedLiteralType(input, typeSchema) + } + + if schema.MinCount != nil && int64(len(values)) < *schema.MinCount { + return nil, userconfig.ErrorTooFewElements(configreader.PrimTypeList, *schema.MinCount) + } + if schema.MaxCount != nil && int64(len(values)) > *schema.MaxCount { + return nil, userconfig.ErrorTooManyElements(configreader.PrimTypeList, *schema.MaxCount) + } + + valuesCasted := make([]interface{}, len(values)) + for i, valueItem := range values { + valueItemCasted, err := validateRuntimeTypes(valueItem, inputSchemas[0].(*userconfig.InputSchema), validResources, aggregators, transformers, false) + if err != nil { + return nil, errors.Wrap(err, s.Index(i)) + } + valuesCasted[i] = valueItemCasted + } + return valuesCasted, nil + } + + // Map + if typeSchemaMap, ok := cast.InterfaceToInterfaceInterfaceMap(typeSchema); ok { + valueMap, ok := cast.InterfaceToInterfaceInterfaceMap(input) + if !ok { + return nil, userconfig.ErrorUnsupportedLiteralType(input, typeSchema) + } + + var genericKey userconfig.CompoundType + var genericValue *userconfig.InputSchema + for k, v := range typeSchemaMap { + ok := false + if genericKey, ok = k.(userconfig.CompoundType); ok { + genericValue = v.(*userconfig.InputSchema) + } + } + + valueMapCasted := make(map[interface{}]interface{}, len(valueMap)) + + // Generic map + if genericValue != nil { + if schema.MinCount != nil && int64(len(valueMap)) < *schema.MinCount { + return nil, userconfig.ErrorTooFewElements(configreader.PrimTypeMap, *schema.MinCount) + } + if schema.MaxCount != nil && int64(len(valueMap)) > *schema.MaxCount { + return nil, userconfig.ErrorTooManyElements(configreader.PrimTypeMap, *schema.MaxCount) + } + + for valueKey, valueVal := range valueMap { + valueKeyCasted, err := validateRuntimeTypes(valueKey, &userconfig.InputSchema{Type: genericKey}, validResources, aggregators, transformers, false) + if err != nil { + return nil, err + } + valueValCasted, err := validateRuntimeTypes(valueVal, genericValue, validResources, aggregators, transformers, false) + if err != nil { + return nil, errors.Wrap(err, s.UserStrStripped(valueKey)) + } + valueMapCasted[valueKeyCasted] = valueValCasted + } + return valueMapCasted, nil + } + + // Fixed map + for typeSchemaKey, typeSchemaValue := range typeSchemaMap { + valueVal, ok := valueMap[typeSchemaKey] + if ok { + valueValCasted, err := validateRuntimeTypes(valueVal, 
typeSchemaValue.(*userconfig.InputSchema), validResources, aggregators, transformers, false) + if err != nil { + return nil, errors.Wrap(err, s.UserStrStripped(typeSchemaKey)) + } + valueMapCasted[typeSchemaKey] = valueValCasted + } else { + if !typeSchemaValue.(*userconfig.InputSchema).Optional { + return nil, userconfig.ErrorMustBeDefined(typeSchemaValue) + } + // don't set default (python has to) + } + } + if !isNestedConstant { + for valueKey := range valueMap { + if _, ok := typeSchemaMap[valueKey]; !ok { + return nil, userconfig.ErrorUnsupportedLiteralMapKey(valueKey, typeSchemaMap) + } + } + } + return valueMapCasted, nil + } + + return nil, userconfig.ErrorInvalidInputType(typeSchema) // unexpected +} + +// outputType should be ValueType|ColumnType, a length-one array of an output type, or a map of {scalar|ValueType -> output type} +func validateInputRuntimeOutputTypes(outputType interface{}, schema *userconfig.InputSchema) error { + // Check for missing + if outputType == nil { + if schema.AllowNull { + return nil + } + return userconfig.ErrorCannotBeNull() + } + + typeSchema := schema.Type + + // CompoundType + if compoundType, ok := typeSchema.(userconfig.CompoundType); ok { + if !compoundType.SupportsType(outputType) { + return userconfig.ErrorUnsupportedOutputType(outputType, compoundType) + } + return nil + } + + // array of *InputSchema + if inputSchemas, ok := cast.InterfaceToInterfaceSlice(typeSchema); ok { + outputTypes, ok := cast.InterfaceToInterfaceSlice(outputType) + if !ok { + return userconfig.ErrorUnsupportedOutputType(outputType, inputSchemas) + } + + err := validateInputRuntimeOutputTypes(outputTypes[0], inputSchemas[0].(*userconfig.InputSchema)) + if err != nil { + return errors.Wrap(err, s.Index(0)) + } + return nil + } + + // Map + if typeSchemaMap, ok := cast.InterfaceToInterfaceInterfaceMap(typeSchema); ok { + outputTypeMap, ok := cast.InterfaceToInterfaceInterfaceMap(outputType) + if !ok { + return userconfig.ErrorUnsupportedOutputType(outputType, typeSchemaMap) + } + + var typeSchemaGenericKey userconfig.CompoundType + var typeSchemaGenericValue *userconfig.InputSchema + for k, v := range typeSchemaMap { + ok := false + if typeSchemaGenericKey, ok = k.(userconfig.CompoundType); ok { + typeSchemaGenericValue = v.(*userconfig.InputSchema) + } + } + + var outputTypeGenericKey userconfig.ValueType + var outputTypeGenericValue interface{} + for k, v := range outputTypeMap { + ok := false + if outputTypeGenericKey, ok = k.(userconfig.ValueType); ok { + outputTypeGenericValue = v + } + } + + // Check length if fixed outputType + if outputTypeGenericValue == nil { + if schema.MinCount != nil && int64(len(outputTypeMap)) < *schema.MinCount { + return userconfig.ErrorTooFewElements(configreader.PrimTypeMap, *schema.MinCount) + } + if schema.MaxCount != nil && int64(len(outputTypeMap)) > *schema.MaxCount { + return userconfig.ErrorTooManyElements(configreader.PrimTypeMap, *schema.MaxCount) + } + } + + // Generic schema map and generic outputType + if typeSchemaGenericValue != nil && outputTypeGenericValue != nil { + if err := validateInputRuntimeOutputTypes(outputTypeGenericKey, &userconfig.InputSchema{Type: typeSchemaGenericKey}); err != nil { + return err + } + if err := validateInputRuntimeOutputTypes(outputTypeGenericValue, typeSchemaGenericValue); err != nil { + return errors.Wrap(err, s.UserStrStripped(outputTypeGenericKey)) + } + return nil + } + + // Generic schema map and fixed outputType (we'll check the types of the fixed map) + if typeSchemaGenericValue != nil &&
outputTypeGenericValue == nil { + for outputTypeKey, outputTypeValue := range outputTypeMap { + if _, err := typeSchemaGenericKey.CastValue(outputTypeKey); err != nil { + return err + } + if err := validateInputRuntimeOutputTypes(outputTypeValue, typeSchemaGenericValue); err != nil { + return errors.Wrap(err, s.UserStrStripped(outputTypeKey)) + } + } + return nil + } + + // Generic outputType map and fixed schema map + if typeSchemaGenericValue == nil && outputTypeGenericValue != nil { + return userconfig.ErrorUnsupportedOutputType(outputType, typeSchemaMap) + // This code would allow for this case (for now we are considering it an error): + // for typeSchemaKey, typeSchemaValue := range typeSchemaMap { + // if _, err := outputTypeGenericKey.CastValue(typeSchemaKey); err != nil { + // return err + // } + // if err := validateInputRuntimeOutputTypes(outputTypeGenericValue, typeSchemaValue.(*userconfig.InputSchema)); err != nil { + // return errors.Wrap(err, s.UserStrStripped(typeSchemaKey)) + // } + // } + // return nil + } + + // Fixed outputType map and fixed schema map + if typeSchemaGenericValue == nil && outputTypeGenericValue == nil { + for typeSchemaKey, typeSchemaValue := range typeSchemaMap { + outputTypeValue, ok := outputTypeMap[typeSchemaKey] + if ok { + if err := validateInputRuntimeOutputTypes(outputTypeValue, typeSchemaValue.(*userconfig.InputSchema)); err != nil { + return errors.Wrap(err, s.UserStrStripped(typeSchemaKey)) + } + } else { + if !typeSchemaValue.(*userconfig.InputSchema).Optional { + return userconfig.ErrorMustBeDefined(typeSchemaValue) + } + } + } + return nil + } + } + + return userconfig.ErrorInvalidInputType(typeSchema) // unexpected +} diff --git a/pkg/operator/context/resources_test.go b/pkg/operator/context/resources_test.go new file mode 100644 index 0000000000..8f9c52a121 --- /dev/null +++ b/pkg/operator/context/resources_test.go @@ -0,0 +1,871 @@ +/* +Copyright 2019 Cortex Labs, Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+*/ + +package context + +import ( + "testing" + + "github.com/stretchr/testify/require" + + cr "github.com/cortexlabs/cortex/pkg/lib/configreader" + "github.com/cortexlabs/cortex/pkg/operator/api/userconfig" +) + +func checkValidateRuntimeTypesEqual(t *testing.T, schemaYAML string, inputYAML string, expected interface{}) { + schema, err := userconfig.ValidateInputSchema(cr.MustReadYAMLStr(schemaYAML), false) + require.NoError(t, err) + input := cr.MustReadYAMLStr(inputYAML) + casted, err := validateRuntimeTypes(input, schema, allResourcesMap, aggregators, transformers, false) + require.NoError(t, err) + require.Equal(t, expected, casted) +} + +func checkValidateRuntimeTypesError(t *testing.T, schemaYAML string, inputYAML string) { + schema, err := userconfig.ValidateInputSchema(cr.MustReadYAMLStr(schemaYAML), false) + require.NoError(t, err) + input := cr.MustReadYAMLStr(inputYAML) + _, err = validateRuntimeTypes(input, schema, allResourcesMap, aggregators, transformers, false) + require.Error(t, err) +} + +func checkValidateRuntimeTypesNoError(t *testing.T, schemaYAML string, inputYAML string) { + schema, err := userconfig.ValidateInputSchema(cr.MustReadYAMLStr(schemaYAML), false) + require.NoError(t, err) + input := cr.MustReadYAMLStr(inputYAML) + _, err = validateRuntimeTypes(input, schema, allResourcesMap, aggregators, transformers, false) + require.NoError(t, err) +} + +func TestValidateRuntimeTypes(t *testing.T) { + + // Replacements + + checkValidateRuntimeTypesEqual(t, `STRING`, `@c1`, `🌝🌝🌝🌝🌝c1`) + + checkValidateRuntimeTypesEqual(t, `STRING`, `@c2`, `🌝🌝🌝🌝🌝c2`) + checkValidateRuntimeTypesEqual(t, `STRING|INT`, `@c2`, `🌝🌝🌝🌝🌝c2`) + checkValidateRuntimeTypesError(t, `INT`, `@c2`) + checkValidateRuntimeTypesError(t, `FLOAT`, `@c2`) + + checkValidateRuntimeTypesError(t, `STRING`, `@c3`) + checkValidateRuntimeTypesNoError(t, ` + map: {INT: FLOAT} + map2: {a: FLOAT, b: FLOAT, c: INT} + str: STRING + floats: [FLOAT] + list: + - STRING: + lat: INT + lon: + a: [STRING] + `, `@c3`) + checkValidateRuntimeTypesNoError(t, ` + map: {2: FLOAT, 3: FLOAT} + map2: {STRING: FLOAT} + str: STRING + floats: [FLOAT] + list: [{STRING: {lat: INT, lon: {a: [STRING]}}}] + `, `@c3`) + checkValidateRuntimeTypesError(t, ` + map: {2: FLOAT, 3: INT} + map2: {a: FLOAT, b: FLOAT, c: INT} + str: STRING + floats: [FLOAT] + list: [{STRING: {lat: INT, lon: {a: [STRING]}}}] + `, `@c3`) + checkValidateRuntimeTypesError(t, ` + map: {INT: FLOAT} + map2: {STRING: INT} + str: STRING + floats: [FLOAT] + list: [{STRING: {lat: INT, lon: {a: [STRING]}}}] + `, `@c3`) + + checkValidateRuntimeTypesError(t, `STRING`, `@c4`) + checkValidateRuntimeTypesNoError(t, ` + map: {INT: FLOAT} + map2: {a: FLOAT, b: FLOAT, c: INT} + str: STRING + floats: [FLOAT] + list: [{STRING: {lat: INT, lon: {a: [STRING]}}}] + `, `@c4`) + checkValidateRuntimeTypesNoError(t, ` + map: {2: FLOAT, 3: FLOAT} + map2: {STRING: FLOAT} + str: STRING + floats: [FLOAT] + list: [{STRING: {lat: INT, lon: {a: [STRING]}}}] + `, `@c4`) + checkValidateRuntimeTypesNoError(t, ` + map: {2: FLOAT, 3: INT} + map2: {STRING: INT} + str: STRING + floats: [FLOAT] + list: [{STRING: {lat: INT, lon: {a: [STRING]}}}] + `, `@c4`) + + checkValidateRuntimeTypesError(t, `INT`, `@c5`) + checkValidateRuntimeTypesEqual(t, `FLOAT`, `@c5`, `🌝🌝🌝🌝🌝c5`) + checkValidateRuntimeTypesEqual(t, `FLOAT|INT`, `@c5`, `🌝🌝🌝🌝🌝c5`) + checkValidateRuntimeTypesError(t, `STRING`, `@c5`) + checkValidateRuntimeTypesError(t, `BOOL`, `@c5`) + + 
checkValidateRuntimeTypesEqual(t, `INT`, `@c6`, `🌝🌝🌝🌝🌝c6`) + checkValidateRuntimeTypesEqual(t, `FLOAT`, `@c6`, `🌝🌝🌝🌝🌝c6`) + checkValidateRuntimeTypesEqual(t, `INT|FLOAT`, `@c6`, `🌝🌝🌝🌝🌝c6`) + checkValidateRuntimeTypesError(t, `STRING`, `@c6`) + checkValidateRuntimeTypesError(t, `BOOL`, `@c6`) + + checkValidateRuntimeTypesEqual(t, `INT`, `@c7`, `🌝🌝🌝🌝🌝c7`) + checkValidateRuntimeTypesEqual(t, `FLOAT`, `@c7`, `🌝🌝🌝🌝🌝c7`) + checkValidateRuntimeTypesEqual(t, `FLOAT|INT`, `@c7`, `🌝🌝🌝🌝🌝c7`) + checkValidateRuntimeTypesError(t, `STRING`, `@c7`) + checkValidateRuntimeTypesError(t, `BOOL`, `@c7`) + + checkValidateRuntimeTypesError(t, `[INT]`, `@cb`) + checkValidateRuntimeTypesEqual(t, `[FLOAT]`, `@cb`, `🌝🌝🌝🌝🌝cb`) + checkValidateRuntimeTypesEqual(t, `[FLOAT|INT]`, `@cb`, `🌝🌝🌝🌝🌝cb`) + checkValidateRuntimeTypesError(t, `[STRING]`, `@cb`) + checkValidateRuntimeTypesError(t, `[BOOL]`, `@cb`) + + checkValidateRuntimeTypesEqual(t, `[INT]`, `@cc`, `🌝🌝🌝🌝🌝cc`) + checkValidateRuntimeTypesEqual(t, `[FLOAT]`, `@cc`, `🌝🌝🌝🌝🌝cc`) + checkValidateRuntimeTypesEqual(t, `[INT|FLOAT]`, `@cc`, `🌝🌝🌝🌝🌝cc`) + checkValidateRuntimeTypesError(t, `[STRING]`, `@cc`) + checkValidateRuntimeTypesError(t, `[BOOL]`, `@cc`) + + checkValidateRuntimeTypesEqual(t, `[INT]`, `@cd`, `🌝🌝🌝🌝🌝cd`) + checkValidateRuntimeTypesEqual(t, `[FLOAT]`, `@cd`, `🌝🌝🌝🌝🌝cd`) + checkValidateRuntimeTypesEqual(t, `[FLOAT|INT]`, `@cd`, `🌝🌝🌝🌝🌝cd`) + checkValidateRuntimeTypesError(t, `[STRING]`, `@cd`) + checkValidateRuntimeTypesError(t, `[BOOL]`, `@cd`) + + checkValidateRuntimeTypesEqual(t, `{a: INT, b: INT}`, `@c8`, `🌝🌝🌝🌝🌝c8`) + checkValidateRuntimeTypesEqual(t, `{STRING: INT}`, `@c8`, `🌝🌝🌝🌝🌝c8`) + checkValidateRuntimeTypesEqual(t, `{a: FLOAT, b: INT}`, `@c8`, `🌝🌝🌝🌝🌝c8`) + checkValidateRuntimeTypesEqual(t, `{STRING: FLOAT}`, `@c8`, `🌝🌝🌝🌝🌝c8`) + checkValidateRuntimeTypesEqual(t, `{a: FLOAT}`, `@c8`, `🌝🌝🌝🌝🌝c8`) + checkValidateRuntimeTypesError(t, `{a: FLOAT, b: INT, c: INT}`, `@c8`) + + checkValidateRuntimeTypesEqual(t, `{a: INT, b: INT}`, `@c9`, `🌝🌝🌝🌝🌝c9`) + checkValidateRuntimeTypesEqual(t, `{STRING: INT}`, `@c9`, `🌝🌝🌝🌝🌝c9`) + checkValidateRuntimeTypesEqual(t, `{a: FLOAT, b: INT}`, `@c9`, `🌝🌝🌝🌝🌝c9`) + checkValidateRuntimeTypesEqual(t, `{STRING: FLOAT}`, `@c9`, `🌝🌝🌝🌝🌝c9`) + checkValidateRuntimeTypesEqual(t, `{a: FLOAT}`, `@c9`, `🌝🌝🌝🌝🌝c9`) + checkValidateRuntimeTypesError(t, `{a: FLOAT, b: INT, c: INT}`, `@c9`) + + checkValidateRuntimeTypesEqual(t, `{a: INT, b: INT}`, `@ca`, `🌝🌝🌝🌝🌝ca`) + checkValidateRuntimeTypesEqual(t, `{STRING: INT}`, `@ca`, `🌝🌝🌝🌝🌝ca`) + checkValidateRuntimeTypesEqual(t, `{a: FLOAT, b: INT}`, `@ca`, `🌝🌝🌝🌝🌝ca`) + checkValidateRuntimeTypesEqual(t, `{STRING: FLOAT}`, `@ca`, `🌝🌝🌝🌝🌝ca`) + checkValidateRuntimeTypesEqual(t, `{a: FLOAT}`, `@ca`, `🌝🌝🌝🌝🌝ca`) + checkValidateRuntimeTypesError(t, `{a: FLOAT, b: INT, c: INT}`, `@ca`) + + checkValidateRuntimeTypesEqual(t, `INT_COLUMN`, `@rc1`, `🌝🌝🌝🌝🌝rc1`) + checkValidateRuntimeTypesEqual(t, `FLOAT_COLUMN`, `@rc1`, `🌝🌝🌝🌝🌝rc1`) + checkValidateRuntimeTypesError(t, `STRING_COLUMN`, `@rc1`) + checkValidateRuntimeTypesEqual(t, `FLOAT_COLUMN|STRING_COLUMN`, `@rc1`, `🌝🌝🌝🌝🌝rc1`) + checkValidateRuntimeTypesError(t, `INT_LIST_COLUMN`, `@rc1`) + checkValidateRuntimeTypesError(t, `[INT_COLUMN]`, `@rc1`) + checkValidateRuntimeTypesError(t, `{INT_COLUMN: INT}`, `@rc1`) + checkValidateRuntimeTypesError(t, `{k: INT_COLUMN}`, `@rc1`) + + checkValidateRuntimeTypesEqual(t, `INT_COLUMN`, `@rc2`, `🌝🌝🌝🌝🌝rc2`) + checkValidateRuntimeTypesEqual(t, `FLOAT_COLUMN`, `@rc2`, `🌝🌝🌝🌝🌝rc2`) + checkValidateRuntimeTypesEqual(t, `STRING_COLUMN`, 
`@rc2`, `🌝🌝🌝🌝🌝rc2`) + checkValidateRuntimeTypesEqual(t, `INT_LIST_COLUMN`, `@rc2`, `🌝🌝🌝🌝🌝rc2`) + checkValidateRuntimeTypesEqual(t, `STRING_LIST_COLUMN`, `@rc2`, `🌝🌝🌝🌝🌝rc2`) + checkValidateRuntimeTypesError(t, `[INT_COLUMN]`, `@rc2`) + checkValidateRuntimeTypesError(t, `[STRING_COLUMN]`, `@rc2`) + checkValidateRuntimeTypesError(t, `{INT_COLUMN: INT}`, `@rc2`) + checkValidateRuntimeTypesError(t, `{k: INT_COLUMN}`, `@rc2`) + + checkValidateRuntimeTypesError(t, `INT_COLUMN`, `@rc3`) + checkValidateRuntimeTypesError(t, `FLOAT_COLUMN`, `@rc3`) + checkValidateRuntimeTypesEqual(t, `STRING_COLUMN`, `@rc3`, `🌝🌝🌝🌝🌝rc3`) + checkValidateRuntimeTypesError(t, `STRING_LIST_COLUMN`, `@rc3`) + checkValidateRuntimeTypesError(t, `[STRING_COLUMN]`, `@rc3`) + + checkValidateRuntimeTypesError(t, `INT_COLUMN`, `@rc4`) + checkValidateRuntimeTypesEqual(t, `FLOAT_COLUMN`, `@rc4`, `🌝🌝🌝🌝🌝rc4`) + checkValidateRuntimeTypesEqual(t, `INT_COLUMN|FLOAT_COLUMN`, `@rc4`, `🌝🌝🌝🌝🌝rc4`) + checkValidateRuntimeTypesError(t, `STRING_COLUMN`, `@rc4`) + checkValidateRuntimeTypesError(t, `FLOAT_LIST_COLUMN`, `@rc4`) + checkValidateRuntimeTypesError(t, `[FLOAT_COLUMN]`, `@rc4`) + + checkValidateRuntimeTypesEqual(t, `STRING`, `@agg1`, `🌝🌝🌝🌝🌝agg1`) + checkValidateRuntimeTypesEqual(t, `INT|STRING`, `@agg1`, `🌝🌝🌝🌝🌝agg1`) + checkValidateRuntimeTypesError(t, `INT`, `@agg1`) + + checkValidateRuntimeTypesEqual(t, `STRING`, `@agg2`, `🌝🌝🌝🌝🌝agg2`) + checkValidateRuntimeTypesEqual(t, `INT`, `@agg2`, `🌝🌝🌝🌝🌝agg2`) + checkValidateRuntimeTypesEqual(t, `{INT: BOOL}`, `@agg2`, `🌝🌝🌝🌝🌝agg2`) + + checkValidateRuntimeTypesError(t, `STRING`, `@agg3`) + checkValidateRuntimeTypesNoError(t, ` + map: {INT: FLOAT} + map2: {a: FLOAT, b: FLOAT, c: INT} + str: STRING + floats: [FLOAT] + list: + - STRING: + lat: INT + lon: + a: [STRING] + `, `@agg3`) + checkValidateRuntimeTypesError(t, ` + map: {INT: INT} + map2: {a: FLOAT, b: FLOAT, c: INT} + str: STRING + floats: [FLOAT] + list: [{STRING: {lat: INT, lon: {a: [STRING]}}}] + `, `@agg3`) + checkValidateRuntimeTypesNoError(t, ` + map: {FLOAT: FLOAT} + map2: {a: FLOAT, b: FLOAT, c: FLOAT} + str: STRING + floats: [FLOAT] + list: [{STRING: {lat: INT, lon: {a: [STRING]}}}] + `, `@agg3`) + checkValidateRuntimeTypesNoError(t, ` + map: {INT: FLOAT} + map2: {STRING: FLOAT} + str: STRING + floats: [FLOAT] + list: [{STRING: {lat: INT, lon: {a: [STRING]}}}] + `, `@agg3`) + checkValidateRuntimeTypesError(t, ` + map: {INT: FLOAT} + map2: {STRING: INT} + str: STRING + floats: [FLOAT] + list: [{STRING: {lat: INT, lon: {a: [STRING]}}}] + `, `@agg3`) + checkValidateRuntimeTypesError(t, ` + map: {2: FLOAT, 3: FLOAT} + map2: {a: FLOAT, b: FLOAT, c: INT} + str: STRING + floats: [FLOAT] + list: [{STRING: {lat: INT, lon: {a: [STRING]}}}] + `, `@agg3`) + + checkValidateRuntimeTypesEqual(t, `INT`, `@agg4`, `🌝🌝🌝🌝🌝agg4`) + checkValidateRuntimeTypesEqual(t, `FLOAT`, `@agg4`, `🌝🌝🌝🌝🌝agg4`) + checkValidateRuntimeTypesEqual(t, `FLOAT|INT`, `@agg4`, `🌝🌝🌝🌝🌝agg4`) + checkValidateRuntimeTypesError(t, `STRING`, `@agg4`) + checkValidateRuntimeTypesError(t, `BOOL`, `@agg4`) + + checkValidateRuntimeTypesError(t, `INT`, `@agg5`) + checkValidateRuntimeTypesEqual(t, `FLOAT`, `@agg5`, `🌝🌝🌝🌝🌝agg5`) + checkValidateRuntimeTypesEqual(t, `INT|FLOAT`, `@agg5`, `🌝🌝🌝🌝🌝agg5`) + checkValidateRuntimeTypesError(t, `STRING`, `@agg5`) + checkValidateRuntimeTypesError(t, `BOOL`, `@agg5`) + + checkValidateRuntimeTypesEqual(t, `[INT]`, `@agg8`, `🌝🌝🌝🌝🌝agg8`) + checkValidateRuntimeTypesEqual(t, `[FLOAT]`, `@agg8`, `🌝🌝🌝🌝🌝agg8`) + checkValidateRuntimeTypesEqual(t, `[FLOAT|INT]`, 
`@agg8`, `🌝🌝🌝🌝🌝agg8`) + checkValidateRuntimeTypesError(t, `[STRING]`, `@agg8`) + checkValidateRuntimeTypesError(t, `[BOOL]`, `@agg8`) + + checkValidateRuntimeTypesError(t, `[INT]`, `@agg9`) + checkValidateRuntimeTypesEqual(t, `[FLOAT]`, `@agg9`, `🌝🌝🌝🌝🌝agg9`) + checkValidateRuntimeTypesEqual(t, `[INT|FLOAT]`, `@agg9`, `🌝🌝🌝🌝🌝agg9`) + checkValidateRuntimeTypesError(t, `[STRING]`, `@agg9`) + checkValidateRuntimeTypesError(t, `[BOOL]`, `@agg9`) + + checkValidateRuntimeTypesEqual(t, `{a: INT, b: INT}`, `@agg6`, `🌝🌝🌝🌝🌝agg6`) + checkValidateRuntimeTypesEqual(t, `{STRING: INT}`, `@agg6`, `🌝🌝🌝🌝🌝agg6`) + checkValidateRuntimeTypesEqual(t, `{a: FLOAT, b: INT}`, `@agg6`, `🌝🌝🌝🌝🌝agg6`) + checkValidateRuntimeTypesEqual(t, `{STRING: FLOAT}`, `@agg6`, `🌝🌝🌝🌝🌝agg6`) + checkValidateRuntimeTypesEqual(t, `{a: FLOAT}`, `@agg6`, `🌝🌝🌝🌝🌝agg6`) + checkValidateRuntimeTypesError(t, `{a: FLOAT, b: INT, c: INT}`, `@agg6`) + + checkValidateRuntimeTypesError(t, `{a: INT, b: INT}`, `@agg7`) + checkValidateRuntimeTypesEqual(t, `{STRING: INT}`, `@agg7`, `🌝🌝🌝🌝🌝agg7`) + checkValidateRuntimeTypesEqual(t, `{STRING: FLOAT}`, `@agg7`, `🌝🌝🌝🌝🌝agg7`) + + checkValidateRuntimeTypesEqual(t, `STRING_COLUMN`, `@tc1`, `🌝🌝🌝🌝🌝tc1`) + checkValidateRuntimeTypesError(t, `INT_COLUMN`, `@tc1`) + checkValidateRuntimeTypesEqual(t, `INT_COLUMN|STRING_COLUMN`, `@tc1`, `🌝🌝🌝🌝🌝tc1`) + checkValidateRuntimeTypesError(t, `FLOAT_COLUMN`, `@tc1`) + checkValidateRuntimeTypesError(t, `STRING_LIST_COLUMN`, `@tc1`) + checkValidateRuntimeTypesError(t, `INT_LIST_COLUMN`, `@tc1`) + checkValidateRuntimeTypesError(t, `FLOAT_LIST_COLUMN`, `@tc1`) + checkValidateRuntimeTypesError(t, `[STRING_COLUMN]`, `@tc1`) + checkValidateRuntimeTypesError(t, `[INT_COLUMN]`, `@tc1`) + checkValidateRuntimeTypesError(t, `{STRING_COLUMN: INT}`, `@tc1`) + checkValidateRuntimeTypesError(t, `{k: STRING_COLUMN}`, `@tc1`) + + checkValidateRuntimeTypesEqual(t, `STRING_COLUMN`, `@tc2`, `🌝🌝🌝🌝🌝tc2`) + checkValidateRuntimeTypesEqual(t, `INT_COLUMN`, `@tc2`, `🌝🌝🌝🌝🌝tc2`) + checkValidateRuntimeTypesEqual(t, `FLOAT_COLUMN`, `@tc2`, `🌝🌝🌝🌝🌝tc2`) + checkValidateRuntimeTypesEqual(t, `STRING_LIST_COLUMN`, `@tc2`, `🌝🌝🌝🌝🌝tc2`) + checkValidateRuntimeTypesEqual(t, `INT_LIST_COLUMN`, `@tc2`, `🌝🌝🌝🌝🌝tc2`) + checkValidateRuntimeTypesEqual(t, `FLOAT_LIST_COLUMN`, `@tc2`, `🌝🌝🌝🌝🌝tc2`) + checkValidateRuntimeTypesError(t, `[STRING_COLUMN]`, `@tc2`) + checkValidateRuntimeTypesError(t, `[INT_COLUMN]`, `@tc2`) + checkValidateRuntimeTypesError(t, `{STRING_COLUMN: INT}`, `@tc2`) + checkValidateRuntimeTypesError(t, `{k: STRING_COLUMN}`, `@tc2`) + + checkValidateRuntimeTypesError(t, `STRING_COLUMN`, `@tc3`) + checkValidateRuntimeTypesEqual(t, `INT_COLUMN`, `@tc3`, `🌝🌝🌝🌝🌝tc3`) + checkValidateRuntimeTypesEqual(t, `FLOAT_COLUMN`, `@tc3`, `🌝🌝🌝🌝🌝tc3`) + checkValidateRuntimeTypesError(t, `STRING_LIST_COLUMN`, `@tc3`) + checkValidateRuntimeTypesError(t, `INT_LIST_COLUMN`, `@tc3`) + checkValidateRuntimeTypesError(t, `FLOAT_LIST_COLUMN`, `@tc3`) + checkValidateRuntimeTypesError(t, `[STRING_COLUMN]`, `@tc3`) + checkValidateRuntimeTypesError(t, `[INT_COLUMN]`, `@tc3`) + checkValidateRuntimeTypesError(t, `{STRING_COLUMN: INT}`, `@tc3`) + checkValidateRuntimeTypesError(t, `{k: STRING_COLUMN}`, `@tc3`) + + checkValidateRuntimeTypesError(t, `STRING_COLUMN`, `@tc4`) + checkValidateRuntimeTypesError(t, `INT_COLUMN`, `@tc4`) + checkValidateRuntimeTypesEqual(t, `FLOAT_COLUMN`, `@tc4`, `🌝🌝🌝🌝🌝tc4`) + checkValidateRuntimeTypesEqual(t, `INT_COLUMN|FLOAT_COLUMN`, `@tc4`, `🌝🌝🌝🌝🌝tc4`) + checkValidateRuntimeTypesError(t, `STRING_LIST_COLUMN`, `@tc4`) + 
checkValidateRuntimeTypesError(t, `INT_LIST_COLUMN`, `@tc4`) + checkValidateRuntimeTypesError(t, `FLOAT_LIST_COLUMN`, `@tc4`) + checkValidateRuntimeTypesError(t, `[STRING_COLUMN]`, `@tc4`) + checkValidateRuntimeTypesError(t, `[INT_COLUMN]`, `@tc4`) + checkValidateRuntimeTypesError(t, `{STRING_COLUMN: INT}`, `@tc4`) + checkValidateRuntimeTypesError(t, `{k: STRING_COLUMN}`, `@tc4`) + + checkValidateRuntimeTypesError(t, `STRING_COLUMN`, `@tc5`) + checkValidateRuntimeTypesError(t, `INT_COLUMN`, `@tc5`) + checkValidateRuntimeTypesError(t, `FLOAT_COLUMN`, `@tc5`) + checkValidateRuntimeTypesEqual(t, `STRING_LIST_COLUMN`, `@tc5`, `🌝🌝🌝🌝🌝tc5`) + checkValidateRuntimeTypesEqual(t, `STRING_LIST_COLUMN|INT_COLUMN`, `@tc5`, `🌝🌝🌝🌝🌝tc5`) + checkValidateRuntimeTypesError(t, `INT_LIST_COLUMN`, `@tc5`) + checkValidateRuntimeTypesError(t, `FLOAT_LIST_COLUMN`, `@tc5`) + checkValidateRuntimeTypesError(t, `[STRING_COLUMN]`, `@tc5`) + checkValidateRuntimeTypesError(t, `[INT_COLUMN]`, `@tc5`) + checkValidateRuntimeTypesError(t, `{STRING_COLUMN: INT}`, `@tc5`) + checkValidateRuntimeTypesError(t, `{k: STRING_COLUMN}`, `@tc5`) + + checkValidateRuntimeTypesError(t, `STRING_COLUMN`, `@tc6`) + checkValidateRuntimeTypesError(t, `INT_COLUMN`, `@tc6`) + checkValidateRuntimeTypesError(t, `FLOAT_COLUMN`, `@tc6`) + checkValidateRuntimeTypesError(t, `STRING_LIST_COLUMN`, `@tc6`) + checkValidateRuntimeTypesEqual(t, `INT_LIST_COLUMN`, `@tc6`, `🌝🌝🌝🌝🌝tc6`) + checkValidateRuntimeTypesEqual(t, `FLOAT_LIST_COLUMN`, `@tc6`, `🌝🌝🌝🌝🌝tc6`) + checkValidateRuntimeTypesError(t, `[STRING_COLUMN]`, `@tc6`) + checkValidateRuntimeTypesError(t, `[INT_COLUMN]`, `@tc6`) + checkValidateRuntimeTypesError(t, `{STRING_COLUMN: INT}`, `@tc6`) + checkValidateRuntimeTypesError(t, `{k: STRING_COLUMN}`, `@tc6`) + + checkValidateRuntimeTypesError(t, `STRING_COLUMN`, `@tc7`) + checkValidateRuntimeTypesError(t, `INT_COLUMN`, `@tc7`) + checkValidateRuntimeTypesError(t, `FLOAT_COLUMN`, `@tc7`) + checkValidateRuntimeTypesError(t, `STRING_LIST_COLUMN`, `@tc7`) + checkValidateRuntimeTypesError(t, `INT_LIST_COLUMN`, `@tc7`) + checkValidateRuntimeTypesEqual(t, `FLOAT_LIST_COLUMN`, `@tc7`, `🌝🌝🌝🌝🌝tc7`) + checkValidateRuntimeTypesEqual(t, `FLOAT_LIST_COLUMN|INT_LIST_COLUMN`, `@tc7`, `🌝🌝🌝🌝🌝tc7`) + checkValidateRuntimeTypesError(t, `[STRING_COLUMN]`, `@tc7`) + checkValidateRuntimeTypesError(t, `[INT_COLUMN]`, `@tc7`) + checkValidateRuntimeTypesError(t, `{STRING_COLUMN: INT}`, `@tc7`) + checkValidateRuntimeTypesError(t, `{k: STRING_COLUMN}`, `@tc7`) + + checkValidateRuntimeTypesEqual(t, + `[INT_COLUMN]`, + `[@tc3, @rc1]`, + []interface{}{"🌝🌝🌝🌝🌝tc3", "🌝🌝🌝🌝🌝rc1"}) + + checkValidateRuntimeTypesEqual(t, + `[FLOAT_COLUMN]`, + `[@tc3, @rc1, @tc4, @rc4]`, + []interface{}{"🌝🌝🌝🌝🌝tc3", "🌝🌝🌝🌝🌝rc1", "🌝🌝🌝🌝🌝tc4", "🌝🌝🌝🌝🌝rc4"}) + + checkValidateRuntimeTypesEqual(t, + `[FLOAT]`, + `[@c5, @c6, 2, 2.2, @agg4, @agg5]`, + []interface{}{"🌝🌝🌝🌝🌝c5", "🌝🌝🌝🌝🌝c6", float64(2), float64(2.2), "🌝🌝🌝🌝🌝agg4", "🌝🌝🌝🌝🌝agg5"}) + + checkValidateRuntimeTypesEqual(t, + `[FLOAT]`, + `@cb`, + "🌝🌝🌝🌝🌝cb") + + checkValidateRuntimeTypesEqual(t, + `[FLOAT]`, + `@cc`, + "🌝🌝🌝🌝🌝cc") + + checkValidateRuntimeTypesEqual(t, + `[FLOAT|INT]`, + `[@c5, @c6, 2, 2.2, @agg4, @agg5]`, + []interface{}{"🌝🌝🌝🌝🌝c5", "🌝🌝🌝🌝🌝c6", int64(2), float64(2.2), "🌝🌝🌝🌝🌝agg4", "🌝🌝🌝🌝🌝agg5"}) + + checkValidateRuntimeTypesEqual(t, + `{2: INT_COLUMN, 3: INT}`, + `{2: @tc3, 3: @agg4}`, + map[interface{}]interface{}{int64(2): "🌝🌝🌝🌝🌝tc3", int64(3): "🌝🌝🌝🌝🌝agg4"}) + + checkValidateRuntimeTypesEqual(t, + `{2: FLOAT_COLUMN, 3: FLOAT}`, + `{2: @tc3, 3: @agg4}`, + 
map[interface{}]interface{}{int64(2): "🌝🌝🌝🌝🌝tc3", int64(3): "🌝🌝🌝🌝🌝agg4"}) + + checkValidateRuntimeTypesEqual(t, + `{FLOAT: FLOAT_COLUMN}`, + `{2: @tc3, 3: @tc4, @agg4: @rc1, @agg5: @rc2, @c5: @rc4, @c6: @tc2}`, + map[interface{}]interface{}{ + float64(2): "🌝🌝🌝🌝🌝tc3", + float64(3): "🌝🌝🌝🌝🌝tc4", + "🌝🌝🌝🌝🌝agg4": "🌝🌝🌝🌝🌝rc1", + "🌝🌝🌝🌝🌝agg5": "🌝🌝🌝🌝🌝rc2", + "🌝🌝🌝🌝🌝c5": "🌝🌝🌝🌝🌝rc4", + "🌝🌝🌝🌝🌝c6": "🌝🌝🌝🌝🌝tc2", + }) + + checkValidateRuntimeTypesNoError(t, ` + FLOAT: + map: {INT: FLOAT} + map2: {a: FLOAT, b: FLOAT, c: INT} + str: STRING + floats: [FLOAT] + list: + - STRING: + lat: INT + lon: + a: [STRING] + `, ` + @agg4: @agg3 + @agg5: @c3 + @c5: @c4 + @c6: + map: {2: 2.2, 3: 3} + map2: {a: 2.2, b: 3, c: 4} + str: test + floats: [1.1, 2.2, 3.3] + list: + - key_1: + lat: 17 + lon: + a: [test1, test2, test3] + key_2: + lat: 88 + lon: + a: [test4, test5, test6] + - key_a: + lat: 12 + lon: + a: [test7, test8, test9] + 2.2: + map: {2: 2.2, @c6: @c6, @agg4: @agg5, 3: @c5, 4: @agg5} + map2: {a: 2.2, b: @c5, c: @agg4} + str: @c1 + floats: [@c5, @c6, 2, 2.2, @agg4, @agg5] + list: + - key_1: + lat: @c6 + lon: + a: [test1, @agg1, test3] + @agg1: + lat: 88 + lon: + a: @ce + - @c1: + lat: @agg4 + lon: + a: @agga + key_2: + lat: 17 + lon: @cf + key_3: + lat: 41 + lon: @aggb + `) + + checkValidateRuntimeTypesError(t, ` + FLOAT: + map: {INT: FLOAT} + map2: {a: FLOAT, b: FLOAT, c: INT} + str: STRING + floats: [FLOAT] + list: + - STRING: + lat: INT + lon: + a: [STRING] + `, ` + 2.2: + map: {2: 2.2, @c6: @c6, @agg4: @agg5, 3: @c5, 4: @agg5} + map2: {a: 2.2, b: @c5, c: @agg5} + str: @c1 + floats: [@c5, @c6, 2, 2.2, @agg4, @agg5] + list: + - key_1: + lat: @c6 + lon: + a: [test1, @agg1, test3] + @agg1: + lat: 88 + lon: + a: @ce + `) + + checkValidateRuntimeTypesNoError(t, ` + - a: FLOAT_COLUMN + b: INT_COLUMN|STRING_COLUMN + c: {1: INT_COLUMN, 2: FLOAT_COLUMN, 3: BOOL, 4: STRING} + d: {INT: INT_COLUMN} + e: {FLOAT_COLUMN: FLOAT|STRING} + f: {INT_LIST_COLUMN|STRING_COLUMN: FLOAT_COLUMN} + g: [FLOAT] + `, ` + - a: @tc4 + b: @rc3 + c: {1: @rc1, 2: @tc3, 3: true, 4: @agg1} + d: {1: @rc1, 2: @tc3, @c6: @rc2, @agg4: @tc2} + e: {@tc3: @agg4, @rc4: test, @tc2: 2.2, @rc1: @c1, @tc4: @agg5, @rc2: 2} + f: {@tc6: @tc4, @tc1: @rc2, @rc3: @rc1} + g: [@c5, @c6, 2, 2.2, @agg4, @agg5] + `) + + checkValidateRuntimeTypesError(t, ` + - a: FLOAT_COLUMN + b: INT_COLUMN|STRING_COLUMN + c: {1: INT_COLUMN, 2: FLOAT_COLUMN, 3: BOOL, 4: STRING} + d: {INT: INT_COLUMN} + e: {FLOAT_COLUMN: FLOAT|STRING} + f: {INT_LIST_COLUMN|STRING_COLUMN: FLOAT_COLUMN} + g: [FLOAT] + `, ` + - a: @tc4 + b: @rc3 + c: {1: @rc1, 2: @tc3, 3: true, 4: @agg1} + d: {1: @rc1, 2: @tc3, @c6: @rc2, @agg4: @tc2} + e: {@tc3: @agg4, @rc4: test, @tc2: 2.2, @rc1: @c1, @tc4: @agg5, @rc2: 2} + f: {@tc7: @tc4, @tc1: @rc2, @rc3: @rc1} + g: [@c5, @c6, 2, 2.2, @agg4, @agg5] + `) + + // No replacements + + checkValidateRuntimeTypesEqual(t, `INT`, `2`, int64(2)) + checkValidateRuntimeTypesError(t, `INT`, `test`) + checkValidateRuntimeTypesError(t, `INT`, `2.2`) + checkValidateRuntimeTypesEqual(t, `FLOAT`, `2`, float64(2)) + checkValidateRuntimeTypesError(t, `FLOAT`, `test`) + checkValidateRuntimeTypesEqual(t, `BOOL`, `true`, true) + checkValidateRuntimeTypesEqual(t, `STRING`, `str`, "str") + checkValidateRuntimeTypesError(t, `STRING`, `1`) + + checkValidateRuntimeTypesEqual(t, `{STRING: FLOAT}`, `{test: 2.2, test2: 4.4}`, + map[interface{}]interface{}{"test": 2.2, "test2": 4.4}) + checkValidateRuntimeTypesError(t, `{STRING: FLOAT}`, `{test: test2}`) + checkValidateRuntimeTypesEqual(t, `{STRING: 
FLOAT}`, `{test: 2}`, + map[interface{}]interface{}{"test": float64(2)}) + checkValidateRuntimeTypesEqual(t, `{STRING: INT}`, `{test: 2}`, + map[interface{}]interface{}{"test": int64(2)}) + checkValidateRuntimeTypesError(t, `{STRING: INT}`, `{test: 2.0}`) + + checkValidateRuntimeTypesEqual(t, `{mean: FLOAT, sum: INT}`, `{mean: 2.2, sum: 4}`, + map[interface{}]interface{}{"mean": float64(2.2), "sum": int64(4)}) + checkValidateRuntimeTypesError(t, `{mean: FLOAT, sum: INT}`, `{mean: 2.2, sum: test}`) + checkValidateRuntimeTypesError(t, `{mean: FLOAT, sum: INT}`, `{mean: false, sum: 4}`) + checkValidateRuntimeTypesError(t, `{mean: FLOAT, sum: INT}`, `{mean: 2.2, 2: 4}`) + checkValidateRuntimeTypesError(t, `{mean: FLOAT, sum: INT}`, `{mean: 2.2, sum: Null}`) + checkValidateRuntimeTypesError(t, `{mean: FLOAT, sum: INT}`, `{mean: 2.2}`) + checkValidateRuntimeTypesError(t, `{mean: FLOAT, sum: INT}`, `{mean: 2.2, sum: 4, stddev: 2}`) + + checkValidateRuntimeTypesEqual(t, `[INT]`, `[1, 2]`, + []interface{}{int64(1), int64(2)}) + checkValidateRuntimeTypesError(t, `[INT]`, `[1.0, 2]`) + checkValidateRuntimeTypesEqual(t, `[FLOAT]`, `[1.0, 2]`, + []interface{}{float64(1), float64(2)}) + + schemaYAML := + ` + map: {STRING: FLOAT} + str: STRING + floats: [FLOAT] + map2: + STRING: + lat: FLOAT + lon: + a: INT + b: [STRING] + c: {mean: FLOAT, sum: [INT], stddev: {STRING: INT}} + bools: [BOOL] + ` + + checkValidateRuntimeTypesNoError(t, schemaYAML, ` + map: {a: 2.2, b: 3} + str: test1 + floats: [2.2, 3.3, 4.4] + map2: + testA: + lat: 9.9 + lon: + a: 17 + b: [test1, test2, test3] + c: {mean: 8.8, sum: [3, 2, 1], stddev: {a: 1, b: 2}} + bools: [true] + testB: + lat: 3.14 + lon: + a: 88 + b: [testX, testY, testZ] + c: {mean: 1.7, sum: [1], stddev: {z: 12}} + bools: [true, false, true] + `) + + checkValidateRuntimeTypesError(t, schemaYAML, ` + map: {a: 2.2, b: 3} + str: test1 + floats: [2.2, 3.3, 4.4] + map2: + testA: + lat: 9.9 + lon: + a: 17 + b: [test1, test2, test3] + c: {mean: 8.8, sum: [3, 2, 1], stddev: {a: 1, b: 2}} + bools: [true] + testB: + lat: 3.14 + lon: + b: [testX, testY, testZ] + c: {mean: 1.7, sum: [1], stddev: {z: 12}} + bools: [true, false, true] + `) + + checkValidateRuntimeTypesError(t, schemaYAML, ` + map: {a: 2.2, b: 3} + str: test1 + floats: [2.2, 3.3, 4.4] + map2: + testA: + lat: 9.9 + lon: + a: 17 + b: [test1, test2, test3] + c: {mean: 8.8, sum: [3, 2, 1], stddev: {a: 1, b: 2}} + bools: [true] + testB: + lat: 3.14 + lon: + a: 88.8 + b: [testX, testY, testZ] + c: {mean: 1.7, sum: [1], stddev: {z: 12}} + bools: [true, false, true] + `) + + checkValidateRuntimeTypesError(t, schemaYAML, ` + map: {a: 2.2, b: 3} + str: test1 + floats: [2.2, 3.3, 4.4] + map2: + testA: + lat: 9.9 + lon: + a: 17 + b: [test1, test2, test3] + c: {mean: 8.8, sum: [3, 2, 1], stddev: {a: 1, b: 2}} + bools: [true] + testB: + lat: 3.14 + lon: + a: 88 + b: [testX, testY, 2] + c: {mean: 1.7, sum: [1], stddev: {z: 12}} + bools: [true, false, true] + `) + + checkValidateRuntimeTypesError(t, schemaYAML, ` + map: {a: 2.2, b: 3} + str: test1 + floats: [2.2, 3.3, 4.4] + map2: + testA: + lat: 9.9 + lon: + a: 17 + b: [test1, test2, test3] + c: {mean: 8.8, sum: [3, 2, 1], stddev: {a: 1, b: test}} + bools: [true] + testB: + lat: 3.14 + lon: + a: 88 + b: [testX, testY, testZ] + c: {mean: 1.7, sum: [1], stddev: {z: 12}} + bools: [true, false, true] + `) + + checkValidateRuntimeTypesError(t, schemaYAML, ` + map: {a: 2.2, b: 3} + str: test1 + floats: [2.2, 3.3, 4.4] + map2: + testA: + lat: 9.9 + lon: + a: 17 + b: [test1, test2, test3] 
+ c: {mean: 8.8, sum: [3, 2, 1], stddev: {a: 1, b: 2}} + bools: [true] + testB: + lat: 3.14 + lon: + a: 88 + b: [testX, testY, testZ] + c: {mean: 1.7, sum: [1], stddev: {z: 12}} + bools: true + `) + + checkValidateRuntimeTypesError(t, schemaYAML, ` + map: {a: 2.2, b: 3} + str: test1 + floats: [2.2, 3.3, 4.4] + map2: + testA: + lat: 9.9 + lon: + a: 17 + b: [test1, test2, test3] + c: {mean: 8.8, sum: [3, 2, 1], stddev: {a: 1, b: 2}} + bools: [true] + testB: + lat: 3.14 + lon: + a: 88 + b: [testX, testY, testZ] + c: {mean: 1.7, sum: [1], stddev: {z: 12}} + bools: [1, 2, 3] + `) + + checkValidateRuntimeTypesEqual(t, `FLOAT|INT`, `2`, int64(2)) + checkValidateRuntimeTypesEqual(t, `INT|FLOAT`, `2`, int64(2)) + checkValidateRuntimeTypesEqual(t, `FLOAT|INT`, `2.2`, float64(2.2)) + checkValidateRuntimeTypesEqual(t, `INT|FLOAT`, `2.2`, float64(2.2)) + checkValidateRuntimeTypesError(t, `STRING`, `2`) + checkValidateRuntimeTypesEqual(t, `STRING|FLOAT`, `2`, float64(2)) + checkValidateRuntimeTypesEqual(t, `{_type: [INT], _max_count: 2}`, `[2]`, []interface{}{int64(2)}) + checkValidateRuntimeTypesError(t, `{_type: [INT], _max_count: 2}`, `[2, 3, 4]`) + checkValidateRuntimeTypesEqual(t, `{_type: [INT], _min_count: 2}`, `[2, 3, 4]`, []interface{}{int64(2), int64(3), int64(4)}) + checkValidateRuntimeTypesError(t, `{_type: [INT], _min_count: 2}`, `[2]`) + checkValidateRuntimeTypesError(t, `{_type: INT, _optional: true}`, `Null`) + checkValidateRuntimeTypesError(t, `{_type: INT, _optional: true}`, ``) + checkValidateRuntimeTypesEqual(t, `{_type: INT, _allow_null: true}`, `Null`, nil) + checkValidateRuntimeTypesEqual(t, `{_type: INT, _allow_null: true}`, ``, nil) + checkValidateRuntimeTypesError(t, `{_type: {a: INT}}`, `Null`) + checkValidateRuntimeTypesError(t, `{_type: {a: INT}, _optional: true}`, `Null`) + checkValidateRuntimeTypesEqual(t, `{_type: {a: INT}, _allow_null: true}`, `Null`, nil) + checkValidateRuntimeTypesEqual(t, `{_type: {a: INT}}`, `{a: 2}`, map[interface{}]interface{}{"a": int64(2)}) + checkValidateRuntimeTypesError(t, `{_type: {a: INT}}`, `{a: Null}`) + checkValidateRuntimeTypesError(t, `{a: {_type: INT, _optional: false}}`, `{a: Null}`) + checkValidateRuntimeTypesError(t, `{a: {_type: INT, _optional: false}}`, `{}`) + checkValidateRuntimeTypesError(t, `{a: {_type: INT, _optional: true}}`, `{a: Null}`) + checkValidateRuntimeTypesEqual(t, `{a: {_type: INT, _optional: true}}`, `{}`, map[interface{}]interface{}{}) + checkValidateRuntimeTypesEqual(t, `{a: {_type: INT, _allow_null: true}}`, `{a: Null}`, map[interface{}]interface{}{"a": nil}) + checkValidateRuntimeTypesError(t, `{a: {_type: INT, _allow_null: true}}`, `{}`) + checkValidateRuntimeTypesEqual(t, `{a: {_type: INT, _allow_null: true, _optional: true}}`, `{}`, map[interface{}]interface{}{}) +} + +func TestValidateResourceReferences(t *testing.T) { + var input, replaced interface{} + var err error + + input = cr.MustReadYAMLStr(`@bad`) + _, err = validateResourceReferences(input, nil, allResourcesMap, allResourceConfigsMap) + require.Error(t, err) + + input = cr.MustReadYAMLStr(`@rc1?test`) + _, err = validateResourceReferences(input, nil, allResourcesMap, allResourceConfigsMap) + require.Error(t, err) + + input = cr.MustReadYAMLStr(`@rc1 test`) + _, err = validateResourceReferences(input, nil, allResourcesMap, allResourceConfigsMap) + require.Error(t, err) + + input = cr.MustReadYAMLStr(`[@rc1 test]`) + _, err = validateResourceReferences(input, nil, allResourcesMap, allResourceConfigsMap) + require.Error(t, err) + + input = 
cr.MustReadYAMLStr(`@rc1 test: 2.2`) + _, err = validateResourceReferences(input, nil, allResourcesMap, allResourceConfigsMap) + require.Error(t, err) + + input = cr.MustReadYAMLStr(`2.2: @rc1 test`) + _, err = validateResourceReferences(input, nil, allResourcesMap, allResourceConfigsMap) + require.Error(t, err) + + input = cr.MustReadYAMLStr(`str`) + replaced, err = validateResourceReferences(input, nil, allResourcesMap, allResourceConfigsMap) + require.NoError(t, err) + require.Equal(t, "str", replaced) + + input = cr.MustReadYAMLStr(`@rc1`) + replaced, err = validateResourceReferences(input, nil, rawColsMap, allResourceConfigsMap) + require.NoError(t, err) + require.Equal(t, "b_rc1", replaced) + replaced, err = validateResourceReferences(input, nil, allResourcesMap, allResourceConfigsMap) + require.NoError(t, err) + require.Equal(t, "b_rc1", replaced) + _, err = validateResourceReferences(input, nil, transformedColsMap, allResourceConfigsMap) + require.Error(t, err) + _, err = validateResourceReferences(input, nil, nil, allResourceConfigsMap) + require.Error(t, err) + + input = cr.MustReadYAMLStr(`[@tc1, rc2, @rc1]`) + replaced, err = validateResourceReferences(input, nil, allResourcesMap, allResourceConfigsMap) + require.NoError(t, err) + require.Equal(t, []interface{}{"e_tc1", "rc2", "b_rc1"}, replaced) + _, err = validateResourceReferences(input, nil, rawColsMap, allResourceConfigsMap) + require.Error(t, err) + _, err = validateResourceReferences(input, nil, transformedColsMap, allResourceConfigsMap) + require.Error(t, err) + + input = cr.MustReadYAMLStr(`{@c5: 1, @agg4: @c6, @c6: @agg5}`) + replaced, err = validateResourceReferences(input, nil, allResourcesMap, allResourceConfigsMap) + require.NoError(t, err) + require.Equal(t, map[interface{}]interface{}{"a_c5": int64(1), "c_agg4": "a_c6", "a_c6": "c_agg5"}, replaced) + + input = cr.MustReadYAMLStr(`[@tc1, @bad, @rc1]`) + _, err = validateResourceReferences(input, nil, allResourcesMap, allResourceConfigsMap) + require.Error(t, err) + + input = cr.MustReadYAMLStr( + ` + map: {@agg1: @c1} + str: @rc1 + floats: [@tc2] + map2: + map3: + lat: @c2 + lon: + @c3: agg2 + b: [@tc1, @agg3] + `) + replaced, err = validateResourceReferences(input, nil, allResourcesMap, allResourceConfigsMap) + require.NoError(t, err) + expected := cr.MustReadYAMLStr( + ` + map: {c_agg1: a_c1} + str: b_rc1 + floats: [e_tc2] + map2: + map3: + lat: a_c2 + lon: + a_c3: agg2 + b: [e_tc1, c_agg3] + `) + require.Equal(t, expected, replaced) +} diff --git a/pkg/operator/context/transformed_columns.go b/pkg/operator/context/transformed_columns.go index 21b7dad144..1a325c667d 100644 --- a/pkg/operator/context/transformed_columns.go +++ b/pkg/operator/context/transformed_columns.go @@ -21,7 +21,6 @@ import ( "github.com/cortexlabs/cortex/pkg/lib/errors" "github.com/cortexlabs/cortex/pkg/lib/hash" - s "github.com/cortexlabs/cortex/pkg/lib/strings" "github.com/cortexlabs/cortex/pkg/operator/api/context" "github.com/cortexlabs/cortex/pkg/operator/api/resource" "github.com/cortexlabs/cortex/pkg/operator/api/userconfig" @@ -32,53 +31,50 @@ func getTransformedColumns( constants context.Constants, rawColumns context.RawColumns, aggregates context.Aggregates, - userTransformers map[string]*context.Transformer, + aggregators context.Aggregators, + transformers context.Transformers, root string, ) (context.TransformedColumns, error) { transformedColumns := 
context.TransformedColumns{} for _, transformedColumnConfig := range config.TransformedColumns { - transformer, err := getTransformer(transformedColumnConfig.Transformer, userTransformers) - if err != nil { - return nil, errors.Wrap(err, userconfig.Identify(transformedColumnConfig), userconfig.TransformerKey) - } + transformer := transformers[transformedColumnConfig.Transformer] - err = validateTransformedColumnInputs(transformedColumnConfig, constants, rawColumns, aggregates, transformer) - if err != nil { - return nil, errors.WithStack(err) + var validInputResources []context.Resource + for _, res := range constants { + validInputResources = append(validInputResources, res) + } + for _, res := range rawColumns { + validInputResources = append(validInputResources, res) + } + for _, res := range aggregates { + validInputResources = append(validInputResources, res) } - valueResourceIDMap := make(map[string]string, len(transformedColumnConfig.Inputs.Args)) - valueResourceIDWithTagsMap := make(map[string]string, len(transformedColumnConfig.Inputs.Args)) - for argName, resourceName := range transformedColumnConfig.Inputs.Args { - resourceNameStr := resourceName.(string) - resource, err := context.GetValueResource(resourceNameStr, constants, aggregates) - if err != nil { - return nil, errors.Wrap(err, userconfig.Identify(transformedColumnConfig), userconfig.InputsKey, userconfig.ArgsKey, argName) - } - valueResourceIDMap[argName] = resource.GetID() - valueResourceIDWithTagsMap[argName] = resource.GetIDWithTags() + castedInput, inputID, err := ValidateInput( + transformedColumnConfig.Input, + transformer.Input, + []resource.Type{resource.RawColumnType, resource.ConstantType, resource.AggregateType}, + validInputResources, + config.Resources, + aggregators, + nil, + ) + if err != nil { + return nil, errors.Wrap(err, userconfig.Identify(transformedColumnConfig), userconfig.InputKey) } + transformedColumnConfig.Input = castedInput var buf bytes.Buffer - buf.WriteString(rawColumns.ColumnInputsID(transformedColumnConfig.Inputs.Columns)) - buf.WriteString(s.Obj(valueResourceIDMap)) + buf.WriteString(inputID) buf.WriteString(transformer.ID) id := hash.Bytes(buf.Bytes()) - buf.Reset() - buf.WriteString(rawColumns.ColumnInputsIDWithTags(transformedColumnConfig.Inputs.Columns)) - buf.WriteString(s.Obj(valueResourceIDWithTagsMap)) - buf.WriteString(transformer.IDWithTags) - buf.WriteString(transformedColumnConfig.Tags.ID()) - idWithTags := hash.Bytes(buf.Bytes()) - transformedColumns[transformedColumnConfig.Name] = &context.TransformedColumn{ ComputedResourceFields: &context.ComputedResourceFields{ ResourceFields: &context.ResourceFields{ ID: id, - IDWithTags: idWithTags, ResourceType: resource.TransformedColumnType, }, }, @@ -89,57 +85,3 @@ func getTransformedColumns( return transformedColumns, nil } - -func validateTransformedColumnInputs( - transformedColumnConfig *userconfig.TransformedColumn, - constants context.Constants, - rawColumns context.RawColumns, - aggregates context.Aggregates, - transformer *context.Transformer, -) error { - if transformedColumnConfig.TransformerPath != nil { - return nil - } - - columnRuntimeTypes, err := context.GetColumnRuntimeTypes(transformedColumnConfig.Inputs.Columns, rawColumns) - if err != nil { - return errors.Wrap(err, userconfig.Identify(transformedColumnConfig), userconfig.InputsKey, userconfig.ColumnsKey) - } - err = userconfig.CheckColumnRuntimeTypesMatch(columnRuntimeTypes, transformer.Inputs.Columns) - if err != nil { - return errors.Wrap(err, 
userconfig.Identify(transformedColumnConfig), userconfig.InputsKey, userconfig.ColumnsKey) - } - - argTypes, err := getTransformedColumnArgTypes(transformedColumnConfig.Inputs.Args, constants, aggregates) - if err != nil { - return errors.Wrap(err, userconfig.Identify(transformedColumnConfig), userconfig.InputsKey, userconfig.ArgsKey) - } - err = userconfig.CheckArgRuntimeTypesMatch(argTypes, transformer.Inputs.Args) - if err != nil { - return errors.Wrap(err, userconfig.Identify(transformedColumnConfig), userconfig.InputsKey, userconfig.ArgsKey) - } - - return nil -} - -func getTransformedColumnArgTypes( - args map[string]interface{}, - constants context.Constants, - aggregates context.Aggregates, -) (map[string]interface{}, error) { - - if len(args) == 0 { - return nil, nil - } - - argTypes := make(map[string]interface{}, len(args)) - for argName, valueResourceName := range args { - valueResourceNameStr := valueResourceName.(string) - valueResource, err := context.GetValueResource(valueResourceNameStr, constants, aggregates) - if err != nil { - return nil, errors.Wrap(err, argName) - } - argTypes[argName] = valueResource.GetType() - } - return argTypes, nil -} diff --git a/pkg/operator/context/transformers.go b/pkg/operator/context/transformers.go index d19aac4fbd..713ccabeec 100644 --- a/pkg/operator/context/transformers.go +++ b/pkg/operator/context/transformers.go @@ -39,7 +39,7 @@ func loadUserTransformers( for _, transConfig := range config.Transformers { impl, ok := impls[transConfig.Path] if !ok { - return nil, errors.Wrap(ErrorImplDoesNotExist(transConfig.Path), userconfig.Identify(transConfig)) + return nil, errors.Wrap(userconfig.ErrorImplDoesNotExist(transConfig.Path), userconfig.Identify(transConfig)) } transformer, err := newTransformer(*transConfig, impl, nil, pythonPackages) if err != nil { @@ -55,7 +55,7 @@ func loadUserTransformers( impl, ok := impls[*transColConfig.TransformerPath] if !ok { - return nil, errors.Wrap(ErrorImplDoesNotExist(*transColConfig.TransformerPath), userconfig.Identify(transColConfig)) + return nil, errors.Wrap(userconfig.ErrorImplDoesNotExist(*transColConfig.TransformerPath), userconfig.Identify(transColConfig)) } implHash := hash.Bytes(impl) @@ -74,9 +74,11 @@ func loadUserTransformers( if err != nil { return nil, err } + transColConfig.Transformer = transformer.Name userTransformers[transformer.Name] = transformer } + return userTransformers, nil } @@ -90,9 +92,10 @@ func newTransformer( implID := hash.Bytes(impl) var buf bytes.Buffer - buf.WriteString(context.DataTypeID(transConfig.Inputs)) + buf.WriteString(context.DataTypeID(transConfig.Input)) buf.WriteString(context.DataTypeID(transConfig.OutputType)) buf.WriteString(implID) + for _, pythonPackage := range pythonPackages { buf.WriteString(pythonPackage.GetID()) } @@ -102,14 +105,12 @@ func newTransformer( transformer := &context.Transformer{ ResourceFields: &context.ResourceFields{ ID: id, - IDWithTags: id, ResourceType: resource.TransformerType, }, Transformer: &transConfig, Namespace: namespace, ImplKey: filepath.Join(consts.TransformersDir, implID+".py"), } - transformer.Transformer.Path = "" if err := uploadTransformer(transformer, impl); err != nil { return nil, err @@ -139,19 +140,6 @@ func uploadTransformer(transformer *context.Transformer, impl []byte) error { return nil } -func getTransformer( - name string, - userTransformers map[string]*context.Transformer, -) (*context.Transformer, error) { - if transformer, ok := builtinTransformers[name]; ok { - return transformer, nil - } - 
if transformer, ok := userTransformers[name]; ok { - return transformer, nil - } - return nil, userconfig.ErrorUndefinedResourceBuiltin(name, resource.TransformerType) -} - func getTransformers( config *userconfig.Config, userTransformers map[string]*context.Transformer, @@ -159,15 +147,23 @@ func getTransformers( transformers := context.Transformers{} for _, transformedColumnConfig := range config.TransformedColumns { - if _, ok := transformers[transformedColumnConfig.Transformer]; ok { + name := transformedColumnConfig.Transformer + + if _, ok := transformers[name]; ok { continue } - transformer, err := getTransformer(transformedColumnConfig.Transformer, userTransformers) - if err != nil { - return nil, errors.Wrap(err, userconfig.Identify(transformedColumnConfig), userconfig.TransformerKey) + if transformer, ok := builtinTransformers[name]; ok { + transformers[name] = transformer + continue } - transformers[transformedColumnConfig.Transformer] = transformer + + if transformer, ok := userTransformers[name]; ok { + transformers[name] = transformer + continue + } + + return nil, errors.Wrap(userconfig.ErrorUndefinedResource(name, resource.TransformerType), userconfig.Identify(transformedColumnConfig), userconfig.TransformerKey) } return transformers, nil diff --git a/pkg/operator/workloads/workload_spec.go b/pkg/operator/workloads/workload_spec.go index b01450590d..29525dfb33 100644 --- a/pkg/operator/workloads/workload_spec.go +++ b/pkg/operator/workloads/workload_spec.go @@ -58,7 +58,10 @@ func uploadWorkloadSpec(workloadSpec *WorkloadSpec, ctx *context.Context) error resources := make(map[string]*context.ResourceFields) for resourceID := range workloadSpec.ResourceIDs { resource := ctx.OneResourceByID(resourceID) - resources[resourceID] = resource.GetResourceFields() + resources[resourceID] = &context.ResourceFields{ + ID: resource.GetID(), + ResourceType: resource.GetResourceType(), + } } savedWorkloadSpec := SavedWorkloadSpec{ From 7c3a7c1a82557eb8d1f52b4951b357605de4bdc7 Mon Sep 17 00:00:00 2001 From: David Eliahu Date: Thu, 6 Jun 2019 17:23:52 -0700 Subject: [PATCH 02/44] Update iris example, fix small bugs --- examples/iris/resources/aggregates.yaml | 36 ++--- examples/iris/resources/apis.yaml | 2 +- examples/iris/resources/models.yaml | 17 +- examples/iris/resources/raw_columns.yaml | 12 +- .../iris/resources/transformed_columns.yaml | 48 +++--- examples/poker/resources/raw_columns.yaml | 20 --- pkg/aggregators/aggregators.yaml | 147 ++++++------------ pkg/estimators/estimators.yaml | 15 ++ pkg/operator/api/userconfig/errors.go | 6 +- pkg/operator/api/userconfig/models.go | 2 +- pkg/operator/context/models.go | 2 +- pkg/transformers/transformers.yaml | 26 ++-- 12 files changed, 120 insertions(+), 213 deletions(-) create mode 100644 pkg/estimators/estimators.yaml diff --git a/examples/iris/resources/aggregates.yaml b/examples/iris/resources/aggregates.yaml index 83c2a8d78d..18c05b211d 100644 --- a/examples/iris/resources/aggregates.yaml +++ b/examples/iris/resources/aggregates.yaml @@ -1,62 +1,44 @@ - kind: aggregate name: sepal_length_mean aggregator: cortex.mean - inputs: - columns: - col: sepal_length + input: @sepal_length - kind: aggregate name: sepal_length_stddev aggregator: cortex.stddev - inputs: - columns: - col: sepal_length + input: @sepal_length - kind: aggregate name: sepal_width_mean aggregator: cortex.mean - inputs: - columns: - col: sepal_width + input: @sepal_width - kind: aggregate name: sepal_width_stddev aggregator: cortex.stddev - inputs: - columns: - col: 
sepal_width + input: @sepal_width - kind: aggregate name: petal_length_mean aggregator: cortex.mean - inputs: - columns: - col: petal_length + input: @petal_length - kind: aggregate name: petal_length_stddev aggregator: cortex.stddev - inputs: - columns: - col: petal_length + input: @petal_length - kind: aggregate name: petal_width_mean aggregator: cortex.mean - inputs: - columns: - col: petal_width + input: @petal_width - kind: aggregate name: petal_width_stddev aggregator: cortex.stddev - inputs: - columns: - col: petal_width + input: @petal_width - kind: aggregate name: class_index aggregator: cortex.index_string - inputs: - columns: - col: class + input: @class diff --git a/examples/iris/resources/apis.yaml b/examples/iris/resources/apis.yaml index d53768a038..7272bd6fe8 100644 --- a/examples/iris/resources/apis.yaml +++ b/examples/iris/resources/apis.yaml @@ -1,5 +1,5 @@ - kind: api name: iris-type - model_name: dnn + model: @dnn compute: replicas: 1 diff --git a/examples/iris/resources/models.yaml b/examples/iris/resources/models.yaml index 0c303b09c0..eedacfa142 100644 --- a/examples/iris/resources/models.yaml +++ b/examples/iris/resources/models.yaml @@ -1,12 +1,14 @@ - kind: model name: dnn - type: classification - target_column: class_indexed - feature_columns: - - sepal_length_normalized - - sepal_width_normalized - - petal_length_normalized - - petal_width_normalized + estimator_path: implementations/models/dnn.py + target_column: @class_indexed + input: + cols: + - @sepal_length_normalized + - @sepal_width_normalized + - @petal_length_normalized + - @petal_width_normalized + num_classes: 3 hparams: hidden_units: [4, 2] data_partition_ratio: @@ -15,4 +17,3 @@ training: batch_size: 10 num_steps: 1000 - aggregates: [class_index] diff --git a/examples/iris/resources/raw_columns.yaml b/examples/iris/resources/raw_columns.yaml index e31cdf048e..0991b9cc0a 100644 --- a/examples/iris/resources/raw_columns.yaml +++ b/examples/iris/resources/raw_columns.yaml @@ -3,7 +3,7 @@ data: type: csv path: s3a://cortex-examples/iris.csv - schema: [sepal_length, sepal_width, petal_length, petal_width, class] + schema: [@sepal_length, @sepal_width, @petal_length, @petal_width, @class] - kind: environment @@ -13,15 +13,15 @@ path: s3a://cortex-examples/iris.parquet schema: - parquet_column_name: sepal_length - raw_column_name: sepal_length + raw_column: @sepal_length - parquet_column_name: sepal_width - raw_column_name: sepal_width + raw_column: @sepal_width - parquet_column_name: petal_length - raw_column_name: petal_length + raw_column: @petal_length - parquet_column_name: petal_width - raw_column_name: petal_width + raw_column: @petal_width - parquet_column_name: class - raw_column_name: class + raw_column: @class - kind: raw_column name: sepal_length diff --git a/examples/iris/resources/transformed_columns.yaml b/examples/iris/resources/transformed_columns.yaml index 92a4472444..ddb96bfa9e 100644 --- a/examples/iris/resources/transformed_columns.yaml +++ b/examples/iris/resources/transformed_columns.yaml @@ -1,48 +1,38 @@ - kind: transformed_column name: sepal_length_normalized transformer: cortex.normalize - inputs: - columns: - num: sepal_length - args: - mean: sepal_length_mean - stddev: sepal_length_stddev + input: + col: @sepal_length + mean: @sepal_length_mean + stddev: @sepal_length_stddev - kind: transformed_column name: sepal_width_normalized transformer: cortex.normalize - inputs: - columns: - num: sepal_width - args: - mean: sepal_width_mean - stddev: sepal_width_stddev + input: + col: 
@sepal_width + mean: @sepal_width_mean + stddev: @sepal_width_stddev + - kind: transformed_column name: petal_length_normalized transformer: cortex.normalize - inputs: - columns: - num: petal_length - args: - mean: petal_length_mean - stddev: petal_length_stddev + input: + col: @petal_length + mean: @petal_length_mean + stddev: @petal_length_stddev - kind: transformed_column name: petal_width_normalized transformer: cortex.normalize - inputs: - columns: - num: petal_width - args: - mean: petal_width_mean - stddev: petal_width_stddev + input: + col: @petal_width + mean: @petal_width_mean + stddev: @petal_width_stddev - kind: transformed_column name: class_indexed transformer: cortex.index_string - inputs: - columns: - text: class - args: - indexes: class_index + input: + col: @class + indexes: @class_index diff --git a/examples/poker/resources/raw_columns.yaml b/examples/poker/resources/raw_columns.yaml index 815f3e7c6d..ac3d6d2096 100644 --- a/examples/poker/resources/raw_columns.yaml +++ b/examples/poker/resources/raw_columns.yaml @@ -2,71 +2,51 @@ name: card_1_suit type: INT_COLUMN required: true - tags: - type: suit - kind: raw_column name: card_1_rank type: INT_COLUMN required: true - tags: - type: rank - kind: raw_column name: card_2_suit type: INT_COLUMN required: true - tags: - type: suit - kind: raw_column name: card_2_rank type: INT_COLUMN required: true - tags: - type: rank - kind: raw_column name: card_3_suit type: INT_COLUMN required: true - tags: - type: suit - kind: raw_column name: card_3_rank type: INT_COLUMN required: true - tags: - type: rank - kind: raw_column name: card_4_suit type: INT_COLUMN required: true - tags: - type: suit - kind: raw_column name: card_4_rank type: INT_COLUMN required: true - tags: - type: rank - kind: raw_column name: card_5_suit type: INT_COLUMN required: true - tags: - type: suit - kind: raw_column name: card_5_rank type: INT_COLUMN required: true - tags: - type: rank - kind: raw_column name: class diff --git a/pkg/aggregators/aggregators.yaml b/pkg/aggregators/aggregators.yaml index c8881bd0cf..afe99e85b0 100644 --- a/pkg/aggregators/aggregators.yaml +++ b/pkg/aggregators/aggregators.yaml @@ -19,11 +19,9 @@ name: approx_count_distinct path: spark/approx_count_distinct.py output_type: INT - inputs: - columns: - col: FLOAT_COLUMN|INT_COLUMN|STRING_COLUMN - args: - rsd: FLOAT + input: + col: FLOAT_COLUMN|INT_COLUMN|STRING_COLUMN + rsd: FLOAT # Spark Builtin: Calculate the average of the column. # source: https://spark.apache.org/docs/latest/api/python/pyspark.sql.html#pyspark.sql.functions.avg @@ -31,9 +29,7 @@ name: avg path: spark/avg.py output_type: FLOAT - inputs: - columns: - col: FLOAT_COLUMN|INT_COLUMN + input: FLOAT_COLUMN|INT_COLUMN # Spark Builtin: Accumulate all of the unique int values in the int col. # source: https://spark.apache.org/docs/latest/api/python/pyspark.sql.html#pyspark.sql.functions.collect_set @@ -41,9 +37,7 @@ name: collect_set_int path: spark/collect_set_int.py output_type: [INT] - inputs: - columns: - col: INT_COLUMN + input: INT_COLUMN # Spark Builtin: Accumulate all of the unique float values in the float col. # source: https://spark.apache.org/docs/latest/api/python/pyspark.sql.html#pyspark.sql.functions.collect_set @@ -51,9 +45,7 @@ name: collect_set_float path: spark/collect_set_float.py output_type: [FLOAT] - inputs: - columns: - col: FLOAT_COLUMN + input: FLOAT_COLUMN # Spark Builtin: Accumulate all of the unique string values in the string col.
# source: https://spark.apache.org/docs/latest/api/python/pyspark.sql.html#pyspark.sql.functions.collect_set @@ -61,9 +53,7 @@ name: collect_set_string path: spark/collect_set_string.py output_type: [STRING] - inputs: - columns: - col: STRING_COLUMN + input: STRING_COLUMN # Spark Builtin: Count the number of values in the column. # source: https://spark.apache.org/docs/latest/api/python/pyspark.sql.html#pyspark.sql.functions.count @@ -71,9 +61,7 @@ name: count path: spark/count.py output_type: INT - inputs: - columns: - col: INT_COLUMN|FLOAT_COLUMN|STRING_COLUMN + input: INT_COLUMN|FLOAT_COLUMN|STRING_COLUMN # Spark Builtin: Given a group of columns, count the unique rows in the group columns. # source: https://spark.apache.org/docs/latest/api/python/pyspark.sql.html#pyspark.sql.functions.countDistinct @@ -81,10 +69,11 @@ name: count_distinct path: spark/count_distinct.py output_type: INT - inputs: - columns: - col: INT_COLUMN|FLOAT_COLUMN|STRING_COLUMN - cols: [INT_COLUMN|FLOAT_COLUMN|STRING_COLUMN] + input: + col: INT_COLUMN|FLOAT_COLUMN|STRING_COLUMN + cols: + _type: [INT_COLUMN|FLOAT_COLUMN|STRING_COLUMN] + _default: [] # Spark Builtin: Calculate the population covariance between col1 and col2 (scaled by 1/N). # source: https://spark.apache.org/docs/latest/api/python/pyspark.sql.html#pyspark.sql.functions.covar_pop @@ -92,10 +81,9 @@ name: covar_pop path: spark/covar_pop.py output_type: FLOAT - inputs: - columns: - col1: INT_COLUMN|FLOAT_COLUMN - col2: INT_COLUMN|FLOAT_COLUMN + input: + col1: INT_COLUMN|FLOAT_COLUMN + col2: INT_COLUMN|FLOAT_COLUMN # Spark Builtin: Calculate the sample covariance between col1 and col2 (scaled by 1/(N-1)). @@ -104,10 +92,9 @@ name: covar_samp path: spark/covar_samp.py output_type: FLOAT - inputs: - columns: - col1: INT_COLUMN|FLOAT_COLUMN - col2: INT_COLUMN|FLOAT_COLUMN + input: + col1: INT_COLUMN|FLOAT_COLUMN + col2: INT_COLUMN|FLOAT_COLUMN # Spark Builtin: Calculate the sharpness of the peak of a frequency-distribution of the input column. # source: https://spark.apache.org/docs/latest/api/python/pyspark.sql.html#pyspark.sql.functions.kurtosis @@ -115,9 +102,7 @@ name: kurtosis path: spark/kurtosis.py output_type: FLOAT - inputs: - columns: - col: INT_COLUMN|FLOAT_COLUMN + input: INT_COLUMN|FLOAT_COLUMN # Spark Builtin: Get the max value of the input int column. # source: https://spark.apache.org/docs/latest/api/python/pyspark.sql.html#pyspark.sql.functions.max @@ -125,9 +110,7 @@ name: max_int path: spark/max_int.py output_type: INT - inputs: - columns: - col: INT_COLUMN + input: INT_COLUMN # Spark Builtin: Get the max value of the input float column. # source: https://spark.apache.org/docs/latest/api/python/pyspark.sql.html#pyspark.sql.functions.max @@ -135,9 +118,7 @@ name: max_float path: spark/max_float.py output_type: FLOAT - inputs: - columns: - col: FLOAT_COLUMN + input: FLOAT_COLUMN # Spark Builtin: Get the max value of the input string column. # source: https://spark.apache.org/docs/latest/api/python/pyspark.sql.html#pyspark.sql.functions.max @@ -145,9 +126,7 @@ name: max_string path: spark/max_string.py output_type: STRING - inputs: - columns: - col: STRING_COLUMN + input: STRING_COLUMN # Spark Builtin: Calculate the mean of the values in the input column. 
# source: https://spark.apache.org/docs/latest/api/python/pyspark.sql.html#pyspark.sql.functions.mean @@ -155,9 +134,7 @@ name: mean path: spark/mean.py output_type: FLOAT - inputs: - columns: - col: FLOAT_COLUMN|INT_COLUMN + input: FLOAT_COLUMN|INT_COLUMN # Spark Builtin: Get the min value of the intput int column. # source: https://spark.apache.org/docs/latest/api/python/pyspark.sql.html#pyspark.sql.functions.min @@ -165,9 +142,7 @@ name: min_int path: spark/min_int.py output_type: INT - inputs: - columns: - col: INT_COLUMN + input: INT_COLUMN # Spark Builtin: Get the min value of the input float column. # source: https://spark.apache.org/docs/latest/api/python/pyspark.sql.html#pyspark.sql.functions.min @@ -175,9 +150,7 @@ name: min_float path: spark/min_float.py output_type: FLOAT - inputs: - columns: - col: FLOAT_COLUMN + input: FLOAT_COLUMN # Spark Builtin: Get the min value of the input string column. # source: https://spark.apache.org/docs/latest/api/python/pyspark.sql.html#pyspark.sql.functions.min @@ -185,9 +158,7 @@ name: min_string path: spark/min_string.py output_type: STRING - inputs: - columns: - col: STRING_COLUMN + input: STRING_COLUMN # Spark Builtin: Calculate the skewness of the values in the input column. # source: https://spark.apache.org/docs/latest/api/python/pyspark.sql.html#pyspark.sql.functions.skewness @@ -195,9 +166,7 @@ name: skewness path: spark/skewness.py output_type: FLOAT - inputs: - columns: - col: FLOAT_COLUMN|INT_COLUMN + input: FLOAT_COLUMN|INT_COLUMN # Spark Builtin: Calculate the standard deviation (scaled by 1/(N-1)). # source: https://spark.apache.org/docs/latest/api/python/pyspark.sql.html#pyspark.sql.functions.stddev @@ -205,9 +174,7 @@ name: stddev path: spark/stddev.py output_type: FLOAT - inputs: - columns: - col: FLOAT_COLUMN|INT_COLUMN + input: FLOAT_COLUMN|INT_COLUMN # Spark Builtin: Calculate the standard deviation (scaled by 1/(N)). # source: https://spark.apache.org/docs/latest/api/python/pyspark.sql.html#pyspark.sql.functions.stddev_pop @@ -215,9 +182,7 @@ name: stddev_pop path: spark/stddev_pop.py output_type: FLOAT - inputs: - columns: - col: FLOAT_COLUMN|INT_COLUMN + input: FLOAT_COLUMN|INT_COLUMN # Spark Builtin: Calculate the standard deviation (scaled by 1/(N-1)). # source: https://spark.apache.org/docs/latest/api/python/pyspark.sql.html#pyspark.sql.functions.stddev_samp @@ -225,9 +190,7 @@ name: stddev_samp path: spark/stddev_samp.py output_type: FLOAT - inputs: - columns: - col: FLOAT_COLUMN|INT_COLUMN + input: FLOAT_COLUMN|INT_COLUMN # Spark Builtin: Sum all of the values in the input int column. # source: https://spark.apache.org/docs/latest/api/python/pyspark.sql.html#pyspark.sql.functions.sum @@ -235,9 +198,7 @@ name: sum_int path: spark/sum_int.py output_type: INT - inputs: - columns: - col: INT_COLUMN + input: INT_COLUMN # Spark Builtin: Sum all of the values in the input float column. # source: https://spark.apache.org/docs/latest/api/python/pyspark.sql.html#pyspark.sql.functions.sum @@ -245,9 +206,7 @@ name: sum_float path: spark/sum_float.py output_type: FLOAT - inputs: - columns: - col: FLOAT_COLUMN + input: FLOAT_COLUMN # Spark Builtin: Sum all of the distinct values in the input int column. 
# source: https://spark.apache.org/docs/latest/api/python/pyspark.sql.html#pyspark.sql.functions.sumDistinct @@ -255,9 +214,7 @@ name: sum_distinct_int path: spark/sum_distinct_int.py output_type: INT - inputs: - columns: - col: INT_COLUMN + input: INT_COLUMN # Spark Builtin: Sum all of the distinct values in the input float column. @@ -266,9 +223,7 @@ name: sum_distinct_float path: spark/sum_distinct_float.py output_type: FLOAT - inputs: - columns: - col: FLOAT_COLUMN + input: FLOAT_COLUMN # Spark Builtin: Calculate the variance (scaled by 1/(N)). # source: https://spark.apache.org/docs/latest/api/python/pyspark.sql.html#pyspark.sql.functions.var_pop @@ -276,9 +231,7 @@ name: var_pop path: spark/var_pop.py output_type: FLOAT - inputs: - columns: - col: FLOAT_COLUMN|INT_COLUMN + input: FLOAT_COLUMN|INT_COLUMN # Spark Builtin: Calculate the variance (scaled by 1/(N-1)). # source: https://spark.apache.org/docs/latest/api/python/pyspark.sql.html#pyspark.sql.functions.var_samp @@ -286,9 +239,7 @@ name: var_samp path: spark/var_samp.py output_type: FLOAT - inputs: - columns: - col: FLOAT_COLUMN|INT_COLUMN + input: FLOAT_COLUMN|INT_COLUMN # Spark Builtin: Calculate the variance (scaled by 1/(N-1)). # source: https://spark.apache.org/docs/latest/api/python/pyspark.sql.html#pyspark.sql.functions.variance @@ -296,9 +247,7 @@ name: variance path: spark/variance.py output_type: FLOAT - inputs: - columns: - col: FLOAT_COLUMN|INT_COLUMN + input: FLOAT_COLUMN|INT_COLUMN # Given the number of buckets, calculates the boundaries of a bucket such that the values in the column can be evenly # split into the buckets. Works well with transformers.bucketize. @@ -308,11 +257,9 @@ name: bucket_boundaries path: bucket_boundaries.py output_type: [FLOAT] - inputs: - columns: - col: FLOAT_COLUMN|INT_COLUMN - args: - num_buckets: INT + input: + col: FLOAT_COLUMN|INT_COLUMN + num_buckets: INT # Enumerates the unique values in a string column and orders them by placing the unique strings in # list ordered by most frequent starting at the 0th index. @@ -323,9 +270,7 @@ name: index_string path: index_string.py output_type: {"index": [STRING], "reversed_index": {STRING: INT}} - inputs: - columns: - col: STRING_COLUMN + input: STRING_COLUMN # Counts the occurrences of each value in the string input column and divides the counts by the total number of values. # For example: An input column with the following values ['t', 'f', 't', 't'] would return {'t': 0.75, 'f': 0.25}. @@ -333,9 +278,7 @@ name: class_distribution_string path: class_distribution.py output_type: {STRING: FLOAT} - inputs: - columns: - col: STRING_COLUMN + input: STRING_COLUMN # Counts the occurrences of each value in the int input column and divides the counts by the total number of values. # For example: An input column with the following values [1, 2, 3, 1] would return {1: 0.5, 2: 0.25, 3: 0.25}. @@ -343,6 +286,4 @@ name: class_distribution_int path: class_distribution.py output_type: {INT: FLOAT} - inputs: - columns: - col: INT_COLUMN + input: INT_COLUMN diff --git a/pkg/estimators/estimators.yaml b/pkg/estimators/estimators.yaml new file mode 100644 index 0000000000..711b2c1b84 --- /dev/null +++ b/pkg/estimators/estimators.yaml @@ -0,0 +1,15 @@ +# Copyright 2019 Cortex Labs, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + diff --git a/pkg/operator/api/userconfig/errors.go b/pkg/operator/api/userconfig/errors.go index 58ae216bed..cf993ea796 100644 --- a/pkg/operator/api/userconfig/errors.go +++ b/pkg/operator/api/userconfig/errors.go @@ -441,9 +441,13 @@ func ErrorInvalidOutputType(provided interface{}) error { } func ErrorUnsupportedLiteralType(provided interface{}, allowedType interface{}) error { + message := fmt.Sprintf("input value's type is not supported by the schema (got %s, expected input with type %s)", DataTypeStr(provided), DataTypeStr(allowedType)) + if str, ok := provided.(string); ok { + message += fmt.Sprintf(" (note: if you are trying to reference a Cortex resource named %s, use \"@%s\")", str, str) + } return Error{ Kind: ErrUnsupportedLiteralType, - message: fmt.Sprintf("input value's type is not supported by the schema (got %s, expected input with type %s)", DataTypeStr(provided), DataTypeStr(allowedType)), + message: message, } } diff --git a/pkg/operator/api/userconfig/models.go b/pkg/operator/api/userconfig/models.go index 9073ac6c91..5a8bbc2f99 100644 --- a/pkg/operator/api/userconfig/models.go +++ b/pkg/operator/api/userconfig/models.go @@ -86,7 +86,7 @@ var modelValidation = &cr.StructValidation{ }, }, { - StructField: "HParams", + StructField: "Hparams", InterfaceValidation: &cr.InterfaceValidation{ Required: false, }, diff --git a/pkg/operator/context/models.go b/pkg/operator/context/models.go index 7a731fc614..e163d3ba83 100644 --- a/pkg/operator/context/models.go +++ b/pkg/operator/context/models.go @@ -103,7 +103,7 @@ func getModels( } // TargetColumn - targetColumnName, _ := yaml.UnescapeAtSymbol(modelConfig.TargetColumn) + targetColumnName, _ := yaml.ExtractAtSymbolText(modelConfig.TargetColumn) targetColumn := columns[targetColumnName] if targetColumn == nil { return nil, errors.Wrap(userconfig.ErrorUndefinedResource(targetColumnName, resource.RawColumnType, resource.TransformedColumnType), userconfig.Identify(modelConfig), userconfig.TargetColumnKey) diff --git a/pkg/transformers/transformers.yaml b/pkg/transformers/transformers.yaml index 0e08b9d5b8..e7fb10ed62 100644 --- a/pkg/transformers/transformers.yaml +++ b/pkg/transformers/transformers.yaml @@ -21,11 +21,9 @@ name: bucketize path: bucketize.py output_type: INT_COLUMN - inputs: - columns: - num: INT_COLUMN|FLOAT_COLUMN - args: - bucket_boundaries: [FLOAT] + input: + col: INT_COLUMN|FLOAT_COLUMN + bucket_boundaries: [FLOAT] # Given the mean and standard deviation of the column, normalize (z-score, standardize) # all of the values in column by (x - mean)/stddev where x is a value in the column. @@ -33,12 +31,10 @@ name: normalize path: normalize.py output_type: FLOAT_COLUMN - inputs: - columns: - num: FLOAT_COLUMN|INT_COLUMN - args: - mean: INT|FLOAT - stddev: INT|FLOAT + input: + col: FLOAT_COLUMN|INT_COLUMN + mean: INT|FLOAT + stddev: INT|FLOAT # Given labels, map the string column to its index in the labels array. 
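# (A minimal usage sketch, not part of this file. Given the normalize schema above, a
# transformed_column in user configuration supplies the named arguments under a single
# `input` map and references other resources with the @ prefix; the resource names
# my_col, my_col_mean, and my_col_stddev below are hypothetical:
#
#   - kind: transformed_column
#     name: my_col_normalized
#     transformer: cortex.normalize
#     input:
#       col: @my_col
#       mean: @my_col_mean
#       stddev: @my_col_stddev
#
# This mirrors the petal_width_normalized example earlier in this patch, where mean and
# stddev reference aggregates such as @petal_width_mean and @petal_width_stddev.)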
# Example: @@ -48,8 +44,6 @@ name: index_string path: index_string.py output_type: INT_COLUMN - inputs: - columns: - text: STRING_COLUMN - args: - indexes: {"index": [STRING], "reversed_index": {STRING: INT}} + input: + col: STRING_COLUMN + indexes: {"index": [STRING], "reversed_index": {STRING: INT}} From be0f18dbfd654e4ade97a98a828283c073f96dd8 Mon Sep 17 00:00:00 2001 From: David Eliahu Date: Thu, 6 Jun 2019 22:10:02 -0700 Subject: [PATCH 03/44] Fix ContextFromSerial() --- pkg/lib/msgpack/msgpack.go | 3 +- pkg/operator/api/context/aggregates.go | 4 +- pkg/operator/api/context/serialize.go | 100 +++++- pkg/operator/api/userconfig/validators.go | 26 +- .../api/userconfig/validators_test.go | 284 +++++++++--------- pkg/operator/context/resources_test.go | 6 +- 6 files changed, 257 insertions(+), 166 deletions(-) diff --git a/pkg/lib/msgpack/msgpack.go b/pkg/lib/msgpack/msgpack.go index 16fb6377b7..17030d01a5 100644 --- a/pkg/lib/msgpack/msgpack.go +++ b/pkg/lib/msgpack/msgpack.go @@ -17,8 +17,9 @@ limitations under the License. package msgpack import ( - "github.com/cortexlabs/cortex/pkg/lib/errors" "github.com/ugorji/go/codec" + + "github.com/cortexlabs/cortex/pkg/lib/errors" ) var mh codec.MsgpackHandle diff --git a/pkg/operator/api/context/aggregates.go b/pkg/operator/api/context/aggregates.go index 7daeda8712..c095456d8c 100644 --- a/pkg/operator/api/context/aggregates.go +++ b/pkg/operator/api/context/aggregates.go @@ -25,8 +25,8 @@ type Aggregates map[string]*Aggregate type Aggregate struct { *userconfig.Aggregate *ComputedResourceFields - Type interface{} `json:"type"` - Key string `json:"key"` + Type userconfig.OutputSchema `json:"type"` + Key string `json:"key"` } func (aggregates Aggregates) OneByID(id string) *Aggregate { diff --git a/pkg/operator/api/context/serialize.go b/pkg/operator/api/context/serialize.go index dee8e00c7d..df147dc310 100644 --- a/pkg/operator/api/context/serialize.go +++ b/pkg/operator/api/context/serialize.go @@ -37,12 +37,12 @@ type DataSplit struct { } type Serial struct { - Context + *Context RawColumnSplit *RawColumnsTypeSplit `json:"raw_columns"` DataSplit *DataSplit `json:"environment_data"` } -func (ctx Context) splitRawColumns() *RawColumnsTypeSplit { +func (ctx *Context) splitRawColumns() *RawColumnsTypeSplit { var rawIntColumns = make(map[string]*RawIntColumn) var rawFloatColumns = make(map[string]*RawFloatColumn) var rawStringColumns = make(map[string]*RawStringColumn) @@ -68,7 +68,7 @@ func (ctx Context) splitRawColumns() *RawColumnsTypeSplit { } } -func (serial Serial) collectRawColumns() RawColumns { +func (serial *Serial) collectRawColumns() RawColumns { var rawColumns = make(map[string]RawColumn) for name, rawColumn := range serial.RawColumnSplit.RawIntColumns { @@ -87,7 +87,7 @@ func (serial Serial) collectRawColumns() RawColumns { return rawColumns } -func (ctx Context) splitEnvironment() *DataSplit { +func (ctx *Context) splitEnvironment() *DataSplit { var split DataSplit switch typedData := ctx.Environment.Data.(type) { case *userconfig.CSVData: @@ -110,7 +110,85 @@ func (serial *Serial) collectEnvironment() (*Environment, error) { return serial.Environment, nil } -func (ctx Context) ToSerial() *Serial { +func (ctx *Context) castSchemaTypes() error { + for _, constant := range ctx.Constants { + if constant.Type != nil { + castedType, err := userconfig.ValidateOutputSchema(constant.Type) + if err != nil { + return err + } + constant.Constant.Type = castedType + } + } + + for _, 
aggregator := range ctx.Aggregators { + if aggregator.OutputType != nil { + casted, err := userconfig.ValidateOutputSchema(aggregator.OutputType) + if err != nil { + return err + } + aggregator.Aggregator.OutputType = casted + } + + if aggregator.Input != nil { + casted, err := userconfig.ValidateInputTypeSchema(aggregator.Input.Type, false, true) + if err != nil { + return err + } + aggregator.Aggregator.Input.Type = casted + } + } + + for _, aggregate := range ctx.Aggregates { + if aggregate.Type != nil { + casted, err := userconfig.ValidateOutputSchema(aggregate.Type) + if err != nil { + return err + } + aggregate.Type = casted + } + } + + for _, transformer := range ctx.Transformers { + if transformer.Input != nil { + casted, err := userconfig.ValidateInputTypeSchema(transformer.Input.Type, false, true) + if err != nil { + return err + } + transformer.Transformer.Input.Type = casted + } + } + + for _, estimator := range ctx.Estimators { + if estimator.Input != nil { + casted, err := userconfig.ValidateInputTypeSchema(estimator.Input.Type, false, true) + if err != nil { + return err + } + estimator.Estimator.Input.Type = casted + } + + if estimator.TrainingInput != nil { + casted, err := userconfig.ValidateInputTypeSchema(estimator.TrainingInput.Type, false, true) + if err != nil { + return err + } + estimator.Estimator.TrainingInput.Type = casted + } + + if estimator.Hparams != nil { + casted, err := userconfig.ValidateInputTypeSchema(estimator.Hparams.Type, true, true) + if err != nil { + return err + } + estimator.Estimator.Hparams.Type = casted + } + } + + return nil +} + +func (ctx *Context) ToSerial() *Serial { serial := Serial{ Context: ctx, RawColumnSplit: ctx.splitRawColumns(), @@ -120,15 +198,23 @@ func (ctx Context) ToSerial() *Serial { return &serial } -func (serial Serial) ContextFromSerial() (*Context, error) { +func (serial *Serial) ContextFromSerial() (*Context, error) { ctx := serial.Context + ctx.RawColumns = serial.collectRawColumns() + environment, err := serial.collectEnvironment() if err != nil { return nil, err } ctx.Environment = environment - return &ctx, nil + + err = ctx.castSchemaTypes() + if err != nil { + return nil, err + } + + return ctx, nil } func (ctx Context) ToMsgpackBytes() ([]byte, error) { diff --git a/pkg/operator/api/userconfig/validators.go b/pkg/operator/api/userconfig/validators.go index ef39ce3cd5..474638ef74 100644 --- a/pkg/operator/api/userconfig/validators.go +++ b/pkg/operator/api/userconfig/validators.go @@ -41,14 +41,14 @@ type InputTypeSchema interface{} // CompundType, length-one array of *InputSchem type OutputSchema interface{} // ValueType, length-one array of OutputSchema, or map of {scalar|ValueType -> OutputSchema} (no *_COLUMN types, compound types, or input options like _default) func inputSchemaValidator(in interface{}) (interface{}, error) { - return ValidateInputSchema(in, false) // This casts it to *InputSchema + return ValidateInputSchema(in, false, false) // This casts it to *InputSchema } func inputSchemaValidatorValueTypesOnly(in interface{}) (interface{}, error) { - return ValidateInputSchema(in, true) // This casts it to *InputSchema + return ValidateInputSchema(in, true, false) // This casts it to *InputSchema } -func ValidateInputSchema(in interface{}, disallowColumnTypes bool) (*InputSchema, error) { +func ValidateInputSchema(in interface{}, disallowColumnTypes bool, isAlreadyParsed bool) (*InputSchema, error) { // Check for cortex options vs short form if inMap, ok := cast.InterfaceToStrInterfaceMap(in); ok { 
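// (Descriptive note on the new isAlreadyParsed parameter, based on the hunks in this
// patch rather than on the full sources. The flag appears to distinguish first-time
// validation of user-written YAML, which passes false as in inputSchemaValidator above
// and in the tests, e.g.:
//
//	inputSchema, err := ValidateInputSchema(cr.MustReadYAMLStr(`[STRING]`), false, false)
//
// from re-validation of schemas that were already parsed and serialized, as in
// Context.castSchemaTypes in serialize.go, which calls ValidateInputTypeSchema with
// isAlreadyParsed set to true so that explicit nulls are tolerated for fields such as
// Default, MinCount, and MaxCount (AllowExplicitNull: isAlreadyParsed).)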
foundUnderscore, foundNonUnderscore := false, false @@ -72,7 +72,7 @@ func ValidateInputSchema(in interface{}, disallowColumnTypes bool) (*InputSchema InterfaceValidation: &cr.InterfaceValidation{ Required: true, Validator: func(t interface{}) (interface{}, error) { - return validateInputTypeSchema(t, disallowColumnTypes) + return ValidateInputTypeSchema(t, disallowColumnTypes, isAlreadyParsed) }, }, }, @@ -81,8 +81,10 @@ func ValidateInputSchema(in interface{}, disallowColumnTypes bool) (*InputSchema BoolValidation: &cr.BoolValidation{}, }, { - StructField: "Default", - InterfaceValidation: &cr.InterfaceValidation{}, + StructField: "Default", + InterfaceValidation: &cr.InterfaceValidation{ + AllowExplicitNull: isAlreadyParsed, + }, }, { StructField: "AllowNull", @@ -92,12 +94,14 @@ func ValidateInputSchema(in interface{}, disallowColumnTypes bool) (*InputSchema StructField: "MinCount", Int64PtrValidation: &cr.Int64PtrValidation{ GreaterThanOrEqualTo: pointer.Int64(0), + AllowExplicitNull: isAlreadyParsed, }, }, { StructField: "MaxCount", Int64PtrValidation: &cr.Int64PtrValidation{ GreaterThanOrEqualTo: pointer.Int64(0), + AllowExplicitNull: isAlreadyParsed, }, }, }, @@ -117,7 +121,7 @@ func ValidateInputSchema(in interface{}, disallowColumnTypes bool) (*InputSchema } } - typeSchema, err := validateInputTypeSchema(in, disallowColumnTypes) + typeSchema, err := ValidateInputTypeSchema(in, disallowColumnTypes, isAlreadyParsed) if err != nil { return nil, err } @@ -132,7 +136,7 @@ func ValidateInputSchema(in interface{}, disallowColumnTypes bool) (*InputSchema return inputSchema, nil } -func validateInputTypeSchema(in interface{}, disallowColumnTypes bool) (InputTypeSchema, error) { +func ValidateInputTypeSchema(in interface{}, disallowColumnTypes bool, isAlreadyParsed bool) (InputTypeSchema, error) { // String if inStr, ok := in.(string); ok { compoundType, err := CompoundTypeFromString(inStr) @@ -150,7 +154,7 @@ func validateInputTypeSchema(in interface{}, disallowColumnTypes bool) (InputTyp if len(inSlice) != 1 { return nil, ErrorTypeListLength(inSlice) } - inputSchema, err := ValidateInputSchema(inSlice[0], disallowColumnTypes) + inputSchema, err := ValidateInputSchema(inSlice[0], disallowColumnTypes, isAlreadyParsed) if err != nil { return nil, errors.Wrap(err, s.Index(0)) } @@ -182,7 +186,7 @@ func validateInputTypeSchema(in interface{}, disallowColumnTypes bool) (InputTyp if disallowColumnTypes && typeKey.IsColumns() { return nil, ErrorColumnTypeNotAllowed(typeKey) } - valueInputSchema, err := ValidateInputSchema(typeValue, disallowColumnTypes) + valueInputSchema, err := ValidateInputSchema(typeValue, disallowColumnTypes, isAlreadyParsed) if err != nil { return nil, errors.Wrap(err, string(typeKey)) } @@ -201,7 +205,7 @@ func validateInputTypeSchema(in interface{}, disallowColumnTypes bool) (InputTyp } } - valueInputSchema, err := ValidateInputSchema(value, disallowColumnTypes) + valueInputSchema, err := ValidateInputSchema(value, disallowColumnTypes, isAlreadyParsed) if err != nil { return nil, errors.Wrap(err, s.UserStrStripped(key)) } diff --git a/pkg/operator/api/userconfig/validators_test.go b/pkg/operator/api/userconfig/validators_test.go index 79f057edab..e936992c2b 100644 --- a/pkg/operator/api/userconfig/validators_test.go +++ b/pkg/operator/api/userconfig/validators_test.go @@ -329,7 +329,7 @@ func TestCastOutputValue(t *testing.T) { } func checkCastInputValueEqual(t *testing.T, inputSchemaYAML string, valueYAML string, expected interface{}) { - inputSchema, err := 
ValidateInputSchema(cr.MustReadYAMLStr(inputSchemaYAML), false) + inputSchema, err := ValidateInputSchema(cr.MustReadYAMLStr(inputSchemaYAML), false, false) require.NoError(t, err) casted, err := CastInputValue(cr.MustReadYAMLStr(valueYAML), inputSchema) require.NoError(t, err) @@ -337,14 +337,14 @@ func checkCastInputValueEqual(t *testing.T, inputSchemaYAML string, valueYAML st } func checkCastInputValueError(t *testing.T, inputSchemaYAML string, valueYAML string) { - inputSchema, err := ValidateInputSchema(cr.MustReadYAMLStr(inputSchemaYAML), false) + inputSchema, err := ValidateInputSchema(cr.MustReadYAMLStr(inputSchemaYAML), false, false) require.NoError(t, err) _, err = CastInputValue(cr.MustReadYAMLStr(valueYAML), inputSchema) require.Error(t, err) } func checkCastInputValueNoError(t *testing.T, inputSchemaYAML string, valueYAML string) { - inputSchema, err := ValidateInputSchema(cr.MustReadYAMLStr(inputSchemaYAML), false) + inputSchema, err := ValidateInputSchema(cr.MustReadYAMLStr(inputSchemaYAML), false, false) require.NoError(t, err) _, err = CastInputValue(cr.MustReadYAMLStr(valueYAML), inputSchema) require.NoError(t, err) @@ -386,137 +386,137 @@ func TestValidateInputSchema(t *testing.T) { var err error inputSchema, err = ValidateInputSchema(cr.MustReadYAMLStr( - `STRING`), false) + `STRING`), false, false) require.NoError(t, err) inputSchema2, err = ValidateInputSchema(cr.MustReadYAMLStr( - `_type: STRING`), false) + `_type: STRING`), false, false) require.NoError(t, err) require.Equal(t, inputSchema, inputSchema2) _, err = ValidateInputSchema(cr.MustReadYAMLStr( - `STRING_COLUMN`), false) + `STRING_COLUMN`), false, false) require.NoError(t, err) _, err = ValidateInputSchema(cr.MustReadYAMLStr( - `STRING_COLUMN`), true) + `STRING_COLUMN`), true, false) require.Error(t, err) _, err = ValidateInputSchema(cr.MustReadYAMLStr( - `INFERRED_COLUMN`), false) + `INFERRED_COLUMN`), false, false) require.Error(t, err) _, err = ValidateInputSchema(cr.MustReadYAMLStr( - `BAD_COLUMN`), false) + `BAD_COLUMN`), false, false) require.Error(t, err) _, err = ValidateInputSchema(cr.MustReadYAMLStr( ` _type: STRING _default: test - `), false) + `), false, false) require.NoError(t, err) _, err = ValidateInputSchema(cr.MustReadYAMLStr( ` _type: STRING_COLUMN _default: test - `), false) + `), false, false) require.Error(t, err) _, err = ValidateInputSchema(cr.MustReadYAMLStr( ` _type: STRING _default: Null - `), false) + `), false, false) require.Error(t, err) _, err = ValidateInputSchema(cr.MustReadYAMLStr( ` _type: STRING _default: 2 - `), false) + `), false, false) require.Error(t, err) // Lists inputSchema, err = ValidateInputSchema(cr.MustReadYAMLStr( - `[STRING]`), false) + `[STRING]`), false, false) require.NoError(t, err) inputSchema2, err = ValidateInputSchema(cr.MustReadYAMLStr( - `_type: [STRING]`), false) + `_type: [STRING]`), false, false) require.NoError(t, err) inputSchema3, err = ValidateInputSchema(cr.MustReadYAMLStr( ` _type: - _type: STRING - `), false) + `), false, false) require.NoError(t, err) require.Equal(t, inputSchema, inputSchema2) require.Equal(t, inputSchema, inputSchema3) _, err = ValidateInputSchema(cr.MustReadYAMLStr( - `[STRING|INT]`), false) + `[STRING|INT]`), false, false) require.NoError(t, err) _, err = ValidateInputSchema(cr.MustReadYAMLStr( - `[STRING_COLUMN|INT_COLUMN]`), false) + `[STRING_COLUMN|INT_COLUMN]`), false, false) require.NoError(t, err) _, err = ValidateInputSchema(cr.MustReadYAMLStr( - `[STRING_COLUMN|INT_COLUMN]`), true) + 
`[STRING_COLUMN|INT_COLUMN]`), true, false) require.Error(t, err) _, err = ValidateInputSchema(cr.MustReadYAMLStr( - `[STRING|INT_COLUMN]`), false) + `[STRING|INT_COLUMN]`), false, false) require.Error(t, err) _, err = ValidateInputSchema(cr.MustReadYAMLStr( ` _type: [STRING] _default: [test1, test2, test3] - `), false) + `), false, false) require.NoError(t, err) _, err = ValidateInputSchema(cr.MustReadYAMLStr( ` _type: [STRING_COLUMN] _default: [test1, test2, test3] - `), false) + `), false, false) require.Error(t, err) _, err = ValidateInputSchema(cr.MustReadYAMLStr( ` _type: [STRING] _default: [test1, 2, test3] - `), false) + `), false, false) require.Error(t, err) _, err = ValidateInputSchema(cr.MustReadYAMLStr( ` _type: [STRING|INT] _default: [test1, 2, test3] - `), false) + `), false, false) require.NoError(t, err) _, err = ValidateInputSchema(cr.MustReadYAMLStr( ` _type: [STRING|FLOAT] _default: [test1, 2, test3] - `), false) + `), false, false) require.NoError(t, err) _, err = ValidateInputSchema(cr.MustReadYAMLStr( ` _type: [STRING] _default: test1 - `), false) + `), false, false) require.Error(t, err) _, err = ValidateInputSchema(cr.MustReadYAMLStr( ` _type: [STRING] _min_count: 2 - `), false) + `), false, false) require.NoError(t, err) _, err = ValidateInputSchema(cr.MustReadYAMLStr( @@ -524,7 +524,7 @@ func TestValidateInputSchema(t *testing.T) { _type: [STRING] _min_count: 2 _max_count: 2 - `), false) + `), false, false) require.NoError(t, err) _, err = ValidateInputSchema(cr.MustReadYAMLStr( @@ -532,7 +532,7 @@ func TestValidateInputSchema(t *testing.T) { _type: [STRING] _min_count: 2 _max_count: 1 - `), false) + `), false, false) require.Error(t, err) _, err = ValidateInputSchema(cr.MustReadYAMLStr( @@ -540,7 +540,7 @@ func TestValidateInputSchema(t *testing.T) { _type: [STRING] _default: [test1] _min_count: 2 - `), false) + `), false, false) require.Error(t, err) _, err = ValidateInputSchema(cr.MustReadYAMLStr( @@ -548,21 +548,21 @@ func TestValidateInputSchema(t *testing.T) { _type: [STRING] _default: [test1, test2] _min_count: 2 - `), false) + `), false, false) require.NoError(t, err) _, err = ValidateInputSchema(cr.MustReadYAMLStr( ` _type: [STRING] _min_count: -1 - `), false) + `), false, false) require.Error(t, err) _, err = ValidateInputSchema(cr.MustReadYAMLStr( ` _type: [STRING] _min_count: test - `), false) + `), false, false) require.Error(t, err) _, err = ValidateInputSchema(cr.MustReadYAMLStr( @@ -570,7 +570,7 @@ func TestValidateInputSchema(t *testing.T) { _type: [STRING] _default: [test1, test2, test3] _max_count: 2 - `), false) + `), false, false) require.Error(t, err) _, err = ValidateInputSchema(cr.MustReadYAMLStr( @@ -578,140 +578,140 @@ func TestValidateInputSchema(t *testing.T) { _type: [STRING] _default: [test1, test2] _max_count: 2 - `), false) + `), false, false) require.NoError(t, err) // Maps inputSchema, err = ValidateInputSchema(cr.MustReadYAMLStr( - `arg1: STRING`), false) + `arg1: STRING`), false, false) require.NoError(t, err) inputSchema2, err = ValidateInputSchema(cr.MustReadYAMLStr( ` arg1: _type: STRING - `), false) + `), false, false) require.NoError(t, err) inputSchema3, err = ValidateInputSchema(cr.MustReadYAMLStr( ` _type: {arg1: STRING} - `), false) + `), false, false) require.NoError(t, err) require.Equal(t, inputSchema, inputSchema2) require.Equal(t, inputSchema, inputSchema3) _, err = ValidateInputSchema(cr.MustReadYAMLStr( - `arg1: STRING_COLUMN`), false) + `arg1: STRING_COLUMN`), false, false) require.NoError(t, err) _, err = 
ValidateInputSchema(cr.MustReadYAMLStr( - `arg1: STRING_COLUMN`), true) + `arg1: STRING_COLUMN`), true, false) require.Error(t, err) _, err = ValidateInputSchema(cr.MustReadYAMLStr( - `STRING_COLUMN: STRING`), false) + `STRING_COLUMN: STRING`), false, false) require.NoError(t, err) _, err = ValidateInputSchema(cr.MustReadYAMLStr( - `STRING_COLUMN: STRING`), true) + `STRING_COLUMN: STRING`), true, false) require.Error(t, err) inputSchema, err = ValidateInputSchema(cr.MustReadYAMLStr( - `_arg1: STRING`), false) + `_arg1: STRING`), false, false) require.Error(t, err) inputSchema, err = ValidateInputSchema(cr.MustReadYAMLStr( - `arg1: test`), false) + `arg1: test`), false, false) require.Error(t, err) inputSchema, err = ValidateInputSchema(cr.MustReadYAMLStr( - `STRING: test`), false) + `STRING: test`), false, false) require.Error(t, err) inputSchema, err = ValidateInputSchema(cr.MustReadYAMLStr( - `STRING_COLUMN: test`), false) + `STRING_COLUMN: test`), false, false) require.Error(t, err) _, err = ValidateInputSchema(cr.MustReadYAMLStr( ` _type: {arg1: STRING} _min_count: 2 - `), false) + `), false, false) require.Error(t, err) _, err = ValidateInputSchema(cr.MustReadYAMLStr( ` _type: {STRING: INT} _default: {test: 2} - `), false) + `), false, false) require.NoError(t, err) _, err = ValidateInputSchema(cr.MustReadYAMLStr( ` _type: {FLOAT: INT} _default: {2: 2} - `), false) + `), false, false) require.NoError(t, err) _, err = ValidateInputSchema(cr.MustReadYAMLStr( ` _type: {STRING: INT} _default: {test: test} - `), false) + `), false, false) require.Error(t, err) _, err = ValidateInputSchema(cr.MustReadYAMLStr( ` _type: {STRING: INT|STRING} _default: {test: test} - `), false) + `), false, false) require.NoError(t, err) _, err = ValidateInputSchema(cr.MustReadYAMLStr( ` _type: {STRING: INT_COLUMN} _min_count: 2 - `), false) + `), false, false) require.NoError(t, err) _, err = ValidateInputSchema(cr.MustReadYAMLStr( ` _type: {STRING_COLUMN: INT} _min_count: 2 - `), false) + `), false, false) require.NoError(t, err) _, err = ValidateInputSchema(cr.MustReadYAMLStr( ` _type: {STRING_COLUMN: INT} _default: {test: 2} - `), false) + `), false, false) require.Error(t, err) _, err = ValidateInputSchema(cr.MustReadYAMLStr( ` _type: {STRING_COLUMN: INT_COLUMN} _min_count: 2 - `), false) + `), false, false) require.NoError(t, err) _, err = ValidateInputSchema(cr.MustReadYAMLStr( ` _type: {STRING_COLUMN: INT_COLUMN|STRING} - `), false) + `), false, false) require.Error(t, err) _, err = ValidateInputSchema(cr.MustReadYAMLStr( ` _type: {STRING_COLUMN: INT_COLUMN|STRING_COLUMN} - `), false) + `), false, false) require.NoError(t, err) _, err = ValidateInputSchema(cr.MustReadYAMLStr( ` _type: {STRING: INT} _min_count: 2 - `), false) + `), false, false) require.NoError(t, err) _, err = ValidateInputSchema(cr.MustReadYAMLStr( @@ -719,7 +719,7 @@ func TestValidateInputSchema(t *testing.T) { arg1: _type: STRING _optional: true - `), false) + `), false, false) require.NoError(t, err) _, err = ValidateInputSchema(cr.MustReadYAMLStr( @@ -727,7 +727,7 @@ func TestValidateInputSchema(t *testing.T) { arg1: _type: STRING _default: test - `), false) + `), false, false) require.NoError(t, err) _, err = ValidateInputSchema(cr.MustReadYAMLStr( @@ -735,7 +735,7 @@ func TestValidateInputSchema(t *testing.T) { arg1: _type: STRING _default: 2 - `), false) + `), false, false) require.Error(t, err) _, err = ValidateInputSchema(cr.MustReadYAMLStr( @@ -743,7 +743,7 @@ func TestValidateInputSchema(t *testing.T) { arg1: _type: STRING 
_default: Null - `), false) + `), false, false) require.Error(t, err) _, err = ValidateInputSchema(cr.MustReadYAMLStr( @@ -751,47 +751,47 @@ func TestValidateInputSchema(t *testing.T) { arg1: _type: STRING _default: Null - `), false) + `), false, false) require.Error(t, err) // Mixed _, err = ValidateInputSchema(cr.MustReadYAMLStr( - `[[STRING]]`), false) + `[[STRING]]`), false, false) require.NoError(t, err) _, err = ValidateInputSchema(cr.MustReadYAMLStr( ` _type: [[STRING]] _default: [[test1, test2]] - `), false) + `), false, false) require.NoError(t, err) _, err = ValidateInputSchema(cr.MustReadYAMLStr( ` _type: [[STRING_COLUMN]] _default: [[test1, test2]] - `), false) + `), false, false) require.Error(t, err) inputSchema, err = ValidateInputSchema(cr.MustReadYAMLStr( ` - arg1: STRING arg2: INT - `), false) + `), false, false) require.NoError(t, err) inputSchema2, err = ValidateInputSchema(cr.MustReadYAMLStr( ` _type: - arg1: STRING arg2: INT - `), false) + `), false, false) require.NoError(t, err) inputSchema3, err = ValidateInputSchema(cr.MustReadYAMLStr( ` - arg1: {_type: STRING} arg2: {_type: INT} - `), false) + `), false, false) require.NoError(t, err) inputSchema4, err = ValidateInputSchema(cr.MustReadYAMLStr( ` @@ -800,7 +800,7 @@ func TestValidateInputSchema(t *testing.T) { _type: STRING arg2: _type: INT - `), false) + `), false, false) require.NoError(t, err) require.Equal(t, inputSchema, inputSchema2) require.Equal(t, inputSchema, inputSchema3) @@ -814,7 +814,7 @@ func TestValidateInputSchema(t *testing.T) { arg2: _type: INT _default: 2 - `), false) + `), false, false) require.NoError(t, err) _, err = ValidateInputSchema(cr.MustReadYAMLStr( @@ -828,7 +828,7 @@ func TestValidateInputSchema(t *testing.T) { arg2: _type: INT _default: 2 - `), false) + `), false, false) require.NoError(t, err) _, err = ValidateInputSchema(cr.MustReadYAMLStr( @@ -843,7 +843,7 @@ func TestValidateInputSchema(t *testing.T) { arg2: _type: INT _default: 2 - `), false) + `), false, false) require.NoError(t, err) _, err = ValidateInputSchema(cr.MustReadYAMLStr( @@ -854,7 +854,7 @@ func TestValidateInputSchema(t *testing.T) { arg_b: _type: STRING _default: test - `), false) + `), false, false) require.NoError(t, err) _, err = ValidateInputSchema(cr.MustReadYAMLStr( @@ -865,7 +865,7 @@ func TestValidateInputSchema(t *testing.T) { arg_b: _type: STRING _default: test - `), false) + `), false, false) require.NoError(t, err) _, err = ValidateInputSchema(cr.MustReadYAMLStr( @@ -876,14 +876,14 @@ func TestValidateInputSchema(t *testing.T) { arg_b: _type: STRING _default: test - `), false) + `), false, false) require.NoError(t, err) _, err = ValidateInputSchema(cr.MustReadYAMLStr( ` arg1: 2: STRING - `), false) + `), false, false) require.NoError(t, err) _, err = ValidateInputSchema(cr.MustReadYAMLStr( @@ -891,7 +891,7 @@ func TestValidateInputSchema(t *testing.T) { arg1: _type: {2: STRING} _default: {2: test} - `), false) + `), false, false) require.NoError(t, err) _, err = ValidateInputSchema(cr.MustReadYAMLStr( @@ -900,53 +900,53 @@ func TestValidateInputSchema(t *testing.T) { 2: _type: STRING _default: test - `), false) + `), false, false) require.NoError(t, err) inputSchema, err = ValidateInputSchema(cr.MustReadYAMLStr( - `[{INT_COLUMN: STRING|INT}]`), false) + `[{INT_COLUMN: STRING|INT}]`), false, false) require.NoError(t, err) inputSchema2, err = ValidateInputSchema(cr.MustReadYAMLStr( ` _type: [{INT_COLUMN: STRING|INT}] - `), false) + `), false, false) require.NoError(t, err) require.Equal(t, inputSchema, 
inputSchema2) _, err = ValidateInputSchema(cr.MustReadYAMLStr( - `map: {BOOL|FLOAT: INT|STRING}`), false) + `map: {BOOL|FLOAT: INT|STRING}`), false, false) require.NoError(t, err) _, err = ValidateInputSchema(cr.MustReadYAMLStr( - `map: {mean: FLOAT, stddev: FLOAT}`), false) + `map: {mean: FLOAT, stddev: FLOAT}`), false, false) require.NoError(t, err) _, err = ValidateInputSchema(cr.MustReadYAMLStr( - `map: {STRING: {lat: FLOAT, lon: FLOAT}}`), false) + `map: {STRING: {lat: FLOAT, lon: FLOAT}}`), false, false) require.NoError(t, err) _, err = ValidateInputSchema(cr.MustReadYAMLStr( - `map: {STRING: {lat: FLOAT, lon: [FLOAT]}}`), false) + `map: {STRING: {lat: FLOAT, lon: [FLOAT]}}`), false, false) require.NoError(t, err) _, err = ValidateInputSchema(cr.MustReadYAMLStr( - `map: {STRING: {FLOAT: INT}}`), false) + `map: {STRING: {FLOAT: INT}}`), false, false) require.NoError(t, err) _, err = ValidateInputSchema(cr.MustReadYAMLStr( - `map: {STRING: {FLOAT: [INT]}}`), false) + `map: {STRING: {FLOAT: [INT]}}`), false, false) require.NoError(t, err) _, err = ValidateInputSchema(cr.MustReadYAMLStr( - `map: {STRING: {lat: FLOAT, lon: {lat2: FLOAT, lon2: INT}}}`), false) + `map: {STRING: {lat: FLOAT, lon: {lat2: FLOAT, lon2: INT}}}`), false, false) require.NoError(t, err) _, err = ValidateInputSchema(cr.MustReadYAMLStr( - `map6: {STRING: {lat: FLOAT, lon: {lat2: FLOAT, lon2: {INT: STRING}}}}`), false) + `map6: {STRING: {lat: FLOAT, lon: {lat2: FLOAT, lon2: {INT: STRING}}}}`), false, false) require.NoError(t, err) _, err = ValidateInputSchema(cr.MustReadYAMLStr( - `map6: {STRING: {lat: FLOAT, lon: {lat2: FLOAT, lon2: {INT: STRING}, mean: BOOL}}}`), false) + `map6: {STRING: {lat: FLOAT, lon: {lat2: FLOAT, lon2: {INT: STRING}, mean: BOOL}}}`), false, false) require.NoError(t, err) _, err = ValidateInputSchema(cr.MustReadYAMLStr( @@ -961,59 +961,59 @@ func TestValidateInputSchema(t *testing.T) { map5: {STRING: {BOOL: [INT]}} map6: {STRING: {lat: FLOAT, lon: {lat2: FLOAT, lon2: INT}}} map6: {STRING: {lat: FLOAT, lon: {lat2: FLOAT, lon2: {INT: STRING}, mean: BOOL}}} - `), false) + `), false, false) require.NoError(t, err) _, err = ValidateInputSchema(cr.MustReadYAMLStr( - `map: {STRING: INT, INT: FLOAT}`), false) + `map: {STRING: INT, INT: FLOAT}`), false, false) require.Error(t, err) _, err = ValidateInputSchema(cr.MustReadYAMLStr( - `map: {STRING: INT, INT: [FLOAT]}`), false) + `map: {STRING: INT, INT: [FLOAT]}`), false, false) require.Error(t, err) _, err = ValidateInputSchema(cr.MustReadYAMLStr( - `map: {mean: FLOAT, INT: FLOAT}`), false) + `map: {mean: FLOAT, INT: FLOAT}`), false, false) require.Error(t, err) _, err = ValidateInputSchema(cr.MustReadYAMLStr( - `map: {mean: FLOAT, INT: [FLOAT]}`), false) + `map: {mean: FLOAT, INT: [FLOAT]}`), false, false) require.Error(t, err) _, err = ValidateInputSchema(cr.MustReadYAMLStr( - `map: {STRING: {lat: FLOAT, STRING: FLOAT}}`), false) + `map: {STRING: {lat: FLOAT, STRING: FLOAT}}`), false, false) require.Error(t, err) _, err = ValidateInputSchema(cr.MustReadYAMLStr( - `map: {STRING: {STRING: test}}`), false) + `map: {STRING: {STRING: test}}`), false, false) require.Error(t, err) _, err = ValidateInputSchema(cr.MustReadYAMLStr( - `cols: [STRING_COLUMN, INT_COLUMN]`), false) + `cols: [STRING_COLUMN, INT_COLUMN]`), false, false) require.Error(t, err) _, err = ValidateInputSchema(cr.MustReadYAMLStr( - `cols: [STRING_COLUMNs]`), false) + `cols: [STRING_COLUMNs]`), false, false) require.Error(t, err) _, err = ValidateInputSchema(cr.MustReadYAMLStr( - `cols: 
[STRING_COLUMN|BAD]`), false) + `cols: [STRING_COLUMN|BAD]`), false, false) require.Error(t, err) _, err = ValidateInputSchema(cr.MustReadYAMLStr( - `cols: Null`), false) + `cols: Null`), false, false) require.Error(t, err) _, err = ValidateInputSchema(cr.MustReadYAMLStr( - `cols: 1`), false) + `cols: 1`), false, false) require.Error(t, err) _, err = ValidateInputSchema(cr.MustReadYAMLStr( - `cols: [1]`), false) + `cols: [1]`), false, false) require.Error(t, err) _, err = ValidateInputSchema(cr.MustReadYAMLStr( - `cols: []`), false) + `cols: []`), false, false) require.Error(t, err) _, err = ValidateInputSchema(cr.MustReadYAMLStr( @@ -1054,7 +1054,7 @@ func TestValidateInputSchema(t *testing.T) { nums4_scalar: [FLOAT|INT] nums5_scalar: [STRING|INT|FLOAT] nums6_scalar: [STRING|INT|FLOAT|BOOL] - `), false) + `), false, false) require.NoError(t, err) // Casting defaults @@ -1063,7 +1063,7 @@ func TestValidateInputSchema(t *testing.T) { ` _type: INT _default: 2 - `), false) + `), false, false) require.NoError(t, err) require.Equal(t, inputSchema.Default, int64(2)) @@ -1071,21 +1071,21 @@ func TestValidateInputSchema(t *testing.T) { ` _type: INT _default: test - `), false) + `), false, false) require.Error(t, err) inputSchema, err = ValidateInputSchema(cr.MustReadYAMLStr( ` _type: INT _default: 2.2 - `), false) + `), false, false) require.Error(t, err) inputSchema, err = ValidateInputSchema(cr.MustReadYAMLStr( ` _type: FLOAT _default: 2 - `), false) + `), false, false) require.NoError(t, err) require.Equal(t, inputSchema.Default, float64(2)) @@ -1093,7 +1093,7 @@ func TestValidateInputSchema(t *testing.T) { ` _type: FLOAT|INT _default: 2 - `), false) + `), false, false) require.NoError(t, err) require.Equal(t, inputSchema.Default, int64(2)) @@ -1101,7 +1101,7 @@ func TestValidateInputSchema(t *testing.T) { ` _type: BOOL _default: true - `), false) + `), false, false) require.NoError(t, err) require.Equal(t, inputSchema.Default, true) @@ -1109,7 +1109,7 @@ func TestValidateInputSchema(t *testing.T) { ` _type: {STRING: FLOAT} _default: {test: 2.2, test2: 4.4} - `), false) + `), false, false) require.NoError(t, err) require.Equal(t, inputSchema.Default, map[interface{}]interface{}{"test": 2.2, "test2": 4.4}) @@ -1117,14 +1117,14 @@ func TestValidateInputSchema(t *testing.T) { ` _type: {STRING: FLOAT} _default: {test: test2} - `), false) + `), false, false) require.Error(t, err) inputSchema, err = ValidateInputSchema(cr.MustReadYAMLStr( ` _type: {STRING: FLOAT} _default: {test: 2} - `), false) + `), false, false) require.NoError(t, err) require.Equal(t, inputSchema.Default, map[interface{}]interface{}{"test": float64(2)}) @@ -1132,7 +1132,7 @@ func TestValidateInputSchema(t *testing.T) { ` _type: {STRING: FLOAT} _default: {test: 2.0} - `), false) + `), false, false) require.NoError(t, err) require.Equal(t, inputSchema.Default, map[interface{}]interface{}{"test": float64(2)}) @@ -1140,7 +1140,7 @@ func TestValidateInputSchema(t *testing.T) { ` _type: {STRING: INT} _default: {test: 2} - `), false) + `), false, false) require.NoError(t, err) require.Equal(t, inputSchema.Default, map[interface{}]interface{}{"test": int64(2)}) @@ -1148,14 +1148,14 @@ func TestValidateInputSchema(t *testing.T) { ` _type: {STRING: INT} _default: {test: 2.0} - `), false) + `), false, false) require.Error(t, err) inputSchema, err = ValidateInputSchema(cr.MustReadYAMLStr( ` _type: {mean: FLOAT, sum: INT} _default: {mean: 2.2, sum: 4} - `), false) + `), false, false) require.NoError(t, err) require.Equal(t, 
inputSchema.Default, map[interface{}]interface{}{"mean": float64(2.2), "sum": int64(4)}) @@ -1163,49 +1163,49 @@ func TestValidateInputSchema(t *testing.T) { ` _type: {mean: FLOAT, sum: INT} _default: {mean: 2.2, sum: test} - `), false) + `), false, false) require.Error(t, err) inputSchema, err = ValidateInputSchema(cr.MustReadYAMLStr( ` _type: {mean: FLOAT, sum: INT} _default: {mean: false, sum: 4} - `), false) + `), false, false) require.Error(t, err) inputSchema, err = ValidateInputSchema(cr.MustReadYAMLStr( ` _type: {mean: FLOAT, sum: INT} _default: {mean: 2.2, 2: 4} - `), false) + `), false, false) require.Error(t, err) inputSchema, err = ValidateInputSchema(cr.MustReadYAMLStr( ` _type: {mean: FLOAT, sum: INT} _default: {mean: 2.2, sum: Null} - `), false) + `), false, false) require.Error(t, err) inputSchema, err = ValidateInputSchema(cr.MustReadYAMLStr( ` _type: {mean: FLOAT, sum: INT} _default: {mean: 2.2} - `), false) + `), false, false) require.Error(t, err) inputSchema, err = ValidateInputSchema(cr.MustReadYAMLStr( ` _type: {mean: FLOAT, sum: INT} _default: {mean: 2.2, sum: 4, stddev: 2} - `), false) + `), false, false) require.Error(t, err) inputSchema, err = ValidateInputSchema(cr.MustReadYAMLStr( ` _type: [INT] _default: [1, 2] - `), false) + `), false, false) require.NoError(t, err) require.Equal(t, inputSchema.Default, []interface{}{int64(1), int64(2)}) @@ -1213,14 +1213,14 @@ func TestValidateInputSchema(t *testing.T) { ` _type: [INT] _default: [1.0, 2] - `), false) + `), false, false) require.Error(t, err) inputSchema, err = ValidateInputSchema(cr.MustReadYAMLStr( ` _type: [FLOAT] _default: [1.0, 2] - `), false) + `), false, false) require.NoError(t, err) require.Equal(t, inputSchema.Default, []interface{}{float64(1), float64(2)}) @@ -1228,7 +1228,7 @@ func TestValidateInputSchema(t *testing.T) { ` _type: [FLOAT|INT] _default: [1.0, 2] - `), false) + `), false, false) require.NoError(t, err) require.Equal(t, inputSchema.Default, []interface{}{float64(1), int64(2)}) @@ -1236,14 +1236,14 @@ func TestValidateInputSchema(t *testing.T) { ` _type: [FLOAT|INT|BOOL] _default: [1.0, 2, true, test] - `), false) + `), false, false) require.Error(t, err) inputSchema, err = ValidateInputSchema(cr.MustReadYAMLStr( ` _type: [FLOAT|INT|BOOL|STRING] _default: [1.0, 2, true, test] - `), false) + `), false, false) require.NoError(t, err) require.Equal(t, inputSchema.Default, []interface{}{float64(1), int64(2), true, "test"}) @@ -1270,7 +1270,7 @@ func TestValidateInputSchema(t *testing.T) { b: [testX, testY, testZ] c: {mean: 1.7, sum: [1], stddev: {z: 12}} d: 17 - `), false) + `), false, false) require.NoError(t, err) require.Equal(t, inputSchema.Default, map[interface{}]interface{}{ "testA": map[interface{}]interface{}{}, @@ -1303,7 +1303,7 @@ func TestValidateInputSchema(t *testing.T) { _default: 2 _default: testA: Null - `), false) + `), false, false) require.Error(t, err) inputSchema, err = ValidateInputSchema(cr.MustReadYAMLStr( @@ -1325,7 +1325,7 @@ func TestValidateInputSchema(t *testing.T) { _default: 2 _default: testA: Null - `), false) + `), false, false) require.NoError(t, err) require.Equal(t, inputSchema.Default, map[interface{}]interface{}{ "testA": nil, @@ -1348,7 +1348,7 @@ func TestValidateInputSchema(t *testing.T) { _default: 2 _default: testA: {} - `), false) + `), false, false) require.Error(t, err) _, err = ValidateInputSchema(cr.MustReadYAMLStr( @@ -1371,7 +1371,7 @@ func TestValidateInputSchema(t *testing.T) { a: 88 c: {mean: 1.7, sum: [1], stddev: {z: 12}} d: 17 - `), 
false) + `), false, false) require.Error(t, err) _, err = ValidateInputSchema(cr.MustReadYAMLStr( @@ -1410,7 +1410,7 @@ func TestValidateInputSchema(t *testing.T) { c: {mean: 1.7, sum: [1], stddev: {z: 12}} bools: [true, false, true] anything: [10, 2.2, test, false] - `), false) + `), false, false) require.NoError(t, err) _, err = ValidateInputSchema(cr.MustReadYAMLStr( @@ -1448,7 +1448,7 @@ func TestValidateInputSchema(t *testing.T) { c: {mean: 1.7, sum: [1], stddev: {z: 12}} bools: [true, false, true] anything: [10, 2.2, test, false] - `), false) + `), false, false) require.Error(t, err) _, err = ValidateInputSchema(cr.MustReadYAMLStr( @@ -1487,7 +1487,7 @@ func TestValidateInputSchema(t *testing.T) { c: {mean: 1.7, sum: [1], stddev: {z: 12}} bools: [true, false, true] anything: [10, 2.2, test, false] - `), false) + `), false, false) require.Error(t, err) _, err = ValidateInputSchema(cr.MustReadYAMLStr( @@ -1526,7 +1526,7 @@ func TestValidateInputSchema(t *testing.T) { c: {mean: 1.7, sum: [1], stddev: {z: 12}} bools: [true, false, true] anything: [10, 2.2, test, false] - `), false) + `), false, false) require.Error(t, err) _, err = ValidateInputSchema(cr.MustReadYAMLStr( @@ -1565,7 +1565,7 @@ func TestValidateInputSchema(t *testing.T) { c: {mean: 1.7, sum: [1], stddev: {z: 12}} bools: [true, false, true] anything: [10, 2.2, test, false] - `), false) + `), false, false) require.Error(t, err) _, err = ValidateInputSchema(cr.MustReadYAMLStr( @@ -1604,7 +1604,7 @@ func TestValidateInputSchema(t *testing.T) { c: {mean: 1.7, sum: [1], stddev: {z: 12}} bools: true anything: [10, 2.2, test, false] - `), false) + `), false, false) require.Error(t, err) _, err = ValidateInputSchema(cr.MustReadYAMLStr( @@ -1643,6 +1643,6 @@ func TestValidateInputSchema(t *testing.T) { c: {mean: 1.7, sum: [1], stddev: {z: 12}} bools: [1, 2, 3] anything: [10, 2.2, test, false] - `), false) + `), false, false) require.Error(t, err) } diff --git a/pkg/operator/context/resources_test.go b/pkg/operator/context/resources_test.go index 8f9c52a121..1d08da4699 100644 --- a/pkg/operator/context/resources_test.go +++ b/pkg/operator/context/resources_test.go @@ -26,7 +26,7 @@ import ( ) func checkValidateRuntimeTypesEqual(t *testing.T, schemaYAML string, inputYAML string, expected interface{}) { - schema, err := userconfig.ValidateInputSchema(cr.MustReadYAMLStr(schemaYAML), false) + schema, err := userconfig.ValidateInputSchema(cr.MustReadYAMLStr(schemaYAML), false, false) require.NoError(t, err) input := cr.MustReadYAMLStr(inputYAML) casted, err := validateRuntimeTypes(input, schema, allResourcesMap, aggregators, transformers, false) @@ -35,7 +35,7 @@ func checkValidateRuntimeTypesEqual(t *testing.T, schemaYAML string, inputYAML s } func checkValidateRuntimeTypesError(t *testing.T, schemaYAML string, inputYAML string) { - schema, err := userconfig.ValidateInputSchema(cr.MustReadYAMLStr(schemaYAML), false) + schema, err := userconfig.ValidateInputSchema(cr.MustReadYAMLStr(schemaYAML), false, false) require.NoError(t, err) input := cr.MustReadYAMLStr(inputYAML) _, err = validateRuntimeTypes(input, schema, allResourcesMap, aggregators, transformers, false) @@ -43,7 +43,7 @@ func checkValidateRuntimeTypesError(t *testing.T, schemaYAML string, inputYAML s } func checkValidateRuntimeTypesNoError(t *testing.T, schemaYAML string, inputYAML string) { - schema, err := userconfig.ValidateInputSchema(cr.MustReadYAMLStr(schemaYAML), false) + schema, err := userconfig.ValidateInputSchema(cr.MustReadYAMLStr(schemaYAML), false, false) 
require.NoError(t, err) input := cr.MustReadYAMLStr(inputYAML) _, err = validateRuntimeTypes(input, schema, allResourcesMap, aggregators, transformers, false) From 01af3d6db5806df43d8e0205bd014f33fa2c74b3 Mon Sep 17 00:00:00 2001 From: David Eliahu Date: Thu, 6 Jun 2019 22:14:23 -0700 Subject: [PATCH 04/44] Fix lint issues --- pkg/estimators/estimators.yaml | 2 -- pkg/operator/api/context/dependencies.go | 4 ++-- pkg/operator/context/resource_fakes_test.go | 4 ++-- 3 files changed, 4 insertions(+), 6 deletions(-) diff --git a/pkg/estimators/estimators.yaml b/pkg/estimators/estimators.yaml index 711b2c1b84..ff16bd7a73 100644 --- a/pkg/estimators/estimators.yaml +++ b/pkg/estimators/estimators.yaml @@ -11,5 +11,3 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - - diff --git a/pkg/operator/api/context/dependencies.go b/pkg/operator/api/context/dependencies.go index 2b9944c9d5..66791ca064 100644 --- a/pkg/operator/api/context/dependencies.go +++ b/pkg/operator/api/context/dependencies.go @@ -178,8 +178,8 @@ func ExtractCortexResources( // convert to slice and sort by ID var resourceIDs []string - for resourceId := range resources { - resourceIDs = append(resourceIDs, resourceId) + for resourceID := range resources { + resourceIDs = append(resourceIDs, resourceID) } sort.Strings(resourceIDs) resoucesSlice := make([]Resource, len(resources)) diff --git a/pkg/operator/context/resource_fakes_test.go b/pkg/operator/context/resource_fakes_test.go index 6b06bddf94..41538473ec 100644 --- a/pkg/operator/context/resource_fakes_test.go +++ b/pkg/operator/context/resource_fakes_test.go @@ -32,7 +32,7 @@ func mustValidateOutputSchema(yamlStr string) userconfig.OutputSchema { } func genConst(id string, outputType string, value string) *context.Constant { - var outType userconfig.OutputSchema = nil + var outType userconfig.OutputSchema if outputType != "" { outType = mustValidateOutputSchema(outputType) } @@ -60,7 +60,7 @@ func genConst(id string, outputType string, value string) *context.Constant { } func genAgg(id string, aggregatorType string) (*context.Aggregate, *context.Aggregator) { - var outputType userconfig.OutputSchema = nil + var outputType userconfig.OutputSchema if aggregatorType != "" { outputType = mustValidateOutputSchema(aggregatorType) } From b4d82e2340de2d46a0f2d24242c2223e750c4a19 Mon Sep 17 00:00:00 2001 From: David Eliahu Date: Fri, 7 Jun 2019 11:19:01 -0700 Subject: [PATCH 05/44] Add outputSchemaValidator --- pkg/operator/api/userconfig/aggregators.go | 6 ++---- pkg/operator/api/userconfig/constants.go | 6 ++---- pkg/operator/api/userconfig/validators.go | 13 +++++++++++++ 3 files changed, 17 insertions(+), 8 deletions(-) diff --git a/pkg/operator/api/userconfig/aggregators.go b/pkg/operator/api/userconfig/aggregators.go index 8b7f429711..57543a0091 100644 --- a/pkg/operator/api/userconfig/aggregators.go +++ b/pkg/operator/api/userconfig/aggregators.go @@ -50,10 +50,8 @@ var aggregatorValidation = &cr.StructValidation{ { StructField: "OutputType", InterfaceValidation: &cr.InterfaceValidation{ - Required: true, - Validator: func(t interface{}) (interface{}, error) { - return ValidateOutputSchema(t) - }, + Required: true, + Validator: outputSchemaValidator, }, }, { diff --git a/pkg/operator/api/userconfig/constants.go b/pkg/operator/api/userconfig/constants.go index 44284cfa43..19ae4138ac 100644 --- a/pkg/operator/api/userconfig/constants.go +++ 
b/pkg/operator/api/userconfig/constants.go @@ -43,10 +43,8 @@ var constantValidation = &cr.StructValidation{ { StructField: "Type", InterfaceValidation: &cr.InterfaceValidation{ - Required: false, - Validator: func(t interface{}) (interface{}, error) { - return ValidateOutputSchema(t) - }, + Required: false, + Validator: outputSchemaValidator, }, }, { diff --git a/pkg/operator/api/userconfig/validators.go b/pkg/operator/api/userconfig/validators.go index 474638ef74..1d9baf36b9 100644 --- a/pkg/operator/api/userconfig/validators.go +++ b/pkg/operator/api/userconfig/validators.go @@ -41,13 +41,26 @@ type InputTypeSchema interface{} // CompundType, length-one array of *InputSchem type OutputSchema interface{} // ValueType, length-one array of OutputSchema, or map of {scalar|ValueType -> OutputSchema} (no *_COLUMN types, compound types, or input options like _default) func inputSchemaValidator(in interface{}) (interface{}, error) { + if in == nil { + return nil, nil + } return ValidateInputSchema(in, false, false) // This casts it to *InputSchema } func inputSchemaValidatorValueTypesOnly(in interface{}) (interface{}, error) { + if in == nil { + return nil, nil + } return ValidateInputSchema(in, true, false) // This casts it to *InputSchema } +func outputSchemaValidator(in interface{}) (interface{}, error) { + if in == nil { + return nil, nil + } + return ValidateOutputSchema(in) +} + func ValidateInputSchema(in interface{}, disallowColumnTypes bool, isAlreadyParsed bool) (*InputSchema, error) { // Check for cortex options vs short form if inMap, ok := cast.InterfaceToStrInterfaceMap(in); ok { From ef8b3b4213a2bae4fb2ca1284bf51e8f3ba343fb Mon Sep 17 00:00:00 2001 From: David Eliahu Date: Fri, 7 Jun 2019 11:19:51 -0700 Subject: [PATCH 06/44] Use compund type for estimator target column --- pkg/operator/api/userconfig/errors.go | 9 ------- pkg/operator/api/userconfig/estimators.go | 30 +++++++++++++---------- pkg/operator/context/estimators.go | 2 +- pkg/operator/context/models.go | 7 ++---- 4 files changed, 20 insertions(+), 28 deletions(-) diff --git a/pkg/operator/api/userconfig/errors.go b/pkg/operator/api/userconfig/errors.go index cf993ea796..d5a284835f 100644 --- a/pkg/operator/api/userconfig/errors.go +++ b/pkg/operator/api/userconfig/errors.go @@ -69,7 +69,6 @@ const ( ErrTypeMapZeroLength ErrGenericTypeMapLength ErrK8sQuantityMustBeInt - ErrTargetColumnIntOrFloat ErrPredictionKeyOnModelWithEstimator ErrSpecifyOnlyOneMissing ErrEnvSchemaMismatch @@ -117,7 +116,6 @@ var errorKinds = []string{ "err_type_map_zero_length", "err_generic_type_map_length", "err_k8s_quantity_must_be_int", - "err_target_column_int_or_float", "err_prediction_key_on_model_with_estimator", "err_specify_only_one_missing", "err_env_schema_mismatch", @@ -507,13 +505,6 @@ func ErrorK8sQuantityMustBeInt(quantityStr string) error { } } -func ErrorTargetColumnIntOrFloat() error { - return Error{ - Kind: ErrTargetColumnIntOrFloat, - message: "models can only predict values of type INT_COLUMN (i.e. classification) or FLOAT_COLUMN (i.e. 
regression)", - } -} - func ErrorPredictionKeyOnModelWithEstimator() error { return Error{ Kind: ErrPredictionKeyOnModelWithEstimator, diff --git a/pkg/operator/api/userconfig/estimators.go b/pkg/operator/api/userconfig/estimators.go index 83bf653ad5..d6fe913334 100644 --- a/pkg/operator/api/userconfig/estimators.go +++ b/pkg/operator/api/userconfig/estimators.go @@ -25,12 +25,12 @@ type Estimators []*Estimator type Estimator struct { ResourceFields - TargetColumn ColumnType `json:"target_column" yaml:"target_column"` - Input *InputSchema `json:"input" yaml:"input"` - TrainingInput *InputSchema `json:"training_input" yaml:"training_input"` - Hparams *InputSchema `json:"hparams" yaml:"hparams"` - PredictionKey string `json:"prediction_key" yaml:"prediction_key"` - Path string `json:"path" yaml:"path"` + TargetColumn *CompoundType `json:"target_column" yaml:"target_column"` + Input *InputSchema `json:"input" yaml:"input"` + TrainingInput *InputSchema `json:"training_input" yaml:"training_input"` + Hparams *InputSchema `json:"hparams" yaml:"hparams"` + PredictionKey string `json:"prediction_key" yaml:"prediction_key"` + Path string `json:"path" yaml:"path"` } var estimatorValidation = &cr.StructValidation{ @@ -54,16 +54,20 @@ var estimatorValidation = &cr.StructValidation{ StructField: "TargetColumn", StringValidation: &cr.StringValidation{ Required: true, - Validator: func(col string) (string, error) { - colType := ColumnTypeFromString(col) - if colType != IntegerColumnType && colType != FloatColumnType { - return "", ErrorTargetColumnIntOrFloat() + Validator: func(colStr string) (string, error) { + _, err := CompoundTypeFromString(colStr) + if err != nil { + return "", err } - return col, nil + return colStr, nil }, }, - Parser: func(str string) (interface{}, error) { - return ColumnTypeFromString(str), nil + Parser: func(colStr string) (interface{}, error) { + colType, err := CompoundTypeFromString(colStr) + if err != nil { + return nil, err + } + return &colType, nil }, }, { diff --git a/pkg/operator/context/estimators.go b/pkg/operator/context/estimators.go index 14da666cdd..08086feaa2 100644 --- a/pkg/operator/context/estimators.go +++ b/pkg/operator/context/estimators.go @@ -67,7 +67,7 @@ func loadUserEstimators( ResourceFields: userconfig.ResourceFields{ Name: implHash, }, - TargetColumn: userconfig.InferredColumnType, + TargetColumn: nil, PredictionKey: modelConfig.PredictionKey, Path: *modelConfig.EstimatorPath, } diff --git a/pkg/operator/context/models.go b/pkg/operator/context/models.go index e163d3ba83..5fdac9f3a7 100644 --- a/pkg/operator/context/models.go +++ b/pkg/operator/context/models.go @@ -108,11 +108,8 @@ func getModels( if targetColumn == nil { return nil, errors.Wrap(userconfig.ErrorUndefinedResource(targetColumnName, resource.RawColumnType, resource.TransformedColumnType), userconfig.Identify(modelConfig), userconfig.TargetColumnKey) } - if targetColumn.GetColumnType() != userconfig.IntegerColumnType && targetColumn.GetColumnType() != userconfig.FloatColumnType { - return nil, userconfig.ErrorTargetColumnIntOrFloat() - } - if estimator.TargetColumn != userconfig.InferredColumnType { - if targetColumn.GetColumnType() != estimator.TargetColumn { + if estimator.TargetColumn != nil { + if !estimator.TargetColumn.SupportsType(targetColumn.GetColumnType()) { return nil, errors.Wrap(userconfig.ErrorUnsupportedOutputType(targetColumn.GetColumnType(), estimator.TargetColumn), userconfig.Identify(modelConfig), userconfig.TargetColumnKey) } } From 
b48b936f84c7abed526033f12322907b395ff197 Mon Sep 17 00:00:00 2001 From: David Eliahu Date: Sat, 8 Jun 2019 10:16:28 -0700 Subject: [PATCH 07/44] Fix bugs --- pkg/operator/api/userconfig/errors.go | 11 ++++++++++- pkg/operator/api/userconfig/estimators.go | 5 ++++- pkg/operator/context/models.go | 6 ++++++ pkg/operator/context/resources.go | 9 ++++++++- 4 files changed, 28 insertions(+), 3 deletions(-) diff --git a/pkg/operator/api/userconfig/errors.go b/pkg/operator/api/userconfig/errors.go index d5a284835f..078d21d0dc 100644 --- a/pkg/operator/api/userconfig/errors.go +++ b/pkg/operator/api/userconfig/errors.go @@ -65,6 +65,7 @@ const ( ErrUnsupportedOutputType ErrMustBeDefined ErrCannotBeNull + ErrUnsupportedConfigKey ErrTypeListLength ErrTypeMapZeroLength ErrGenericTypeMapLength @@ -112,6 +113,7 @@ var errorKinds = []string{ "err_unsupported_output_type", "err_must_be_defined", "err_cannot_be_null", + "error_unsupported_config_key", "err_type_list_length", "err_type_map_zero_length", "err_generic_type_map_length", @@ -364,7 +366,7 @@ func ErrorCannotMixValueAndColumnTypes(provided interface{}) error { func ErrorColumnTypeLiteral(provided interface{}) error { return Error{ Kind: ErrColumnTypeLiteral, - message: fmt.Sprintf("%s: literal values cannot be provided for column input types", s.UserStrStripped(provided)), + message: fmt.Sprintf("%s: literal values cannot be provided for column input types (e.g. use FLOAT_COLUMN instead of FLOAT)", s.UserStrStripped(provided)), } } @@ -477,6 +479,13 @@ func ErrorCannotBeNull() error { } } +func ErrorUnsupportedConfigKey() error { + return Error{ + Kind: ErrUnsupportedConfigKey, + message: "is not supported for this resource", + } +} + func ErrorTypeListLength(provided interface{}) error { return Error{ Kind: ErrTypeListLength, diff --git a/pkg/operator/api/userconfig/estimators.go b/pkg/operator/api/userconfig/estimators.go index d6fe913334..4ac2030231 100644 --- a/pkg/operator/api/userconfig/estimators.go +++ b/pkg/operator/api/userconfig/estimators.go @@ -55,10 +55,13 @@ var estimatorValidation = &cr.StructValidation{ StringValidation: &cr.StringValidation{ Required: true, Validator: func(colStr string) (string, error) { - _, err := CompoundTypeFromString(colStr) + colType, err := CompoundTypeFromString(colStr) if err != nil { return "", err } + if colType.IsValues() { + return "", ErrorColumnTypeLiteral(colStr) + } return colStr, nil }, }, diff --git a/pkg/operator/context/models.go b/pkg/operator/context/models.go index 5fdac9f3a7..6285e1ef69 100644 --- a/pkg/operator/context/models.go +++ b/pkg/operator/context/models.go @@ -79,6 +79,9 @@ func getModels( modelConfig.Input = castedInput // TrainingInput + if modelConfig.EstimatorPath == nil && estimator.TrainingInput == nil && modelConfig.TrainingInput != nil { + return nil, errors.Wrap(userconfig.ErrorUnsupportedConfigKey(), userconfig.Identify(modelConfig), userconfig.TrainingInputKey) + } castedTrainingInput, trainingInputID, err := ValidateInput( modelConfig.TrainingInput, estimator.TrainingInput, @@ -94,6 +97,9 @@ func getModels( modelConfig.TrainingInput = castedTrainingInput // Hparams + if modelConfig.EstimatorPath == nil && estimator.Hparams == nil && modelConfig.Hparams != nil { + return nil, errors.Wrap(userconfig.ErrorUnsupportedConfigKey(), userconfig.Identify(modelConfig), userconfig.HparamsKey) + } if estimator.Hparams != nil { castedHparams, err := userconfig.CastInputValue(modelConfig.Hparams, estimator.Hparams) if err != nil { diff --git a/pkg/operator/context/resources.go 
b/pkg/operator/context/resources.go index 6230e83898..cbab68bca4 100644 --- a/pkg/operator/context/resources.go +++ b/pkg/operator/context/resources.go @@ -51,8 +51,15 @@ func ValidateInput( castedInput := input - // Skip validation if schema is nil (i.e. user didn't define the aggregator/transformer/estimator) if schema != nil { + if input == nil { + if schema.Optional { + return nil, hash.Any(nil), nil + } else { + return nil, "", userconfig.ErrorMustBeDefined(schema) + } + } + castedInput, err = validateRuntimeTypes(input, schema, validResourcesMap, aggregators, transformers, false) if err != nil { return nil, "", err From 09a8452327c22b51b3038c422425c741631055c2 Mon Sep 17 00:00:00 2001 From: David Eliahu Date: Sat, 8 Jun 2019 10:20:15 -0700 Subject: [PATCH 08/44] Add estimators, update transformers and aggregators --- examples/iris/implementations/models/dnn.py | 17 - examples/iris/resources/models.yaml | 4 +- pkg/aggregators/bucket_boundaries.py | 4 +- pkg/aggregators/class_distribution.py | 4 +- pkg/aggregators/index_string.py | 4 +- pkg/estimators/boosted_trees_classifier.py | 91 +++ pkg/estimators/boosted_trees_regressor.py | 76 +++ pkg/estimators/dnn_classifier.py | 82 +++ .../dnn_linear_combined_classifier.py | 125 ++++ .../dnn_linear_combined_regressor.py | 110 ++++ pkg/estimators/dnn_regressor.py | 67 ++ pkg/estimators/estimators.yaml | 570 ++++++++++++++++++ pkg/estimators/linear_classifier.py | 66 ++ pkg/estimators/linear_regressor.py | 51 ++ pkg/transformers/bucketize.py | 10 +- pkg/transformers/index_string.py | 16 +- pkg/transformers/normalize.py | 12 +- 17 files changed, 1265 insertions(+), 44 deletions(-) delete mode 100644 examples/iris/implementations/models/dnn.py create mode 100644 pkg/estimators/boosted_trees_classifier.py create mode 100644 pkg/estimators/boosted_trees_regressor.py create mode 100644 pkg/estimators/dnn_classifier.py create mode 100644 pkg/estimators/dnn_linear_combined_classifier.py create mode 100644 pkg/estimators/dnn_linear_combined_regressor.py create mode 100644 pkg/estimators/dnn_regressor.py create mode 100644 pkg/estimators/linear_classifier.py create mode 100644 pkg/estimators/linear_regressor.py diff --git a/examples/iris/implementations/models/dnn.py b/examples/iris/implementations/models/dnn.py deleted file mode 100644 index c3bd71c895..0000000000 --- a/examples/iris/implementations/models/dnn.py +++ /dev/null @@ -1,17 +0,0 @@ -import tensorflow as tf - - -def create_estimator(run_config, model_config): - feature_columns = [ - tf.feature_column.numeric_column("sepal_length_normalized"), - tf.feature_column.numeric_column("sepal_width_normalized"), - tf.feature_column.numeric_column("petal_length_normalized"), - tf.feature_column.numeric_column("petal_width_normalized"), - ] - - return tf.estimator.DNNClassifier( - feature_columns=feature_columns, - hidden_units=model_config["hparams"]["hidden_units"], - n_classes=len(model_config["aggregates"]["class_index"]["index"]), - config=run_config, - ) diff --git a/examples/iris/resources/models.yaml b/examples/iris/resources/models.yaml index eedacfa142..761e7f6d24 100644 --- a/examples/iris/resources/models.yaml +++ b/examples/iris/resources/models.yaml @@ -1,9 +1,9 @@ - kind: model name: dnn - estimator_path: implementations/models/dnn.py + estimator: cortex.dnn_classifier target_column: @class_indexed input: - cols: + numeric_columns: - @sepal_length_normalized - @sepal_width_normalized - @petal_length_normalized diff --git a/pkg/aggregators/bucket_boundaries.py 
b/pkg/aggregators/bucket_boundaries.py index 1959c76631..26b8cdf2ca 100644 --- a/pkg/aggregators/bucket_boundaries.py +++ b/pkg/aggregators/bucket_boundaries.py @@ -13,11 +13,11 @@ # limitations under the License. -def aggregate_spark(data, columns, args): +def aggregate_spark(data, input): from pyspark.ml.feature import QuantileDiscretizer discretizer = QuantileDiscretizer( - numBuckets=args["num_buckets"], inputCol=columns["col"], outputCol="_" + numBuckets=input["num_buckets"], inputCol=input["col"], outputCol="_" ).fit(data) return discretizer.getSplits() diff --git a/pkg/aggregators/class_distribution.py b/pkg/aggregators/class_distribution.py index cc7cc2ef5b..276d2bce3b 100644 --- a/pkg/aggregators/class_distribution.py +++ b/pkg/aggregators/class_distribution.py @@ -13,11 +13,11 @@ # limitations under the License. -def aggregate_spark(data, columns, args): +def aggregate_spark(data, input): import pyspark.sql.functions as F from functools import reduce - rows = data.groupBy(F.col(columns["col"])).count().orderBy(F.col("count").desc()).collect() + rows = data.groupBy(F.col(input)).count().orderBy(F.col("count").desc()).collect() sum = float(reduce(lambda x, y: x + y, (r[1] for r in rows))) diff --git a/pkg/aggregators/index_string.py b/pkg/aggregators/index_string.py index 1012057e02..0770b36d68 100644 --- a/pkg/aggregators/index_string.py +++ b/pkg/aggregators/index_string.py @@ -13,10 +13,10 @@ # limitations under the License. -def aggregate_spark(data, columns, args): +def aggregate_spark(data, input): from pyspark.ml.feature import StringIndexer - indexer = StringIndexer(inputCol=columns["col"]) + indexer = StringIndexer(inputCol=input) index = indexer.fit(data).labels reversed_index = {v: k for k, v in enumerate(index)} return {"index": index, "reversed_index": reversed_index} diff --git a/pkg/estimators/boosted_trees_classifier.py b/pkg/estimators/boosted_trees_classifier.py new file mode 100644 index 0000000000..67594299a3 --- /dev/null +++ b/pkg/estimators/boosted_trees_classifier.py @@ -0,0 +1,91 @@ +import tensorflow as tf + + +def create_estimator(run_config, model_config): + feature_columns = [] + + for col_name in model_config["input"]["numeric_columns"]: + feature_columns.append(tf.feature_column.numeric_column(col_name)) + + for col_info in model_config["input"]["categorical_columns_with_vocab"]: + col = tf.feature_column.categorical_column_with_vocabulary_list( + col_info["col"], col_info["vocab"] + ) + + if "weight_column" in col_info: + col = tf.feature_column.weighted_categorical_column(col, col_info["weight_column"]) + + if "embedding_size" in col_info: + col = tf.feature_column.embedding_column(col, col_info["embedding_size"]) + else: + col = tf.feature_column.indicator_column(col) + + feature_columns.append(col) + + for col_info in model_config["input"]["categorical_columns_with_identity"]: + col = tf.feature_column.categorical_columns_with_identity( + col_info["col"], col_info["num_classes"] + ) + + if "weight_column" in col_info: + col = tf.feature_column.weighted_categorical_column(col, col_info["weight_column"]) + + if "embedding_size" in col_info: + col = tf.feature_column.embedding_column(col, col_info["embedding_size"]) + else: + col = tf.feature_column.indicator_column(col) + + feature_columns.append(col) + + for col_info in model_config["input"]["categorical_columns_with_hash_bucket"]: + col = tf.feature_column.categorical_columns_with_hash_bucket( + col_info["col"], col_info["hash_bucket_size"] + ) + + if "weight_column" in col_info: + col = 
tf.feature_column.weighted_categorical_column(col, col_info["weight_column"]) + + if "embedding_size" in col_info: + col = tf.feature_column.embedding_column(col, col_info["embedding_size"]) + else: + col = tf.feature_column.indicator_column(col) + + feature_columns.append(col) + + for col_info in model_config["input"]["bucketized_columns"]: + feature_columns.append( + tf.feature_column.bucketized_column( + tf.feature_column.numeric_column(col_info["col"]), col_info["boundaries"] + ) + ) + + if "num_classes" in model_config["input"] and "target_vocab" in model_config["input"]: + raise ValueError('either "num_classes" or "target_vocab" must be specified, but not both') + + if "num_classes" not in model_config["input"] and "target_vocab" not in model_config["input"]: + raise ValueError('either "num_classes" or "target_vocab" must be specified') + + if "num_classes" in model_config["input"]: + target_vocab = None + num_classes = model_config["input"]["num_classes"] + else: + target_vocab = model_config["input"]["target_vocab"] + num_classes = len(target_vocab) + + return tf.estimator.BoostedTreesClassifier( + feature_columns=feature_columns, + n_batches_per_layer=model_config["hparams"]["batches_per_layer"], + n_classes=num_classes, + label_vocabulary=target_vocab, + weight_column=model_config["input"].get("weight_column", None), + n_trees=model_config["hparams"]["num_trees"], + max_depth=model_config["hparams"]["max_depth"], + learning_rate=model_config["hparams"]["learning_rate"], + l1_regularization=model_config["hparams"]["l1_regularization"], + l2_regularization=model_config["hparams"]["l2_regularization"], + tree_complexity=model_config["hparams"]["tree_complexity"], + min_node_weight=model_config["hparams"]["min_node_weight"], + center_bias=model_config["hparams"]["center_bias"], + quantile_sketch_epsilon=model_config["hparams"]["quantile_sketch_epsilon"], + config=run_config, + ) diff --git a/pkg/estimators/boosted_trees_regressor.py b/pkg/estimators/boosted_trees_regressor.py new file mode 100644 index 0000000000..3dfc3321e8 --- /dev/null +++ b/pkg/estimators/boosted_trees_regressor.py @@ -0,0 +1,76 @@ +import tensorflow as tf + + +def create_estimator(run_config, model_config): + feature_columns = [] + + for col_name in model_config["input"]["numeric_columns"]: + feature_columns.append(tf.feature_column.numeric_column(col_name)) + + for col_info in model_config["input"]["categorical_columns_with_vocab"]: + col = tf.feature_column.categorical_column_with_vocabulary_list( + col_info["col"], col_info["vocab"] + ) + + if "weight_column" in col_info: + col = tf.feature_column.weighted_categorical_column(col, col_info["weight_column"]) + + if "embedding_size" in col_info: + col = tf.feature_column.embedding_column(col, col_info["embedding_size"]) + else: + col = tf.feature_column.indicator_column(col) + + feature_columns.append(col) + + for col_info in model_config["input"]["categorical_columns_with_identity"]: + col = tf.feature_column.categorical_columns_with_identity( + col_info["col"], col_info["num_classes"] + ) + + if "weight_column" in col_info: + col = tf.feature_column.weighted_categorical_column(col, col_info["weight_column"]) + + if "embedding_size" in col_info: + col = tf.feature_column.embedding_column(col, col_info["embedding_size"]) + else: + col = tf.feature_column.indicator_column(col) + + feature_columns.append(col) + + for col_info in model_config["input"]["categorical_columns_with_hash_bucket"]: + col = tf.feature_column.categorical_columns_with_hash_bucket( + 
col_info["col"], col_info["hash_bucket_size"] + ) + + if "weight_column" in col_info: + col = tf.feature_column.weighted_categorical_column(col, col_info["weight_column"]) + + if "embedding_size" in col_info: + col = tf.feature_column.embedding_column(col, col_info["embedding_size"]) + else: + col = tf.feature_column.indicator_column(col) + + feature_columns.append(col) + + for col_info in model_config["input"]["bucketized_columns"]: + feature_columns.append( + tf.feature_column.bucketized_column( + tf.feature_column.numeric_column(col_info["col"]), col_info["boundaries"] + ) + ) + + return tf.estimator.BoostedTreesClassifier( + feature_columns=feature_columns, + n_batches_per_layer=model_config["hparams"]["batches_per_layer"], + weight_column=model_config["input"].get("weight_column", None), + n_trees=model_config["hparams"]["num_trees"], + max_depth=model_config["hparams"]["max_depth"], + learning_rate=model_config["hparams"]["learning_rate"], + l1_regularization=model_config["hparams"]["l1_regularization"], + l2_regularization=model_config["hparams"]["l2_regularization"], + tree_complexity=model_config["hparams"]["tree_complexity"], + min_node_weight=model_config["hparams"]["min_node_weight"], + center_bias=model_config["hparams"]["center_bias"], + quantile_sketch_epsilon=model_config["hparams"]["quantile_sketch_epsilon"], + config=run_config, + ) diff --git a/pkg/estimators/dnn_classifier.py b/pkg/estimators/dnn_classifier.py new file mode 100644 index 0000000000..de5bfcc599 --- /dev/null +++ b/pkg/estimators/dnn_classifier.py @@ -0,0 +1,82 @@ +import tensorflow as tf + + +def create_estimator(run_config, model_config): + feature_columns = [] + + for col_name in model_config["input"]["numeric_columns"]: + feature_columns.append(tf.feature_column.numeric_column(col_name)) + + for col_info in model_config["input"]["categorical_columns_with_vocab"]: + col = tf.feature_column.categorical_column_with_vocabulary_list( + col_info["col"], col_info["vocab"] + ) + + if "weight_column" in col_info: + col = tf.feature_column.weighted_categorical_column(col, col_info["weight_column"]) + + if "embedding_size" in col_info: + col = tf.feature_column.embedding_column(col, col_info["embedding_size"]) + else: + col = tf.feature_column.indicator_column(col) + + feature_columns.append(col) + + for col_info in model_config["input"]["categorical_columns_with_identity"]: + col = tf.feature_column.categorical_columns_with_identity( + col_info["col"], col_info["num_classes"] + ) + + if "weight_column" in col_info: + col = tf.feature_column.weighted_categorical_column(col, col_info["weight_column"]) + + if "embedding_size" in col_info: + col = tf.feature_column.embedding_column(col, col_info["embedding_size"]) + else: + col = tf.feature_column.indicator_column(col) + + feature_columns.append(col) + + for col_info in model_config["input"]["categorical_columns_with_hash_bucket"]: + col = tf.feature_column.categorical_columns_with_hash_bucket( + col_info["col"], col_info["hash_bucket_size"] + ) + + if "weight_column" in col_info: + col = tf.feature_column.weighted_categorical_column(col, col_info["weight_column"]) + + if "embedding_size" in col_info: + col = tf.feature_column.embedding_column(col, col_info["embedding_size"]) + else: + col = tf.feature_column.indicator_column(col) + + feature_columns.append(col) + + for col_info in model_config["input"]["bucketized_columns"]: + feature_columns.append( + tf.feature_column.bucketized_column( + tf.feature_column.numeric_column(col_info["col"]), col_info["boundaries"] + 
) + ) + + if "num_classes" in model_config["input"] and "target_vocab" in model_config["input"]: + raise ValueError('either "num_classes" or "target_vocab" must be specified, but not both') + + if "num_classes" not in model_config["input"] and "target_vocab" not in model_config["input"]: + raise ValueError('either "num_classes" or "target_vocab" must be specified') + + if "num_classes" in model_config["input"]: + target_vocab = None + num_classes = model_config["input"]["num_classes"] + else: + target_vocab = model_config["input"]["target_vocab"] + num_classes = len(target_vocab) + + return tf.estimator.DNNClassifier( + feature_columns=feature_columns, + n_classes=num_classes, + label_vocabulary=target_vocab, + hidden_units=model_config["hparams"]["hidden_units"], + weight_column=model_config["input"].get("weight_column", None), + config=run_config, + ) diff --git a/pkg/estimators/dnn_linear_combined_classifier.py b/pkg/estimators/dnn_linear_combined_classifier.py new file mode 100644 index 0000000000..f9e4410a94 --- /dev/null +++ b/pkg/estimators/dnn_linear_combined_classifier.py @@ -0,0 +1,125 @@ +import tensorflow as tf + + +def create_estimator(run_config, model_config): + dnn_feature_columns = [] + + for col_name in model_config["input"]["dnn_columns"]["numeric_columns"]: + feature_columns.append(tf.feature_column.numeric_column(col_name)) + + for col_info in model_config["input"]["dnn_columns"]["categorical_columns_with_vocab"]: + col = tf.feature_column.categorical_column_with_vocabulary_list( + col_info["col"], col_info["vocab"] + ) + + if "weight_column" in col_info: + col = tf.feature_column.weighted_categorical_column(col, col_info["weight_column"]) + + if "embedding_size" in col_info: + col = tf.feature_column.embedding_column(col, col_info["embedding_size"]) + else: + col = tf.feature_column.indicator_column(col) + + dnn_feature_columns.append(col) + + for col_info in model_config["input"]["dnn_columns"]["categorical_columns_with_identity"]: + col = tf.feature_column.categorical_columns_with_identity( + col_info["col"], col_info["num_classes"] + ) + + if "weight_column" in col_info: + col = tf.feature_column.weighted_categorical_column(col, col_info["weight_column"]) + + if "embedding_size" in col_info: + col = tf.feature_column.embedding_column(col, col_info["embedding_size"]) + else: + col = tf.feature_column.indicator_column(col) + + dnn_feature_columns.append(col) + + for col_info in model_config["input"]["dnn_columns"]["categorical_columns_with_hash_bucket"]: + col = tf.feature_column.categorical_columns_with_hash_bucket( + col_info["col"], col_info["hash_bucket_size"] + ) + + if "weight_column" in col_info: + col = tf.feature_column.weighted_categorical_column(col, col_info["weight_column"]) + + if "embedding_size" in col_info: + col = tf.feature_column.embedding_column(col, col_info["embedding_size"]) + else: + col = tf.feature_column.indicator_column(col) + + dnn_feature_columns.append(col) + + for col_info in model_config["input"]["dnn_columns"]["bucketized_columns"]: + dnn_feature_columns.append( + tf.feature_column.bucketized_column( + tf.feature_column.numeric_column(col_info["col"]), col_info["boundaries"] + ) + ) + + linear_feature_columns = [] + + for col_name in model_config["input"]["linear_columns"]["numeric_columns"]: + linear_feature_columns.append(tf.feature_column.numeric_column(col_name)) + + for col_info in model_config["input"]["linear_columns"]["categorical_columns_with_vocab"]: + col = tf.feature_column.categorical_column_with_vocabulary_list( + 
col_info["col"], col_info["vocab"] + ) + + if "weight_column" in col_info: + col = tf.feature_column.weighted_categorical_column(col, col_info["weight_column"]) + + linear_feature_columns.append(col) + + for col_info in model_config["input"]["linear_columns"]["categorical_columns_with_identity"]: + col = tf.feature_column.categorical_columns_with_identity( + col_info["col"], col_info["num_classes"] + ) + + if "weight_column" in col_info: + col = tf.feature_column.weighted_categorical_column(col, col_info["weight_column"]) + + linear_feature_columns.append(col) + + for col_info in model_config["input"]["linear_columns"]["categorical_columns_with_hash_bucket"]: + col = tf.feature_column.categorical_columns_with_hash_bucket( + col_info["col"], col_info["hash_bucket_size"] + ) + + if "weight_column" in col_info: + col = tf.feature_column.weighted_categorical_column(col, col_info["weight_column"]) + + linear_feature_columns.append(col) + + for col_info in model_config["input"]["linear_columns"]["bucketized_columns"]: + linear_feature_columns.append( + tf.feature_column.bucketized_column( + tf.feature_column.numeric_column(col_info["col"]), col_info["boundaries"] + ) + ) + + if "num_classes" in model_config["input"] and "target_vocab" in model_config["input"]: + raise ValueError('either "num_classes" or "target_vocab" must be specified, but not both') + + if "num_classes" not in model_config["input"] and "target_vocab" not in model_config["input"]: + raise ValueError('either "num_classes" or "target_vocab" must be specified') + + if "num_classes" in model_config["input"]: + target_vocab = None + num_classes = model_config["input"]["num_classes"] + else: + target_vocab = model_config["input"]["target_vocab"] + num_classes = len(target_vocab) + + return tf.estimator.DNNClassifier( + linear_feature_columns=linear_feature_columns, + dnn_feature_columns=dnn_feature_columns, + n_classes=num_classes, + label_vocabulary=target_vocab, + dnn_hidden_units=model_config["hparams"]["dnn_hidden_units"], + weight_column=model_config["input"].get("weight_column", None), + config=run_config, + ) diff --git a/pkg/estimators/dnn_linear_combined_regressor.py b/pkg/estimators/dnn_linear_combined_regressor.py new file mode 100644 index 0000000000..cb6730d763 --- /dev/null +++ b/pkg/estimators/dnn_linear_combined_regressor.py @@ -0,0 +1,110 @@ +import tensorflow as tf + + +def create_estimator(run_config, model_config): + dnn_feature_columns = [] + + for col_name in model_config["input"]["dnn_columns"]["numeric_columns"]: + feature_columns.append(tf.feature_column.numeric_column(col_name)) + + for col_info in model_config["input"]["dnn_columns"]["categorical_columns_with_vocab"]: + col = tf.feature_column.categorical_column_with_vocabulary_list( + col_info["col"], col_info["vocab"] + ) + + if "weight_column" in col_info: + col = tf.feature_column.weighted_categorical_column(col, col_info["weight_column"]) + + if "embedding_size" in col_info: + col = tf.feature_column.embedding_column(col, col_info["embedding_size"]) + else: + col = tf.feature_column.indicator_column(col) + + dnn_feature_columns.append(col) + + for col_info in model_config["input"]["dnn_columns"]["categorical_columns_with_identity"]: + col = tf.feature_column.categorical_columns_with_identity( + col_info["col"], col_info["num_classes"] + ) + + if "weight_column" in col_info: + col = tf.feature_column.weighted_categorical_column(col, col_info["weight_column"]) + + if "embedding_size" in col_info: + col = tf.feature_column.embedding_column(col, 
col_info["embedding_size"]) + else: + col = tf.feature_column.indicator_column(col) + + dnn_feature_columns.append(col) + + for col_info in model_config["input"]["dnn_columns"]["categorical_columns_with_hash_bucket"]: + col = tf.feature_column.categorical_columns_with_hash_bucket( + col_info["col"], col_info["hash_bucket_size"] + ) + + if "weight_column" in col_info: + col = tf.feature_column.weighted_categorical_column(col, col_info["weight_column"]) + + if "embedding_size" in col_info: + col = tf.feature_column.embedding_column(col, col_info["embedding_size"]) + else: + col = tf.feature_column.indicator_column(col) + + dnn_feature_columns.append(col) + + for col_info in model_config["input"]["dnn_columns"]["bucketized_columns"]: + dnn_feature_columns.append( + tf.feature_column.bucketized_column( + tf.feature_column.numeric_column(col_info["col"]), col_info["boundaries"] + ) + ) + + linear_feature_columns = [] + + for col_name in model_config["input"]["linear_columns"]["numeric_columns"]: + linear_feature_columns.append(tf.feature_column.numeric_column(col_name)) + + for col_info in model_config["input"]["linear_columns"]["categorical_columns_with_vocab"]: + col = tf.feature_column.categorical_column_with_vocabulary_list( + col_info["col"], col_info["vocab"] + ) + + if "weight_column" in col_info: + col = tf.feature_column.weighted_categorical_column(col, col_info["weight_column"]) + + linear_feature_columns.append(col) + + for col_info in model_config["input"]["linear_columns"]["categorical_columns_with_identity"]: + col = tf.feature_column.categorical_columns_with_identity( + col_info["col"], col_info["num_classes"] + ) + + if "weight_column" in col_info: + col = tf.feature_column.weighted_categorical_column(col, col_info["weight_column"]) + + linear_feature_columns.append(col) + + for col_info in model_config["input"]["linear_columns"]["categorical_columns_with_hash_bucket"]: + col = tf.feature_column.categorical_columns_with_hash_bucket( + col_info["col"], col_info["hash_bucket_size"] + ) + + if "weight_column" in col_info: + col = tf.feature_column.weighted_categorical_column(col, col_info["weight_column"]) + + linear_feature_columns.append(col) + + for col_info in model_config["input"]["linear_columns"]["bucketized_columns"]: + linear_feature_columns.append( + tf.feature_column.bucketized_column( + tf.feature_column.numeric_column(col_info["col"]), col_info["boundaries"] + ) + ) + + return tf.estimator.DNNClassifier( + linear_feature_columns=linear_feature_columns, + dnn_feature_columns=dnn_feature_columns, + dnn_hidden_units=model_config["hparams"]["dnn_hidden_units"], + weight_column=model_config["input"].get("weight_column", None), + config=run_config, + ) diff --git a/pkg/estimators/dnn_regressor.py b/pkg/estimators/dnn_regressor.py new file mode 100644 index 0000000000..4a3cf491cb --- /dev/null +++ b/pkg/estimators/dnn_regressor.py @@ -0,0 +1,67 @@ +import tensorflow as tf + + +def create_estimator(run_config, model_config): + feature_columns = [] + + for col_name in model_config["input"]["numeric_columns"]: + feature_columns.append(tf.feature_column.numeric_column(col_name)) + + for col_info in model_config["input"]["categorical_columns_with_vocab"]: + col = tf.feature_column.categorical_column_with_vocabulary_list( + col_info["col"], col_info["vocab"] + ) + + if "weight_column" in col_info: + col = tf.feature_column.weighted_categorical_column(col, col_info["weight_column"]) + + if "embedding_size" in col_info: + col = tf.feature_column.embedding_column(col, 
col_info["embedding_size"]) + else: + col = tf.feature_column.indicator_column(col) + + feature_columns.append(col) + + for col_info in model_config["input"]["categorical_columns_with_identity"]: + col = tf.feature_column.categorical_columns_with_identity( + col_info["col"], col_info["num_classes"] + ) + + if "weight_column" in col_info: + col = tf.feature_column.weighted_categorical_column(col, col_info["weight_column"]) + + if "embedding_size" in col_info: + col = tf.feature_column.embedding_column(col, col_info["embedding_size"]) + else: + col = tf.feature_column.indicator_column(col) + + feature_columns.append(col) + + for col_info in model_config["input"]["categorical_columns_with_hash_bucket"]: + col = tf.feature_column.categorical_columns_with_hash_bucket( + col_info["col"], col_info["hash_bucket_size"] + ) + + if "weight_column" in col_info: + col = tf.feature_column.weighted_categorical_column(col, col_info["weight_column"]) + + if "embedding_size" in col_info: + col = tf.feature_column.embedding_column(col, col_info["embedding_size"]) + else: + col = tf.feature_column.indicator_column(col) + + feature_columns.append(col) + + for col_info in model_config["input"]["bucketized_columns"]: + feature_columns.append( + tf.feature_column.bucketized_column( + tf.feature_column.numeric_column(col_info["col"]), col_info["boundaries"] + ) + ) + + return tf.estimator.DNNRegressor( + feature_columns=feature_columns, + hidden_units=model_config["hparams"]["hidden_units"], + weight_column=model_config["input"].get("weight_column", None), + config=run_config, + ) diff --git a/pkg/estimators/estimators.yaml b/pkg/estimators/estimators.yaml index ff16bd7a73..388853157f 100644 --- a/pkg/estimators/estimators.yaml +++ b/pkg/estimators/estimators.yaml @@ -11,3 +11,573 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+ +- kind: estimator + name: dnn_classifier + path: dnn_classifier.py + target_column: FLOAT + input: + # Specify num_classes if target is INT_COLUMN + num_classes: + _type: INT + _optional: True + + # Specify target_vocab if target is STRING_COLUMN + target_vocab: + _type: [STRING] + _optional: True + + numeric_columns: + _type: [INT_COLUMN|FLOAT_COLUMN] + _default: [] + categorical_columns_with_identity: + _type: + - col: INT_COLUMN + num_classes: INT + embedding_size: # If not specified, an indicator column will be used instead + _type: INT + _optional: true + weight_column: + _type: INT_COLUMN|FLOAT_COLUMN + _optional: True + _default: [] + categorical_columns_with_vocab: + _type: + - col: STRING_COLUMN + vocab: [STRING] + embedding_size: # If not specified, an indicator column will be used instead + _type: INT + _optional: true + weight_column: + _type: INT_COLUMN|FLOAT_COLUMN + _optional: True + _default: [] + categorical_columns_with_hash_bucket: + _type: + - col: STRING_COLUMN|INT_COLUMN + hash_bucket_size: INT + embedding_size: # If not specified, an indicator column will be used instead + _type: INT + _optional: true + weight_column: + _type: INT_COLUMN|FLOAT_COLUMN + _optional: True + _default: [] + bucketized_columns: + _type: + - col: INT_COLUMN|FLOAT_COLUMN + boundaries: [FLOAT] + _default: [] + training_input: + _type: + weight_column: + _type: INT_COLUMN|FLOAT_COLUMN + _optional: True + _default: {} + hparams: + hidden_units: [INT] + +- kind: estimator + name: dnn_regressor + path: dnn_regressor.py + target_column: FLOAT_COLUMN + input: + numeric_columns: + _type: [INT_COLUMN|FLOAT_COLUMN] + _default: [] + categorical_columns_with_identity: + _type: + - col: INT_COLUMN + num_classes: INT + embedding_size: # If not specified, an indicator column will be used instead + _type: INT + _optional: true + weight_column: + _type: INT_COLUMN|FLOAT_COLUMN + _optional: True + _default: [] + categorical_columns_with_vocab: + _type: + - col: STRING_COLUMN + vocab: [STRING] + embedding_size: # If not specified, an indicator column will be used instead + _type: INT + _optional: true + weight_column: + _type: INT_COLUMN|FLOAT_COLUMN + _optional: True + _default: [] + categorical_columns_with_hash_bucket: + _type: + - col: STRING_COLUMN|INT_COLUMN + hash_bucket_size: INT + embedding_size: # If not specified, an indicator column will be used instead + _type: INT + _optional: true + weight_column: + _type: INT_COLUMN|FLOAT_COLUMN + _optional: True + _default: [] + bucketized_columns: + _type: + - col: INT_COLUMN|FLOAT_COLUMN + boundaries: [FLOAT] + _default: [] + training_input: + weight_column: + _type: INT_COLUMN|FLOAT_COLUMN + _optional: True + hparams: + hidden_units: [INT] + +- kind: estimator + name: linear_classifier + path: linear_classifier.py + target_column: INT_COLUMN|STRING_COLUMN + input: + # Specify num_classes if target is INT_COLUMN + num_classes: + _type: INT + _optional: True + + # Specify target_vocab if target is STRING_COLUMN + target_vocab: + _type: [STRING] + _optional: True + + numeric_columns: + _type: [INT_COLUMN|FLOAT_COLUMN] + _default: [] + categorical_columns_with_identity: + _type: + - col: INT_COLUMN + num_classes: INT + weight_column: + _type: INT_COLUMN|FLOAT_COLUMN + _optional: True + _default: [] + categorical_columns_with_vocab: + _type: + - col: STRING_COLUMN + vocab: [STRING] + weight_column: + _type: INT_COLUMN|FLOAT_COLUMN + _optional: True + _default: [] + categorical_columns_with_hash_bucket: + _type: + - col: STRING_COLUMN|INT_COLUMN + 
hash_bucket_size: INT + weight_column: + _type: INT_COLUMN|FLOAT_COLUMN + _optional: True + _default: [] + bucketized_columns: + _type: + - col: INT_COLUMN|FLOAT_COLUMN + boundaries: [FLOAT] + _default: [] + training_input: + weight_column: + _type: INT_COLUMN|FLOAT_COLUMN + _optional: True + +- kind: estimator + name: linear_regressor + path: linear_regressor.py + target_column: FLOAT_COLUMN + input: + numeric_columns: + _type: [INT_COLUMN|FLOAT_COLUMN] + _default: [] + categorical_columns_with_identity: + _type: + - col: INT_COLUMN + num_classes: INT + weight_column: + _type: INT_COLUMN|FLOAT_COLUMN + _optional: True + _default: [] + categorical_columns_with_vocab: + _type: + - col: STRING_COLUMN + vocab: [STRING] + weight_column: + _type: INT_COLUMN|FLOAT_COLUMN + _optional: True + _default: [] + categorical_columns_with_hash_bucket: + _type: + - col: STRING_COLUMN|INT_COLUMN + hash_bucket_size: INT + weight_column: + _type: INT_COLUMN|FLOAT_COLUMN + _optional: True + _default: [] + bucketized_columns: + _type: + - col: INT_COLUMN|FLOAT_COLUMN + boundaries: [FLOAT] + _default: [] + training_input: + weight_column: + _type: INT_COLUMN|FLOAT_COLUMN + _optional: True + +- kind: estimator + name: dnn_linear_combined_classifier + path: dnn_linear_combined_classifier.py + target_column: INT_COLUMN|STRING_COLUMN + input: + # Specify num_classes if target is INT_COLUMN + num_classes: + _type: INT + _optional: True + + # Specify target_vocab if target is STRING_COLUMN + target_vocab: + _type: [STRING] + _optional: True + + dnn_columns: + numeric_columns: + _type: [INT_COLUMN|FLOAT_COLUMN] + _default: [] + categorical_columns_with_identity: + _type: + - col: INT_COLUMN + num_classes: INT + embedding_size: # If not specified, an indicator column will be used instead + _type: INT + _optional: true + weight_column: + _type: INT_COLUMN|FLOAT_COLUMN + _optional: True + _default: [] + categorical_columns_with_vocab: + _type: + - col: STRING_COLUMN + vocab: [STRING] + embedding_size: # If not specified, an indicator column will be used instead + _type: INT + _optional: true + weight_column: + _type: INT_COLUMN|FLOAT_COLUMN + _optional: True + _default: [] + categorical_columns_with_hash_bucket: + _type: + - col: STRING_COLUMN|INT_COLUMN + hash_bucket_size: INT + embedding_size: # If not specified, an indicator column will be used instead + _type: INT + _optional: true + weight_column: + _type: INT_COLUMN|FLOAT_COLUMN + _optional: True + _default: [] + bucketized_columns: + _type: + - col: INT_COLUMN|FLOAT_COLUMN + boundaries: [FLOAT] + _default: [] + + linear_columns: + numeric_columns: + _type: [INT_COLUMN|FLOAT_COLUMN] + _default: [] + categorical_columns_with_identity: + _type: + - col: INT_COLUMN + num_classes: INT + weight_column: + _type: INT_COLUMN|FLOAT_COLUMN + _optional: True + _default: [] + categorical_columns_with_vocab: + _type: + - col: STRING_COLUMN + vocab: [STRING] + weight_column: + _type: INT_COLUMN|FLOAT_COLUMN + _optional: True + _default: [] + categorical_columns_with_hash_bucket: + _type: + - col: STRING_COLUMN|INT_COLUMN + hash_bucket_size: INT + weight_column: + _type: INT_COLUMN|FLOAT_COLUMN + _optional: True + _default: [] + bucketized_columns: + _type: + - col: INT_COLUMN|FLOAT_COLUMN + boundaries: [FLOAT] + _default: [] + + training_input: + weight_column: + _type: INT_COLUMN|FLOAT_COLUMN + _optional: True + hparams: + dnn_hidden_units: [INT] + +- kind: estimator + name: dnn_linear_combined_regressor + path: dnn_linear_combined_regressor.py + target_column: FLOAT_COLUMN + 
input: + dnn_columns: + numeric_columns: + _type: [INT_COLUMN|FLOAT_COLUMN] + _default: [] + categorical_columns_with_identity: + _type: + - col: INT_COLUMN + num_classes: INT + embedding_size: # If not specified, an indicator column will be used instead + _type: INT + _optional: true + weight_column: + _type: INT_COLUMN|FLOAT_COLUMN + _optional: True + _default: [] + categorical_columns_with_vocab: + _type: + - col: STRING_COLUMN + vocab: [STRING] + embedding_size: # If not specified, an indicator column will be used instead + _type: INT + _optional: true + weight_column: + _type: INT_COLUMN|FLOAT_COLUMN + _optional: True + _default: [] + categorical_columns_with_hash_bucket: + _type: + - col: STRING_COLUMN|INT_COLUMN + hash_bucket_size: INT + embedding_size: # If not specified, an indicator column will be used instead + _type: INT + _optional: true + weight_column: + _type: INT_COLUMN|FLOAT_COLUMN + _optional: True + _default: [] + bucketized_columns: + _type: + - col: INT_COLUMN|FLOAT_COLUMN + boundaries: [FLOAT] + _default: [] + + linear_columns: + numeric_columns: + _type: [INT_COLUMN|FLOAT_COLUMN] + _default: [] + categorical_columns_with_identity: + _type: + - col: INT_COLUMN + num_classes: INT + weight_column: + _type: INT_COLUMN|FLOAT_COLUMN + _optional: True + _default: [] + categorical_columns_with_vocab: + _type: + - col: STRING_COLUMN + vocab: [STRING] + weight_column: + _type: INT_COLUMN|FLOAT_COLUMN + _optional: True + _default: [] + categorical_columns_with_hash_bucket: + _type: + - col: STRING_COLUMN|INT_COLUMN + hash_bucket_size: INT + weight_column: + _type: INT_COLUMN|FLOAT_COLUMN + _optional: True + _default: [] + bucketized_columns: + _type: + - col: INT_COLUMN|FLOAT_COLUMN + boundaries: [FLOAT] + _default: [] + + training_input: + weight_column: + _type: INT_COLUMN|FLOAT_COLUMN + _optional: True + hparams: + dnn_hidden_units: [INT] + +- kind: estimator + name: boosted_trees_classifier + path: boosted_trees_classifier.py + target_column: INT_COLUMN|STRING_COLUMN + input: + # Specify num_classes if target is INT_COLUMN + num_classes: + _type: INT + _optional: True + + # Specify target_vocab if target is STRING_COLUMN + target_vocab: + _type: [STRING] + _optional: True + + numeric_columns: + _type: [INT_COLUMN|FLOAT_COLUMN] + _default: [] + categorical_columns_with_identity: + _type: + - col: INT_COLUMN + num_classes: INT + embedding_size: # If not specified, an indicator column will be used instead + _type: INT + _optional: true + weight_column: + _type: INT_COLUMN|FLOAT_COLUMN + _optional: True + _default: [] + categorical_columns_with_vocab: + _type: + - col: STRING_COLUMN + vocab: [STRING] + embedding_size: # If not specified, an indicator column will be used instead + _type: INT + _optional: true + weight_column: + _type: INT_COLUMN|FLOAT_COLUMN + _optional: True + _default: [] + categorical_columns_with_hash_bucket: + _type: + - col: STRING_COLUMN|INT_COLUMN + hash_bucket_size: INT + embedding_size: # If not specified, an indicator column will be used instead + _type: INT + _optional: true + weight_column: + _type: INT_COLUMN|FLOAT_COLUMN + _optional: True + _default: [] + bucketized_columns: + _type: + - col: INT_COLUMN|FLOAT_COLUMN + boundaries: [FLOAT] + _default: [] + training_input: + weight_column: + _type: INT_COLUMN|FLOAT_COLUMN + _optional: True + hparams: + batches_per_layer: INT + num_trees: + _type: INT + _default: 100 + max_depth: + _type: INT + _default: 6 + learning_rate: + _type: FLOAT + _default: 0.1 + l1_regularization: + _type: FLOAT + _default: 
0 + l2_regularization: + _type: FLOAT + _default: 0 + tree_complexity: + _type: FLOAT + _default: 0 + min_node_weight: + _type: FLOAT + _default: 0 + center_bias: + _type: BOOL + _default: False + quantile_sketch_epsilon: + _type: FLOAT + _default: 0.01 + +- kind: estimator + name: boosted_trees_regressor + path: boosted_trees_regressor.py + target_column: FLOAT_COLUMN + input: + numeric_columns: + _type: [INT_COLUMN|FLOAT_COLUMN] + _default: [] + categorical_columns_with_identity: + _type: + - col: INT_COLUMN + num_classes: INT + embedding_size: # If not specified, an indicator column will be used instead + _type: INT + _optional: true + weight_column: + _type: INT_COLUMN|FLOAT_COLUMN + _optional: True + _default: [] + categorical_columns_with_vocab: + _type: + - col: STRING_COLUMN + vocab: [STRING] + embedding_size: # If not specified, an indicator column will be used instead + _type: INT + _optional: true + weight_column: + _type: INT_COLUMN|FLOAT_COLUMN + _optional: True + _default: [] + categorical_columns_with_hash_bucket: + _type: + - col: STRING_COLUMN|INT_COLUMN + hash_bucket_size: INT + embedding_size: # If not specified, an indicator column will be used instead + _type: INT + _optional: true + weight_column: + _type: INT_COLUMN|FLOAT_COLUMN + _optional: True + _default: [] + bucketized_columns: + _type: + - col: INT_COLUMN|FLOAT_COLUMN + boundaries: [FLOAT] + _default: [] + training_input: + weight_column: + _type: INT_COLUMN|FLOAT_COLUMN + _optional: True + hparams: + batches_per_layer: INT + num_trees: + _type: INT + _default: 100 + max_depth: + _type: INT + _default: 6 + learning_rate: + _type: FLOAT + _default: 0.1 + l1_regularization: + _type: FLOAT + _default: 0 + l2_regularization: + _type: FLOAT + _default: 0 + tree_complexity: + _type: FLOAT + _default: 0 + min_node_weight: + _type: FLOAT + _default: 0 + center_bias: + _type: BOOL + _default: False + quantile_sketch_epsilon: + _type: FLOAT + _default: 0.01 diff --git a/pkg/estimators/linear_classifier.py b/pkg/estimators/linear_classifier.py new file mode 100644 index 0000000000..adc9d24afb --- /dev/null +++ b/pkg/estimators/linear_classifier.py @@ -0,0 +1,66 @@ +import tensorflow as tf + + +def create_estimator(run_config, model_config): + feature_columns = [] + + for col_name in model_config["input"]["numeric_columns"]: + feature_columns.append(tf.feature_column.numeric_column(col_name)) + + for col_info in model_config["input"]["categorical_columns_with_vocab"]: + col = tf.feature_column.categorical_column_with_vocabulary_list( + col_info["col"], col_info["vocab"] + ) + + if "weight_column" in col_info: + col = tf.feature_column.weighted_categorical_column(col, col_info["weight_column"]) + + feature_columns.append(col) + + for col_info in model_config["input"]["categorical_columns_with_identity"]: + col = tf.feature_column.categorical_columns_with_identity( + col_info["col"], col_info["num_classes"] + ) + + if "weight_column" in col_info: + col = tf.feature_column.weighted_categorical_column(col, col_info["weight_column"]) + + feature_columns.append(col) + + for col_info in model_config["input"]["categorical_columns_with_hash_bucket"]: + col = tf.feature_column.categorical_columns_with_hash_bucket( + col_info["col"], col_info["hash_bucket_size"] + ) + + if "weight_column" in col_info: + col = tf.feature_column.weighted_categorical_column(col, col_info["weight_column"]) + + feature_columns.append(col) + + for col_info in model_config["input"]["bucketized_columns"]: + feature_columns.append( + 
tf.feature_column.bucketized_column( + tf.feature_column.numeric_column(col_info["col"]), col_info["boundaries"] + ) + ) + + if "num_classes" in model_config["input"] and "target_vocab" in model_config["input"]: + raise ValueError('either "num_classes" or "target_vocab" must be specified, but not both') + + if "num_classes" not in model_config["input"] and "target_vocab" not in model_config["input"]: + raise ValueError('either "num_classes" or "target_vocab" must be specified') + + if "num_classes" in model_config["input"]: + target_vocab = None + num_classes = model_config["input"]["num_classes"] + else: + target_vocab = model_config["input"]["target_vocab"] + num_classes = len(target_vocab) + + return tf.estimator.DNNClassifier( + feature_columns=feature_columns, + n_classes=num_classes, + label_vocabulary=target_vocab, + weight_column=model_config["input"].get("weight_column", None), + config=run_config, + ) diff --git a/pkg/estimators/linear_regressor.py b/pkg/estimators/linear_regressor.py new file mode 100644 index 0000000000..cf95399ddc --- /dev/null +++ b/pkg/estimators/linear_regressor.py @@ -0,0 +1,51 @@ +import tensorflow as tf + + +def create_estimator(run_config, model_config): + feature_columns = [] + + for col_name in model_config["input"]["numeric_columns"]: + feature_columns.append(tf.feature_column.numeric_column(col_name)) + + for col_info in model_config["input"]["categorical_columns_with_vocab"]: + col = tf.feature_column.categorical_column_with_vocabulary_list( + col_info["col"], col_info["vocab"] + ) + + if "weight_column" in col_info: + col = tf.feature_column.weighted_categorical_column(col, col_info["weight_column"]) + + feature_columns.append(col) + + for col_info in model_config["input"]["categorical_columns_with_identity"]: + col = tf.feature_column.categorical_columns_with_identity( + col_info["col"], col_info["num_classes"] + ) + + if "weight_column" in col_info: + col = tf.feature_column.weighted_categorical_column(col, col_info["weight_column"]) + + feature_columns.append(col) + + for col_info in model_config["input"]["categorical_columns_with_hash_bucket"]: + col = tf.feature_column.categorical_columns_with_hash_bucket( + col_info["col"], col_info["hash_bucket_size"] + ) + + if "weight_column" in col_info: + col = tf.feature_column.weighted_categorical_column(col, col_info["weight_column"]) + + feature_columns.append(col) + + for col_info in model_config["input"]["bucketized_columns"]: + feature_columns.append( + tf.feature_column.bucketized_column( + tf.feature_column.numeric_column(col_info["col"]), col_info["boundaries"] + ) + ) + + return tf.estimator.DNNRegressor( + feature_columns=feature_columns, + weight_column=model_config["input"].get("weight_column", None), + config=run_config, + ) diff --git a/pkg/transformers/bucketize.py b/pkg/transformers/bucketize.py index d0b93c7512..a38a4cae0b 100644 --- a/pkg/transformers/bucketize.py +++ b/pkg/transformers/bucketize.py @@ -13,21 +13,21 @@ # limitations under the License. 
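
One note on the feature-column loops repeated across these estimator implementations: TensorFlow's helpers are named tf.feature_column.categorical_column_with_identity and tf.feature_column.categorical_column_with_hash_bucket (singular "column"), whereas the loops above call categorical_columns_with_identity / categorical_columns_with_hash_bucket. A minimal sketch of that shared block using the standard helper names, assuming the same col_info layout as above:

    import tensorflow as tf

    def build_categorical_columns(input_config):
        columns = []

        for col_info in input_config["categorical_columns_with_identity"]:
            # identity columns: integer ids in [0, num_classes)
            col = tf.feature_column.categorical_column_with_identity(
                col_info["col"], col_info["num_classes"]
            )
            if "weight_column" in col_info:
                col = tf.feature_column.weighted_categorical_column(col, col_info["weight_column"])
            if "embedding_size" in col_info:
                col = tf.feature_column.embedding_column(col, col_info["embedding_size"])
            else:
                col = tf.feature_column.indicator_column(col)
            columns.append(col)

        for col_info in input_config["categorical_columns_with_hash_bucket"]:
            # hashed columns: string or int values hashed into hash_bucket_size buckets
            col = tf.feature_column.categorical_column_with_hash_bucket(
                col_info["col"], col_info["hash_bucket_size"]
            )
            if "weight_column" in col_info:
                col = tf.feature_column.weighted_categorical_column(col, col_info["weight_column"])
            if "embedding_size" in col_info:
                col = tf.feature_column.embedding_column(col, col_info["embedding_size"])
            else:
                col = tf.feature_column.indicator_column(col)
            columns.append(col)

        return columns

Similarly, the linear and combined variants would presumably return tf.estimator.LinearClassifier, tf.estimator.LinearRegressor, tf.estimator.DNNLinearCombinedClassifier, and tf.estimator.DNNLinearCombinedRegressor, and boosted_trees_regressor.py would return tf.estimator.BoostedTreesRegressor, rather than the classes currently constructed in those files.
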
-def transform_spark(data, columns, args, transformed_column_name): +def transform_spark(data, input, transformed_column_name): from pyspark.ml.feature import Bucketizer import pyspark.sql.functions as F new_b = Bucketizer( - splits=args["bucket_boundaries"], inputCol=columns["num"], outputCol=transformed_column_name + splits=input["bucket_boundaries"], inputCol=input["col"], outputCol=transformed_column_name ) return new_b.transform(data).withColumn( transformed_column_name, F.col(transformed_column_name).cast("int") ) -def transform_python(sample, args): - num = sample["num"] - buckets = args["bucket_boundaries"][1:] +def transform_python(input): + num = input["col"] + buckets = input["bucket_boundaries"][1:] for id, v in enumerate(buckets): if num < v: return id diff --git a/pkg/transformers/index_string.py b/pkg/transformers/index_string.py index 8ab183b57f..06410bcda7 100644 --- a/pkg/transformers/index_string.py +++ b/pkg/transformers/index_string.py @@ -13,12 +13,12 @@ # limitations under the License. -def transform_spark(data, columns, args, transformed_column_name): +def transform_spark(data, input, transformed_column_name): from pyspark.ml.feature import StringIndexerModel import pyspark.sql.functions as F indexer = StringIndexerModel.from_labels( - args["indexes"]["index"], inputCol=columns["text"], outputCol=transformed_column_name + input["indexes"]["index"], inputCol=input["col"], outputCol=transformed_column_name ) return indexer.transform(data).withColumn( @@ -26,12 +26,12 @@ def transform_spark(data, columns, args, transformed_column_name): ) -def transform_python(sample, args): - if sample["text"] in args["indexes"]["reversed_index"]: - return args["indexes"]["reversed_index"][sample["text"]] +def transform_python(input): + if input["col"] in input["indexes"]["reversed_index"]: + return input["indexes"]["reversed_index"][input["col"]] - raise Exception("Could not find {} in index: {}".format(sample["text"], args)) + raise Exception("Could not find {} in index".format(input["col"])) -def reverse_transform_python(transformed_value, args): - return args["indexes"]["index"][transformed_value] +def reverse_transform_python(transformed_value, input): + return input["indexes"]["index"][transformed_value] diff --git a/pkg/transformers/normalize.py b/pkg/transformers/normalize.py index e24faddd2e..75a3d92c80 100644 --- a/pkg/transformers/normalize.py +++ b/pkg/transformers/normalize.py @@ -13,15 +13,15 @@ # limitations under the License. 
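
The transformer interface change above replaces the separate columns and args arguments with a single input dict that carries both the resolved column values and any referenced aggregate or constant values. As a purely illustrative example of what index_string.py now receives (values are hypothetical):

    # What the operator might pass to index_string's transform_python under
    # the new single-"input" convention (illustrative values only)
    example_input = {
        "col": "versicolor",  # resolved value of the referenced STRING_COLUMN
        "indexes": {          # resolved value of the referenced aggregate
            "index": ["setosa", "versicolor", "virginica"],
            "reversed_index": {"setosa": 0, "versicolor": 1, "virginica": 2},
        },
    }
    # transform_python(example_input)            -> 1
    # reverse_transform_python(1, example_input) -> "versicolor"
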
-def transform_spark(data, columns, args, transformed_column_name): +def transform_spark(data, input, transformed_column_name): return data.withColumn( - transformed_column_name, ((data[columns["num"]] - args["mean"]) / args["stddev"]) + transformed_column_name, ((data[input["col"]] - input["mean"]) / input["stddev"]) ) -def transform_python(sample, args): - return (sample["num"] - args["mean"]) / args["stddev"] +def transform_python(input): + return (input["col"] - input["mean"]) / input["stddev"] -def reverse_transform_python(transformed_value, args): - return args["mean"] + (transformed_value * args["stddev"]) +def reverse_transform_python(transformed_value, input): + return input["mean"] + (transformed_value * input["stddev"]) From ebc016f4f763ed22c41c6b21a5903d58f6c3ceb7 Mon Sep 17 00:00:00 2001 From: David Eliahu Date: Mon, 10 Jun 2019 18:48:49 -0700 Subject: [PATCH 09/44] Progress --- images/operator/Dockerfile | 1 + pkg/aggregators/aggregators.yaml | 10 +- pkg/estimators/estimators.yaml | 2 +- pkg/operator/api/userconfig/compound_type.go | 7 - pkg/workloads/consts.py | 2 + pkg/workloads/lib/context.py | 641 ++++++++++++++----- pkg/workloads/lib/test/util_test.py | 134 ++-- pkg/workloads/lib/tf_lib.py | 14 +- pkg/workloads/lib/util.py | 174 +++-- pkg/workloads/spark_job/spark_job.py | 13 +- pkg/workloads/spark_job/spark_util.py | 335 ++++++---- pkg/workloads/tf_api/api.py | 91 +-- pkg/workloads/tf_train/train.py | 4 +- pkg/workloads/tf_train/train_util.py | 37 +- 14 files changed, 958 insertions(+), 507 deletions(-) diff --git a/images/operator/Dockerfile b/images/operator/Dockerfile index b4db9f5002..5767564b17 100644 --- a/images/operator/Dockerfile +++ b/images/operator/Dockerfile @@ -21,6 +21,7 @@ RUN chmod +x /usr/local/bin/kubectl COPY pkg/transformers /src/transformers COPY pkg/aggregators /src/aggregators +COPY pkg/estimators /src/estimators COPY --from=builder /go/src/github.com/cortexlabs/cortex/pkg/operator/operator /root/ RUN chmod +x /root/operator diff --git a/pkg/aggregators/aggregators.yaml b/pkg/aggregators/aggregators.yaml index afe99e85b0..b72fac69af 100644 --- a/pkg/aggregators/aggregators.yaml +++ b/pkg/aggregators/aggregators.yaml @@ -21,7 +21,9 @@ output_type: INT input: col: FLOAT_COLUMN|INT_COLUMN|STRING_COLUMN - rsd: FLOAT + rsd: + _type: FLOAT + _optional: TRUE # Spark Builtin: Calculate the average of the column. # source: https://spark.apache.org/docs/latest/api/python/pyspark.sql.html#pyspark.sql.functions.avg @@ -70,10 +72,8 @@ path: spark/count_distinct.py output_type: INT input: - col: INT_COLUMN|FLOAT_COLUMN|STRING_COLUMN - cols: - _type: [INT_COLUMN|FLOAT_COLUMN|STRING_COLUMN] - _default: [] + _type: [INT_COLUMN|FLOAT_COLUMN|STRING_COLUMN] + _min_count: 1 # Spark Builtin: Calculate the population covariance between col1 and col2 (scaled by 1/N). 
# source: https://spark.apache.org/docs/latest/api/python/pyspark.sql.html#pyspark.sql.functions.covar_pop diff --git a/pkg/estimators/estimators.yaml b/pkg/estimators/estimators.yaml index 388853157f..43a508645c 100644 --- a/pkg/estimators/estimators.yaml +++ b/pkg/estimators/estimators.yaml @@ -15,7 +15,7 @@ - kind: estimator name: dnn_classifier path: dnn_classifier.py - target_column: FLOAT + target_column: INT_COLUMN|STRING_COLUMN input: # Specify num_classes if target is INT_COLUMN num_classes: diff --git a/pkg/operator/api/userconfig/compound_type.go b/pkg/operator/api/userconfig/compound_type.go index 0620d7ee66..fdab9891d2 100644 --- a/pkg/operator/api/userconfig/compound_type.go +++ b/pkg/operator/api/userconfig/compound_type.go @@ -20,7 +20,6 @@ import ( "strings" "github.com/cortexlabs/cortex/pkg/lib/cast" - "github.com/cortexlabs/cortex/pkg/lib/configreader" ) type CompoundType string @@ -138,10 +137,7 @@ func (compoundType *CompoundType) CastValue(value interface{}) (interface{}, err return nil, ErrorColumnTypeLiteral(value) } - var validPrimitiveTypes []configreader.PrimitiveType - if parsed.valueTypes[IntegerValueType] { - validPrimitiveTypes = append(validPrimitiveTypes, configreader.PrimTypeInt) valueInt, ok := cast.InterfaceToInt64(value) if ok { return valueInt, nil @@ -149,7 +145,6 @@ func (compoundType *CompoundType) CastValue(value interface{}) (interface{}, err } if parsed.valueTypes[FloatValueType] { - validPrimitiveTypes = append(validPrimitiveTypes, configreader.PrimTypeFloat) valueFloat, ok := cast.InterfaceToFloat64(value) if ok { return valueFloat, nil @@ -157,14 +152,12 @@ func (compoundType *CompoundType) CastValue(value interface{}) (interface{}, err } if parsed.valueTypes[StringValueType] { - validPrimitiveTypes = append(validPrimitiveTypes, configreader.PrimTypeString) if valueStr, ok := value.(string); ok { return valueStr, nil } } if parsed.valueTypes[BoolValueType] { - validPrimitiveTypes = append(validPrimitiveTypes, configreader.PrimTypeBool) if valueBool, ok := value.(bool); ok { return valueBool, nil } diff --git a/pkg/workloads/consts.py b/pkg/workloads/consts.py index ed330d68be..01b03211a4 100644 --- a/pkg/workloads/consts.py +++ b/pkg/workloads/consts.py @@ -40,3 +40,5 @@ VALUE_TYPE_BOOL = "BOOL" VALUE_TYPES = [VALUE_TYPE_INT, VALUE_TYPE_FLOAT, VALUE_TYPE_STRING, VALUE_TYPE_BOOL] + +ALL_TYPES = set(COLUMN_LIST_TYPES + VALUE_TYPES) diff --git a/pkg/workloads/lib/context.py b/pkg/workloads/lib/context.py index 2bbb8f4f88..e155ee96b2 100644 --- a/pkg/workloads/lib/context.py +++ b/pkg/workloads/lib/context.py @@ -78,6 +78,7 @@ def __init__(self, **kwargs): self.aggregates = self.ctx["aggregates"] self.constants = self.ctx["constants"] self.models = self.ctx["models"] + self.estimators = self.ctx["estimators"] self.apis = self.ctx["apis"] self.training_datasets = {k: v["dataset"] for k, v in self.models.items()} self.api_version = self.cortex_config["api_version"] @@ -98,11 +99,7 @@ def __init__(self, **kwargs): ) ) - self.columns = util.merge_dicts_overwrite( - self.raw_columns, self.transformed_columns # self.aggregates - ) - - self.values = util.merge_dicts_overwrite(self.aggregates, self.constants) + self.columns = util.merge_dicts_overwrite(self.raw_columns, self.transformed_columns) self.raw_column_names = list(self.raw_columns.keys()) self.transformed_column_names = list(self.transformed_columns.keys()) @@ -111,8 +108,9 @@ def __init__(self, **kwargs): # Internal caches self._transformer_impls = {} 
self._aggregator_impls = {} - self._model_impls = {} + self._estimator_impls = {} self._metadatas = {} + self._obj_cache = {} self.spark_uploaded_impls = {} # This affects Tensorflow S3 access @@ -151,41 +149,30 @@ def is_constant(self, name): def is_aggregate(self, name): return name in self.aggregates - def create_column_inputs_map(self, values_map, column_name): - """Construct an inputs dict with actual data""" - columns_input_config = self.transformed_columns[column_name]["inputs"]["columns"] - return create_inputs_map(values_map, columns_input_config) - def download_file(self, impl_key, cache_impl_path): if not os.path.isfile(cache_impl_path): self.storage.download_file(impl_key, cache_impl_path) return cache_impl_path - def get_python_file(self, impl_key, module_name): + def download_python_file(self, impl_key, module_name): cache_impl_path = os.path.join(self.cache_dir, "{}.py".format(module_name)) self.download_file(impl_key, cache_impl_path) return cache_impl_path def get_obj(self, key): + if key in self._obj_cache: + return self._obj_cache[key] + cache_path = os.path.join(self.cache_dir, key) self.download_file(key, cache_path) - - return util.read_msgpack(cache_path) - - def populate_args(self, args_dict): - return { - arg_name: self.get_obj(self.values[value_name]["key"]) - for arg_name, value_name in args_dict.items() - } - - def store_aggregate_result(self, result, aggregate): - self.storage.put_msgpack(result, aggregate["key"]) + self._obj_cache[key] = util.read_msgpack(cache_path) + return self._obj_cache[key] def load_module(self, module_prefix, module_name, impl_key): full_module_name = "{}_{}".format(module_prefix, module_name) try: - impl_path = self.get_python_file(impl_key, full_module_name) + impl_path = self.download_python_file(impl_key, full_module_name) except CortexException as e: e.wrap("unable to find python file " + module_name) raise @@ -197,10 +184,11 @@ def load_module(self, module_prefix, module_name, impl_key): return impl, impl_path - def get_aggregator_impl(self, column_name): - aggregator_name = self.aggregates[column_name]["aggregator"] + def get_aggregator_impl(self, aggregate_name): + aggregator_name = self.aggregates[aggregate_name]["aggregator"] if aggregator_name in self._aggregator_impls: return self._aggregator_impls[aggregator_name] + aggregator = self.aggregators[aggregator_name] module_prefix = "aggregator" @@ -212,13 +200,13 @@ def get_aggregator_impl(self, column_name): module_prefix, aggregator["name"], aggregator["impl_key"] ) except CortexException as e: - e.wrap("aggregate " + column_name, "aggregator") + e.wrap("aggregate " + aggregate_name, "aggregator") raise try: _validate_impl(impl, AGGREGATOR_IMPL_VALIDATION) except CortexException as e: - e.wrap("aggregate " + column_name, "aggregator " + aggregator["name"]) + e.wrap("aggregate " + aggregate_name, "aggregator " + aggregator["name"]) raise self._aggregator_impls[aggregator_name] = (impl, impl_path) @@ -255,21 +243,33 @@ def get_transformer_impl(self, column_name): self._transformer_impls[transformer_name] = (impl, impl_path) return (impl, impl_path) - def get_model_impl(self, model_name): - if model_name in self._model_impls: - return self._model_impls[model_name] + def get_estimator_impl(self, model_name): + estimator_name = self.models[model_name]["aggregator"] + if estimator_name in self._estimator_impls: + return self._estimator_impls[estimator_name] - model = self.models[model_name] + estimator = self.estimators[estimator_name] + + module_prefix = "estimator" + if 
"namespace" in estimator and estimator.get("namespace", None) is not None: + module_prefix += "_" + estimator["namespace"] + + try: + impl, impl_path = self.load_module( + module_prefix, estimator["name"], estimator["impl_key"] + ) + except CortexException as e: + e.wrap("model " + model_name, "estimator") + raise try: - impl, impl_path = self.load_module("model", model_name, model["impl_key"]) _validate_impl(impl, MODEL_IMPL_VALIDATION) except CortexException as e: - e.wrap("model " + model_name) + e.wrap("model " + model_name, "estimator " + estimator["name"]) raise - self._model_impls[model_name] = impl - return impl + self._estimator_impls[estimator_name] = (impl, impl_path) + return (impl, impl_path) # Mode must be "training" or "evaluation" def get_training_data_parts(self, model_name, mode, part_prefix="part"): @@ -288,135 +288,54 @@ def get_training_data_parts(self, model_name, mode, part_prefix="part"): training_data_parts_prefix = os.path.join(data_key, part_prefix) return self.storage.search(prefix=training_data_parts_prefix) - def column_config(self, column_name): - if self.is_raw_column(column_name): - return self.raw_column_config(column_name) - elif self.is_transformed_column(column_name): - return self.transformed_column_config(column_name) - return None + def store_aggregate_result(self, result, aggregate): + self.storage.put_msgpack(result, aggregate["key"]) - def raw_column_config(self, column_name): - raw_column = self.raw_columns[column_name] - if raw_column is None: - return None - config = deepcopy(raw_column) - config_keys = ["name", "type", "required", "min", "max", "values", "tags"] - util.keep_dict_keys(config, config_keys) - return config - - def transformed_column_config(self, column_name): - transformed_column = self.transformed_columns[column_name] - if transformed_column is None: - return None - config = deepcopy(transformed_column) - config_keys = ["name", "transformer", "inputs", "tags", "type"] - util.keep_dict_keys(config, config_keys) - config["inputs"] = self._expand_inputs_config(config["inputs"]) - config["transformer"] = self.transformer_config(config["transformer"]) - return config - - def value_config(self, value_name): - if self.is_constant(value_name): - return self.constant_config(value_name) - elif self.is_aggregate(value_name): - return self.aggregate_config(value_name) - return None - - def constant_config(self, constant_name): - constant = self.constants[constant_name] - if constant is None: - return None - config = deepcopy(constant) - config_keys = ["name", "type", "tags"] - util.keep_dict_keys(config, config_keys) - return config - - def aggregate_config(self, aggregate_name): - aggregate = self.aggregates[aggregate_name] - if aggregate is None: - return None - config = deepcopy(aggregate) - config_keys = ["name", "type", "inputs", "aggregator", "tags"] - util.keep_dict_keys(config, config_keys) - config["inputs"] = self._expand_inputs_config(config["inputs"]) - config["aggregator"] = self.aggregator_config(config["aggregator"]) - return config + def extract_column_names(self, input): + column_names = set() + for resource_name in util.extract_resource_refs(input): + if resource_name in self.columns: + column_names.add(resource_name) + return column_names def model_config(self, model_name): model = self.models[model_name] if model is None: return None + estimator = self.estimators[model["estimator"]] + model_config = deepcopy(model) config_keys = [ "name", - "type", - "path", - "target_column", - "prediction_key", - "feature_columns", - 
"training_columns", - "hparams", - "data_partition_ratio", - "aggregates", - "training", - "evaluation", + "estimator" + "estimator_path" + "target_column" + "input" + "training_input" + "hparams" + "prediction_key" + "data_partition_ratio" + "training" + "evaluation" "tags", ] util.keep_dict_keys(model_config, config_keys) - for i, column_name in enumerate(model_config["feature_columns"]): - model_config["feature_columns"][i] = self.column_config(column_name) - - for i, column_name in enumerate(model_config["training_columns"]): - model_config["training_columns"][i] = self.column_config(column_name) - - model_config["target_column"] = self.column_config(model_config["target_column"]) - - aggregates_dict = {key: key for key in model_config["aggregates"]} - model_config["aggregates"] = self.populate_args(aggregates_dict) + model_config["target_column"] = util.get_resource_ref(model["target_column"]) + model_config["input"] = self.populate_values( + model["input"], estimator["input"], preserve_column_refs=False + ) + if model["training_input"] is not None: + model_config["training_input"] = self.populate_values( + model["training_input"], estimator["training_input"], preserve_column_refs=False + ) + if model["hparams"] is not None: + model_config["hparams"] = self.populate_values( + model["hparams"], estimator["hparams"], preserve_column_refs=False + ) return model_config - def aggregator_config(self, aggregator_name): - aggregator = self.aggregators[aggregator_name] - if aggregator is None: - return None - config = deepcopy(aggregator) - config_keys = ["name", "output_type", "inputs"] - util.keep_dict_keys(config, config_keys) - config["name"] = aggregator_name # Use the fully qualified name (includes namespace) - return config - - def transformer_config(self, transformer_name): - transformer = self.transformers[transformer_name] - if transformer is None: - return None - config = deepcopy(transformer) - config_keys = ["name", "output_type", "inputs"] - util.keep_dict_keys(config, config_keys) - config["name"] = transformer_name # Use the fully qualified name (includes namespace) - return config - - def _expand_inputs_config(self, inputs_config): - inputs_config["columns"] = self._expand_columns_input_dict(inputs_config["columns"]) - inputs_config["args"] = self._expand_args_dict(inputs_config["args"]) - return inputs_config - - def _expand_columns_input_dict(self, input_columns_dict): - expanded = {} - for column_name, value in input_columns_dict.items(): - if util.is_str(value): - expanded[column_name] = self.column_config(value) - elif util.is_list(value): - expanded[column_name] = [self.column_config(name) for name in value] - return expanded - - def _expand_args_dict(self, args_dict): - expanded = {} - for arg_name, value_name in args_dict.items(): - expanded[arg_name] = self.value_config(value_name) - return expanded - def get_resource_status(self, resource): key = self.resource_status_key(resource) return self.storage.get_json(key) @@ -496,23 +415,208 @@ def get_inferred_column_type(self, column_name): return column_type + # replaces column references with column names (unless preserve_column_refs = true, then leaves them untouched) + def populate_values(self, input, input_schema, preserve_column_refs): + if input is None: + if input_schema is None: + return None + if input_schema["_allow_null"]: + return None + raise UserException("Null is not allowed") + + if util.is_resource_ref(input): + res_name = util.get_resource_ref(input) + if res_name in self.constants: + const_val = 
self.constants[res_name]["value"] + try: + return self.populate_values(const_val, input_schema, preserve_column_refs) + except CortexException as e: + e.wrap("constant " + res_name) + raise + + if res_name in self.aggregates: + agg_val = self.get_obj(self.aggregates[res_name]["key"]) + try: + return self.populate_values(agg_val, input_schema, preserve_column_refs) + except CortexException as e: + e.wrap("aggregate " + res_name) + raise + + if res_name in self.columns: + if input_schema is not None: + col_type = self.get_inferred_column_type(res_name) + if not column_type_matches(col_type, input_schema["_type"]): + raise UserException( + "column {}: column type mismatch: got {}, expected {}".format( + res_name, col_type, input_schema["_type"] + ) + ) + if preserve_column_refs: + return input + else: + return res_name + + if util.is_list(input): + elem_schema = None + if input_schema is not None: + if not util.is_list(input_schema["_type"]): + raise UserException("unexpected type (list)") + elem_schema = input_schema["_type"][0] + + min_count = input_schema.get("_min_count") + if min_count is not None and len(input) < min_count: + raise UserException( + "list has length {}, but the minimum length is {}".format( + len(input), min_count + ) + ) + + max_count = input_schema.get("_max_count") + if max_count is not None and len(input) > max_count: + raise UserException( + "list has length {}, but the maximum length is {}".format( + len(input), max_count + ) + ) + + casted = [] + for i, elem in enumerate(input): + try: + casted.append(self.populate_values(elem, elem_schema, preserve_column_refs)) + except CortexException as e: + e.wrap("index " + i) + raise + return casted + + if util.is_dict(input): + if input_schema is None: + casted = {} + for key, val in input.items(): + key_casted = self.populate_values(key, None, preserve_column_refs) + try: + val_casted = self.populate_values(val, None, preserve_column_refs) + except CortexException as e: + e.wrap(util.pp_str_flat(key_casted)) + raise + casted[key_casted] = val_casted + return casted + + if not util.is_dict(input_schema["_type"]): + raise UserException("unexpected type (map)") + + min_count = input_schema.get("_min_count") + if min_count is not None and len(input) < min_count: + raise UserException( + "map has length {}, but the minimum length is {}".format(len(input), min_count) + ) + + max_count = input_schema.get("_max_count") + if max_count is not None and len(input) > max_count: + raise UserException( + "map has length {}, but the maximum length is {}".format(len(input), max_count) + ) + + is_generic_map = False + if len(input_schema["_type"]) == 1: + input_type_key = next(iter(input_schema["_type"].keys())) + if is_compound_type(input_type_key): + is_generic_map = True + generic_map_key = input_type_key + generic_map_value = input_schema["_type"][input_type_key] + + if is_generic_map: + casted = {} + for key, val in input.items(): + key_casted = self.populate_values(key, generic_map_key, preserve_column_refs) + try: + val_casted = self.populate_values( + val, generic_map_value, preserve_column_refs + ) + except CortexException as e: + e.wrap(util.pp_str_flat(key_casted)) + raise + casted[key_casted] = val_casted + return casted + + # fixed map + casted = {} + for key, val_schema in input_schema["_type"]: + default = None + if key not in input: + if input_schema.get("_optional") is not True: + raise UserException("missing key: " + util.pp_str_flat(key)) + if input_schema.get("_default") is None: + continue + default = 
input_schema["_default"] + + val = input.get(key, default) + try: + val_casted = self.populate_values(val, val_schema, preserve_column_refs) + except CortexException as e: + e.wrap(util.pp_str_flat(key)) + raise + casted[key] = val_casted + return casted + + if input_schema is None: + return input + if util.is_list(input_schema["_type"]) or util.is_dict(input_schema["_type"]): + raise UserException("unexpected type (scalar)") + return cast_compound_type(input, input_schema["_type"]) + + +def is_compound_type(type_str): + if not util.is_str(type_str): + return False + for subtype in type_str.split("|"): + if subtype not in consts.ALL_TYPES: + return False + return True + + +def column_type_matches(value_type, schema_type): + if consts.COLUMN_TYPE_FLOAT in schema_type: + schema_type = schema_type + "|" + consts.COLUMN_TYPE_INT + return value_type in schema_type + + +def cast_compound_type(value, type_str): + allowed_types = type_str.split("|") + if consts.VALUE_TYPE_INT in allowed_types: + if util.is_int(value): + return value + if consts.VALUE_TYPE_FLOAT in allowed_types: + if util.is_int(value): + return float(value) + if util.is_float(value): + return value + if consts.VALUE_TYPE_STRING in allowed_types: + if util.is_string(value): + return value + if consts.VALUE_TYPE_BOOL in allowed_types: + if util.is_bool(value): + return value + + raise UserException( + "input value's type is not supported by the schema (got {}, expected input with type {})".format( + util.pp_str_flat(value), type_str + ) + ) + MODEL_IMPL_VALIDATION = { "required": [{"name": "create_estimator", "args": ["run_config", "model_config"]}], "optional": [{"name": "transform_tensorflow", "args": ["features", "labels", "model_config"]}], } -AGGREGATOR_IMPL_VALIDATION = { - "required": [{"name": "aggregate_spark", "args": ["data", "columns", "args"]}] -} +AGGREGATOR_IMPL_VALIDATION = {"required": [{"name": "aggregate_spark", "args": ["data", "input"]}]} TRANSFORMER_IMPL_VALIDATION = { "optional": [ - {"name": "transform_spark", "args": ["data", "columns", "args", "transformed_column_name"]}, - {"name": "reverse_transform_python", "args": ["transformed_value", "args"]}, - # it is possible to not define transform_python() - # if you are only using the transformation for training columns, and not for inference - {"name": "transform_python", "args": ["sample", "args"]}, + {"name": "transform_spark", "args": ["data", "input", "transformed_column_name"]}, + # it is possible to not define transform_python() if you are only using the transformation for training columns, and not for inference + {"name": "transform_python", "args": ["input"]}, + {"name": "reverse_transform_python", "args": ["transformed_value", "input"]}, ] } @@ -559,21 +663,6 @@ def _validate_required_fn_args(impl, fn_name, args): ) -def create_inputs_map(values_map, input_config): - inputs = {} - for input_name, input_config_item in input_config.items(): - if util.is_str(input_config_item): - inputs[input_name] = values_map[input_config_item] - elif util.is_int(input_config_item): - inputs[input_name] = values_map[input_config_item] - elif util.is_list(input_config_item): - inputs[input_name] = [values_map[f] for f in input_config_item] - else: - raise CortexException("invalid column inputs") - - return inputs - - def _deserialize_raw_ctx(raw_ctx): raw_columns = raw_ctx["raw_columns"] raw_ctx["raw_columns"] = util.merge_dicts_overwrite(*raw_columns.values()) @@ -587,3 +676,211 @@ def _deserialize_raw_ctx(raw_ctx): else: raise CortexException("expected csv_data or 
parquet_data but found " + data_split) return raw_ctx + + +# input should already have non-column arguments replaced, and all types validated +def create_transformer_inputs_from_map(input, input_schema, col_value_map): + if util.is_str(input): + res_name = util.get_resource_ref(input) + if res_name in col_value_map: + value = col_value_map[res_name] + if input_schema is not None: + valid_col_types = input_schema["_type"].split("|") + if util.is_int(value) and consts.COLUMN_TYPE_INT not in valid_col_types: + value = float(value) + if util.is_int_list(value) and consts.COLUMN_TYPE_INT_LIST not in valid_col_types: + value = [float(elem) for elem in value] + return value + return input + + if util.is_list(input): + replaced = [] + for item in input: + sub_schema = None + if input_schema is not None: + sub_schema = input_schema["_type"][0] + replaced.append(create_transformer_inputs_from_map(item, sub_schema, col_value_map)) + return replaced + + if util.is_dict(input): + replaced = {} + + if input_schema is not None: + is_generic_map = False + generic_key_sub_schema = None + generic_val_sub_schema = None + if len(input_schema["_type"]) == 1: + input_type_key = next(iter(input_schema["_type"].keys())) + if is_compound_type(input_type_key): + is_generic_map = True + generic_key_sub_schema = { + "_type": input_type_key, + "_optional": False, + "_default": None, + "_allow_null": False, + "_min_count": None, + "_max_count": None, + } + generic_val_sub_schema = input_schema["_type"][input_type_key] + + for key, val in input.items(): + key_sub_schema = None + val_sub_schema = None + if input_schema is not None: + if is_generic_map: + key_sub_schema = generic_key_sub_schema + val_sub_schema = generic_val_sub_schema + else: + val_sub_schema = input_schema["_type"].get(key) + + key_replaced = create_transformer_inputs_from_map(key, key_sub_schema, col_value_map) + val_replaced = create_transformer_inputs_from_map(val, val_sub_schema, col_value_map) + replaced[key_replaced] = val_replaced + return replaced + + return input + + +# input should already have non-column arguments replaced, and all types validated +def create_transformer_inputs_from_lists(input, input_schema, input_cols_sorted, col_values): + col_value_map = {} + for col_name, col_value in zip(input_cols_sorted, col_values): + col_value_map[col_name] = col_value + + return create_transformer_inputs_from_map(input, input_schema, col_value_map, col_type_map) + + +# def create_column_inputs_map(self, values_map, column_name): +# """Construct an inputs dict with actual data""" +# columns_input_config = self.transformed_columns[column_name]["inputs"]["columns"] +# return create_inputs_map(values_map, columns_input_config) + +# def create_inputs_map(values_map, input_config): +# inputs = {} +# for input_name, input_config_item in input_config.items(): +# if util.is_str(input_config_item): +# inputs[input_name] = values_map[input_config_item] +# elif util.is_int(input_config_item): +# inputs[input_name] = values_map[input_config_item] +# elif util.is_list(input_config_item): +# inputs[input_name] = [values_map[f] for f in input_config_item] +# else: +# raise CortexException("invalid column inputs") + +# return inputs + +# def populate_args(self, args_dict): +# return { +# arg_name: self.get_obj(self.values[value_name]["key"]) +# for arg_name, value_name in args_dict.items() +# } + +# def get_model_impl(self, model_name): +# if model_name in self._model_impls: +# return self._model_impls[model_name] + +# model = self.models[model_name] + +# try: +# 
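The two create_transformer_inputs_* helpers above substitute per-row column values for column references inside the transformer's input. A toy illustration of that substitution, using "@" as a stand-in for the internal resource-reference escape sequence and ignoring the schema-driven float upcasting:

def resolve(input, col_value_map):
    # replace column references with the current row's values; leave literals alone
    if isinstance(input, str) and input.startswith("@"):
        return col_value_map.get(input[1:], input)
    if isinstance(input, list):
        return [resolve(item, col_value_map) for item in input]
    if isinstance(input, dict):
        return {resolve(k, col_value_map): resolve(v, col_value_map) for k, v in input.items()}
    return input

transformer_input = {"col": "@age", "buckets": [0, 18, 65]}
row = {"age": 42}
assert resolve(transformer_input, row) == {"col": 42, "buckets": [0, 18, 65]}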
impl, impl_path = self.load_module("model", model_name, model["impl_key"]) +# _validate_impl(impl, MODEL_IMPL_VALIDATION) +# except CortexException as e: +# e.wrap("model " + model_name) +# raise + +# self._model_impls[model_name] = impl +# return impl + +# def column_config(self, column_name): +# if self.is_raw_column(column_name): +# return self.raw_column_config(column_name) +# elif self.is_transformed_column(column_name): +# return self.transformed_column_config(column_name) +# return None + +# def raw_column_config(self, column_name): +# raw_column = self.raw_columns[column_name] +# if raw_column is None: +# return None +# config = deepcopy(raw_column) +# config_keys = ["name", "type", "required", "min", "max", "values", "tags"] +# util.keep_dict_keys(config, config_keys) +# return config + +# def transformed_column_config(self, column_name): +# transformed_column = self.transformed_columns[column_name] +# if transformed_column is None: +# return None +# config = deepcopy(transformed_column) +# config_keys = ["name", "transformer", "inputs", "tags", "type"] +# util.keep_dict_keys(config, config_keys) +# config["inputs"] = self._expand_inputs_config(config["inputs"]) +# config["transformer"] = self.transformer_config(config["transformer"]) +# return config + +# def value_config(self, value_name): +# if self.is_constant(value_name): +# return self.constant_config(value_name) +# elif self.is_aggregate(value_name): +# return self.aggregate_config(value_name) +# return None + +# def constant_config(self, constant_name): +# constant = self.constants[constant_name] +# if constant is None: +# return None +# config = deepcopy(constant) +# config_keys = ["name", "type", "tags"] +# util.keep_dict_keys(config, config_keys) +# return config + +# def aggregate_config(self, aggregate_name): +# aggregate = self.aggregates[aggregate_name] +# if aggregate is None: +# return None +# config = deepcopy(aggregate) +# config_keys = ["name", "type", "inputs", "aggregator", "tags"] +# util.keep_dict_keys(config, config_keys) +# config["inputs"] = self._expand_inputs_config(config["inputs"]) +# config["aggregator"] = self.aggregator_config(config["aggregator"]) +# return config + +# def aggregator_config(self, aggregator_name): +# aggregator = self.aggregators[aggregator_name] +# if aggregator is None: +# return None +# config = deepcopy(aggregator) +# config_keys = ["name", "output_type", "inputs"] +# util.keep_dict_keys(config, config_keys) +# config["name"] = aggregator_name # Use the fully qualified name (includes namespace) +# return config + +# def transformer_config(self, transformer_name): +# transformer = self.transformers[transformer_name] +# if transformer is None: +# return None +# config = deepcopy(transformer) +# config_keys = ["name", "output_type", "inputs"] +# util.keep_dict_keys(config, config_keys) +# config["name"] = transformer_name # Use the fully qualified name (includes namespace) +# return config + +# def _expand_inputs_config(self, inputs_config): +# inputs_config["columns"] = self._expand_columns_input_dict(inputs_config["columns"]) +# inputs_config["args"] = self._expand_args_dict(inputs_config["args"]) +# return inputs_config + +# def _expand_columns_input_dict(self, input_columns_dict): +# expanded = {} +# for column_name, value in input_columns_dict.items(): +# if util.util.is_str(value): +# expanded[column_name] = self.column_config(value) +# elif util.is_list(value): +# expanded[column_name] = [self.column_config(name) for name in value] +# return expanded + +# def 
_expand_args_dict(self, args_dict): +# expanded = {} +# for arg_name, value_name in args_dict.items(): +# expanded[arg_name] = self.value_config(value_name) +# return expanded diff --git a/pkg/workloads/lib/test/util_test.py b/pkg/workloads/lib/test/util_test.py index d9a75a903f..4ad0c193e8 100644 --- a/pkg/workloads/lib/test/util_test.py +++ b/pkg/workloads/lib/test/util_test.py @@ -166,64 +166,66 @@ def test_validate_column_type(): assert util.validate_column_type(["2", "string"], "STRING_LIST_COLUMN") == True -def test_validate_value_type(): - assert util.validate_value_type(2, "INT") == True - assert util.validate_value_type(2.2, "INT") == False - assert util.validate_value_type("2", "INT") == False - assert util.validate_value_type(None, "INT") == True - - assert util.validate_value_type(2.2, "FLOAT") == True - assert util.validate_value_type(2, "FLOAT") == False - assert util.validate_value_type("2", "FLOAT") == False - assert util.validate_value_type(None, "FLOAT") == True - - assert util.validate_value_type(False, "BOOL") == True - assert util.validate_value_type(2, "BOOL") == False - assert util.validate_value_type("2", "BOOL") == False - assert util.validate_value_type(None, "BOOL") == True - - assert util.validate_value_type(2.2, "INT|FLOAT") == True - assert util.validate_value_type(2, "FLOAT|INT") == True - assert util.validate_value_type("2", "FLOAT|INT") == False - assert util.validate_value_type(None, "INT|FLOAT") == True - - assert util.validate_value_type({"test": 2.2}, {"STRING": "FLOAT"}) == True - assert util.validate_value_type({"test": 2.2, "test2": 3.3}, {"STRING": "FLOAT"}) == True - assert util.validate_value_type({}, {"STRING": "FLOAT"}) == True - assert util.validate_value_type({"test": "2.2"}, {"STRING": "FLOAT"}) == False - assert util.validate_value_type({2: 2.2}, {"STRING": "FLOAT"}) == False - - assert util.validate_value_type({"test": 2.2}, {"STRING": "INT|FLOAT"}) == True - assert util.validate_value_type({"a": 2.2, "b": False}, {"STRING": "FLOAT|BOOL"}) == True - assert util.validate_value_type({"test": 2.2, "test2": 3}, {"STRING": "FLOAT|BOOL"}) == False - assert util.validate_value_type({}, {"STRING": "INT|FLOAT"}) == True - assert util.validate_value_type({"test": "2.2"}, {"STRING": "FLOAT|INT"}) == False - assert util.validate_value_type({2: 2.2}, {"STRING": "INT|FLOAT"}) == False - - assert util.validate_value_type({"f": 2.2, "i": 2}, {"f": "FLOAT", "i": "INT"}) == True - assert util.validate_value_type({"f": 2.2, "i": 2.2}, {"f": "FLOAT", "i": "INT"}) == False - assert util.validate_value_type({"f": "s", "i": 2}, {"f": "FLOAT", "i": "INT"}) == False - assert util.validate_value_type({"f": 2.2}, {"f": "FLOAT", "i": "INT"}) == False - assert util.validate_value_type({"f": 2.2, "i": None}, {"f": "FLOAT", "i": "INT"}) == True - assert util.validate_value_type({"f": 0.2, "i": 2, "e": 1}, {"f": "FLOAT", "i": "INT"}) == False - - assert util.validate_value_type(["s"], ["STRING"]) == True - assert util.validate_value_type(["a", "b", "c"], ["STRING"]) == True - assert util.validate_value_type([], ["STRING"]) == True - assert util.validate_value_type(None, ["STRING"]) == True - assert util.validate_value_type([2], ["STRING"]) == False - assert util.validate_value_type(["a", False, "c"], ["STRING"]) == False - assert util.validate_value_type("a", ["STRING"]) == False - - assert util.validate_value_type([2], ["FLOAT|INT|BOOL"]) == True - assert util.validate_value_type([2.2], ["FLOAT|INT|BOOL"]) == True - assert util.validate_value_type([False], 
["FLOAT|INT|BOOL"]) == True - assert util.validate_value_type([2, 2.2, False], ["FLOAT|INT|BOOL"]) == True - assert util.validate_value_type([], ["FLOAT|INT|BOOL"]) == True - assert util.validate_value_type(None, ["FLOAT|INT|BOOL"]) == True - assert util.validate_value_type([2, "s", True], ["FLOAT|INT|BOOL"]) == False - assert util.validate_value_type(["s"], ["FLOAT|INT|BOOL"]) == False - assert util.validate_value_type(2, ["FLOAT|INT|BOOL"]) == False +def test_validate_output_type(): + assert util.validate_output_type(2, "INT") == True + assert util.validate_output_type(2.2, "INT") == False + assert util.validate_output_type("2", "INT") == False + assert util.validate_output_type(None, "INT") == True + + assert util.validate_output_type(2.2, "FLOAT") == True + assert util.validate_output_type(2, "FLOAT") == False + assert util.validate_output_type("2", "FLOAT") == False + assert util.validate_output_type(None, "FLOAT") == True + + assert util.validate_output_type(False, "BOOL") == True + assert util.validate_output_type(2, "BOOL") == False + assert util.validate_output_type("2", "BOOL") == False + assert util.validate_output_type(None, "BOOL") == True + + assert util.validate_output_type(2.2, "INT|FLOAT") == True + assert util.validate_output_type(2, "FLOAT|INT") == True + assert util.validate_output_type("2", "FLOAT|INT") == False + assert util.validate_output_type(None, "INT|FLOAT") == True + + assert util.validate_output_type({"test": 2.2}, {"STRING": "FLOAT"}) == True + assert util.validate_output_type({"test": 2.2, "test2": 3.3}, {"STRING": "FLOAT"}) == True + assert util.validate_output_type({}, {"STRING": "FLOAT"}) == True + assert util.validate_output_type({"test": "2.2"}, {"STRING": "FLOAT"}) == False + assert util.validate_output_type({2: 2.2}, {"STRING": "FLOAT"}) == False + + assert util.validate_output_type({"test": 2.2}, {"STRING": "INT|FLOAT"}) == True + assert util.validate_output_type({"a": 2.2, "b": False}, {"STRING": "FLOAT|BOOL"}) == True + assert util.validate_output_type({"test": 2.2, "test2": 3}, {"STRING": "FLOAT|BOOL"}) == False + assert util.validate_output_type({}, {"STRING": "INT|FLOAT"}) == True + assert util.validate_output_type({"test": "2.2"}, {"STRING": "FLOAT|INT"}) == False + assert util.validate_output_type({2: 2.2}, {"STRING": "INT|FLOAT"}) == False + + assert util.validate_output_type({"f": 2.2, "i": 2}, {"f": "FLOAT", "i": "INT"}) == True + assert util.validate_output_type({"f": 2.2, "i": 2.2}, {"f": "FLOAT", "i": "INT"}) == False + assert util.validate_output_type({"f": "s", "i": 2}, {"f": "FLOAT", "i": "INT"}) == False + assert util.validate_output_type({"f": 2.2}, {"f": "FLOAT", "i": "INT"}) == False + assert util.validate_output_type({"f": 2.2, "i": None}, {"f": "FLOAT", "i": "INT"}) == True + assert ( + util.validate_output_type({"f": 0.2, "i": 2, "e": 1}, {"f": "FLOAT", "i": "INT"}) == False + ) + + assert util.validate_output_type(["s"], ["STRING"]) == True + assert util.validate_output_type(["a", "b", "c"], ["STRING"]) == True + assert util.validate_output_type([], ["STRING"]) == True + assert util.validate_output_type(None, ["STRING"]) == True + assert util.validate_output_type([2], ["STRING"]) == False + assert util.validate_output_type(["a", False, "c"], ["STRING"]) == False + assert util.validate_output_type("a", ["STRING"]) == False + + assert util.validate_output_type([2], ["FLOAT|INT|BOOL"]) == True + assert util.validate_output_type([2.2], ["FLOAT|INT|BOOL"]) == True + assert util.validate_output_type([False], ["FLOAT|INT|BOOL"]) == 
True + assert util.validate_output_type([2, 2.2, False], ["FLOAT|INT|BOOL"]) == True + assert util.validate_output_type([], ["FLOAT|INT|BOOL"]) == True + assert util.validate_output_type(None, ["FLOAT|INT|BOOL"]) == True + assert util.validate_output_type([2, "s", True], ["FLOAT|INT|BOOL"]) == False + assert util.validate_output_type(["s"], ["FLOAT|INT|BOOL"]) == False + assert util.validate_output_type(2, ["FLOAT|INT|BOOL"]) == False value_type = { "map": {"STRING": "FLOAT"}, @@ -270,7 +272,7 @@ def test_validate_value_type(): }, }, } - assert util.validate_value_type(value, value_type) == True + assert util.validate_output_type(value, value_type) == True value = { "map": {"a": 2.2, "b": float(3)}, @@ -292,7 +294,7 @@ def test_validate_value_type(): "testB": None, }, } - assert util.validate_value_type(value, value_type) == True + assert util.validate_output_type(value, value_type) == True value = { "map": {"a": 2.2, "b": float(3)}, @@ -320,7 +322,7 @@ def test_validate_value_type(): }, }, } - assert util.validate_value_type(value, value_type) == False + assert util.validate_output_type(value, value_type) == False value = { "map": {"a": 2.2, "b": float(3)}, @@ -349,7 +351,7 @@ def test_validate_value_type(): }, }, } - assert util.validate_value_type(value, value_type) == False + assert util.validate_output_type(value, value_type) == False value = { "map": {"a": 2.2, "b": float(3)}, @@ -378,7 +380,7 @@ def test_validate_value_type(): }, }, } - assert util.validate_value_type(value, value_type) == False + assert util.validate_output_type(value, value_type) == False value = { "map": {"a": 2.2, "b": float(3)}, @@ -407,7 +409,7 @@ def test_validate_value_type(): }, }, } - assert util.validate_value_type(value, value_type) == False + assert util.validate_output_type(value, value_type) == False value = { "map": {"a": 2.2, "b": float(3)}, @@ -436,7 +438,7 @@ def test_validate_value_type(): }, }, } - assert util.validate_value_type(value, value_type) == False + assert util.validate_output_type(value, value_type) == False value = { "map": {"a": 2.2, "b": float(3)}, @@ -465,4 +467,4 @@ def test_validate_value_type(): }, }, } - assert util.validate_value_type(value, value_type) == False + assert util.validate_output_type(value, value_type) == False diff --git a/pkg/workloads/lib/tf_lib.py b/pkg/workloads/lib/tf_lib.py index 09962f5d50..86eb2b4791 100644 --- a/pkg/workloads/lib/tf_lib.py +++ b/pkg/workloads/lib/tf_lib.py @@ -40,16 +40,16 @@ def get_column_tf_types(model_name, ctx, training=True): model = ctx.models[model_name] column_types = {} - for column_name in model["feature_columns"]: + for column_name in ctx.extract_column_names(model["input"]): column_type = ctx.get_inferred_column_type(column_name) column_types[column_name] = CORTEX_TYPE_TO_TF_TYPE[column_type] if training: - target_column_name = model["target_column"] - column_type = ctx.get_inferred_column_type(target_column_name) - column_types[target_column_name] = CORTEX_TYPE_TO_TF_TYPE[column_type] + for column_name in ctx.extract_column_names(model["target_column"]): + column_type = ctx.get_inferred_column_type(column_name) + column_types[column_name] = CORTEX_TYPE_TO_TF_TYPE[column_type] - for column_name in model["training_columns"]: + for column_name in ctx.extract_column_names(model.get("training_input")): column_type = ctx.get_inferred_column_type(column_name) column_types[column_name] = CORTEX_TYPE_TO_TF_TYPE[column_type] @@ -74,12 +74,12 @@ def get_feature_spec(model_name, ctx, training=True): def 
get_base_input_columns(model_name, ctx): model = ctx.models[model_name] base_column_names = set() - for column_name in model["feature_columns"]: + for column_name in ctx.extract_column_names(model["input"]): if ctx.is_raw_column(column_name): base_column_names.add(column_name) else: transformed_column = ctx.transformed_columns[column_name] - for name in util.flatten_all_values(transformed_column["inputs"]["columns"]): + for name in ctx.extract_column_names(transformed_column["input"]): base_column_names.add(name) return [ctx.raw_columns[name] for name in base_column_names] diff --git a/pkg/workloads/lib/util.py b/pkg/workloads/lib/util.py index b9ab3fb2a1..9d96b722b8 100644 --- a/pkg/workloads/lib/util.py +++ b/pkg/workloads/lib/util.py @@ -671,25 +671,18 @@ def log_job_finished(workload_id): CORTEX_TYPE_TO_VALIDATOR = { consts.COLUMN_TYPE_INT: is_int, consts.COLUMN_TYPE_INT_LIST: is_int_list, - consts.COLUMN_TYPE_FLOAT: is_float, - consts.COLUMN_TYPE_FLOAT_LIST: is_float_list, + consts.COLUMN_TYPE_FLOAT: is_float_or_int, + consts.COLUMN_TYPE_FLOAT_LIST: is_float_or_int_list, consts.COLUMN_TYPE_STRING: is_str, consts.COLUMN_TYPE_STRING_LIST: is_str_list, consts.VALUE_TYPE_INT: is_int, - consts.VALUE_TYPE_FLOAT: is_float, + consts.VALUE_TYPE_FLOAT: is_float_or_int, consts.VALUE_TYPE_STRING: is_str, consts.VALUE_TYPE_BOOL: is_bool, } -CORTEX_TYPE_TO_UPCAST_VALIDATOR = merge_dicts_overwrite( - CORTEX_TYPE_TO_VALIDATOR, - { - consts.COLUMN_TYPE_FLOAT: is_float_or_int, - consts.COLUMN_TYPE_FLOAT_LIST: is_float_or_int_list, - }, -) - CORTEX_TYPE_TO_UPCASTER = { + consts.VALUE_TYPE_FLOAT: lambda x: float(x), consts.COLUMN_TYPE_FLOAT: lambda x: float(x), consts.COLUMN_TYPE_FLOAT_LIST: lambda ls: [float(item) for item in ls], } @@ -717,12 +710,12 @@ def validate_column_type(value, column_type): return False -def validate_value_type(value, value_type): +def validate_output_type(value, output_type): if value is None: return True - if is_str(value_type): - valid_types = value_type.split("|") + if is_str(output_type): + valid_types = output_type.split("|") for valid_type in valid_types: if CORTEX_TYPE_TO_VALIDATOR[valid_type](value): @@ -730,47 +723,150 @@ def validate_value_type(value, value_type): return False - if is_dict(value_type): + if is_list(output_type): + if not (len(output_type) == 1 and is_str(output_type[0])): + return False + if not is_list(value): + return False + for value_item in value: + if not validate_output_type(value_item, output_type[0]): + return False + return True + + if is_dict(output_type): if not is_dict(value): return False - if len(value_type) == 0: - if len(value) == 0: - return True + if len(output_type) == 0: return False is_generic_map = False - if len(value_type) == 1: - value_type_key = next(iter(value_type.keys())) - if value_type_key in consts.VALUE_TYPES: + if len(output_type) == 1: + output_type_key = next(iter(output_type.keys())) + if output_type_key in consts.VALUE_TYPES: is_generic_map = True - generic_map_key = value_type_key - generic_map_value = value_type[value_type_key] + generic_map_key = output_type_key + generic_map_value = output_type[output_type_key] if is_generic_map: for value_key, value_val in value.items(): - if not validate_value_type(value_key, generic_map_key): + if not validate_output_type(value_key, generic_map_key): return False - if not validate_value_type(value_val, generic_map_value): + if not validate_output_type(value_val, generic_map_value): return False return True - if len(value) != len(value_type): - return False + # Fixed map 
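The validator change above means an INT value is now accepted wherever a FLOAT (or FLOAT_LIST) is expected, with CORTEX_TYPE_TO_UPCASTER doing the conversion. A small sketch of that behavior; the dictionary keys are illustrative stand-ins for the consts constants:

def is_float_or_int(value):
    # bools are excluded even though Python bool subclasses int
    return isinstance(value, float) or (isinstance(value, int) and not isinstance(value, bool))

UPCASTERS = {
    "FLOAT": lambda x: float(x),
    "FLOAT_LIST": lambda ls: [float(item) for item in ls],
}

value = 3
assert is_float_or_int(value)
assert UPCASTERS["FLOAT"](value) == 3.0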
for value_key, value_val in value.items(): - if value_key not in value_type: + if value_key not in output_type: return False - if not validate_value_type(value_val, value_type[value_key]): - return False - return True - - if is_list(value_type): - if not (len(value_type) == 1 and is_str(value_type[0])): - return False - if not is_list(value): - return False - for value_item in value: - if not validate_value_type(value_item, value_type[0]): + if not validate_output_type(value_val, output_type[value_key]): return False return True return False + + +# Casts int -> float. Input is assumed to be already validated +def cast_output_type(value, output_type): + if is_str(output_type): + if ( + is_int(value) + and consts.VALUE_TYPE_FLOAT in output_type + and consts.VALUE_TYPE_INT not in output_type + ): + return float(value) + return value + + if is_list(output_type): + casted = [] + for item in value: + casted.append(cast_output_type(item, output_type[0])) + return casted + + if is_dict(output_type): + is_generic_map = False + if len(output_type) == 1: + output_type_key = next(iter(output_type.keys())) + if output_type_key in consts.VALUE_TYPES: + is_generic_map = True + generic_map_key = output_type_key + generic_map_value = output_type[output_type_key] + + if is_generic_map: + casted = {} + for value_key, value_val in value.items(): + casted_key = cast_output_type(value_key, generic_map_key) + casted_val = cast_output_type(value_val, generic_map_value) + casted[casted_key] = casted_val + return casted + + # Fixed map + casted = {} + for output_type_key, output_type_val in output_type.items(): + casted_val = cast_output_type(value[output_type_key], output_type_val) + casted[output_type_key] = casted_val + return casted + + return value + + +escape_seq = "🌝🌝🌝🌝🌝" + + +def is_resource_ref(obj): + if not is_str(obj): + return False + return obj.startswith(escape_seq) + + +def get_resource_ref(obj): + if not is_str(obj): + return None + if not obj.startswith(escape_seq): + return None + return obj[len(escape_seq) :] + + +def extract_resource_refs(input): + if is_str(input): + res = util.get_resource_ref(input) + if res is not None: + return set(res) + return set() + + if is_dict(input): + resources = set() + for key, val in input.items(): + resources = resources.union(extract_resource_refs(key)) + resources = resources.union(extract_resource_refs(val)) + return resources + + if is_list(input): + resources = set() + for item in input: + resources = resources.union(extract_resource_refs(val)) + return resources + + return set() + + +# def replace_resource_refs(input): +# if is_str(input): +# res = util.get_resource_ref(input) +# if res is not None: +# return res +# return input + +# if is_dict(input): +# replaced = {} +# for key, val in input.items(): +# replaced[replace_resource_refs(key)] = replace_resource_refs(val) +# return replaced + +# if is_list(input): +# replaced = [] +# for item in input: +# replaced.append(replace_resource_refs(item)) +# return replaced + +# return input diff --git a/pkg/workloads/spark_job/spark_job.py b/pkg/workloads/spark_job/spark_job.py index 7d20715b9c..9e839e57d1 100644 --- a/pkg/workloads/spark_job/spark_job.py +++ b/pkg/workloads/spark_job/spark_job.py @@ -194,7 +194,7 @@ def ingest_raw_dataset(spark, ctx, cols_to_validate, should_ingest): return raw_df -def run_custom_aggregators(spark, ctx, cols_to_aggregate, raw_df): +def run_aggregators(spark, ctx, cols_to_aggregate, raw_df): logger.info("Aggregating") results = {} @@ -245,14 +245,7 @@ def 
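Usage sketch for the resource-reference helpers defined above: the escape sequence marks strings that refer to other resources (columns, constants, aggregates) rather than literal values, and get_resource_ref() strips it back off:

escape_seq = "🌝🌝🌝🌝🌝"

def get_resource_ref(obj):
    # returns the referenced resource name, or None for a plain string
    if isinstance(obj, str) and obj.startswith(escape_seq):
        return obj[len(escape_seq):]
    return None

ref = escape_seq + "age_normalized"
assert get_resource_ref(ref) == "age_normalized"
assert get_resource_ref("just a string") is None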
validate_transformers(spark, ctx, cols_to_transform, raw_df): try: input_columns_dict = transformed_column["inputs"]["columns"] - input_cols = [] - - for k in sorted(input_columns_dict.keys()): - if util.is_list(input_columns_dict[k]): - input_cols += sorted(input_columns_dict[k]) - else: - input_cols.append(input_columns_dict[k]) - + input_cols = sorted(ctx.extract_column_names(transformed_column["input"])) tf_name = transformed_column["name"] logger.info("Transforming {} to {}".format(", ".join(input_cols), tf_name)) @@ -317,7 +310,7 @@ def run_job(args): raw_df = ingest_raw_dataset(spark, ctx, cols_to_validate, should_ingest) if len(cols_to_aggregate) > 0: - run_custom_aggregators(spark, ctx, cols_to_aggregate, raw_df) + run_aggregators(spark, ctx, cols_to_aggregate, raw_df) if len(cols_to_transform) > 0: validate_transformers(spark, ctx, cols_to_transform, raw_df) diff --git a/pkg/workloads/spark_job/spark_util.py b/pkg/workloads/spark_job/spark_util.py index 9c65a7f6b6..1bd2875cc2 100644 --- a/pkg/workloads/spark_job/spark_util.py +++ b/pkg/workloads/spark_job/spark_util.py @@ -42,8 +42,13 @@ CORTEX_TYPE_TO_ACCEPTABLE_SPARK_TYPES = { consts.COLUMN_TYPE_INT: [IntegerType(), LongType()], consts.COLUMN_TYPE_INT_LIST: [ArrayType(IntegerType(), True), ArrayType(LongType(), True)], - consts.COLUMN_TYPE_FLOAT: [FloatType(), DoubleType()], - consts.COLUMN_TYPE_FLOAT_LIST: [ArrayType(FloatType(), True), ArrayType(DoubleType(), True)], + consts.COLUMN_TYPE_FLOAT: [FloatType(), DoubleType(), IntegerType(), LongType()], + consts.COLUMN_TYPE_FLOAT_LIST: [ + ArrayType(FloatType(), True), + ArrayType(DoubleType(), True), + ArrayType(IntegerType(), True), + ArrayType(LongType(), True), + ], consts.COLUMN_TYPE_STRING: [StringType()], consts.COLUMN_TYPE_STRING_LIST: [ArrayType(StringType(), True)], } @@ -143,7 +148,9 @@ def log_df_schema(df, logger_func=logger.info): def write_training_data(model_name, df, ctx, spark): model = ctx.models[model_name] training_dataset = model["dataset"] - column_names = model["feature_columns"] + [model["target_column"]] + model["training_columns"] + column_names = ctx.extract_column_names( + [model["input"], model["target_column"], model.get("training_input")] + ) df = df.select(*column_names) @@ -328,7 +335,8 @@ def read_csv(ctx, spark): "expected " + len(data_config["schema"]) + " column(s) but got " + len(df.columns) ) - renamed_cols = [F.col(c).alias(data_config["schema"][idx]) for idx, c in enumerate(df.columns)] + col_names = [util.get_resource_ref(col_ref) for col_ref in data_config["schema"]] + renamed_cols = [F.col(c).alias(col_names[idx]) for idx, c in enumerate(df.columns)] return df.select(*renamed_cols) @@ -336,11 +344,11 @@ def read_parquet(ctx, spark): parquet_config = ctx.environment["data"] df = spark.read.parquet(parquet_config["path"]) - alias_map = { - c["parquet_column_name"]: c["raw_column_name"] - for c in parquet_config["schema"] - if c["parquet_column_name"] in ctx.raw_columns - } + alias_map = {} + for parquet_col_config in parquet_config["schema"]: + col_name = util.get_resource_ref(parquet_col_config["raw_column"]) + if col_name in ctx.raw_columns: + alias_map[col_name] = parquet_col_config["parquet_column_name"] missing_cols = set(alias_map.keys()) - set(df.columns) if len(missing_cols) > 0: @@ -348,187 +356,223 @@ def read_parquet(ctx, spark): log_df_schema(df, logger.error) raise UserException("missing column(s) in input dataset", str(missing_cols)) - selectExprs = ["{} as {}".format(alias_map[alias], alias) for alias in alias_map.keys()] 
+ selectExprs = [ + "{} as {}".format(parq_name, col_name) for col_name, parq_name in alias_map.items() + ] return df.selectExpr(*selectExprs) -def column_names_to_index(columns_input_config): - column_list = [] - for k, v in columns_input_config.items(): - if util.is_list(v): - column_list += v - else: - column_list.append(v) - - required_input_columns_sorted = sorted(set(column_list)) - - index_to_col_map = dict( - [(column_name, idx) for idx, column_name in enumerate(required_input_columns_sorted)] - ) - - columns_input_config_indexed = create_inputs_map(index_to_col_map, columns_input_config) - return required_input_columns_sorted, columns_input_config_indexed - - # not included in this list: collect_list, grouping, grouping_id AGG_SPARK_LIST = set( - [ - "approx_count_distinct", - "avg", - "collect_set", - "count", - "countDistinct", - "kurtosis", - "max", - "mean", - "min", - "skewness", - "stddev", - "stddev_pop", - "stddev_samp", - "sum", - "sumDistinct", - "var_pop", - "var_samp", - "variance", - ] + "approx_count_distinct", + "avg", + "collect_set_int", + "collect_set_float", + "collect_set_string", + "count", + "count_distinct", + "covar_pop", + "covar_samp", + "kurtosis", + "max_int", + "max_float", + "max_string", + "mean", + "min_int", + "min_float", + "min_string", + "skewness", + "stddev", + "stddev_pop", + "stddev_samp", + "sum_int", + "sum_float", + "sum_distinct_int", + "sum_distinct_float", + "var_pop", + "var_samp", + "variance", ) -def extract_spark_name(f_name): - if f_name.endswith("_string") or f_name.endswith("_float") or f_name.endswith("_int"): - f_name = "_".join(f_name.split("_")[:-1]) - snake_case_mapping = {"sum_distinct": "sumDistinct", "count_distinct": "countDistinct"} - return snake_case_mapping.get(f_name, f_name) - - -def split_aggregators(columns_to_aggregate, ctx): - aggregate_resources = [ctx.aggregates[r] for r in columns_to_aggregate] +def split_aggregators(aggregate_names, ctx): + aggregate_resources = [ctx.aggregates[agg_name] for agg_name in aggregate_names] builtin_aggregates = [] custom_aggregates = [] - for r in aggregate_resources: - aggregator = ctx.aggregators[r["aggregator"]] - spark_name = extract_spark_name(aggregator["name"]) - if aggregator.get("namespace", None) == "cortex" and spark_name in AGG_SPARK_LIST: - builtin_aggregates.append(r) + for agg in aggregate_resources: + aggregator = ctx.aggregators[agg["aggregator"]] + if aggregator.get("namespace", None) == "cortex" and aggregator["name"] in AGG_SPARK_LIST: + builtin_aggregates.append(agg) else: - custom_aggregates.append(r) + custom_aggregates.append(agg) return builtin_aggregates, custom_aggregates def run_builtin_aggregators(builtin_aggregates, df, ctx, spark): agg_cols = [] - for r in builtin_aggregates: - aggregator = ctx.aggregators[r["aggregator"]] - f_name = extract_spark_name(aggregator["name"]) - - agg_func = getattr(F, f_name) - col_name_list = [] - columns_dict = r["inputs"]["columns"] - - if "col" in columns_dict.keys(): - col_name_list.append(columns_dict["col"]) - if "cols" in columns_dict.keys(): - col_name_list += columns_dict["cols"] - if "col1" in columns_dict.keys() and "col2" in columns_dict.keys(): - col_name_list.append(columns_dict["col1"]) - col_name_list.append(columns_dict["col2"]) - - if len(col_name_list) == 0: - raise CortexException("input columns not found in aggregator: {}".format(r)) - - args = {} - if r["inputs"].get("args", None) is not None and len(r["inputs"]["args"]) > 0: - args = ctx.populate_args(r["inputs"]["args"]) - col_list = 
[F.col(c) for c in col_name_list] - agg_cols.append(agg_func(*col_list, **args).alias(r["name"])) + for agg in builtin_aggregates: + aggregator = ctx.aggregators[agg["aggregator"]] + input_repl = ctx.populate_values( + agg["input"], aggregator["input"], preserve_column_refs=False + ) + + if aggregator["name"] == "approx_count_distinct": + agg_cols.append( + F.approxCountDistinct(input_repl["col"], input_repl.get("rsd")).alias(agg["name"]) + ) + if aggregator["name"] == "avg": + agg_cols.append(F.avg(input_repl).alias(agg["name"])) + if aggregator["name"] in set("collect_set_int", "collect_set_float", "collect_set_string"): + agg_cols.append(F.collect_set(input_repl).alias(agg["name"])) + if aggregator["name"] == "count": + agg_cols.append(F.count(input_repl).alias(agg["name"])) + if aggregator["name"] == "count_distinct": + agg_cols.append(F.countDistinct(*input_repl).alias(agg["name"])) + if aggregator["name"] == "covar_pop": + agg_cols.append(F.covar_pop(input_repl["col1"], input_repl["col2"]).alias(agg["name"])) + if aggregator["name"] == "covar_samp": + agg_cols.append(F.covar_samp(input_repl["col1"], input_repl["col2"]).alias(agg["name"])) + if aggregator["name"] == "kurtosis": + agg_cols.append(F.kurtosis(input_repl).alias(agg["name"])) + if aggregator["name"] in set("max_int", "max_float", "max_string"): + agg_cols.append(F.max(input_repl).alias(agg["name"])) + if aggregator["name"] == "mean": + agg_cols.append(F.mean(input_repl).alias(agg["name"])) + if aggregator["name"] in set("min_int", "min_float", "min_string"): + agg_cols.append(F.min(input_repl).alias(agg["name"])) + if aggregator["name"] == "skewness": + agg_cols.append(F.skewness(input_repl).alias(agg["name"])) + if aggregator["name"] == "stddev": + agg_cols.append(F.stddev(input_repl).alias(agg["name"])) + if aggregator["name"] == "stddev_pop": + agg_cols.append(F.stddev_pop(input_repl).alias(agg["name"])) + if aggregator["name"] == "stddev_samp": + agg_cols.append(F.stddev_samp(input_repl).alias(agg["name"])) + if aggregator["name"] in set("sum_int", "sum_float"): + agg_cols.append(F.sum(input_repl).alias(agg["name"])) + if aggregator["name"] in set("sum_distinct_int", "sum_distinct_float"): + agg_cols.append(F.sumDistinct(input_repl).alias(agg["name"])) + if aggregator["name"] == "var_pop": + agg_cols.append(F.var_pop(input_repl).alias(agg["name"])) + if aggregator["name"] == "var_samp": + agg_cols.append(F.var_samp(input_repl).alias(agg["name"])) + if aggregator["name"] == "variance": + agg_cols.append(F.variance(input_repl).alias(agg["name"])) results = df.agg(*agg_cols).collect()[0].asDict() - for r in builtin_aggregates: - ctx.store_aggregate_result(results[r["name"]], r) + for agg in builtin_aggregates: + result = results[agg["name"]] + aggregator = ctx.aggregators[agg["aggregator"]] + result = util.cast_output_type(result, aggregator["output_type"]) + + results[agg["name"]] = result + ctx.store_aggregate_result(result, agg) return results -def run_custom_aggregator(aggregator_resource, df, ctx, spark): - aggregator = ctx.aggregators[aggregator_resource["aggregator"]] - aggregate_name = aggregator_resource["name"] - aggregator_impl, _ = ctx.get_aggregator_impl(aggregate_name) - input_schema = aggregator_resource["inputs"] - aggregator_column_input = input_schema["columns"] - args_schema = input_schema["args"] - args = {} - if input_schema.get("args", None) is not None and len(args_schema) > 0: - args = ctx.populate_args(input_schema["args"]) +def run_custom_aggregator(aggregate, df, ctx, spark): + aggregator = 
ctx.aggregators[aggregate["aggregator"]] + aggregator_impl, _ = ctx.get_aggregator_impl(aggregate["name"]) + input_repl = ctx.populate_values( + aggregate["input"], aggregator["input"], preserve_column_refs=False + ) + try: - result = aggregator_impl.aggregate_spark(df, aggregator_column_input, args) + result = aggregator_impl.aggregate_spark(df, input_repl) except Exception as e: raise UserRuntimeException( - "aggregate " + aggregator_resource["name"], + "aggregate " + aggregate["name"], "aggregator " + aggregator["name"], "function aggregate_spark", ) from e - if aggregator["output_type"] and not util.validate_value_type( + if "output_type" in aggregator and not util.validate_output_type( result, aggregator["output_type"] ): raise UserException( - "aggregate " + aggregator_resource["name"], + "aggregate " + aggregate["name"], "aggregator " + aggregator["name"], "type of {} is not {}".format( util.str_rep(util.pp_str(result), truncate=100), aggregator["output_type"] ), ) - ctx.store_aggregate_result(result, aggregator_resource) + result = util.cast_output_type(result, aggregator["output_type"]) + ctx.store_aggregate_result(result, aggregate) return result -def extract_inputs(column_name, ctx): - columns_input_config = ctx.transformed_columns[column_name]["inputs"]["columns"] - impl_args_schema = ctx.transformed_columns[column_name]["inputs"]["args"] - if impl_args_schema is not None: - impl_args = ctx.populate_args(impl_args_schema) - else: - impl_args = {} - return columns_input_config, impl_args +# def extract_inputs(column_name, ctx): +# columns_input_config = ctx.transformed_columns[column_name]["inputs"]["columns"] +# impl_args_schema = ctx.transformed_columns[column_name]["inputs"]["args"] +# if impl_args_schema is not None: +# impl_args = ctx.populate_args(impl_args_schema) +# else: +# impl_args = {} +# return columns_input_config, impl_args def execute_transform_spark(column_name, df, ctx, spark): trans_impl, trans_impl_path = ctx.get_transformer_impl(column_name) + transformed_column = ctx.transformed_columns[column_name] + transformer = ctx.transformers[transformed_column["transformer"]] if trans_impl_path not in ctx.spark_uploaded_impls: spark.sparkContext.addPyFile(trans_impl_path) # Executor pods need this because of the UDF ctx.spark_uploaded_impls[trans_impl_path] = True - columns_input_config, impl_args = extract_inputs(column_name, ctx) + input_repl = ctx.populate_values( + transformed_column["input"], transformer["input"], preserve_column_refs=False + ) try: - return trans_impl.transform_spark(df, columns_input_config, impl_args, column_name) + return trans_impl.transform_spark(df, input_repl, column_name) except Exception as e: raise UserRuntimeException("function transform_spark") from e +# def column_names_to_index(columns_input_config): +# column_list = [] +# for k, v in columns_input_config.items(): +# if util.is_list(v): +# column_list += v +# else: +# column_list.append(v) + +# required_input_columns_sorted = sorted(set(column_list)) + +# index_to_col_map = dict( +# [(column_name, idx) for idx, column_name in enumerate(required_input_columns_sorted)] +# ) + +# columns_input_config_indexed = create_inputs_map(index_to_col_map, columns_input_config) +# return required_input_columns_sorted, columns_input_config_indexed + + def execute_transform_python(column_name, df, ctx, spark, validate=False): trans_impl, trans_impl_path = ctx.get_transformer_impl(column_name) - columns_input_config, impl_args = extract_inputs(column_name, ctx) + transformed_column = 
ctx.transformed_columns[column_name] + transformer = ctx.transformers[transformed_column["transformer"]] + + input_cols_sorted = sorted(ctx.extract_column_names(transformed_column["input"])) + input_repl = ctx.populate_values( + transformed_column["input"], transformer["input"], preserve_column_refs=True + ) if trans_impl_path not in ctx.spark_uploaded_impls: spark.sparkContext.addPyFile(trans_impl_path) # Executor pods need this because of the UDF - # not a dictionary because it is possible that one column may map to multiple input names ctx.spark_uploaded_impls[trans_impl_path] = True - required_columns_sorted, columns_input_config_indexed = column_names_to_index( - columns_input_config - ) - def _transform(*values): - inputs = create_inputs_map(values, columns_input_config_indexed) - return trans_impl.transform_python(inputs, impl_args) + transformer_input = create_transformer_inputs_from_lists( + input_repl, transformer["input"], input_cols_sorted, values + ) + return trans_impl.transform_python(transformer_input) transform_python_func = _transform @@ -541,7 +585,7 @@ def _transform_and_validate(*values): if not util.validate_column_type(result, column_type): raise UserException( "transformed column " + column_name, - "tranformation " + transformed_column["transformer"], + "tranformer " + transformed_column["transformer"], "type of {} is not {}".format(result, column_type), ) @@ -549,9 +593,9 @@ def _transform_and_validate(*values): transform_python_func = _transform_and_validate - column_data_type_str = ctx.get_inferred_column_type(column_name) - transform_udf = F.udf(transform_python_func, CORTEX_TYPE_TO_SPARK_TYPE[column_data_type_str]) - return df.withColumn(column_name, transform_udf(*required_columns_sorted)) + column_type = ctx.get_inferred_column_type(column_name) + transform_udf = F.udf(transform_python_func, CORTEX_TYPE_TO_SPARK_TYPE[column_type]) + return df.withColumn(column_name, transform_udf(*input_cols_sorted)) def infer_type(obj): @@ -577,14 +621,20 @@ def validate_transformer(column_name, test_df, ctx, spark): if transformer["output_type"] == consts.COLUMN_TYPE_INFERRED: sample_df = test_df.collect() sample = sample_df[0] - inputs = ctx.create_column_inputs_map(sample, column_name) - _, impl_args = extract_inputs(column_name, ctx) - initial_transformed_sample = trans_impl.transform_python(inputs, impl_args) + input_repl = ctx.populate_values( + transformed_column["input"], transformer["input"], preserve_column_refs=True + ) + transformer_input = create_transformer_inputs_from_map( + input_repl, transformer["input"], sample + ) + initial_transformed_sample = trans_impl.transform_python(transformer_input) inferred_python_type = infer_type(initial_transformed_sample) for row in sample_df: - inputs = ctx.create_column_inputs_map(row, column_name) - transformed_sample = trans_impl.transform_python(inputs, impl_args) + transformer_input = create_transformer_inputs_from_map( + input_repl, transformer["input"], row + ) + transformed_sample = trans_impl.transform_python(transformer_input) if inferred_python_type != infer_type(transformed_sample): raise UserRuntimeException( "transformed column " + column_name, @@ -635,9 +685,12 @@ def validate_transformer(column_name, test_df, ctx, spark): except Exception as e: raise UserRuntimeException("function transform_spark") from e - actual_structfield = transform_spark_df.select(column_name).schema.fields[0] + if transformer["output_type"] == consts.COLUMN_TYPE_INFERRED: + inferred_spark_type = 
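For reference, a hypothetical transformer written against the new single-input signatures validated by TRANSFORMER_IMPL_VALIDATION (transform_spark(data, input, transformed_column_name), transform_python(input), reverse_transform_python(transformed_value, input)); the "col"/"mean"/"stddev" keys are assumptions for this sketch:

def transform_spark(data, input, transformed_column_name):
    import pyspark.sql.functions as F
    # input["col"] resolves to a column name, input["mean"]/input["stddev"] to values
    return data.withColumn(
        transformed_column_name,
        (F.col(input["col"]) - input["mean"]) / input["stddev"],
    )

def transform_python(input):
    # here input["col"] is the raw value for this row, substituted in by
    # create_transformer_inputs_from_map()
    return (input["col"] - input["mean"]) / input["stddev"]

def reverse_transform_python(transformed_value, input):
    return transformed_value * input["stddev"] + input["mean"]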
transform_spark_df.select(column_name).schema[0].dataType + ctx.write_metadata(transformed_column["id"], {"type": inferred_spark_type}) # check that expected output column has the correct data type + actual_structfield = transform_spark_df.select(column_name).schema.fields[0] if ( actual_structfield.dataType not in CORTEX_TYPE_TO_ACCEPTABLE_SPARK_TYPES[ @@ -656,11 +709,7 @@ def validate_transformer(column_name, test_df, ctx, spark): ) ) - if transformer["output_type"] == consts.COLUMN_TYPE_INFERRED: - inferred_spark_type = transform_spark_df.select(column_name).schema[0].dataType - ctx.write_metadata(transformed_column["id"], {"type": inferred_spark_type}) - - # perform the necessary upcast/downcast for the column e.g INT -> LONG or DOUBLE -> FLOAT + # perform the necessary casting for the column transform_spark_df = transform_spark_df.withColumn( column_name, F.col(column_name).cast( @@ -732,9 +781,13 @@ def transform_column(column_name, df, ctx, spark): trans_impl, _ = ctx.get_transformer_impl(column_name) if hasattr(trans_impl, "transform_spark"): - column_type = CORTEX_TYPE_TO_SPARK_TYPE[ctx.get_inferred_column_type(column_name)] df = execute_transform_spark(column_name, df, ctx, spark) - return df.withColumn(column_name, F.col(column_name).cast(column_type)) + return df.withColumn( + column_name, + F.col(column_name).cast( + CORTEX_TYPE_TO_SPARK_TYPE[ctx.get_inferred_column_type(column_name)] + ), + ) elif hasattr(trans_impl, "transform_python"): return execute_transform_python(column_name, df, ctx, spark) else: @@ -747,7 +800,9 @@ def transform_column(column_name, df, ctx, spark): def transform(model_name, accumulated_df, ctx, spark): model = ctx.models[model_name] - column_names = model["feature_columns"] + [model["target_column"]] + model["training_columns"] + column_names = ctx.extract_column_names( + [model["input"], model["target_column"], model.get("training_input")] + ) for column_name in column_names: accumulated_df = transform_column(column_name, accumulated_df, ctx, spark) diff --git a/pkg/workloads/tf_api/api.py b/pkg/workloads/tf_api/api.py index 4d3e6d69c6..858b938a4f 100644 --- a/pkg/workloads/tf_api/api.py +++ b/pkg/workloads/tf_api/api.py @@ -39,10 +39,12 @@ local_cache = { "ctx": None, "model": None, + "estimator": None, + "target_col": None, + "target_col_type": None, "stub": None, "api": None, "trans_impls": {}, - "transform_args_cache": {}, "required_inputs": None, "metadata": None, } @@ -66,21 +68,23 @@ def transform_sample(sample): transformed_sample = {} - for column_name in model["feature_columns"]: + for column_name in ctx.extract_column_names(model["input"]): if ctx.is_raw_column(column_name): transformed_value = sample[column_name] else: - inputs = ctx.create_column_inputs_map(sample, column_name) + transformed_column = ctx.transformed_columns[column_name] + input_repl = ctx.populate_values( + transformed_column["input"], None, preserve_column_refs=False + ) trans_impl = local_cache["trans_impls"][column_name] if not hasattr(trans_impl, "transform_python"): raise UserException( "transformed column " + column_name, - "transformer " + ctx.transformed_sample[column_name]["transformer"], - "transform_python function missing", + "transformer " + transformed_column["transformer"], + "transform_python() function is missing", ) + transformed_value = trans_impl.transform_python(input_repl) - args = local_cache["transform_args_cache"].get(column_name, {}) - transformed_value = trans_impl.transform_python(inputs, args) transformed_sample[column_name] = 
transformed_value return transformed_sample @@ -109,22 +113,18 @@ def create_prediction_request(transformed_sample): def reverse_transform(value): ctx = local_cache["ctx"] model = local_cache["model"] + target_col = local_cache["target_col"] - trans_impl = local_cache["trans_impls"].get(model["target_column"], None) + trans_impl = local_cache["trans_impls"].get(target_col["name"]) if not (trans_impl and hasattr(trans_impl, "reverse_transform_python")): return None - transformer_name = model["target_column"] - input_schema = ctx.transformed_columns[transformer_name]["inputs"] - - if input_schema.get("args", None) is not None and len(input_schema["args"]) > 0: - args = local_cache["transform_args_cache"].get(transformer_name, {}) + input_repl = ctx.populate_values(target_col["input"], None, preserve_column_refs=False) try: - result = trans_impl.reverse_transform_python(value, args) + result = trans_impl.reverse_transform_python(value, input_repl) except Exception as e: raise UserRuntimeException( - "transformer " + ctx.transformed_columns[model["target_column"]]["transformer"], - "function reverse_transform_python", + "transformer " + target_col["transformer"], "function reverse_transform_python" ) from e return result @@ -140,33 +140,34 @@ def parse_response_proto(response_proto): response_proto.result() may be necessary (TF > 1.2?) """ model = local_cache["model"] + estimator = local_cache["estimator"] + target_col_type = local_cache["target_col_type"] - if model["type"] == "regression": - prediction_key = "predictions" - if model["type"] == "classification": + if target_col_type in {consts.COLUMN_TYPE_STRING, consts.COLUMN_TYPE_INT}: prediction_key = "class_ids" + else: + prediction_key = "predictions" - if model["prediction_key"]: - prediction_key = model["prediction_key"] + if estimator["prediction_key"]: + prediction_key = estimator["prediction_key"] results_dict = json_format.MessageToDict(response_proto) outputs = results_dict["outputs"] value_key = DTYPE_TO_VALUE_KEY[outputs[prediction_key]["dtype"]] predicted = outputs[prediction_key][value_key][0] + predicted = util.upcast(predicted, target_col_type) result = {} for key in outputs.keys(): value_key = DTYPE_TO_VALUE_KEY[outputs[key]["dtype"]] result[key] = outputs[key][value_key] - if model["type"] == "regression": - predicted = float(predicted) - result["predicted_value"] = predicted - result["predicted_value_reversed"] = reverse_transform(predicted) - if model["type"] == "classification": - predicted = int(predicted) + if target_col_type in {consts.COLUMN_TYPE_STRING, consts.COLUMN_TYPE_INT}: result["predicted_class"] = predicted result["predicted_class_reversed"] = reverse_transform(predicted) + else: + result["predicted_value"] = predicted + result["predicted_value_reversed"] = reverse_transform(predicted) return result @@ -205,13 +206,15 @@ def run_predict(sample): def is_valid_sample(sample): + ctx = local_cache["ctx"] + for column in local_cache["required_inputs"]: if column["name"] not in sample: return False, "{} is missing".format(column["name"]) sample_val = sample[column["name"]] - column_type = local_cache["ctx"].get_inferred_column_type(column["name"]) - is_valid = util.CORTEX_TYPE_TO_UPCAST_VALIDATOR[column_type](sample_val) + column_type = ctx.get_inferred_column_type(column["name"]) + is_valid = util.CORTEX_TYPE_TO_VALIDATOR[column_type](sample_val) if not is_valid: return (False, "{} should be a {}".format(column["name"], column_type)) @@ -240,8 +243,11 @@ def predict(app_name, api_name): except Exception as e: 
return "Malformed JSON", status.HTTP_400_BAD_REQUEST + ctx = local_cache["ctx"] model = local_cache["model"] + estimator = local_cache["estimator"] api = local_cache["api"] + target_col_type = local_cache["target_col_type"] response = {} @@ -267,7 +273,7 @@ def predict(app_name, api_name): return prediction_failed(sample, reason) for column in local_cache["required_inputs"]: - column_type = local_cache["ctx"].get_inferred_column_type(column["name"]) + column_type = ctx.get_inferred_column_type(column["name"]) sample[column["name"]] = util.upcast(sample[column["name"]], column_type) try: @@ -287,10 +293,10 @@ def predict(app_name, api_name): predictions.append(result) - if model["type"] == "regression": - response["regression_predictions"] = predictions - if model["type"] == "classification": + if target_col_type in {consts.COLUMN_TYPE_STRING, consts.COLUMN_TYPE_INT}: response["classification_predictions"] = predictions + else: + response["regression_predictions"] = predictions response["resource_id"] = api["id"] @@ -303,26 +309,31 @@ def start(args): api = ctx.apis_id_map[args.api] model = ctx.models[api["model_name"]] + estimator = ctx.estimators[model["estimator"]] tf_lib.set_logging_verbosity(ctx.environment["log_level"]["tensorflow"]) local_cache["ctx"] = ctx local_cache["api"] = api local_cache["model"] = model + local_cache["estimator"] = estimator + local_cache["target_col"] = ctx.columns[util.get_resource_ref(model["target_column"])] + local_cache["target_col_type"] = ctx.get_inferred_column_type( + util.get_resource_ref(model["target_column"]) + ) if not os.path.isdir(args.model_dir): ctx.storage.download_and_unzip(model["key"], args.model_dir) - for column_name in model["feature_columns"] + [model["target_column"]]: + for column_name in ctx.extract_column_names([model["input"], model["target_column"]]): if ctx.is_transformed_column(column_name): trans_impl, _ = ctx.get_transformer_impl(column_name) local_cache["trans_impls"][column_name] = trans_impl transformed_column = ctx.transformed_columns[column_name] - input_args_schema = transformed_column["inputs"]["args"] - # cache aggregates and constants in memory - if input_args_schema is not None: - local_cache["transform_args_cache"][column_name] = ctx.populate_args( - input_args_schema - ) + + # cache aggregate values + for resource_name in util.extract_resource_refs(transformed_column["input"]): + if resource_name in ctx.aggregates: + ctx.get_obj(ctx.aggregates[resource_name]["key"]) channel = grpc.insecure_channel("localhost:" + str(args.tf_serve_port)) local_cache["stub"] = prediction_service_pb2_grpc.PredictionServiceStub(channel) diff --git a/pkg/workloads/tf_train/train.py b/pkg/workloads/tf_train/train.py index 2b3e24f9c7..35f74a2068 100644 --- a/pkg/workloads/tf_train/train.py +++ b/pkg/workloads/tf_train/train.py @@ -41,8 +41,8 @@ def train(args): ctx.upload_resource_status_start(model) try: - model_impl = ctx.get_model_impl(model["name"]) - train_util.train(model["name"], model_impl, ctx, model_dir) + estimator_impl = ctx.get_estimator_impl(model["name"]) + train_util.train(model["name"], estimator_impl, ctx, model_dir) ctx.upload_resource_status_success(model) logger.info("Caching") diff --git a/pkg/workloads/tf_train/train_util.py b/pkg/workloads/tf_train/train_util.py index 8f7105d25d..2bf3bf2506 100644 --- a/pkg/workloads/tf_train/train_util.py +++ b/pkg/workloads/tf_train/train_util.py @@ -36,17 +36,17 @@ def get_input_placeholder(model_name, ctx, training=True): def get_label_placeholder(model_name, ctx): model = 
ctx.models[model_name] - target_column_name = model["target_column"] + target_column_name = util.get_resource_ref(model["target_column"]) column_type = tf_lib.CORTEX_TYPE_TO_TF_TYPE[ctx.columns[target_column_name]["type"]] return tf.placeholder(shape=[None], dtype=column_type) -def get_transform_tensor_fn(ctx, model_impl, model_name): +def get_transform_tensor_fn(ctx, estimator_impl, model_name): model = ctx.models[model_name] model_config = ctx.model_config(model["name"]) def transform_tensor_fn_wrapper(inputs, labels): - return model_impl.transform_tensorflow(inputs, labels, model_config) + return estimator_impl.transform_tensorflow(inputs, labels, model_config) return transform_tensor_fn_wrapper @@ -58,14 +58,14 @@ def generate_example_parsing_fn(model_name, ctx, training=True): def _parse_example(example_proto): features = tf.parse_single_example(serialized=example_proto, features=feature_spec) - target = features.pop(model["target_column"], None) + target = features.pop(util.get_resource_ref(model["target_column"]), None) return features, target return _parse_example # Mode must be "training" or "evaluation" -def generate_input_fn(model_name, ctx, mode, model_impl): +def generate_input_fn(model_name, ctx, mode, estimator_impl): model = ctx.models[model_name] filenames = ctx.get_training_data_parts(model_name, mode) @@ -84,8 +84,8 @@ def _input_fn(): if model[mode]["shuffle"]: dataset = dataset.shuffle(buffer_size) - if hasattr(model_impl, "transform_tensorflow"): - dataset = dataset.map(get_transform_tensor_fn(ctx, model_impl, model_name)) + if hasattr(estimator_impl, "transform_tensorflow"): + dataset = dataset.map(get_transform_tensor_fn(ctx, estimator_impl, model_name)) dataset = dataset.batch(model[mode]["batch_size"]) dataset = dataset.prefetch(buffer_size) @@ -98,14 +98,14 @@ def _input_fn(): return _input_fn -def generate_json_serving_input_fn(model_name, ctx, model_impl): +def generate_json_serving_input_fn(model_name, ctx, estimator_impl): def _json_serving_input_fn(): inputs = get_input_placeholder(model_name, ctx, training=False) labels = get_label_placeholder(model_name, ctx) features = {key: tensor for key, tensor in inputs.items()} - if hasattr(model_impl, "transform_tensorflow"): - features, _ = get_transform_tensor_fn(ctx, model_impl, model_name)(features, labels) + if hasattr(estimator_impl, "transform_tensorflow"): + features, _ = get_transform_tensor_fn(ctx, estimator_impl, model_name)(features, labels) features = {key: tf.expand_dims(tensor, 0) for key, tensor in features.items()} return tf.estimator.export.ServingInputReceiver(features=features, receiver_tensors=inputs) @@ -124,7 +124,7 @@ def get_regression_eval_metrics(labels, predictions): return metrics -def train(model_name, model_impl, ctx, model_dir): +def train(model_name, estimator_impl, ctx, model_dir): model = ctx.models[model_name] util.mkdir_p(model_dir) @@ -143,9 +143,9 @@ def train(model_name, model_impl, ctx, model_dir): model_dir=model_dir, ) - train_input_fn = generate_input_fn(model_name, ctx, "training", model_impl) - eval_input_fn = generate_input_fn(model_name, ctx, "evaluation", model_impl) - serving_input_fn = generate_json_serving_input_fn(model_name, ctx, model_impl) + train_input_fn = generate_input_fn(model_name, ctx, "training", estimator_impl) + eval_input_fn = generate_input_fn(model_name, ctx, "evaluation", estimator_impl) + serving_input_fn = generate_json_serving_input_fn(model_name, ctx, estimator_impl) exporter = tf.estimator.FinalExporter("estimator", serving_input_fn, 
as_text=False) train_num_steps = model["training"]["num_steps"] @@ -177,13 +177,14 @@ def train(model_name, model_impl, ctx, model_dir): model_config = ctx.model_config(model["name"]) try: - estimator = model_impl.create_estimator(run_config, model_config) + tf_estimator = estimator_impl.create_estimator(run_config, model_config) except Exception as e: raise UserRuntimeException("model " + model_name) from e - if model["type"] == "regression": - estimator = tf.contrib.estimator.add_metrics(estimator, get_regression_eval_metrics) + target_col_name = util.get_resource_ref(model["target_column"]) + if ctx.get_inferred_column_type(target_col_name) == consts.COLUMN_TYPE_FLOAT: + tf_estimator = tf.contrib.estimator.add_metrics(tf_estimator, get_regression_eval_metrics) - tf.estimator.train_and_evaluate(estimator, train_spec, eval_spec) + tf.estimator.train_and_evaluate(tf_estimator, train_spec, eval_spec) return model_dir From da65a21a95fa179183d361c2ca6d7cac7f0303e8 Mon Sep 17 00:00:00 2001 From: David Eliahu Date: Mon, 10 Jun 2019 18:49:19 -0700 Subject: [PATCH 10/44] Don't allow INT_COLUMN for FLOAT_COLUMN --- pkg/operator/api/userconfig/compound_type.go | 6 -- pkg/operator/context/resources_test.go | 60 ++++++++++++++------ pkg/workloads/lib/context.py | 52 +++-------------- pkg/workloads/spark_job/spark_util.py | 10 +--- 4 files changed, 53 insertions(+), 75 deletions(-) diff --git a/pkg/operator/api/userconfig/compound_type.go b/pkg/operator/api/userconfig/compound_type.go index fdab9891d2..9c88432bd2 100644 --- a/pkg/operator/api/userconfig/compound_type.go +++ b/pkg/operator/api/userconfig/compound_type.go @@ -101,12 +101,6 @@ func (compoundType *CompoundType) SupportsType(t interface{}) bool { parsed, _ := parseCompoundType(string(*compoundType)) if columnType, ok := t.(ColumnType); ok { - if columnType == IntegerColumnType { - return parsed.columnTypes[IntegerColumnType] || parsed.columnTypes[FloatColumnType] - } - if columnType == IntegerListColumnType { - return parsed.columnTypes[IntegerListColumnType] || parsed.columnTypes[FloatListColumnType] - } return parsed.columnTypes[columnType] || columnType == InferredColumnType } diff --git a/pkg/operator/context/resources_test.go b/pkg/operator/context/resources_test.go index 1d08da4699..f69aa904e5 100644 --- a/pkg/operator/context/resources_test.go +++ b/pkg/operator/context/resources_test.go @@ -176,9 +176,9 @@ func TestValidateRuntimeTypes(t *testing.T) { checkValidateRuntimeTypesError(t, `{a: FLOAT, b: INT, c: INT}`, `@ca`) checkValidateRuntimeTypesEqual(t, `INT_COLUMN`, `@rc1`, `🌝🌝🌝🌝🌝rc1`) - checkValidateRuntimeTypesEqual(t, `FLOAT_COLUMN`, `@rc1`, `🌝🌝🌝🌝🌝rc1`) + checkValidateRuntimeTypesError(t, `FLOAT_COLUMN`, `@rc1`) checkValidateRuntimeTypesError(t, `STRING_COLUMN`, `@rc1`) - checkValidateRuntimeTypesEqual(t, `FLOAT_COLUMN|STRING_COLUMN`, `@rc1`, `🌝🌝🌝🌝🌝rc1`) + checkValidateRuntimeTypesEqual(t, `INT_COLUMN|STRING_COLUMN`, `@rc1`, `🌝🌝🌝🌝🌝rc1`) checkValidateRuntimeTypesError(t, `INT_LIST_COLUMN`, `@rc1`) checkValidateRuntimeTypesError(t, `[INT_COLUMN]`, `@rc1`) checkValidateRuntimeTypesError(t, `{INT_COLUMN: INT}`, `@rc1`) @@ -323,7 +323,7 @@ func TestValidateRuntimeTypes(t *testing.T) { checkValidateRuntimeTypesError(t, `STRING_COLUMN`, `@tc3`) checkValidateRuntimeTypesEqual(t, `INT_COLUMN`, `@tc3`, `🌝🌝🌝🌝🌝tc3`) - checkValidateRuntimeTypesEqual(t, `FLOAT_COLUMN`, `@tc3`, `🌝🌝🌝🌝🌝tc3`) + checkValidateRuntimeTypesError(t, `FLOAT_COLUMN`, `@tc3`) checkValidateRuntimeTypesError(t, `STRING_LIST_COLUMN`, `@tc3`) 
checkValidateRuntimeTypesError(t, `INT_LIST_COLUMN`, `@tc3`) checkValidateRuntimeTypesError(t, `FLOAT_LIST_COLUMN`, `@tc3`) @@ -361,7 +361,7 @@ func TestValidateRuntimeTypes(t *testing.T) { checkValidateRuntimeTypesError(t, `FLOAT_COLUMN`, `@tc6`) checkValidateRuntimeTypesError(t, `STRING_LIST_COLUMN`, `@tc6`) checkValidateRuntimeTypesEqual(t, `INT_LIST_COLUMN`, `@tc6`, `🌝🌝🌝🌝🌝tc6`) - checkValidateRuntimeTypesEqual(t, `FLOAT_LIST_COLUMN`, `@tc6`, `🌝🌝🌝🌝🌝tc6`) + checkValidateRuntimeTypesError(t, `FLOAT_LIST_COLUMN`, `@tc6`) checkValidateRuntimeTypesError(t, `[STRING_COLUMN]`, `@tc6`) checkValidateRuntimeTypesError(t, `[INT_COLUMN]`, `@tc6`) checkValidateRuntimeTypesError(t, `{STRING_COLUMN: INT}`, `@tc6`) @@ -386,8 +386,12 @@ func TestValidateRuntimeTypes(t *testing.T) { checkValidateRuntimeTypesEqual(t, `[FLOAT_COLUMN]`, - `[@tc3, @rc1, @tc4, @rc4]`, - []interface{}{"🌝🌝🌝🌝🌝tc3", "🌝🌝🌝🌝🌝rc1", "🌝🌝🌝🌝🌝tc4", "🌝🌝🌝🌝🌝rc4"}) + `[@tc4, @rc4]`, + []interface{}{"🌝🌝🌝🌝🌝tc4", "🌝🌝🌝🌝🌝rc4"}) + + checkValidateRuntimeTypesError(t, + `[FLOAT_COLUMN]`, + `[@tc3, @rc1, @tc4, @rc4]`) checkValidateRuntimeTypesEqual(t, `[FLOAT]`, @@ -416,16 +420,16 @@ func TestValidateRuntimeTypes(t *testing.T) { checkValidateRuntimeTypesEqual(t, `{2: FLOAT_COLUMN, 3: FLOAT}`, - `{2: @tc3, 3: @agg4}`, - map[interface{}]interface{}{int64(2): "🌝🌝🌝🌝🌝tc3", int64(3): "🌝🌝🌝🌝🌝agg4"}) + `{2: @tc4, 3: @agg4}`, + map[interface{}]interface{}{int64(2): "🌝🌝🌝🌝🌝tc4", int64(3): "🌝🌝🌝🌝🌝agg4"}) checkValidateRuntimeTypesEqual(t, `{FLOAT: FLOAT_COLUMN}`, - `{2: @tc3, 3: @tc4, @agg4: @rc1, @agg5: @rc2, @c5: @rc4, @c6: @tc2}`, + `{2: @tc4, 3: @tc4, @agg4: @rc4, @agg5: @rc2, @c5: @rc4, @c6: @tc2}`, map[interface{}]interface{}{ - float64(2): "🌝🌝🌝🌝🌝tc3", + float64(2): "🌝🌝🌝🌝🌝tc4", float64(3): "🌝🌝🌝🌝🌝tc4", - "🌝🌝🌝🌝🌝agg4": "🌝🌝🌝🌝🌝rc1", + "🌝🌝🌝🌝🌝agg4": "🌝🌝🌝🌝🌝rc4", "🌝🌝🌝🌝🌝agg5": "🌝🌝🌝🌝🌝rc2", "🌝🌝🌝🌝🌝c5": "🌝🌝🌝🌝🌝rc4", "🌝🌝🌝🌝🌝c6": "🌝🌝🌝🌝🌝tc2", @@ -524,16 +528,38 @@ func TestValidateRuntimeTypes(t *testing.T) { c: {1: INT_COLUMN, 2: FLOAT_COLUMN, 3: BOOL, 4: STRING} d: {INT: INT_COLUMN} e: {FLOAT_COLUMN: FLOAT|STRING} - f: {INT_LIST_COLUMN|STRING_COLUMN: FLOAT_COLUMN} - g: [FLOAT] + f: {FLOAT_COLUMN: FLOAT|STRING} + g: {INT_LIST_COLUMN|STRING_COLUMN: FLOAT_COLUMN} + h: [FLOAT] `, ` - a: @tc4 b: @rc3 - c: {1: @rc1, 2: @tc3, 3: true, 4: @agg1} + c: {1: @rc1, 2: @tc4, 3: true, 4: @agg1} d: {1: @rc1, 2: @tc3, @c6: @rc2, @agg4: @tc2} - e: {@tc3: @agg4, @rc4: test, @tc2: 2.2, @rc1: @c1, @tc4: @agg5, @rc2: 2} - f: {@tc6: @tc4, @tc1: @rc2, @rc3: @rc1} - g: [@c5, @c6, 2, 2.2, @agg4, @agg5] + e: {@rc4: test, @tc2: 2.2, @tc4: @agg5, @rc2: 2} + f: {@tc4: @agg4, @rc4: @c1} + g: {@tc6: @tc4, @tc1: @rc2, @rc3: @rc4} + h: [@c5, @c6, 2, 2.2, @agg4, @agg5] + `) + + checkValidateRuntimeTypesError(t, ` + - a: FLOAT_COLUMN + b: INT_COLUMN|STRING_COLUMN + c: {1: INT_COLUMN, 2: FLOAT_COLUMN, 3: BOOL, 4: STRING} + d: {INT: INT_COLUMN} + e: {FLOAT_COLUMN: FLOAT|STRING} + f: {FLOAT_COLUMN: FLOAT|STRING} + g: {INT_LIST_COLUMN|STRING_COLUMN: FLOAT_COLUMN} + h: [FLOAT] + `, ` + - a: @tc4 + b: @rc3 + c: {1: @rc1, 2: @tc4, 3: true, 4: @agg1} + d: {1: @rc1, 2: @tc3, @c6: @rc2, @agg4: @tc2} + e: {@rc4: test, @tc2: 2.2, @tc4: @agg5, @rc2: 2} + f: {@tc4: @agg4, @rc4: @c1} + g: {@tc6: @tc4, @tc1: @rc2, @rc3: @rc1} + h: [@c5, @c6, 2, 2.2, @agg4, @agg5] `) checkValidateRuntimeTypesError(t, ` diff --git a/pkg/workloads/lib/context.py b/pkg/workloads/lib/context.py index e155ee96b2..ed0ec474ea 100644 --- a/pkg/workloads/lib/context.py +++ b/pkg/workloads/lib/context.py @@ -679,62 +679,24 @@ def 
_deserialize_raw_ctx(raw_ctx): # input should already have non-column arguments replaced, and all types validated -def create_transformer_inputs_from_map(input, input_schema, col_value_map): +def create_transformer_inputs_from_map(input, col_value_map): if util.is_str(input): res_name = util.get_resource_ref(input) if res_name in col_value_map: - value = col_value_map[res_name] - if input_schema is not None: - valid_col_types = input_schema["_type"].split("|") - if util.is_int(value) and consts.COLUMN_TYPE_INT not in valid_col_types: - value = float(value) - if util.is_int_list(value) and consts.COLUMN_TYPE_INT_LIST not in valid_col_types: - value = [float(elem) for elem in value] - return value + return col_value_map[res_name] return input if util.is_list(input): replaced = [] for item in input: - sub_schema = None - if input_schema is not None: - sub_schema = input_schema["_type"][0] - replaced.append(create_transformer_inputs_from_map(item, sub_schema, col_value_map)) + replaced.append(create_transformer_inputs_from_map(item, col_value_map)) return replaced if util.is_dict(input): replaced = {} - - if input_schema is not None: - is_generic_map = False - generic_key_sub_schema = None - generic_val_sub_schema = None - if len(input_schema["_type"]) == 1: - input_type_key = next(iter(input_schema["_type"].keys())) - if is_compound_type(input_type_key): - is_generic_map = True - generic_key_sub_schema = { - "_type": input_type_key, - "_optional": False, - "_default": None, - "_allow_null": False, - "_min_count": None, - "_max_count": None, - } - generic_val_sub_schema = input_schema["_type"][input_type_key] - for key, val in input.items(): - key_sub_schema = None - val_sub_schema = None - if input_schema is not None: - if is_generic_map: - key_sub_schema = generic_key_sub_schema - val_sub_schema = generic_val_sub_schema - else: - val_sub_schema = input_schema["_type"].get(key) - - key_replaced = create_transformer_inputs_from_map(key, key_sub_schema, col_value_map) - val_replaced = create_transformer_inputs_from_map(val, val_sub_schema, col_value_map) + key_replaced = create_transformer_inputs_from_map(key, col_value_map) + val_replaced = create_transformer_inputs_from_map(val, col_value_map) replaced[key_replaced] = val_replaced return replaced @@ -742,12 +704,12 @@ def create_transformer_inputs_from_map(input, input_schema, col_value_map): # input should already have non-column arguments replaced, and all types validated -def create_transformer_inputs_from_lists(input, input_schema, input_cols_sorted, col_values): +def create_transformer_inputs_from_lists(input, input_cols_sorted, col_values): col_value_map = {} for col_name, col_value in zip(input_cols_sorted, col_values): col_value_map[col_name] = col_value - return create_transformer_inputs_from_map(input, input_schema, col_value_map, col_type_map) + return create_transformer_inputs_from_map(input, col_value_map) # def create_column_inputs_map(self, values_map, column_name): diff --git a/pkg/workloads/spark_job/spark_util.py b/pkg/workloads/spark_job/spark_util.py index 1bd2875cc2..5f258bf79e 100644 --- a/pkg/workloads/spark_job/spark_util.py +++ b/pkg/workloads/spark_job/spark_util.py @@ -570,7 +570,7 @@ def execute_transform_python(column_name, df, ctx, spark, validate=False): def _transform(*values): transformer_input = create_transformer_inputs_from_lists( - input_repl, transformer["input"], input_cols_sorted, values + input_repl, input_cols_sorted, values ) return trans_impl.transform_python(transformer_input) @@ -624,16 +624,12 @@ 
def validate_transformer(column_name, test_df, ctx, spark): input_repl = ctx.populate_values( transformed_column["input"], transformer["input"], preserve_column_refs=True ) - transformer_input = create_transformer_inputs_from_map( - input_repl, transformer["input"], sample - ) + transformer_input = create_transformer_inputs_from_map(input_repl, sample) initial_transformed_sample = trans_impl.transform_python(transformer_input) inferred_python_type = infer_type(initial_transformed_sample) for row in sample_df: - transformer_input = create_transformer_inputs_from_map( - input_repl, transformer["input"], row - ) + transformer_input = create_transformer_inputs_from_map(input_repl, row) transformed_sample = trans_impl.transform_python(transformer_input) if inferred_python_type != infer_type(transformed_sample): raise UserRuntimeException( From f6682b8a9714f0893a3b98d25b15640839111c3b Mon Sep 17 00:00:00 2001 From: David Eliahu Date: Mon, 10 Jun 2019 18:49:53 -0700 Subject: [PATCH 11/44] Fix bugs --- pkg/estimators/estimators.yaml | 56 +++++++++++++++++---------- pkg/workloads/lib/context.py | 29 ++++++++++---- pkg/workloads/lib/test/util_test.py | 36 ++++++++--------- pkg/workloads/lib/util.py | 40 +++++++++++++++---- pkg/workloads/spark_job/spark_job.py | 2 - pkg/workloads/spark_job/spark_util.py | 18 ++++----- pkg/workloads/tf_api/api.py | 14 ++++--- pkg/workloads/tf_train/train.py | 2 +- pkg/workloads/tf_train/train_util.py | 1 + 9 files changed, 126 insertions(+), 72 deletions(-) diff --git a/pkg/estimators/estimators.yaml b/pkg/estimators/estimators.yaml index 43a508645c..da0c86d533 100644 --- a/pkg/estimators/estimators.yaml +++ b/pkg/estimators/estimators.yaml @@ -124,9 +124,11 @@ boundaries: [FLOAT] _default: [] training_input: - weight_column: - _type: INT_COLUMN|FLOAT_COLUMN - _optional: True + _type: + weight_column: + _type: INT_COLUMN|FLOAT_COLUMN + _optional: True + _default: {} hparams: hidden_units: [INT] @@ -178,9 +180,11 @@ boundaries: [FLOAT] _default: [] training_input: - weight_column: - _type: INT_COLUMN|FLOAT_COLUMN - _optional: True + _type: + weight_column: + _type: INT_COLUMN|FLOAT_COLUMN + _optional: True + _default: {} - kind: estimator name: linear_regressor @@ -220,9 +224,11 @@ boundaries: [FLOAT] _default: [] training_input: - weight_column: - _type: INT_COLUMN|FLOAT_COLUMN - _optional: True + _type: + weight_column: + _type: INT_COLUMN|FLOAT_COLUMN + _optional: True + _default: {} - kind: estimator name: dnn_linear_combined_classifier @@ -317,9 +323,11 @@ _default: [] training_input: - weight_column: - _type: INT_COLUMN|FLOAT_COLUMN - _optional: True + _type: + weight_column: + _type: INT_COLUMN|FLOAT_COLUMN + _optional: True + _default: {} hparams: dnn_hidden_units: [INT] @@ -406,9 +414,11 @@ _default: [] training_input: - weight_column: - _type: INT_COLUMN|FLOAT_COLUMN - _optional: True + _type: + weight_column: + _type: INT_COLUMN|FLOAT_COLUMN + _optional: True + _default: {} hparams: dnn_hidden_units: [INT] @@ -469,9 +479,11 @@ boundaries: [FLOAT] _default: [] training_input: - weight_column: - _type: INT_COLUMN|FLOAT_COLUMN - _optional: True + _type: + weight_column: + _type: INT_COLUMN|FLOAT_COLUMN + _optional: True + _default: {} hparams: batches_per_layer: INT num_trees: @@ -549,9 +561,11 @@ boundaries: [FLOAT] _default: [] training_input: - weight_column: - _type: INT_COLUMN|FLOAT_COLUMN - _optional: True + _type: + weight_column: + _type: INT_COLUMN|FLOAT_COLUMN + _optional: True + _default: {} hparams: batches_per_layer: INT num_trees: diff --git 
a/pkg/workloads/lib/context.py b/pkg/workloads/lib/context.py index ed0ec474ea..45b7aab520 100644 --- a/pkg/workloads/lib/context.py +++ b/pkg/workloads/lib/context.py @@ -244,7 +244,7 @@ def get_transformer_impl(self, column_name): return (impl, impl_path) def get_estimator_impl(self, model_name): - estimator_name = self.models[model_name]["aggregator"] + estimator_name = self.models[model_name]["estimator"] if estimator_name in self._estimator_impls: return self._estimator_impls[estimator_name] @@ -521,13 +521,15 @@ def populate_values(self, input, input_schema, preserve_column_refs): input_type_key = next(iter(input_schema["_type"].keys())) if is_compound_type(input_type_key): is_generic_map = True - generic_map_key = input_type_key + generic_map_key_schema = input_schema_from_type_schema(input_type_key) generic_map_value = input_schema["_type"][input_type_key] if is_generic_map: casted = {} for key, val in input.items(): - key_casted = self.populate_values(key, generic_map_key, preserve_column_refs) + key_casted = self.populate_values( + key, generic_map_key_schema, preserve_column_refs + ) try: val_casted = self.populate_values( val, generic_map_value, preserve_column_refs @@ -540,14 +542,14 @@ def populate_values(self, input, input_schema, preserve_column_refs): # fixed map casted = {} - for key, val_schema in input_schema["_type"]: + for key, val_schema in input_schema["_type"].items(): default = None if key not in input: - if input_schema.get("_optional") is not True: + if val_schema.get("_optional") is not True: raise UserException("missing key: " + util.pp_str_flat(key)) - if input_schema.get("_default") is None: + if val_schema.get("_default") is None: continue - default = input_schema["_default"] + default = val_schema["_default"] val = input.get(key, default) try: @@ -565,6 +567,17 @@ def populate_values(self, input, input_schema, preserve_column_refs): return cast_compound_type(input, input_schema["_type"]) +def input_schema_from_type_schema(type_schema): + return { + "_type": type_schema, + "_optional": False, + "_default": None, + "_allow_null": False, + "_min_count": None, + "_max_count": None, + } + + def is_compound_type(type_str): if not util.is_str(type_str): return False @@ -591,7 +604,7 @@ def cast_compound_type(value, type_str): if util.is_float(value): return value if consts.VALUE_TYPE_STRING in allowed_types: - if util.is_string(value): + if util.is_str(value): return value if consts.VALUE_TYPE_BOOL in allowed_types: if util.is_bool(value): diff --git a/pkg/workloads/lib/test/util_test.py b/pkg/workloads/lib/test/util_test.py index 4ad0c193e8..2557c068cb 100644 --- a/pkg/workloads/lib/test/util_test.py +++ b/pkg/workloads/lib/test/util_test.py @@ -146,24 +146,24 @@ def test_print_samples_horiz(caplog): assert "\n".join(records) + "\n" == expected -def test_validate_column_type(): - assert util.validate_column_type(2, "INT_COLUMN") == True - assert util.validate_column_type(2.2, "INT_COLUMN") == False - assert util.validate_column_type("2", "INT_COLUMN") == False - assert util.validate_column_type(None, "INT_COLUMN") == True - - assert util.validate_column_type(2.2, "FLOAT_COLUMN") == True - assert util.validate_column_type(2, "FLOAT_COLUMN") == False - assert util.validate_column_type("2", "FLOAT_COLUMN") == False - assert util.validate_column_type(None, "FLOAT_COLUMN") == True - - assert util.validate_column_type("2", "STRING_COLUMN") == True - assert util.validate_column_type(2, "STRING_COLUMN") == False - assert util.validate_column_type(2.2, "STRING_COLUMN") 
== False - assert util.validate_column_type(None, "STRING_COLUMN") == True - - assert util.validate_column_type("2", "STRING_LIST_COLUMN") == False - assert util.validate_column_type(["2", "string"], "STRING_LIST_COLUMN") == True +def test_validate_cortex_type(): + assert util.validate_cortex_type(2, "INT_COLUMN") == True + assert util.validate_cortex_type(2.2, "INT_COLUMN") == False + assert util.validate_cortex_type("2", "INT_COLUMN") == False + assert util.validate_cortex_type(None, "INT_COLUMN") == True + + assert util.validate_cortex_type(2.2, "FLOAT_COLUMN") == True + assert util.validate_cortex_type(2, "FLOAT_COLUMN") == False + assert util.validate_cortex_type("2", "FLOAT_COLUMN") == False + assert util.validate_cortex_type(None, "FLOAT_COLUMN") == True + + assert util.validate_cortex_type("2", "STRING_COLUMN") == True + assert util.validate_cortex_type(2, "STRING_COLUMN") == False + assert util.validate_cortex_type(2.2, "STRING_COLUMN") == False + assert util.validate_cortex_type(None, "STRING_COLUMN") == True + + assert util.validate_cortex_type("2", "STRING_LIST_COLUMN") == False + assert util.validate_cortex_type(["2", "string"], "STRING_LIST_COLUMN") == True def test_validate_output_type(): diff --git a/pkg/workloads/lib/util.py b/pkg/workloads/lib/util.py index 9d96b722b8..35cd40919a 100644 --- a/pkg/workloads/lib/util.py +++ b/pkg/workloads/lib/util.py @@ -54,6 +54,10 @@ def pp_str(obj, indent=0): return indent_str(out, indent) +def pp(obj, indent=0): + print(pp_str(obj, indent)) + + def pp_str_flat(obj, indent=0): try: out = json.dumps(obj, sort_keys=True) @@ -687,22 +691,42 @@ def log_job_finished(workload_id): consts.COLUMN_TYPE_FLOAT_LIST: lambda ls: [float(item) for item in ls], } +CORTEX_TYPE_TO_CASTER = { + consts.COLUMN_TYPE_INT: lambda x: int(x), + consts.COLUMN_TYPE_INT_LIST: lambda ls: [int(item) for item in ls], + consts.COLUMN_TYPE_FLOAT: lambda x: float(x), + consts.COLUMN_TYPE_FLOAT_LIST: lambda ls: [float(item) for item in ls], + consts.COLUMN_TYPE_STRING: lambda x: str(x), + consts.COLUMN_TYPE_STRING_LIST: lambda ls: [str(item) for item in ls], + consts.VALUE_TYPE_INT: lambda x: int(x), + consts.VALUE_TYPE_FLOAT: lambda x: float(x), + consts.VALUE_TYPE_STRING: lambda x: str(x), + consts.VALUE_TYPE_BOOL: lambda x: bool(x), +} + -def upcast(value, column_type): - upcaster = CORTEX_TYPE_TO_UPCASTER.get(column_type, None) +def upcast(value, cortex_type): + upcaster = CORTEX_TYPE_TO_UPCASTER.get(cortex_type, None) if upcaster: return upcaster(value) return value -def validate_column_type(value, column_type): +def cast(value, cortex_type): + upcaster = CORTEX_TYPE_TO_CASTER.get(cortex_type, None) + if upcaster: + return upcaster(value) + return value + + +def validate_cortex_type(value, cortex_type): if value is None: return True - if not is_str(column_type): + if not is_str(cortex_type): raise - valid_types = column_type.split("|") + valid_types = cortex_type.split("|") for valid_type in valid_types: if CORTEX_TYPE_TO_VALIDATOR[valid_type](value): return True @@ -829,9 +853,9 @@ def get_resource_ref(obj): def extract_resource_refs(input): if is_str(input): - res = util.get_resource_ref(input) + res = get_resource_ref(input) if res is not None: - return set(res) + return {res} return set() if is_dict(input): @@ -844,7 +868,7 @@ def extract_resource_refs(input): if is_list(input): resources = set() for item in input: - resources = resources.union(extract_resource_refs(val)) + resources = resources.union(extract_resource_refs(item)) return resources return set() 
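(A minimal illustrative sketch of the renamed util.py helpers, assuming the `from lib import util` import style used elsewhere in the workload packages; it only restates behavior already pinned down by the CORTEX_TYPE_TO_CASTER table and the util_test.py hunk above, and is not itself part of the patch.)

    from lib import util

    # validate_cortex_type() checks a runtime value against a cortex column type;
    # None is always treated as valid so missing values pass through.
    assert util.validate_cortex_type(2, "INT_COLUMN") is True
    assert util.validate_cortex_type("2", "INT_COLUMN") is False
    assert util.validate_cortex_type(None, "FLOAT_COLUMN") is True

    # cast() coerces a value to the requested cortex type via CORTEX_TYPE_TO_CASTER,
    # e.g. when normalizing a prediction returned by TensorFlow Serving.
    assert util.cast("3", "INT_COLUMN") == 3
    assert isinstance(util.cast(2, "FLOAT_COLUMN"), float)

    # extract_resource_refs() recursively collects resource references from nested
    # str / list / dict inputs and returns them as a set (see the hunk just above).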
diff --git a/pkg/workloads/spark_job/spark_job.py b/pkg/workloads/spark_job/spark_job.py index 9e839e57d1..508e0bc514 100644 --- a/pkg/workloads/spark_job/spark_job.py +++ b/pkg/workloads/spark_job/spark_job.py @@ -243,8 +243,6 @@ def validate_transformers(spark, ctx, cols_to_transform, raw_df): for transformed_column in resource_list: ctx.upload_resource_status_start(transformed_column) try: - input_columns_dict = transformed_column["inputs"]["columns"] - input_cols = sorted(ctx.extract_column_names(transformed_column["input"])) tf_name = transformed_column["name"] logger.info("Transforming {} to {}".format(", ".join(input_cols), tf_name)) diff --git a/pkg/workloads/spark_job/spark_util.py b/pkg/workloads/spark_job/spark_util.py index 5f258bf79e..f24308fbc1 100644 --- a/pkg/workloads/spark_job/spark_util.py +++ b/pkg/workloads/spark_job/spark_util.py @@ -21,7 +21,7 @@ import pyspark.sql.functions as F from lib import util -from lib.context import create_inputs_map +from lib.context import create_transformer_inputs_from_map, create_transformer_inputs_from_lists from lib.exceptions import CortexException, UserException, UserRuntimeException from lib.log import get_logger import consts @@ -364,7 +364,7 @@ def read_parquet(ctx, spark): # not included in this list: collect_list, grouping, grouping_id -AGG_SPARK_LIST = set( +AGG_SPARK_LIST = { "approx_count_distinct", "avg", "collect_set_int", @@ -393,7 +393,7 @@ def read_parquet(ctx, spark): "var_pop", "var_samp", "variance", -) +} def split_aggregators(aggregate_names, ctx): @@ -426,7 +426,7 @@ def run_builtin_aggregators(builtin_aggregates, df, ctx, spark): ) if aggregator["name"] == "avg": agg_cols.append(F.avg(input_repl).alias(agg["name"])) - if aggregator["name"] in set("collect_set_int", "collect_set_float", "collect_set_string"): + if aggregator["name"] in {"collect_set_int", "collect_set_float", "collect_set_string"}: agg_cols.append(F.collect_set(input_repl).alias(agg["name"])) if aggregator["name"] == "count": agg_cols.append(F.count(input_repl).alias(agg["name"])) @@ -438,11 +438,11 @@ def run_builtin_aggregators(builtin_aggregates, df, ctx, spark): agg_cols.append(F.covar_samp(input_repl["col1"], input_repl["col2"]).alias(agg["name"])) if aggregator["name"] == "kurtosis": agg_cols.append(F.kurtosis(input_repl).alias(agg["name"])) - if aggregator["name"] in set("max_int", "max_float", "max_string"): + if aggregator["name"] in {"max_int", "max_float", "max_string"}: agg_cols.append(F.max(input_repl).alias(agg["name"])) if aggregator["name"] == "mean": agg_cols.append(F.mean(input_repl).alias(agg["name"])) - if aggregator["name"] in set("min_int", "min_float", "min_string"): + if aggregator["name"] in {"min_int", "min_float", "min_string"}: agg_cols.append(F.min(input_repl).alias(agg["name"])) if aggregator["name"] == "skewness": agg_cols.append(F.skewness(input_repl).alias(agg["name"])) @@ -452,9 +452,9 @@ def run_builtin_aggregators(builtin_aggregates, df, ctx, spark): agg_cols.append(F.stddev_pop(input_repl).alias(agg["name"])) if aggregator["name"] == "stddev_samp": agg_cols.append(F.stddev_samp(input_repl).alias(agg["name"])) - if aggregator["name"] in set("sum_int", "sum_float"): + if aggregator["name"] in {"sum_int", "sum_float"}: agg_cols.append(F.sum(input_repl).alias(agg["name"])) - if aggregator["name"] in set("sum_distinct_int", "sum_distinct_float"): + if aggregator["name"] in {"sum_distinct_int", "sum_distinct_float"}: agg_cols.append(F.sumDistinct(input_repl).alias(agg["name"])) if aggregator["name"] == "var_pop": 
agg_cols.append(F.var_pop(input_repl).alias(agg["name"])) @@ -582,7 +582,7 @@ def _transform(*values): def _transform_and_validate(*values): result = _transform(*values) - if not util.validate_column_type(result, column_type): + if not util.validate_cortex_type(result, column_type): raise UserException( "transformed column " + column_name, "tranformer " + transformed_column["transformer"], diff --git a/pkg/workloads/tf_api/api.py b/pkg/workloads/tf_api/api.py index 858b938a4f..37046488b6 100644 --- a/pkg/workloads/tf_api/api.py +++ b/pkg/workloads/tf_api/api.py @@ -18,6 +18,7 @@ import argparse import tensorflow as tf import traceback +import time from flask import Flask, request, jsonify from flask_api import status from waitress import serve @@ -25,11 +26,13 @@ from tensorflow_serving.apis import predict_pb2 from tensorflow_serving.apis import get_model_metadata_pb2 from tensorflow_serving.apis import prediction_service_pb2_grpc +from google.protobuf import json_format + +import consts from lib import util, tf_lib, package, Context from lib.log import get_logger from lib.exceptions import CortexException, UserRuntimeException, UserException -from google.protobuf import json_format -import time +from lib.context import create_transformer_inputs_from_map logger = get_logger() logger.propagate = False # prevent double logging (flask modifies root logger) @@ -74,8 +77,9 @@ def transform_sample(sample): else: transformed_column = ctx.transformed_columns[column_name] input_repl = ctx.populate_values( - transformed_column["input"], None, preserve_column_refs=False + transformed_column["input"], None, preserve_column_refs=True ) + transformer_input = create_transformer_inputs_from_map(input_repl, sample) trans_impl = local_cache["trans_impls"][column_name] if not hasattr(trans_impl, "transform_python"): raise UserException( @@ -83,7 +87,7 @@ def transform_sample(sample): "transformer " + transformed_column["transformer"], "transform_python() function is missing", ) - transformed_value = trans_impl.transform_python(input_repl) + transformed_value = trans_impl.transform_python(transformer_input) transformed_sample[column_name] = transformed_value @@ -155,7 +159,7 @@ def parse_response_proto(response_proto): outputs = results_dict["outputs"] value_key = DTYPE_TO_VALUE_KEY[outputs[prediction_key]["dtype"]] predicted = outputs[prediction_key][value_key][0] - predicted = util.upcast(predicted, target_col_type) + predicted = util.cast(predicted, target_col_type) result = {} for key in outputs.keys(): diff --git a/pkg/workloads/tf_train/train.py b/pkg/workloads/tf_train/train.py index 35f74a2068..5e9f03c52c 100644 --- a/pkg/workloads/tf_train/train.py +++ b/pkg/workloads/tf_train/train.py @@ -41,7 +41,7 @@ def train(args): ctx.upload_resource_status_start(model) try: - estimator_impl = ctx.get_estimator_impl(model["name"]) + estimator_impl, _ = ctx.get_estimator_impl(model["name"]) train_util.train(model["name"], estimator_impl, ctx, model_dir) ctx.upload_resource_status_success(model) diff --git a/pkg/workloads/tf_train/train_util.py b/pkg/workloads/tf_train/train_util.py index 2bf3bf2506..7cdc131fc8 100644 --- a/pkg/workloads/tf_train/train_util.py +++ b/pkg/workloads/tf_train/train_util.py @@ -20,6 +20,7 @@ import math import tensorflow as tf +import consts from lib import util, tf_lib from lib.exceptions import UserRuntimeException From cabc489df661a330a77938a71748aa371e9deaf0 Mon Sep 17 00:00:00 2001 From: David Eliahu Date: Mon, 10 Jun 2019 19:22:39 -0700 Subject: [PATCH 12/44] Clean up --- 
pkg/workloads/lib/context.py | 144 +------------------------- pkg/workloads/lib/util.py | 22 ---- pkg/workloads/spark_job/spark_util.py | 28 ----- 3 files changed, 1 insertion(+), 193 deletions(-) diff --git a/pkg/workloads/lib/context.py b/pkg/workloads/lib/context.py index 45b7aab520..12f98a58d7 100644 --- a/pkg/workloads/lib/context.py +++ b/pkg/workloads/lib/context.py @@ -445,7 +445,7 @@ def populate_values(self, input, input_schema, preserve_column_refs): if res_name in self.columns: if input_schema is not None: col_type = self.get_inferred_column_type(res_name) - if not column_type_matches(col_type, input_schema["_type"]): + if col_type not in input_schema["_type"]: raise UserException( "column {}: column type mismatch: got {}, expected {}".format( res_name, col_type, input_schema["_type"] @@ -587,12 +587,6 @@ def is_compound_type(type_str): return True -def column_type_matches(value_type, schema_type): - if consts.COLUMN_TYPE_FLOAT in schema_type: - schema_type = schema_type + "|" + consts.COLUMN_TYPE_INT - return value_type in schema_type - - def cast_compound_type(value, type_str): allowed_types = type_str.split("|") if consts.VALUE_TYPE_INT in allowed_types: @@ -723,139 +717,3 @@ def create_transformer_inputs_from_lists(input, input_cols_sorted, col_values): col_value_map[col_name] = col_value return create_transformer_inputs_from_map(input, col_value_map) - - -# def create_column_inputs_map(self, values_map, column_name): -# """Construct an inputs dict with actual data""" -# columns_input_config = self.transformed_columns[column_name]["inputs"]["columns"] -# return create_inputs_map(values_map, columns_input_config) - -# def create_inputs_map(values_map, input_config): -# inputs = {} -# for input_name, input_config_item in input_config.items(): -# if util.is_str(input_config_item): -# inputs[input_name] = values_map[input_config_item] -# elif util.is_int(input_config_item): -# inputs[input_name] = values_map[input_config_item] -# elif util.is_list(input_config_item): -# inputs[input_name] = [values_map[f] for f in input_config_item] -# else: -# raise CortexException("invalid column inputs") - -# return inputs - -# def populate_args(self, args_dict): -# return { -# arg_name: self.get_obj(self.values[value_name]["key"]) -# for arg_name, value_name in args_dict.items() -# } - -# def get_model_impl(self, model_name): -# if model_name in self._model_impls: -# return self._model_impls[model_name] - -# model = self.models[model_name] - -# try: -# impl, impl_path = self.load_module("model", model_name, model["impl_key"]) -# _validate_impl(impl, MODEL_IMPL_VALIDATION) -# except CortexException as e: -# e.wrap("model " + model_name) -# raise - -# self._model_impls[model_name] = impl -# return impl - -# def column_config(self, column_name): -# if self.is_raw_column(column_name): -# return self.raw_column_config(column_name) -# elif self.is_transformed_column(column_name): -# return self.transformed_column_config(column_name) -# return None - -# def raw_column_config(self, column_name): -# raw_column = self.raw_columns[column_name] -# if raw_column is None: -# return None -# config = deepcopy(raw_column) -# config_keys = ["name", "type", "required", "min", "max", "values", "tags"] -# util.keep_dict_keys(config, config_keys) -# return config - -# def transformed_column_config(self, column_name): -# transformed_column = self.transformed_columns[column_name] -# if transformed_column is None: -# return None -# config = deepcopy(transformed_column) -# config_keys = ["name", "transformer", 
"inputs", "tags", "type"] -# util.keep_dict_keys(config, config_keys) -# config["inputs"] = self._expand_inputs_config(config["inputs"]) -# config["transformer"] = self.transformer_config(config["transformer"]) -# return config - -# def value_config(self, value_name): -# if self.is_constant(value_name): -# return self.constant_config(value_name) -# elif self.is_aggregate(value_name): -# return self.aggregate_config(value_name) -# return None - -# def constant_config(self, constant_name): -# constant = self.constants[constant_name] -# if constant is None: -# return None -# config = deepcopy(constant) -# config_keys = ["name", "type", "tags"] -# util.keep_dict_keys(config, config_keys) -# return config - -# def aggregate_config(self, aggregate_name): -# aggregate = self.aggregates[aggregate_name] -# if aggregate is None: -# return None -# config = deepcopy(aggregate) -# config_keys = ["name", "type", "inputs", "aggregator", "tags"] -# util.keep_dict_keys(config, config_keys) -# config["inputs"] = self._expand_inputs_config(config["inputs"]) -# config["aggregator"] = self.aggregator_config(config["aggregator"]) -# return config - -# def aggregator_config(self, aggregator_name): -# aggregator = self.aggregators[aggregator_name] -# if aggregator is None: -# return None -# config = deepcopy(aggregator) -# config_keys = ["name", "output_type", "inputs"] -# util.keep_dict_keys(config, config_keys) -# config["name"] = aggregator_name # Use the fully qualified name (includes namespace) -# return config - -# def transformer_config(self, transformer_name): -# transformer = self.transformers[transformer_name] -# if transformer is None: -# return None -# config = deepcopy(transformer) -# config_keys = ["name", "output_type", "inputs"] -# util.keep_dict_keys(config, config_keys) -# config["name"] = transformer_name # Use the fully qualified name (includes namespace) -# return config - -# def _expand_inputs_config(self, inputs_config): -# inputs_config["columns"] = self._expand_columns_input_dict(inputs_config["columns"]) -# inputs_config["args"] = self._expand_args_dict(inputs_config["args"]) -# return inputs_config - -# def _expand_columns_input_dict(self, input_columns_dict): -# expanded = {} -# for column_name, value in input_columns_dict.items(): -# if util.util.is_str(value): -# expanded[column_name] = self.column_config(value) -# elif util.is_list(value): -# expanded[column_name] = [self.column_config(name) for name in value] -# return expanded - -# def _expand_args_dict(self, args_dict): -# expanded = {} -# for arg_name, value_name in args_dict.items(): -# expanded[arg_name] = self.value_config(value_name) -# return expanded diff --git a/pkg/workloads/lib/util.py b/pkg/workloads/lib/util.py index 35cd40919a..2e2c1b2947 100644 --- a/pkg/workloads/lib/util.py +++ b/pkg/workloads/lib/util.py @@ -872,25 +872,3 @@ def extract_resource_refs(input): return resources return set() - - -# def replace_resource_refs(input): -# if is_str(input): -# res = util.get_resource_ref(input) -# if res is not None: -# return res -# return input - -# if is_dict(input): -# replaced = {} -# for key, val in input.items(): -# replaced[replace_resource_refs(key)] = replace_resource_refs(val) -# return replaced - -# if is_list(input): -# replaced = [] -# for item in input: -# replaced.append(replace_resource_refs(item)) -# return replaced - -# return input diff --git a/pkg/workloads/spark_job/spark_util.py b/pkg/workloads/spark_job/spark_util.py index f24308fbc1..1f43f3b069 100644 --- a/pkg/workloads/spark_job/spark_util.py 
+++ b/pkg/workloads/spark_job/spark_util.py @@ -508,16 +508,6 @@ def run_custom_aggregator(aggregate, df, ctx, spark): return result -# def extract_inputs(column_name, ctx): -# columns_input_config = ctx.transformed_columns[column_name]["inputs"]["columns"] -# impl_args_schema = ctx.transformed_columns[column_name]["inputs"]["args"] -# if impl_args_schema is not None: -# impl_args = ctx.populate_args(impl_args_schema) -# else: -# impl_args = {} -# return columns_input_config, impl_args - - def execute_transform_spark(column_name, df, ctx, spark): trans_impl, trans_impl_path = ctx.get_transformer_impl(column_name) transformed_column = ctx.transformed_columns[column_name] @@ -536,24 +526,6 @@ def execute_transform_spark(column_name, df, ctx, spark): raise UserRuntimeException("function transform_spark") from e -# def column_names_to_index(columns_input_config): -# column_list = [] -# for k, v in columns_input_config.items(): -# if util.is_list(v): -# column_list += v -# else: -# column_list.append(v) - -# required_input_columns_sorted = sorted(set(column_list)) - -# index_to_col_map = dict( -# [(column_name, idx) for idx, column_name in enumerate(required_input_columns_sorted)] -# ) - -# columns_input_config_indexed = create_inputs_map(index_to_col_map, columns_input_config) -# return required_input_columns_sorted, columns_input_config_indexed - - def execute_transform_python(column_name, df, ctx, spark, validate=False): trans_impl, trans_impl_path = ctx.get_transformer_impl(column_name) transformed_column = ctx.transformed_columns[column_name] From 301dd0e8186a395e37b6ce2fea7250d0bad27bd2 Mon Sep 17 00:00:00 2001 From: David Eliahu Date: Mon, 10 Jun 2019 23:56:30 -0700 Subject: [PATCH 13/44] Fix some tests --- images/test/Dockerfile | 1 + pkg/workloads/lib/test/context_test.py | 68 ------------------- pkg/workloads/lib/test/util_test.py | 10 ++- pkg/workloads/lib/util.py | 6 +- .../test/integration/iris_context.py | 2 +- 5 files changed, 9 insertions(+), 78 deletions(-) delete mode 100644 pkg/workloads/lib/test/context_test.py diff --git a/images/test/Dockerfile b/images/test/Dockerfile index c643b96790..fff92acae7 100644 --- a/images/test/Dockerfile +++ b/images/test/Dockerfile @@ -5,6 +5,7 @@ RUN pip3 install pytest mock COPY pkg/workloads /src COPY pkg/aggregators /aggregators COPY pkg/transformers /transformers +COPY pkg/estimators /estimators COPY images/test/run.sh /src/run.sh diff --git a/pkg/workloads/lib/test/context_test.py b/pkg/workloads/lib/test/context_test.py deleted file mode 100644 index ce8ad8798a..0000000000 --- a/pkg/workloads/lib/test/context_test.py +++ /dev/null @@ -1,68 +0,0 @@ -# Copyright 2019 Cortex Labs, Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import sys -import os - -from lib import context - - -def test_create_inputs_map(): - values_map = { - "e11": "value_11", - "e12": 12, - "e21": 2.1, - "e22": "value_22", - "e0": "value_e0", - "e1": "value_e1", - "e2": "value_e2", - "e3": "value_e3", - "e4": "value_e4", - } - - input_config = {"in": "e11"} - inputs = context.create_inputs_map(values_map, input_config) - assert inputs == {"in": "value_11"} - - input_config = {"a1": ["e11", "e12"], "a2": ["e21", "e22"]} - inputs = context.create_inputs_map(values_map, input_config) - assert inputs == {"a1": ["value_11", 12], "a2": [2.1, "value_22"]} - - values_map2 = { - "f1": 111, - "f2": 2.22, - "f3": "3", - "f4": "4", - "f5": "5", - "f6": "6", - "f7": "7", - "f8": "8", - "f9": "9", - } - - input_config = {"in1": "f1"} - inputs = context.create_inputs_map(values_map2, input_config) - assert inputs == {"in1": 111} - - input_config = {"in1": "f1", "in2": "f2"} - inputs = context.create_inputs_map(values_map2, input_config) - assert inputs == {"in1": 111, "in2": 2.22} - - input_config = {"in1": ["f1", "f2", "f3"]} - inputs = context.create_inputs_map(values_map2, input_config) - assert inputs == {"in1": [111, 2.22, "3"]} - - input_config = {"in1": ["f1", "f2", "f3"], "in2": ["f4", "f5", "f6"]} - inputs = context.create_inputs_map(values_map2, input_config) - assert inputs == {"in1": [111, 2.22, "3"], "in2": ["4", "5", "6"]} diff --git a/pkg/workloads/lib/test/util_test.py b/pkg/workloads/lib/test/util_test.py index 2557c068cb..ef2ed60d73 100644 --- a/pkg/workloads/lib/test/util_test.py +++ b/pkg/workloads/lib/test/util_test.py @@ -153,7 +153,7 @@ def test_validate_cortex_type(): assert util.validate_cortex_type(None, "INT_COLUMN") == True assert util.validate_cortex_type(2.2, "FLOAT_COLUMN") == True - assert util.validate_cortex_type(2, "FLOAT_COLUMN") == False + assert util.validate_cortex_type(2, "FLOAT_COLUMN") == True assert util.validate_cortex_type("2", "FLOAT_COLUMN") == False assert util.validate_cortex_type(None, "FLOAT_COLUMN") == True @@ -173,7 +173,7 @@ def test_validate_output_type(): assert util.validate_output_type(None, "INT") == True assert util.validate_output_type(2.2, "FLOAT") == True - assert util.validate_output_type(2, "FLOAT") == False + assert util.validate_output_type(2, "FLOAT") == True assert util.validate_output_type("2", "FLOAT") == False assert util.validate_output_type(None, "FLOAT") == True @@ -195,7 +195,7 @@ def test_validate_output_type(): assert util.validate_output_type({"test": 2.2}, {"STRING": "INT|FLOAT"}) == True assert util.validate_output_type({"a": 2.2, "b": False}, {"STRING": "FLOAT|BOOL"}) == True - assert util.validate_output_type({"test": 2.2, "test2": 3}, {"STRING": "FLOAT|BOOL"}) == False + assert util.validate_output_type({"test": 2.2, "test2": 3}, {"STRING": "FLOAT|BOOL"}) == True assert util.validate_output_type({}, {"STRING": "INT|FLOAT"}) == True assert util.validate_output_type({"test": "2.2"}, {"STRING": "FLOAT|INT"}) == False assert util.validate_output_type({2: 2.2}, {"STRING": "INT|FLOAT"}) == False @@ -205,9 +205,7 @@ def test_validate_output_type(): assert util.validate_output_type({"f": "s", "i": 2}, {"f": "FLOAT", "i": "INT"}) == False assert util.validate_output_type({"f": 2.2}, {"f": "FLOAT", "i": "INT"}) == False assert util.validate_output_type({"f": 2.2, "i": None}, {"f": "FLOAT", "i": "INT"}) == True - assert ( - util.validate_output_type({"f": 0.2, "i": 2, "e": 1}, {"f": "FLOAT", "i": "INT"}) == False - ) + assert util.validate_output_type({"f": 0.2, "i": 2, "e": 1}, 
{"f": "FLOAT", "i": "INT"}) == True assert util.validate_output_type(["s"], ["STRING"]) == True assert util.validate_output_type(["a", "b", "c"], ["STRING"]) == True diff --git a/pkg/workloads/lib/util.py b/pkg/workloads/lib/util.py index 2e2c1b2947..1d11d57f98 100644 --- a/pkg/workloads/lib/util.py +++ b/pkg/workloads/lib/util.py @@ -780,10 +780,10 @@ def validate_output_type(value, output_type): return True # Fixed map - for value_key, value_val in value.items(): - if value_key not in output_type: + for type_key, type_val in output_type.items(): + if type_key not in value: return False - if not validate_output_type(value_val, output_type[value_key]): + if not validate_output_type(value[type_key], type_val): return False return True diff --git a/pkg/workloads/spark_job/test/integration/iris_context.py b/pkg/workloads/spark_job/test/integration/iris_context.py index cfcda308b8..d36a9e4dd1 100644 --- a/pkg/workloads/spark_job/test/integration/iris_context.py +++ b/pkg/workloads/spark_job/test/integration/iris_context.py @@ -19,7 +19,7 @@ 1. cx deploy 2. get a path to a context -3. ssh into a docker container (spark/tf_train) +3. ssh into a docker container (spark or tf_train) docker run -it --entrypoint "/bin/bash" cortexlabs/spark 4. run the following in python3 shell From aff0cf0f070d9b87d09a7da7babe806c0a74fb7f Mon Sep 17 00:00:00 2001 From: David Eliahu Date: Tue, 11 Jun 2019 00:24:03 -0700 Subject: [PATCH 14/44] Fix api bug --- cli/cmd/get.go | 18 +++++++++++++----- 1 file changed, 13 insertions(+), 5 deletions(-) diff --git a/cli/cmd/get.go b/cli/cmd/get.go index 335791282a..dddb2a3c19 100644 --- a/cli/cmd/get.go +++ b/cli/cmd/get.go @@ -27,6 +27,7 @@ import ( "github.com/cortexlabs/cortex/pkg/lib/errors" "github.com/cortexlabs/cortex/pkg/lib/json" "github.com/cortexlabs/cortex/pkg/lib/msgpack" + "github.com/cortexlabs/cortex/pkg/lib/sets/strset" s "github.com/cortexlabs/cortex/pkg/lib/strings" libtime "github.com/cortexlabs/cortex/pkg/lib/time" "github.com/cortexlabs/cortex/pkg/lib/urls" @@ -411,13 +412,20 @@ func describeAPI(name string, resourcesRes *schema.GetResourcesResponse) (string } out += titleStr("Endpoint") - var samplePlaceholderFields []string + resIDs := strset.New() combinedInput := []interface{}{model.Input, model.TrainingInput} - for _, res := range ctx.ExtractCortexResources(combinedInput, resource.RawColumnType) { - rawColumn := res.(context.RawColumn) - fieldStr := `"` + rawColumn.GetName() + `": ` + rawColumn.GetColumnType().JSONPlaceholder() - samplePlaceholderFields = append(samplePlaceholderFields, fieldStr) + for _, res := range ctx.ExtractCortexResources(combinedInput) { + resIDs.Add(res.GetID()) + resIDs.Merge(ctx.AllComputedResourceDependencies(res.GetID())) + } + var samplePlaceholderFields []string + for rawColumnName, rawColumn := range ctx.RawColumns { + if resIDs.Has(rawColumn.GetID()) { + fieldStr := `"` + rawColumnName + `": ` + rawColumn.GetColumnType().JSONPlaceholder() + samplePlaceholderFields = append(samplePlaceholderFields, fieldStr) + } } + sort.Strings(samplePlaceholderFields) samplesPlaceholderStr := `{ "samples": [ { ` + strings.Join(samplePlaceholderFields, ", ") + " } ] }" out += "URL: " + urls.Join(resourcesRes.APIsBaseURL, anyAPIStatus.Path) + "\n" out += "Method: POST\n" From 20fc3ca4eba2a31961e6db95b80b7dfa1b75006e Mon Sep 17 00:00:00 2001 From: David Eliahu Date: Tue, 11 Jun 2019 01:04:30 -0700 
Subject: [PATCH 15/44] Fix predict CLI command --- cli/cmd/predict.go | 62 +++++++++++++------------------------ pkg/workloads/tf_api/api.py | 26 ++++++---------- 2 files changed, 32 insertions(+), 56 deletions(-) diff --git a/cli/cmd/predict.go b/cli/cmd/predict.go index bf70164480..4fe20925e1 100644 --- a/cli/cmd/predict.go +++ b/cli/cmd/predict.go @@ -24,6 +24,7 @@ import ( "github.com/spf13/cobra" + "github.com/cortexlabs/cortex/pkg/lib/cast" "github.com/cortexlabs/cortex/pkg/lib/errors" "github.com/cortexlabs/cortex/pkg/lib/files" "github.com/cortexlabs/cortex/pkg/lib/json" @@ -42,22 +43,15 @@ func init() { } type PredictResponse struct { - ResourceID string `json:"resource_id"` - ClassificationPredictions []ClassificationPrediction `json:"classification_predictions,omitempty"` - RegressionPredictions []RegressionPrediction `json:"regression_predictions,omitempty"` + ResourceID string `json:"resource_id"` + Predictions []Prediction `json:"predictions"` } -type ClassificationPrediction struct { - PredictedClass int `json:"predicted_class"` - PredictedClassReversed interface{} `json:"predicted_class_reversed"` - Probabilities []float64 `json:"probabilities"` - TransformedSample interface{} `json:"transformed_sample"` -} - -type RegressionPrediction struct { - PredictedValue float64 `json:"predicted_value"` - PredictedValueReversed interface{} `json:"predicted_value_reversed"` - TransformedSample interface{} `json:"transformed_sample"` +type Prediction struct { + Prediction interface{} `json:"prediction"` + PredictionReversed interface{} `json:"prediction_reversed"` + TransformedSample interface{} `json:"transformed_sample"` + Response interface{} `json:"response"` } var predictCmd = &cobra.Command{ @@ -108,34 +102,22 @@ var predictCmd = &cobra.Command{ apiStart := libtime.LocalTimestampHuman(api.Start) fmt.Println("\n" + apiName + " was last updated on " + apiStart + "\n") - if predictResponse.ClassificationPredictions != nil { - if len(predictResponse.ClassificationPredictions) == 1 { - fmt.Println("Predicted class:") - } else { - fmt.Println("Predicted classes:") - } - for _, prediction := range predictResponse.ClassificationPredictions { - if prediction.PredictedClassReversed != nil { - json, _ := json.Marshal(prediction.PredictedClassReversed) - fmt.Println(s.TrimPrefixAndSuffix(string(json), "\"")) - } else { - fmt.Println(prediction.PredictedClass) - } - } + if len(predictResponse.Predictions) == 1 { + fmt.Println("Prediction:") + } else { + fmt.Println("Predictions:") } - if predictResponse.RegressionPredictions != nil { - if len(predictResponse.RegressionPredictions) == 1 { - fmt.Println("Predicted value:") - } else { - fmt.Println("Predicted values:") + + for _, prediction := range predictResponse.Predictions { + value := prediction.Prediction + if prediction.PredictionReversed != nil { + value = prediction.PredictionReversed } - for _, prediction := range predictResponse.RegressionPredictions { - if prediction.PredictedValueReversed != nil { - json, _ := json.Marshal(prediction.PredictedValueReversed) - fmt.Println(s.TrimPrefixAndSuffix(string(json), "\"")) - } else { - fmt.Println(s.Round(prediction.PredictedValue, 2, true)) - } + + if casted, ok := cast.InterfaceToFloat64(value); ok { + fmt.Println(s.Round(casted, 2, true)) + } else { + fmt.Println(s.UserStrStripped(value)) } } }, diff --git a/pkg/workloads/tf_api/api.py b/pkg/workloads/tf_api/api.py index 
37046488b6..54d50d0d1b 100644 --- a/pkg/workloads/tf_api/api.py +++ b/pkg/workloads/tf_api/api.py @@ -158,20 +158,17 @@ def parse_response_proto(response_proto): results_dict = json_format.MessageToDict(response_proto) outputs = results_dict["outputs"] value_key = DTYPE_TO_VALUE_KEY[outputs[prediction_key]["dtype"]] - predicted = outputs[prediction_key][value_key][0] - predicted = util.cast(predicted, target_col_type) + prediction = outputs[prediction_key][value_key][0] + prediction = util.cast(prediction, target_col_type) result = {} + result["prediction"] = prediction + result["prediction_reversed"] = reverse_transform(prediction) + + result["response"] = {} for key in outputs.keys(): value_key = DTYPE_TO_VALUE_KEY[outputs[key]["dtype"]] - result[key] = outputs[key][value_key] - - if target_col_type in {consts.COLUMN_TYPE_STRING, consts.COLUMN_TYPE_INT}: - result["predicted_class"] = predicted - result["predicted_class_reversed"] = reverse_transform(predicted) - else: - result["predicted_value"] = predicted - result["predicted_value_reversed"] = reverse_transform(predicted) + result["response"][key] = outputs[key][value_key] return result @@ -198,7 +195,6 @@ def run_predict(sample): prediction_request = create_prediction_request(transformed_sample) response_proto = local_cache["stub"].Predict(prediction_request, timeout=10.0) result = parse_response_proto(response_proto) - result["transformed_sample"] = transformed_sample util.log_indent("Raw sample:", indent=4) util.log_pretty(sample, indent=6) util.log_indent("Transformed sample:", indent=4) @@ -206,6 +202,8 @@ def run_predict(sample): util.log_indent("Prediction:", indent=4) util.log_pretty(result, indent=6) + result["transformed_sample"] = transformed_sample + return result @@ -297,11 +295,7 @@ def predict(app_name, api_name): predictions.append(result) - if target_col_type in {consts.COLUMN_TYPE_STRING, consts.COLUMN_TYPE_INT}: - response["classification_predictions"] = predictions - else: - response["regression_predictions"] = predictions - + response["predictions"] = predictions response["resource_id"] = api["id"] return jsonify(response) From fad5553baf97bd3f61822e512408b7f96ab36212 Mon Sep 17 00:00:00 2001 From: David Eliahu Date: Tue, 11 Jun 2019 10:26:16 -0700 Subject: [PATCH 16/44] Rename boolean functions --- pkg/lib/configreader/string.go | 6 +++--- pkg/lib/regex/regex.go | 4 ++-- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/pkg/lib/configreader/string.go b/pkg/lib/configreader/string.go index 8ef0742163..003d15a332 100644 --- a/pkg/lib/configreader/string.go +++ b/pkg/lib/configreader/string.go @@ -193,19 +193,19 @@ func ValidateStringVal(val string, v *StringValidation) error { } if v.AlphaNumericDashDotUnderscore { - if !regex.CheckAlphaNumericDashDotUnderscore(val) { + if !regex.IsAlphaNumericDashDotUnderscore(val) { return ErrorAlphaNumericDashDotUnderscore(val) } } if v.AlphaNumericDashUnderscore { - if !regex.CheckAlphaNumericDashUnderscore(val) { + if !regex.IsAlphaNumericDashUnderscore(val) { return ErrorAlphaNumericDashUnderscore(val) } } if v.AlphaNumericDashDotUnderscoreOrEmpty { - if !regex.CheckAlphaNumericDashDotUnderscore(val) && val != "" { + if !regex.IsAlphaNumericDashDotUnderscore(val) && val != "" { return ErrorAlphaNumericDashDotUnderscore(val) } } diff --git a/pkg/lib/regex/regex.go b/pkg/lib/regex/regex.go index fb4aacf7b0..4d59289be0 100644 --- a/pkg/lib/regex/regex.go +++ b/pkg/lib/regex/regex.go @@ -31,12 +31,12 @@ func MatchAnyRegex(s string, regexes []*regexp.Regexp) bool { var 
alphaNumericDashDotUnderscoreRegex = regexp.MustCompile(`^[a-zA-Z0-9_\-\.]+$`) -func CheckAlphaNumericDashDotUnderscore(s string) bool { +func IsAlphaNumericDashDotUnderscore(s string) bool { return alphaNumericDashDotUnderscoreRegex.MatchString(s) } var alphaNumericDashUnderscoreRegex = regexp.MustCompile(`^[a-zA-Z0-9_\-]+$`) -func CheckAlphaNumericDashUnderscore(s string) bool { +func IsAlphaNumericDashUnderscore(s string) bool { return alphaNumericDashUnderscoreRegex.MatchString(s) } From 1fb68cdca4b6cc9d97b5d0e081d8b1bfa0aabd5c Mon Sep 17 00:00:00 2001 From: David Eliahu Date: Tue, 11 Jun 2019 12:26:54 -0700 Subject: [PATCH 17/44] Address Ivan comments --- cli/cmd/get.go | 2 +- pkg/lib/cast/interface.go | 16 ---------------- 2 files changed, 1 insertion(+), 17 deletions(-) diff --git a/cli/cmd/get.go b/cli/cmd/get.go index dddb2a3c19..108c9e71b5 100644 --- a/cli/cmd/get.go +++ b/cli/cmd/get.go @@ -421,7 +421,7 @@ func describeAPI(name string, resourcesRes *schema.GetResourcesResponse) (string var samplePlaceholderFields []string for rawColumnName, rawColumn := range ctx.RawColumns { if resIDs.Has(rawColumn.GetID()) { - fieldStr := `"` + rawColumnName + `": ` + rawColumn.GetColumnType().JSONPlaceholder() + fieldStr := fmt.Sprintf("\"%s\": %s", rawColumnName, rawColumn.GetColumnType().JSONPlaceholder()) samplePlaceholderFields = append(samplePlaceholderFields, fieldStr) } } diff --git a/pkg/lib/cast/interface.go b/pkg/lib/cast/interface.go index c15147ae97..f2ebebd508 100644 --- a/pkg/lib/cast/interface.go +++ b/pkg/lib/cast/interface.go @@ -742,19 +742,3 @@ func IsScalarType(in interface{}) bool { } return false } - -func ToScalarType(in interface{}) (interface{}, bool) { - if casted, ok := InterfaceToInt64(in); ok { - return casted, true - } - if casted, ok := InterfaceToFloat64(in); ok { - return casted, true - } - if casted, ok := in.(bool); ok { - return casted, true - } - if casted, ok := in.(string); ok { - return casted, true - } - return in, false -} From da86b23de614bba3d4ff3fc4b472d4f0f4d1a997 Mon Sep 17 00:00:00 2001 From: David Eliahu Date: Tue, 11 Jun 2019 16:28:04 -0700 Subject: [PATCH 18/44] Update iris example, fix estimator bugs --- examples/iris/app.yaml | 26 +++++++++ examples/iris/resources/aggregates.yaml | 44 --------------- examples/iris/resources/apis.yaml | 5 -- examples/iris/resources/models.yaml | 19 ------- examples/iris/resources/raw_columns.yaml | 53 ------------------- .../iris/resources/transformed_columns.yaml | 38 ------------- pkg/estimators/boosted_trees_regressor.py | 2 +- .../dnn_linear_combined_classifier.py | 4 +- .../dnn_linear_combined_regressor.py | 4 +- pkg/estimators/linear_classifier.py | 2 +- pkg/estimators/linear_regressor.py | 2 +- 11 files changed, 33 insertions(+), 166 deletions(-) delete mode 100644 examples/iris/resources/aggregates.yaml delete mode 100644 examples/iris/resources/apis.yaml delete mode 100644 examples/iris/resources/models.yaml delete mode 100644 examples/iris/resources/raw_columns.yaml delete mode 100644 examples/iris/resources/transformed_columns.yaml diff --git a/examples/iris/app.yaml b/examples/iris/app.yaml index bd81736cac..42f5a580f4 100644 --- a/examples/iris/app.yaml +++ b/examples/iris/app.yaml @@ -1,2 +1,28 @@ - kind: app name: iris + +- kind: environment + name: dev + data: + type: csv + path: s3a://cortex-examples/iris.csv + schema: [@sepal_length, @sepal_width, @petal_length, @petal_width, @class] + +- kind: model + name: dnn + estimator: cortex.dnn_classifier + target_column: @class + input: + 
numeric_columns: [@sepal_length, @sepal_width, @petal_length, @petal_width] + target_vocab: ['Iris-setosa', 'Iris-versicolor', 'Iris-virginica'] + hparams: + hidden_units: [4, 2] + training: + batch_size: 10 + num_steps: 1000 + +- kind: api + name: iris-type + model: @dnn + compute: + replicas: 1 diff --git a/examples/iris/resources/aggregates.yaml b/examples/iris/resources/aggregates.yaml deleted file mode 100644 index 18c05b211d..0000000000 --- a/examples/iris/resources/aggregates.yaml +++ /dev/null @@ -1,44 +0,0 @@ -- kind: aggregate - name: sepal_length_mean - aggregator: cortex.mean - input: @sepal_length - -- kind: aggregate - name: sepal_length_stddev - aggregator: cortex.stddev - input: @sepal_length - -- kind: aggregate - name: sepal_width_mean - aggregator: cortex.mean - input: @sepal_width - -- kind: aggregate - name: sepal_width_stddev - aggregator: cortex.stddev - input: @sepal_width - -- kind: aggregate - name: petal_length_mean - aggregator: cortex.mean - input: @petal_length - -- kind: aggregate - name: petal_length_stddev - aggregator: cortex.stddev - input: @petal_length - -- kind: aggregate - name: petal_width_mean - aggregator: cortex.mean - input: @petal_width - -- kind: aggregate - name: petal_width_stddev - aggregator: cortex.stddev - input: @petal_width - -- kind: aggregate - name: class_index - aggregator: cortex.index_string - input: @class diff --git a/examples/iris/resources/apis.yaml b/examples/iris/resources/apis.yaml deleted file mode 100644 index 7272bd6fe8..0000000000 --- a/examples/iris/resources/apis.yaml +++ /dev/null @@ -1,5 +0,0 @@ -- kind: api - name: iris-type - model: @dnn - compute: - replicas: 1 diff --git a/examples/iris/resources/models.yaml b/examples/iris/resources/models.yaml deleted file mode 100644 index 761e7f6d24..0000000000 --- a/examples/iris/resources/models.yaml +++ /dev/null @@ -1,19 +0,0 @@ -- kind: model - name: dnn - estimator: cortex.dnn_classifier - target_column: @class_indexed - input: - numeric_columns: - - @sepal_length_normalized - - @sepal_width_normalized - - @petal_length_normalized - - @petal_width_normalized - num_classes: 3 - hparams: - hidden_units: [4, 2] - data_partition_ratio: - training: 80 - evaluation: 20 - training: - batch_size: 10 - num_steps: 1000 diff --git a/examples/iris/resources/raw_columns.yaml b/examples/iris/resources/raw_columns.yaml deleted file mode 100644 index 0991b9cc0a..0000000000 --- a/examples/iris/resources/raw_columns.yaml +++ /dev/null @@ -1,53 +0,0 @@ -- kind: environment - name: dev - data: - type: csv - path: s3a://cortex-examples/iris.csv - schema: [@sepal_length, @sepal_width, @petal_length, @petal_width, @class] - - -- kind: environment - name: prod - data: - type: parquet - path: s3a://cortex-examples/iris.parquet - schema: - - parquet_column_name: sepal_length - raw_column: @sepal_length - - parquet_column_name: sepal_width - raw_column: @sepal_width - - parquet_column_name: petal_length - raw_column: @petal_length - - parquet_column_name: petal_width - raw_column: @petal_width - - parquet_column_name: class - raw_column: @class - -- kind: raw_column - name: sepal_length - type: FLOAT_COLUMN - min: 0 - max: 10 - -- kind: raw_column - name: sepal_width - type: FLOAT_COLUMN - min: 0 - max: 10 - -- kind: raw_column - name: petal_length - type: FLOAT_COLUMN - min: 0 - max: 10 - -- kind: raw_column - name: petal_width - type: FLOAT_COLUMN - min: 0 - max: 10 - -- kind: raw_column - name: class - type: STRING_COLUMN - values: ['Iris-setosa', 'Iris-versicolor', 'Iris-virginica'] diff 
--git a/examples/iris/resources/transformed_columns.yaml b/examples/iris/resources/transformed_columns.yaml deleted file mode 100644 index ddb96bfa9e..0000000000 --- a/examples/iris/resources/transformed_columns.yaml +++ /dev/null @@ -1,38 +0,0 @@ -- kind: transformed_column - name: sepal_length_normalized - transformer: cortex.normalize - input: - col: @sepal_length - mean: @sepal_length_mean - stddev: @sepal_length_stddev - -- kind: transformed_column - name: sepal_width_normalized - transformer: cortex.normalize - input: - col: @sepal_width - mean: @sepal_width_mean - stddev: @sepal_width_stddev - -- kind: transformed_column - name: petal_length_normalized - transformer: cortex.normalize - input: - col: @petal_length - mean: @petal_length_mean - stddev: @petal_length_stddev - -- kind: transformed_column - name: petal_width_normalized - transformer: cortex.normalize - input: - col: @petal_width - mean: @petal_width_mean - stddev: @petal_width_stddev - -- kind: transformed_column - name: class_indexed - transformer: cortex.index_string - input: - col: @class - indexes: @class_index diff --git a/pkg/estimators/boosted_trees_regressor.py b/pkg/estimators/boosted_trees_regressor.py index 3dfc3321e8..55881eec4f 100644 --- a/pkg/estimators/boosted_trees_regressor.py +++ b/pkg/estimators/boosted_trees_regressor.py @@ -59,7 +59,7 @@ def create_estimator(run_config, model_config): ) ) - return tf.estimator.BoostedTreesClassifier( + return tf.estimator.BoostedTreesRegressor( feature_columns=feature_columns, n_batches_per_layer=model_config["hparams"]["batches_per_layer"], weight_column=model_config["input"].get("weight_column", None), diff --git a/pkg/estimators/dnn_linear_combined_classifier.py b/pkg/estimators/dnn_linear_combined_classifier.py index f9e4410a94..2857bcc88c 100644 --- a/pkg/estimators/dnn_linear_combined_classifier.py +++ b/pkg/estimators/dnn_linear_combined_classifier.py @@ -5,7 +5,7 @@ def create_estimator(run_config, model_config): dnn_feature_columns = [] for col_name in model_config["input"]["dnn_columns"]["numeric_columns"]: - feature_columns.append(tf.feature_column.numeric_column(col_name)) + dnn_feature_columns.append(tf.feature_column.numeric_column(col_name)) for col_info in model_config["input"]["dnn_columns"]["categorical_columns_with_vocab"]: col = tf.feature_column.categorical_column_with_vocabulary_list( @@ -114,7 +114,7 @@ def create_estimator(run_config, model_config): target_vocab = model_config["input"]["target_vocab"] num_classes = len(target_vocab) - return tf.estimator.DNNClassifier( + return tf.estimator.DNNLinearCombinedClassifier( linear_feature_columns=linear_feature_columns, dnn_feature_columns=dnn_feature_columns, n_classes=num_classes, diff --git a/pkg/estimators/dnn_linear_combined_regressor.py b/pkg/estimators/dnn_linear_combined_regressor.py index cb6730d763..14ba51f7a6 100644 --- a/pkg/estimators/dnn_linear_combined_regressor.py +++ b/pkg/estimators/dnn_linear_combined_regressor.py @@ -5,7 +5,7 @@ def create_estimator(run_config, model_config): dnn_feature_columns = [] for col_name in model_config["input"]["dnn_columns"]["numeric_columns"]: - feature_columns.append(tf.feature_column.numeric_column(col_name)) + dnn_feature_columns.append(tf.feature_column.numeric_column(col_name)) for col_info in model_config["input"]["dnn_columns"]["categorical_columns_with_vocab"]: col = tf.feature_column.categorical_column_with_vocabulary_list( @@ -101,7 +101,7 @@ def create_estimator(run_config, model_config): ) ) - return tf.estimator.DNNClassifier( + return 
tf.estimator.DNNLinearCombinedRegressor( linear_feature_columns=linear_feature_columns, dnn_feature_columns=dnn_feature_columns, dnn_hidden_units=model_config["hparams"]["dnn_hidden_units"], diff --git a/pkg/estimators/linear_classifier.py b/pkg/estimators/linear_classifier.py index adc9d24afb..70514f84af 100644 --- a/pkg/estimators/linear_classifier.py +++ b/pkg/estimators/linear_classifier.py @@ -57,7 +57,7 @@ def create_estimator(run_config, model_config): target_vocab = model_config["input"]["target_vocab"] num_classes = len(target_vocab) - return tf.estimator.DNNClassifier( + return tf.estimator.LinearClassifier( feature_columns=feature_columns, n_classes=num_classes, label_vocabulary=target_vocab, diff --git a/pkg/estimators/linear_regressor.py b/pkg/estimators/linear_regressor.py index cf95399ddc..7c990d2829 100644 --- a/pkg/estimators/linear_regressor.py +++ b/pkg/estimators/linear_regressor.py @@ -44,7 +44,7 @@ def create_estimator(run_config, model_config): ) ) - return tf.estimator.DNNRegressor( + return tf.estimator.LinearRegressor( feature_columns=feature_columns, weight_column=model_config["input"].get("weight_column", None), config=run_config, From a137b1250f6c529c6e1bbd2a875b35e553828012 Mon Sep 17 00:00:00 2001 From: David Eliahu Date: Tue, 11 Jun 2019 16:36:16 -0700 Subject: [PATCH 19/44] Clean up --- pkg/consts/consts.go | 1 - pkg/workloads/lib/context.py | 63 +++++++++++++++--------- pkg/workloads/lib/util.py | 25 +++++----- pkg/workloads/spark_job/spark_util.py | 70 ++++++++++++--------------- pkg/workloads/tf_api/api.py | 17 +++---- 5 files changed, 93 insertions(+), 83 deletions(-) diff --git a/pkg/consts/consts.go b/pkg/consts/consts.go index fd80d06442..8b92580411 100644 --- a/pkg/consts/consts.go +++ b/pkg/consts/consts.go @@ -46,7 +46,6 @@ var ( AggregatesDir = "aggregates" TransformersDir = "transformers" EstimatorsDir = "estimators" - ModelImplsDir = "model_implementations" PythonPackagesDir = "python_packages" ModelsDir = "models" ConstantsDir = "constants" diff --git a/pkg/workloads/lib/context.py b/pkg/workloads/lib/context.py index 12f98a58d7..dabb356c13 100644 --- a/pkg/workloads/lib/context.py +++ b/pkg/workloads/lib/context.py @@ -415,14 +415,15 @@ def get_inferred_column_type(self, column_name): return column_type - # replaces column references with column names (unless preserve_column_refs = true, then leaves them untouched) + # Replace aggregates and constants with their values, and columns with their names (unless preserve_column_refs == False) + # Also validate against input_schema (if not None) def populate_values(self, input, input_schema, preserve_column_refs): if input is None: if input_schema is None: return None - if input_schema["_allow_null"]: + if input_schema.get("_allow_null") == True: return None - raise UserException("Null is not allowed") + raise UserException("Null value is not allowed") if util.is_resource_ref(input): res_name = util.get_resource_ref(input) @@ -447,8 +448,10 @@ def populate_values(self, input, input_schema, preserve_column_refs): col_type = self.get_inferred_column_type(res_name) if col_type not in input_schema["_type"]: raise UserException( - "column {}: column type mismatch: got {}, expected {}".format( - res_name, col_type, input_schema["_type"] + "column {}: unsupported input type (expected type {}, got type {})".format( + res_name, + util.data_type_str(input_schema["_type"]), + util.data_type_str(col_type), ) ) if preserve_column_refs: @@ -460,13 +463,17 @@ def populate_values(self, input, input_schema, 
preserve_column_refs): elem_schema = None if input_schema is not None: if not util.is_list(input_schema["_type"]): - raise UserException("unexpected type (list)") + raise UserException( + "unsupported input type (expected type {}, got {})".format( + util.data_type_str(input_schema["_type"]), util.pp_str_flat(input) + ) + ) elem_schema = input_schema["_type"][0] min_count = input_schema.get("_min_count") if min_count is not None and len(input) < min_count: raise UserException( - "list has length {}, but the minimum length is {}".format( + "list has length {}, but the minimum allowed length is {}".format( len(input), min_count ) ) @@ -474,7 +481,7 @@ def populate_values(self, input, input_schema, preserve_column_refs): max_count = input_schema.get("_max_count") if max_count is not None and len(input) > max_count: raise UserException( - "list has length {}, but the maximum length is {}".format( + "list has length {}, but the maximum allowed length is {}".format( len(input), max_count ) ) @@ -496,24 +503,32 @@ def populate_values(self, input, input_schema, preserve_column_refs): try: val_casted = self.populate_values(val, None, preserve_column_refs) except CortexException as e: - e.wrap(util.pp_str_flat(key_casted)) + e.wrap(util.pp_str_flat(key)) raise casted[key_casted] = val_casted return casted if not util.is_dict(input_schema["_type"]): - raise UserException("unexpected type (map)") + raise UserException( + "unsupported input type (expected type {}, got {})".format( + util.data_type_str(input_schema["_type"]), util.pp_str_flat(input) + ) + ) min_count = input_schema.get("_min_count") if min_count is not None and len(input) < min_count: raise UserException( - "map has length {}, but the minimum length is {}".format(len(input), min_count) + "map has length {}, but the minimum allowed length is {}".format( + len(input), min_count + ) ) max_count = input_schema.get("_max_count") if max_count is not None and len(input) > max_count: raise UserException( - "map has length {}, but the maximum length is {}".format(len(input), max_count) + "map has length {}, but the maximum allowed length is {}".format( + len(input), max_count + ) ) is_generic_map = False @@ -535,7 +550,7 @@ def populate_values(self, input, input_schema, preserve_column_refs): val, generic_map_value, preserve_column_refs ) except CortexException as e: - e.wrap(util.pp_str_flat(key_casted)) + e.wrap(util.pp_str_flat(key)) raise casted[key_casted] = val_casted return casted @@ -543,15 +558,15 @@ def populate_values(self, input, input_schema, preserve_column_refs): # fixed map casted = {} for key, val_schema in input_schema["_type"].items(): - default = None - if key not in input: + if key in input: + val = input[key] + else: if val_schema.get("_optional") is not True: raise UserException("missing key: " + util.pp_str_flat(key)) if val_schema.get("_default") is None: continue - default = val_schema["_default"] + val = val_schema["_default"] - val = input.get(key, default) try: val_casted = self.populate_values(val, val_schema, preserve_column_refs) except CortexException as e: @@ -562,8 +577,12 @@ def populate_values(self, input, input_schema, preserve_column_refs): if input_schema is None: return input - if util.is_list(input_schema["_type"]) or util.is_dict(input_schema["_type"]): - raise UserException("unexpected type (scalar)") + if not util.is_str(input_schema["_type"]): + raise UserException( + "unsupported input type (expected type {}, got {})".format( + util.data_type_str(input_schema["_type"]), util.pp_str_flat(input) + ) + 
) return cast_compound_type(input, input_schema["_type"]) @@ -605,8 +624,8 @@ def cast_compound_type(value, type_str): return value raise UserException( - "input value's type is not supported by the schema (got {}, expected input with type {})".format( - util.pp_str_flat(value), type_str + "unsupported input type (expected type {}, got {})".format( + util.data_type_str(type_str), util.pp_str_flat(value) ) ) @@ -689,7 +708,7 @@ def _deserialize_raw_ctx(raw_ctx): def create_transformer_inputs_from_map(input, col_value_map): if util.is_str(input): res_name = util.get_resource_ref(input) - if res_name in col_value_map: + if res_name is not None and res_name in col_value_map: return col_value_map[res_name] return input diff --git a/pkg/workloads/lib/util.py b/pkg/workloads/lib/util.py index 1d11d57f98..51a20b9393 100644 --- a/pkg/workloads/lib/util.py +++ b/pkg/workloads/lib/util.py @@ -66,6 +66,11 @@ def pp_str_flat(obj, indent=0): return indent_str(out, indent) +def data_type_str(obj): + # TODO. Also call this method with output types? + return pp_str_flat(obj) + + def log_indent(obj, indent=0, logging_func=logger.info): if not is_str(obj): text = repr(obj) @@ -748,8 +753,6 @@ def validate_output_type(value, output_type): return False if is_list(output_type): - if not (len(output_type) == 1 and is_str(output_type[0])): - return False if not is_list(value): return False for value_item in value: @@ -760,8 +763,6 @@ def validate_output_type(value, output_type): if is_dict(output_type): if not is_dict(value): return False - if len(output_type) == 0: - return False is_generic_map = False if len(output_type) == 1: @@ -787,10 +788,10 @@ def validate_output_type(value, output_type): return False return True - return False + return False # unexpected -# Casts int -> float. 
Input is assumed to be already validated +# value is assumed to be already validated against output_type def cast_output_type(value, output_type): if is_str(output_type): if ( @@ -858,6 +859,12 @@ def extract_resource_refs(input): return {res} return set() + if is_list(input): + resources = set() + for item in input: + resources = resources.union(extract_resource_refs(item)) + return resources + if is_dict(input): resources = set() for key, val in input.items(): @@ -865,10 +872,4 @@ def extract_resource_refs(input): resources = resources.union(extract_resource_refs(val)) return resources - if is_list(input): - resources = set() - for item in input: - resources = resources.union(extract_resource_refs(item)) - return resources - return set() diff --git a/pkg/workloads/spark_job/spark_util.py b/pkg/workloads/spark_job/spark_util.py index 1f43f3b069..e59b59d53c 100644 --- a/pkg/workloads/spark_job/spark_util.py +++ b/pkg/workloads/spark_job/spark_util.py @@ -397,12 +397,12 @@ def read_parquet(ctx, spark): def split_aggregators(aggregate_names, ctx): - aggregate_resources = [ctx.aggregates[agg_name] for agg_name in aggregate_names] + aggregates = [ctx.aggregates[agg_name] for agg_name in aggregate_names] builtin_aggregates = [] custom_aggregates = [] - for agg in aggregate_resources: + for agg in aggregates: aggregator = ctx.aggregators[agg["aggregator"]] if aggregator.get("namespace", None) == "cortex" and aggregator["name"] in AGG_SPARK_LIST: builtin_aggregates.append(agg) @@ -416,52 +416,50 @@ def run_builtin_aggregators(builtin_aggregates, df, ctx, spark): agg_cols = [] for agg in builtin_aggregates: aggregator = ctx.aggregators[agg["aggregator"]] - input_repl = ctx.populate_values( - agg["input"], aggregator["input"], preserve_column_refs=False - ) + input = ctx.populate_values(agg["input"], aggregator["input"], preserve_column_refs=False) if aggregator["name"] == "approx_count_distinct": agg_cols.append( - F.approxCountDistinct(input_repl["col"], input_repl.get("rsd")).alias(agg["name"]) + F.approxCountDistinct(input["col"], input.get("rsd")).alias(agg["name"]) ) if aggregator["name"] == "avg": - agg_cols.append(F.avg(input_repl).alias(agg["name"])) + agg_cols.append(F.avg(input).alias(agg["name"])) if aggregator["name"] in {"collect_set_int", "collect_set_float", "collect_set_string"}: - agg_cols.append(F.collect_set(input_repl).alias(agg["name"])) + agg_cols.append(F.collect_set(input).alias(agg["name"])) if aggregator["name"] == "count": - agg_cols.append(F.count(input_repl).alias(agg["name"])) + agg_cols.append(F.count(input).alias(agg["name"])) if aggregator["name"] == "count_distinct": - agg_cols.append(F.countDistinct(*input_repl).alias(agg["name"])) + agg_cols.append(F.countDistinct(*input).alias(agg["name"])) if aggregator["name"] == "covar_pop": - agg_cols.append(F.covar_pop(input_repl["col1"], input_repl["col2"]).alias(agg["name"])) + agg_cols.append(F.covar_pop(input["col1"], input["col2"]).alias(agg["name"])) if aggregator["name"] == "covar_samp": - agg_cols.append(F.covar_samp(input_repl["col1"], input_repl["col2"]).alias(agg["name"])) + agg_cols.append(F.covar_samp(input["col1"], input["col2"]).alias(agg["name"])) if aggregator["name"] == "kurtosis": - agg_cols.append(F.kurtosis(input_repl).alias(agg["name"])) + agg_cols.append(F.kurtosis(input).alias(agg["name"])) if aggregator["name"] in {"max_int", "max_float", "max_string"}: - agg_cols.append(F.max(input_repl).alias(agg["name"])) + agg_cols.append(F.max(input).alias(agg["name"])) if aggregator["name"] == "mean": - 
agg_cols.append(F.mean(input_repl).alias(agg["name"])) + agg_cols.append(F.mean(input).alias(agg["name"])) if aggregator["name"] in {"min_int", "min_float", "min_string"}: - agg_cols.append(F.min(input_repl).alias(agg["name"])) + agg_cols.append(F.min(input).alias(agg["name"])) if aggregator["name"] == "skewness": - agg_cols.append(F.skewness(input_repl).alias(agg["name"])) + agg_cols.append(F.skewness(input).alias(agg["name"])) if aggregator["name"] == "stddev": - agg_cols.append(F.stddev(input_repl).alias(agg["name"])) + agg_cols.append(F.stddev(input).alias(agg["name"])) if aggregator["name"] == "stddev_pop": - agg_cols.append(F.stddev_pop(input_repl).alias(agg["name"])) + agg_cols.append(F.stddev_pop(input).alias(agg["name"])) if aggregator["name"] == "stddev_samp": - agg_cols.append(F.stddev_samp(input_repl).alias(agg["name"])) + agg_cols.append(F.stddev_samp(input).alias(agg["name"])) if aggregator["name"] in {"sum_int", "sum_float"}: - agg_cols.append(F.sum(input_repl).alias(agg["name"])) + agg_cols.append(F.sum(input).alias(agg["name"])) if aggregator["name"] in {"sum_distinct_int", "sum_distinct_float"}: - agg_cols.append(F.sumDistinct(input_repl).alias(agg["name"])) + agg_cols.append(F.sumDistinct(input).alias(agg["name"])) if aggregator["name"] == "var_pop": - agg_cols.append(F.var_pop(input_repl).alias(agg["name"])) + agg_cols.append(F.var_pop(input).alias(agg["name"])) if aggregator["name"] == "var_samp": - agg_cols.append(F.var_samp(input_repl).alias(agg["name"])) + agg_cols.append(F.var_samp(input).alias(agg["name"])) if aggregator["name"] == "variance": - agg_cols.append(F.variance(input_repl).alias(agg["name"])) + agg_cols.append(F.variance(input).alias(agg["name"])) results = df.agg(*agg_cols).collect()[0].asDict() @@ -479,12 +477,10 @@ def run_builtin_aggregators(builtin_aggregates, df, ctx, spark): def run_custom_aggregator(aggregate, df, ctx, spark): aggregator = ctx.aggregators[aggregate["aggregator"]] aggregator_impl, _ = ctx.get_aggregator_impl(aggregate["name"]) - input_repl = ctx.populate_values( - aggregate["input"], aggregator["input"], preserve_column_refs=False - ) + input = ctx.populate_values(aggregate["input"], aggregator["input"], preserve_column_refs=False) try: - result = aggregator_impl.aggregate_spark(df, input_repl) + result = aggregator_impl.aggregate_spark(df, input) except Exception as e: raise UserRuntimeException( "aggregate " + aggregate["name"], @@ -517,11 +513,11 @@ def execute_transform_spark(column_name, df, ctx, spark): spark.sparkContext.addPyFile(trans_impl_path) # Executor pods need this because of the UDF ctx.spark_uploaded_impls[trans_impl_path] = True - input_repl = ctx.populate_values( + input = ctx.populate_values( transformed_column["input"], transformer["input"], preserve_column_refs=False ) try: - return trans_impl.transform_spark(df, input_repl, column_name) + return trans_impl.transform_spark(df, input, column_name) except Exception as e: raise UserRuntimeException("function transform_spark") from e @@ -532,7 +528,7 @@ def execute_transform_python(column_name, df, ctx, spark, validate=False): transformer = ctx.transformers[transformed_column["transformer"]] input_cols_sorted = sorted(ctx.extract_column_names(transformed_column["input"])) - input_repl = ctx.populate_values( + input = ctx.populate_values( transformed_column["input"], transformer["input"], preserve_column_refs=True ) @@ -541,9 +537,7 @@ def execute_transform_python(column_name, df, ctx, spark, validate=False): ctx.spark_uploaded_impls[trans_impl_path] = True def 
_transform(*values): - transformer_input = create_transformer_inputs_from_lists( - input_repl, input_cols_sorted, values - ) + transformer_input = create_transformer_inputs_from_lists(input, input_cols_sorted, values) return trans_impl.transform_python(transformer_input) transform_python_func = _transform @@ -593,15 +587,15 @@ def validate_transformer(column_name, test_df, ctx, spark): if transformer["output_type"] == consts.COLUMN_TYPE_INFERRED: sample_df = test_df.collect() sample = sample_df[0] - input_repl = ctx.populate_values( + input = ctx.populate_values( transformed_column["input"], transformer["input"], preserve_column_refs=True ) - transformer_input = create_transformer_inputs_from_map(input_repl, sample) + transformer_input = create_transformer_inputs_from_map(input, sample) initial_transformed_sample = trans_impl.transform_python(transformer_input) inferred_python_type = infer_type(initial_transformed_sample) for row in sample_df: - transformer_input = create_transformer_inputs_from_map(input_repl, row) + transformer_input = create_transformer_inputs_from_map(input, row) transformed_sample = trans_impl.transform_python(transformer_input) if inferred_python_type != infer_type(transformed_sample): raise UserRuntimeException( diff --git a/pkg/workloads/tf_api/api.py b/pkg/workloads/tf_api/api.py index 54d50d0d1b..90d00d8445 100644 --- a/pkg/workloads/tf_api/api.py +++ b/pkg/workloads/tf_api/api.py @@ -76,17 +76,17 @@ def transform_sample(sample): transformed_value = sample[column_name] else: transformed_column = ctx.transformed_columns[column_name] - input_repl = ctx.populate_values( - transformed_column["input"], None, preserve_column_refs=True - ) - transformer_input = create_transformer_inputs_from_map(input_repl, sample) trans_impl = local_cache["trans_impls"][column_name] if not hasattr(trans_impl, "transform_python"): raise UserException( "transformed column " + column_name, "transformer " + transformed_column["transformer"], - "transform_python() function is missing", + "transform_python function is missing", ) + input = ctx.populate_values( + transformed_column["input"], None, preserve_column_refs=True + ) + transformer_input = create_transformer_inputs_from_map(input, sample) transformed_value = trans_impl.transform_python(transformer_input) transformed_sample[column_name] = transformed_value @@ -123,9 +123,9 @@ def reverse_transform(value): if not (trans_impl and hasattr(trans_impl, "reverse_transform_python")): return None - input_repl = ctx.populate_values(target_col["input"], None, preserve_column_refs=False) + input = ctx.populate_values(target_col["input"], None, preserve_column_refs=False) try: - result = trans_impl.reverse_transform_python(value, input_repl) + result = trans_impl.reverse_transform_python(value, input) except Exception as e: raise UserRuntimeException( "transformer " + target_col["transformer"], "function reverse_transform_python" @@ -246,10 +246,7 @@ def predict(app_name, api_name): return "Malformed JSON", status.HTTP_400_BAD_REQUEST ctx = local_cache["ctx"] - model = local_cache["model"] - estimator = local_cache["estimator"] api = local_cache["api"] - target_col_type = local_cache["target_col_type"] response = {} From 45f7bac8adcda36dacce247a5d795a5e084ce53f Mon Sep 17 00:00:00 2001 From: David Eliahu Date: Tue, 11 Jun 2019 17:27:45 -0700 Subject: [PATCH 20/44] Fix embed --- pkg/operator/api/userconfig/embed.go | 7 ++++--- pkg/operator/api/userconfig/errors.go | 6 +++++- 2 files changed, 9 insertions(+), 4 deletions(-) diff --git 
a/pkg/operator/api/userconfig/embed.go b/pkg/operator/api/userconfig/embed.go index 3778e189d0..f14d8d450e 100644 --- a/pkg/operator/api/userconfig/embed.go +++ b/pkg/operator/api/userconfig/embed.go @@ -40,9 +40,10 @@ var embedValidation = &cr.StructValidation{ { StructField: "Args", InterfaceMapValidation: &cr.InterfaceMapValidation{ - Required: false, - AllowEmpty: true, - Default: make(map[string]interface{}), + Required: false, + AllowEmpty: true, + AllowCortexResources: true, + Default: make(map[string]interface{}), }, }, typeFieldValidation, diff --git a/pkg/operator/api/userconfig/errors.go b/pkg/operator/api/userconfig/errors.go index 078d21d0dc..9f5e7fbe17 100644 --- a/pkg/operator/api/userconfig/errors.go +++ b/pkg/operator/api/userconfig/errors.go @@ -364,9 +364,13 @@ func ErrorCannotMixValueAndColumnTypes(provided interface{}) error { } func ErrorColumnTypeLiteral(provided interface{}) error { + colName := "column_name" + if providedStr, ok := provided.(string); ok { + colName = providedStr + } return Error{ Kind: ErrColumnTypeLiteral, - message: fmt.Sprintf("%s: literal values cannot be provided for column input types (e.g. use FLOAT_COLUMN instead of FLOAT)", s.UserStrStripped(provided)), + message: fmt.Sprintf("%s: literal values cannot be provided for column input types (use a reference to a column, e.g. \"@%s\")", s.UserStrStripped(provided), colName), } } From e09499088f84e2f5caf09f9dcde653892fdb8306 Mon Sep 17 00:00:00 2001 From: David Eliahu Date: Tue, 11 Jun 2019 17:58:26 -0700 Subject: [PATCH 21/44] Fix ALL_TYPES --- pkg/workloads/consts.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pkg/workloads/consts.py b/pkg/workloads/consts.py index 01b03211a4..aefa8f8447 100644 --- a/pkg/workloads/consts.py +++ b/pkg/workloads/consts.py @@ -41,4 +41,4 @@ VALUE_TYPES = [VALUE_TYPE_INT, VALUE_TYPE_FLOAT, VALUE_TYPE_STRING, VALUE_TYPE_BOOL] -ALL_TYPES = set(COLUMN_LIST_TYPES + VALUE_TYPES) +ALL_TYPES = set(COLUMN_TYPES + VALUE_TYPES) From 2f5c8ed0005a90fa8ba55515c6822f0fc21bc5ca Mon Sep 17 00:00:00 2001 From: David Eliahu Date: Tue, 11 Jun 2019 23:56:19 -0700 Subject: [PATCH 22/44] Misc fixes --- pkg/workloads/lib/context.py | 18 +++--- pkg/workloads/lib/util.py | 35 +++++++++++- pkg/workloads/spark_job/spark_util.py | 81 ++++++++++++++++----------- 3 files changed, 88 insertions(+), 46 deletions(-) diff --git a/pkg/workloads/lib/context.py b/pkg/workloads/lib/context.py index dabb356c13..bca64aa6e6 100644 --- a/pkg/workloads/lib/context.py +++ b/pkg/workloads/lib/context.py @@ -408,7 +408,7 @@ def get_metadata(self, resource_id, use_cache=True): def get_inferred_column_type(self, column_name): column = self.columns[column_name] - column_type = self.columns[column_name].get("type", "unknown") + column_type = self.columns[column_name]["type"] if column_type == consts.COLUMN_TYPE_INFERRED: column_type = self.get_metadata(column["id"])["type"] self.columns[column_name]["type"] = column_type @@ -465,7 +465,7 @@ def populate_values(self, input, input_schema, preserve_column_refs): if not util.is_list(input_schema["_type"]): raise UserException( "unsupported input type (expected type {}, got {})".format( - util.data_type_str(input_schema["_type"]), util.pp_str_flat(input) + util.data_type_str(input_schema["_type"]), util.user_obj_str(input) ) ) elem_schema = input_schema["_type"][0] @@ -503,7 +503,7 @@ def populate_values(self, input, input_schema, preserve_column_refs): try: val_casted = self.populate_values(val, None, preserve_column_refs) except CortexException 
as e: - e.wrap(util.pp_str_flat(key)) + e.wrap(util.user_obj_str(key)) raise casted[key_casted] = val_casted return casted @@ -511,7 +511,7 @@ def populate_values(self, input, input_schema, preserve_column_refs): if not util.is_dict(input_schema["_type"]): raise UserException( "unsupported input type (expected type {}, got {})".format( - util.data_type_str(input_schema["_type"]), util.pp_str_flat(input) + util.data_type_str(input_schema["_type"]), util.user_obj_str(input) ) ) @@ -550,7 +550,7 @@ def populate_values(self, input, input_schema, preserve_column_refs): val, generic_map_value, preserve_column_refs ) except CortexException as e: - e.wrap(util.pp_str_flat(key)) + e.wrap(util.user_obj_str(key)) raise casted[key_casted] = val_casted return casted @@ -562,7 +562,7 @@ def populate_values(self, input, input_schema, preserve_column_refs): val = input[key] else: if val_schema.get("_optional") is not True: - raise UserException("missing key: " + util.pp_str_flat(key)) + raise UserException("missing key: " + util.user_obj_str(key)) if val_schema.get("_default") is None: continue val = val_schema["_default"] @@ -570,7 +570,7 @@ def populate_values(self, input, input_schema, preserve_column_refs): try: val_casted = self.populate_values(val, val_schema, preserve_column_refs) except CortexException as e: - e.wrap(util.pp_str_flat(key)) + e.wrap(util.user_obj_str(key)) raise casted[key] = val_casted return casted @@ -580,7 +580,7 @@ def populate_values(self, input, input_schema, preserve_column_refs): if not util.is_str(input_schema["_type"]): raise UserException( "unsupported input type (expected type {}, got {})".format( - util.data_type_str(input_schema["_type"]), util.pp_str_flat(input) + util.data_type_str(input_schema["_type"]), util.user_obj_str(input) ) ) return cast_compound_type(input, input_schema["_type"]) @@ -625,7 +625,7 @@ def cast_compound_type(value, type_str): raise UserException( "unsupported input type (expected type {}, got {})".format( - util.data_type_str(type_str), util.pp_str_flat(value) + util.data_type_str(type_str), util.user_obj_str(value) ) ) diff --git a/pkg/workloads/lib/util.py b/pkg/workloads/lib/util.py index 51a20b9393..8dc9174d9c 100644 --- a/pkg/workloads/lib/util.py +++ b/pkg/workloads/lib/util.py @@ -66,9 +66,35 @@ def pp_str_flat(obj, indent=0): return indent_str(out, indent) -def data_type_str(obj): - # TODO. Also call this method with output types? 
- return pp_str_flat(obj) +def data_type_str(data_type): + data_type_str = pp_str_flat(flatten_type_schema(data_type)) + for t in consts.ALL_TYPES: + data_type_str = data_type_str.replace('"' + t, t) + data_type_str = data_type_str.replace(t + '"', t) + return data_type_str + + +def flatten_type_schema(data_type): + if is_list(data_type): + flattened = [] + for item in data_type: + flattened.append(flatten_type_schema(item)) + return flattened + + if is_dict(data_type): + if "_type" in data_type: + return flatten_type_schema(data_type["_type"]) + + flattened = {} + for key, val in data_type.items(): + flattened[key] = flatten_type_schema(val) + return flattened + + return data_type + + +def user_obj_str(obj): + return truncate_str(pp_str_flat(obj), 1000) def log_indent(obj, indent=0, logging_func=logger.info): @@ -731,6 +757,9 @@ def validate_cortex_type(value, cortex_type): if not is_str(cortex_type): raise + if cortex_type == consts.COLUMN_TYPE_INFERRED: + return True + valid_types = cortex_type.split("|") for valid_type in valid_types: if CORTEX_TYPE_TO_VALIDATOR[valid_type](value): diff --git a/pkg/workloads/spark_job/spark_util.py b/pkg/workloads/spark_job/spark_util.py index e59b59d53c..ce1ebda8d8 100644 --- a/pkg/workloads/spark_job/spark_util.py +++ b/pkg/workloads/spark_job/spark_util.py @@ -53,6 +53,15 @@ consts.COLUMN_TYPE_STRING_LIST: [ArrayType(StringType(), True)], } +CORTEX_TYPE_TO_ACCEPTABLE_PYTHON_TYPE_STRS = { + consts.COLUMN_TYPE_INT: ["int"], + consts.COLUMN_TYPE_INT_LIST: ["[int]"], + consts.COLUMN_TYPE_FLOAT: ["float", "int"], + consts.COLUMN_TYPE_FLOAT_LIST: ["[float]", "[int]"], + consts.COLUMN_TYPE_STRING: ["string"], + consts.COLUMN_TYPE_STRING_LIST: ["[string]"], +} + CORTEX_TYPE_TO_CASTABLE_SPARK_TYPES = { "csv": { consts.COLUMN_TYPE_INT: [IntegerType(), LongType()], @@ -296,7 +305,7 @@ def ingest(ctx, spark): raise UserException( "raw column " + raw_column_name, "type mismatch", - "expected {} but found {}".format( + "expected {} but got {}".format( " or ".join(str(x) for x in expected_types), actual_spark_type ), ) @@ -488,14 +497,14 @@ def run_custom_aggregator(aggregate, df, ctx, spark): "function aggregate_spark", ) from e - if "output_type" in aggregator and not util.validate_output_type( + if aggregator.get("output_type") is not None and not util.validate_output_type( result, aggregator["output_type"] ): raise UserException( "aggregate " + aggregate["name"], "aggregator " + aggregator["name"], - "type of {} is not {}".format( - util.str_rep(util.pp_str(result), truncate=100), aggregator["output_type"] + "unsupported return type (expected type {}, got {})".format( + util.data_type_str(aggregator["output_type"]), util.user_obj_str(result) ), ) @@ -552,7 +561,10 @@ def _transform_and_validate(*values): raise UserException( "transformed column " + column_name, "tranformer " + transformed_column["transformer"], - "type of {} is not {}".format(result, column_type), + "incorrect return value type: expected {}, got {}.".format( + " or ".join(CORTEX_TYPE_TO_ACCEPTABLE_PYTHON_TYPE_STRS[column_type]), + util.user_obj_str(result), + ), ) return result @@ -564,7 +576,7 @@ def _transform_and_validate(*values): return df.withColumn(column_name, transform_udf(*input_cols_sorted)) -def infer_type(obj): +def infer_python_type(obj): obj_type = type(obj) if obj_type == list: @@ -591,14 +603,14 @@ def validate_transformer(column_name, test_df, ctx, spark): transformed_column["input"], transformer["input"], preserve_column_refs=True ) transformer_input = 
create_transformer_inputs_from_map(input, sample) - initial_transformed_sample = trans_impl.transform_python(transformer_input) - inferred_python_type = infer_type(initial_transformed_sample) + initial_transformed_value = trans_impl.transform_python(transformer_input) + inferred_python_type = infer_python_type(initial_transformed_value) for row in sample_df: transformer_input = create_transformer_inputs_from_map(input, row) - transformed_sample = trans_impl.transform_python(transformer_input) - if inferred_python_type != infer_type(transformed_sample): - raise UserRuntimeException( + transformed_value = trans_impl.transform_python(transformer_input) + if inferred_python_type != infer_python_type(transformed_value): + raise UserException( "transformed column " + column_name, "type inference failed, mixed data types in dataframe.", 'expected type of "' @@ -625,7 +637,7 @@ def validate_transformer(column_name, test_df, ctx, spark): # check that the return object is a dataframe if type(transform_spark_df) is not DataFrame: raise UserException( - "expected pyspark.sql.dataframe.DataFrame but found type {}".format( + "expected pyspark.sql.dataframe.DataFrame but got type {}".format( type(transform_spark_df) ) ) @@ -641,35 +653,36 @@ def validate_transformer(column_name, test_df, ctx, spark): ) ) + if transformer["output_type"] == consts.COLUMN_TYPE_INFERRED: + inferred_spark_type = SPARK_TYPE_TO_CORTEX_TYPE[ + transform_spark_df.select(column_name).schema[0].dataType + ] + ctx.write_metadata(transformed_column["id"], {"type": inferred_spark_type}) + # check that transformer run on data try: transform_spark_df.select(column_name).collect() except Exception as e: raise UserRuntimeException("function transform_spark") from e - if transformer["output_type"] == consts.COLUMN_TYPE_INFERRED: - inferred_spark_type = transform_spark_df.select(column_name).schema[0].dataType - ctx.write_metadata(transformed_column["id"], {"type": inferred_spark_type}) - # check that expected output column has the correct data type - actual_structfield = transform_spark_df.select(column_name).schema.fields[0] - if ( - actual_structfield.dataType - not in CORTEX_TYPE_TO_ACCEPTABLE_SPARK_TYPES[ - ctx.get_inferred_column_type(column_name) - ] - ): - raise UserException( - "incorrect column type, expected {}, found {}.".format( - " or ".join( - str(t) - for t in CORTEX_TYPE_TO_ACCEPTABLE_SPARK_TYPES[ - ctx.get_inferred_column_type(column_name) - ] - ), - actual_structfield.dataType, + if transformer["output_type"] != consts.COLUMN_TYPE_INFERRED: + actual_structfield = transform_spark_df.select(column_name).schema.fields[0] + if ( + actual_structfield.dataType + not in CORTEX_TYPE_TO_ACCEPTABLE_SPARK_TYPES[transformer["output_type"]] + ): + raise UserException( + "incorrect column type: expected {}, got {}.".format( + " or ".join( + str(t) + for t in CORTEX_TYPE_TO_ACCEPTABLE_SPARK_TYPES[ + transformer["output_type"] + ] + ), + actual_structfield.dataType, + ) ) - ) # perform the necessary casting for the column transform_spark_df = transform_spark_df.withColumn( @@ -703,7 +716,7 @@ def validate_transformer(column_name, test_df, ctx, spark): transformer["output_type"] == consts.COLUMN_TYPE_INFERRED and inferred_spark_type != inferred_python_type ): - raise UserRuntimeException( + raise UserException( "transformed column " + column_name, "type inference failed, transform_spark and transform_python had differing types.", "transform_python: " + inferred_python_type, From b2f371cc0cc14251e332c815c75f79ca3badab9d Mon Sep 17 00:00:00 
2001 From: David Eliahu Date: Wed, 12 Jun 2019 00:20:52 -0700 Subject: [PATCH 23/44] Escape emojis in python prints --- pkg/workloads/lib/util.py | 20 +++++++++++++------- 1 file changed, 13 insertions(+), 7 deletions(-) diff --git a/pkg/workloads/lib/util.py b/pkg/workloads/lib/util.py index 8dc9174d9c..3d60b80938 100644 --- a/pkg/workloads/lib/util.py +++ b/pkg/workloads/lib/util.py @@ -35,6 +35,9 @@ logger = get_logger() +resource_escape_seq = "🌝🌝🌝🌝🌝" +resource_escape_seq_raw = r"\ud83c\udf1d\ud83c\udf1d\ud83c\udf1d\ud83c\udf1d\ud83c\udf1d" + def isclose(a, b, rel_tol=1e-09, abs_tol=0.0): return abs(a - b) <= max(rel_tol * max(abs(a), abs(b)), abs_tol) @@ -51,6 +54,8 @@ def pp_str(obj, indent=0): out = json.dumps(obj, sort_keys=True, indent=2) except: out = pprint.pformat(obj, width=120) + out = out.replace(resource_escape_seq, "@") + out = out.replace(resource_escape_seq_raw, "@") return indent_str(out, indent) @@ -63,6 +68,8 @@ def pp_str_flat(obj, indent=0): out = json.dumps(obj, sort_keys=True) except: out = str(obj).replace("\n", "") + out = out.replace(resource_escape_seq, "@") + out = out.replace(resource_escape_seq_raw, "@") return indent_str(out, indent) @@ -864,21 +871,20 @@ def cast_output_type(value, output_type): return value -escape_seq = "🌝🌝🌝🌝🌝" - - def is_resource_ref(obj): if not is_str(obj): return False - return obj.startswith(escape_seq) + return obj.startswith(resource_escape_seq) or obj.startswith(resource_escape_seq_raw) def get_resource_ref(obj): if not is_str(obj): return None - if not obj.startswith(escape_seq): - return None - return obj[len(escape_seq) :] + if obj.startswith(resource_escape_seq): + return obj[len(resource_escape_seq) :] + elif obj.startswith(resource_escape_seq_raw): + return obj[len(resource_escape_seq_raw) :] + return None def extract_resource_refs(input): From bc36e64e781dd9df271235ca52c4efe23e48e45d Mon Sep 17 00:00:00 2001 From: David Eliahu Date: Wed, 12 Jun 2019 00:39:53 -0700 Subject: [PATCH 24/44] Check target_column type in Python --- pkg/workloads/lib/context.py | 20 +++++++++++++++++--- 1 file changed, 17 insertions(+), 3 deletions(-) diff --git a/pkg/workloads/lib/context.py b/pkg/workloads/lib/context.py index bca64aa6e6..dcb8556c9d 100644 --- a/pkg/workloads/lib/context.py +++ b/pkg/workloads/lib/context.py @@ -303,6 +303,20 @@ def model_config(self, model_name): if model is None: return None estimator = self.estimators[model["estimator"]] + target_column = self.columns[util.get_resource_ref(model["target_column"])] + + if estimator.get("target_column") is not None: + target_col_type = self.get_inferred_column_type(target_column["name"]) + if target_col_type not in estimator["target_column"]: + raise UserException( + "model " + model_name, + "target_column", + target_column["name"], + "unsupported type (expected type {}, got type {})".format( + util.data_type_str(estimator["target_column"]), + util.data_type_str(target_col_type), + ), + ) model_config = deepcopy(model) config_keys = [ @@ -321,15 +335,15 @@ def model_config(self, model_name): ] util.keep_dict_keys(model_config, config_keys) - model_config["target_column"] = util.get_resource_ref(model["target_column"]) + model_config["target_column"] = target_column["name"] model_config["input"] = self.populate_values( model["input"], estimator["input"], preserve_column_refs=False ) - if model["training_input"] is not None: + if model.get("training_input") is not None: model_config["training_input"] = self.populate_values( model["training_input"], estimator["training_input"], 
preserve_column_refs=False ) - if model["hparams"] is not None: + if model.get("hparams") is not None: model_config["hparams"] = self.populate_values( model["hparams"], estimator["hparams"], preserve_column_refs=False ) From 1effca3255bde78a95f68772865a8485d94bbdf8 Mon Sep 17 00:00:00 2001 From: David Eliahu Date: Wed, 12 Jun 2019 00:42:43 -0700 Subject: [PATCH 25/44] Fix msgpack.loads --- pkg/workloads/lib/storage/s3.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pkg/workloads/lib/storage/s3.py b/pkg/workloads/lib/storage/s3.py index 90348dff04..afdb36b4b8 100644 --- a/pkg/workloads/lib/storage/s3.py +++ b/pkg/workloads/lib/storage/s3.py @@ -143,7 +143,7 @@ def get_msgpack(self, key, allow_missing=False): obj = self._read_bytes_from_s3(key, allow_missing) if obj == None: return None - return msgpack.load(obj, raw=False) + return msgpack.loads(obj, raw=False) def put_pyobj(self, obj, key): self._upload_string_to_s3(pickle.dumps(obj), key) From d105d7a0ebc903c08c98a0414f61c4f25930bd33 Mon Sep 17 00:00:00 2001 From: David Eliahu Date: Wed, 12 Jun 2019 00:50:33 -0700 Subject: [PATCH 26/44] Escape emojis --- pkg/operator/api/userconfig/resource.go | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/pkg/operator/api/userconfig/resource.go b/pkg/operator/api/userconfig/resource.go index da91815692..91426a0697 100644 --- a/pkg/operator/api/userconfig/resource.go +++ b/pkg/operator/api/userconfig/resource.go @@ -19,6 +19,8 @@ package userconfig import ( "fmt" + "github.com/cortexlabs/yaml" + s "github.com/cortexlabs/cortex/pkg/lib/strings" "github.com/cortexlabs/cortex/pkg/operator/api/resource" ) @@ -74,6 +76,8 @@ func Identify(r Resource) string { } func identify(filePath string, resourceType resource.Type, name string, index int, embed *Embed) string { + name, _ = yaml.UnescapeAtSymbol(name) + resourceTypeStr := resourceType.String() if resourceType == resource.UnknownType { resourceTypeStr = "resource" From cae1bd988f0fe9fd42ea3fae27b14cf94926824c Mon Sep 17 00:00:00 2001 From: David Eliahu Date: Wed, 12 Jun 2019 08:22:09 -0700 Subject: [PATCH 27/44] Use target_vocab in built-in estimator prediction responses --- pkg/workloads/tf_api/api.py | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/pkg/workloads/tf_api/api.py b/pkg/workloads/tf_api/api.py index 90d00d8445..ce763fca79 100644 --- a/pkg/workloads/tf_api/api.py +++ b/pkg/workloads/tf_api/api.py @@ -159,7 +159,21 @@ def parse_response_proto(response_proto): outputs = results_dict["outputs"] value_key = DTYPE_TO_VALUE_KEY[outputs[prediction_key]["dtype"]] prediction = outputs[prediction_key][value_key][0] - prediction = util.cast(prediction, target_col_type) + + target_vocab_estimators = { + "dnn_classifier", + "linear_classifier", + "dnn_linear_combined_classifier", + "boosted_trees_classifier", + } + if ( + estimator["namespace"] == "cortex" + and estimator["name"] in target_vocab_estimators + and model["input"].get("target_vocab") is not None + ): + prediction = model["input"]["target_vocab"][int(prediction)] + else: + prediction = util.cast(prediction, target_col_type) result = {} result["prediction"] = prediction From a323d1b74664703bcbfc53e7871ec10d7cc9c393 Mon Sep 17 00:00:00 2001 From: David Eliahu Date: Wed, 12 Jun 2019 09:11:26 -0700 Subject: [PATCH 28/44] Update kubernetes config --- cortex-installer.sh | 43 +++++++++++++++++++++++-------------------- 1 file changed, 23 insertions(+), 20 deletions(-) diff --git 
a/cortex-installer.sh b/cortex-installer.sh index c1ce4f05dc..630e2b0b91 100755 --- a/cortex-installer.sh +++ b/cortex-installer.sh @@ -335,7 +335,7 @@ metadata: name: argo-executor namespace: ${CORTEX_NAMESPACE} --- -apiVersion: rbac.authorization.k8s.io/v1beta1 +apiVersion: rbac.authorization.k8s.io/v1 kind: RoleBinding metadata: name: argo-executor @@ -381,13 +381,13 @@ rules: verbs: [create, get, list, watch, update, patch, delete] - apiGroups: [\"\"] resources: [configmaps] - verbs: [get, list, watch] + verbs: [get, watch, list] - apiGroups: [\"\"] resources: [persistentvolumeclaims] verbs: [create, delete] - apiGroups: [argoproj.io] resources: [workflows, workflows/finalizers] - verbs: [get, list, watch, update, patch] + verbs: [get, list, watch, update, patch, delete] --- apiVersion: rbac.authorization.k8s.io/v1 kind: RoleBinding @@ -412,7 +412,7 @@ data: config: | namespace: ${CORTEX_NAMESPACE} --- -apiVersion: apps/v1beta2 +apiVersion: apps/v1 kind: Deployment metadata: name: argo-controller @@ -436,9 +436,9 @@ spec: - Always command: - workflow-controller - name: argo-controller image: ${CORTEX_IMAGE_ARGO_CONTROLLER} imagePullPolicy: Always + name: argo-controller serviceAccountName: argo-controller " | kubectl apply -f - >/dev/null } @@ -455,7 +455,7 @@ metadata: name: spark-operator namespace: ${CORTEX_NAMESPACE} --- -apiVersion: rbac.authorization.k8s.io/v1beta1 +apiVersion: rbac.authorization.k8s.io/v1 kind: Role metadata: name: spark-operator @@ -486,7 +486,7 @@ rules: resources: [sparkapplications, scheduledsparkapplications] verbs: [\"*\"] --- -apiVersion: rbac.authorization.k8s.io/v1beta1 +apiVersion: rbac.authorization.k8s.io/v1 kind: RoleBinding metadata: name: spark-operator @@ -500,7 +500,7 @@ roleRef: name: spark-operator apiGroup: rbac.authorization.k8s.io --- -apiVersion: apps/v1beta1 +apiVersion: apps/v1 kind: Deployment metadata: name: spark-operator @@ -730,7 +730,7 @@ metadata: name: spark namespace: ${CORTEX_NAMESPACE} --- -apiVersion: rbac.authorization.k8s.io/v1beta1 +apiVersion: rbac.authorization.k8s.io/v1 kind: Role metadata: name: spark @@ -745,7 +745,7 @@ rules: resources: [services] verbs: [\"*\"] --- -apiVersion: rbac.authorization.k8s.io/v1beta1 +apiVersion: rbac.authorization.k8s.io/v1 kind: RoleBinding metadata: name: spark @@ -774,7 +774,7 @@ metadata: name: nginx namespace: ${CORTEX_NAMESPACE} --- -apiVersion: rbac.authorization.k8s.io/v1beta1 +apiVersion: rbac.authorization.k8s.io/v1 kind: Role metadata: name: nginx @@ -812,7 +812,7 @@ rules: resources: [configmaps] verbs: [get, list, watch, create] --- -apiVersion: rbac.authorization.k8s.io/v1beta1 +apiVersion: rbac.authorization.k8s.io/v1 kind: RoleBinding metadata: name: nginx @@ -834,7 +834,7 @@ metadata: data: use-proxy-protocol: \"true\" --- -apiVersion: extensions/v1beta1 +apiVersion: apps/v1 kind: Deployment metadata: name: nginx-backend-operator @@ -895,7 +895,7 @@ spec: app.kubernetes.io/name: nginx-backend-operator app.kubernetes.io/part-of: ingress-nginx --- -apiVersion: extensions/v1beta1 +apiVersion: apps/v1 kind: Deployment metadata: name: nginx-controller-operator @@ -998,7 +998,7 @@ spec: port: 443 targetPort: https --- -apiVersion: extensions/v1beta1 +apiVersion: apps/v1 kind: Deployment metadata: name: nginx-backend-apis @@ -1059,7 +1059,7 @@ spec: app.kubernetes.io/name: nginx-backend-apis app.kubernetes.io/part-of: ingress-nginx --- -apiVersion: extensions/v1beta1 +apiVersion: apps/v1 kind: Deployment metadata: name: nginx-controller-apis @@ -1178,7 +1178,7 @@ metadata: 
labels: app: fluentd --- -apiVersion: rbac.authorization.k8s.io/v1beta1 +apiVersion: rbac.authorization.k8s.io/v1 kind: Role metadata: name: fluentd @@ -1188,7 +1188,7 @@ rules: resources: [pods] verbs: [get, list, watch] --- -apiVersion: rbac.authorization.k8s.io/v1beta1 +apiVersion: rbac.authorization.k8s.io/v1 kind: RoleBinding metadata: name: fluentd @@ -1305,7 +1305,7 @@ metadata: name: operator namespace: ${CORTEX_NAMESPACE} --- -apiVersion: rbac.authorization.k8s.io/v1beta1 +apiVersion: rbac.authorization.k8s.io/v1 kind: RoleBinding metadata: name: operator @@ -1319,7 +1319,7 @@ roleRef: name: cluster-admin apiGroup: rbac.authorization.k8s.io --- -apiVersion: apps/v1beta1 +apiVersion: apps/v1 kind: Deployment metadata: name: operator @@ -1328,6 +1328,9 @@ metadata: workloadType: operator spec: replicas: 1 + selector: + matchLabels: + workloadId: operator template: metadata: labels: From a03194b387b11fe0492c0a42e1bdbc70c833736c Mon Sep 17 00:00:00 2001 From: David Eliahu Date: Wed, 12 Jun 2019 15:18:05 -0700 Subject: [PATCH 29/44] Fix AllComputedResourceDependencies --- cli/cmd/get.go | 2 +- pkg/operator/api/context/dependencies.go | 24 +++++++++++++++--------- 2 files changed, 16 insertions(+), 10 deletions(-) diff --git a/cli/cmd/get.go b/cli/cmd/get.go index 108c9e71b5..b3b3d5a0d0 100644 --- a/cli/cmd/get.go +++ b/cli/cmd/get.go @@ -414,7 +414,7 @@ func describeAPI(name string, resourcesRes *schema.GetResourcesResponse) (string out += titleStr("Endpoint") resIDs := strset.New() combinedInput := []interface{}{model.Input, model.TrainingInput} - for _, res := range ctx.ExtractCortexResources(combinedInput) { + for _, res := range ctx.ExtractCortexResources(combinedInput, resource.ConstantType, resource.RawColumnType, resource.AggregateType, resource.TransformedColumnType) { resIDs.Add(res.GetID()) resIDs.Merge(ctx.AllComputedResourceDependencies(res.GetID())) } diff --git a/pkg/operator/api/context/dependencies.go b/pkg/operator/api/context/dependencies.go index 66791ca064..eb7a8fe3c9 100644 --- a/pkg/operator/api/context/dependencies.go +++ b/pkg/operator/api/context/dependencies.go @@ -28,13 +28,19 @@ import ( ) func (ctx *Context) AllComputedResourceDependencies(resourceID string) strset.Set { - dependencies := ctx.DirectComputedResourceDependencies(resourceID) - for dependency := range dependencies.Copy() { - for subDependency := range ctx.AllComputedResourceDependencies(dependency) { - dependencies.Add(subDependency) - } + allDependencies := strset.New() + ctx.allComputedResourceDependenciesHelper(resourceID, allDependencies) + return allDependencies +} + +func (ctx *Context) allComputedResourceDependenciesHelper(resourceID string, allDependencies strset.Set) { + subDependencies := ctx.DirectComputedResourceDependencies(resourceID) + subDependencies.Subtract(allDependencies) + allDependencies.Merge(subDependencies) + + for dependency := range subDependencies { + ctx.allComputedResourceDependenciesHelper(dependency, allDependencies) } - return dependencies } func (ctx *Context) DirectComputedResourceDependencies(resourceID string) strset.Set { @@ -95,7 +101,7 @@ func (ctx *Context) aggregatesDependencies(aggregate *Aggregate) strset.Set { dependencies.Add(pythonPackage.GetID()) } - for _, res := range ctx.ExtractCortexResources(aggregate.Input) { + for _, res := range ctx.ExtractCortexResources(aggregate.Input, resource.ConstantType, resource.RawColumnType) { dependencies.Add(res.GetID()) } @@ -109,7 +115,7 @@ func (ctx *Context) transformedColumnDependencies(transformedColumn 
*Transformed dependencies.Add(pythonPackage.GetID()) } - for _, res := range ctx.ExtractCortexResources(transformedColumn.Input) { + for _, res := range ctx.ExtractCortexResources(transformedColumn.Input, resource.ConstantType, resource.RawColumnType, resource.AggregateType) { dependencies.Add(res.GetID()) } @@ -137,7 +143,7 @@ func (ctx *Context) modelDependencies(model *Model) strset.Set { dependencies.Add(model.Dataset.ID) combinedInput := []interface{}{model.Input, model.TrainingInput, model.TargetColumn} - for _, res := range ctx.ExtractCortexResources(combinedInput) { + for _, res := range ctx.ExtractCortexResources(combinedInput, resource.ConstantType, resource.RawColumnType, resource.AggregateType, resource.TransformedColumnType) { dependencies.Add(res.GetID()) } From f94ef4e5729c54fc56a5e93ab3fe052f12a8a4f0 Mon Sep 17 00:00:00 2001 From: David Eliahu Date: Wed, 12 Jun 2019 15:47:40 -0700 Subject: [PATCH 30/44] Clean up --- pkg/operator/api/userconfig/errors.go | 15 +++++++++++++-- pkg/operator/context/models.go | 4 ++-- 2 files changed, 15 insertions(+), 4 deletions(-) diff --git a/pkg/operator/api/userconfig/errors.go b/pkg/operator/api/userconfig/errors.go index 9f5e7fbe17..b5efdc55c7 100644 --- a/pkg/operator/api/userconfig/errors.go +++ b/pkg/operator/api/userconfig/errors.go @@ -271,9 +271,20 @@ func ErrorRawColumnNotInEnv(envName string) error { } func ErrorUndefinedResource(resourceName string, resourceTypes ...resource.Type) error { - message := fmt.Sprintf("%s %s is not defined", s.StrsOr(resource.Types(resourceTypes).StringList()), s.UserStr(resourceName)) + message := fmt.Sprintf("%s is not defined", s.UserStr(resourceName)) + + if len(resourceTypes) == 1 { + message = fmt.Sprintf("%s %s is not defined", resourceTypes[0].String(), s.UserStr(resourceName)) + } else if len(resourceTypes) > 1 { + message = fmt.Sprintf("%s is not defined as a %s", s.UserStr(resourceName), s.StrsOr(resource.Types(resourceTypes).StringList())) + } + if strings.HasPrefix(resourceName, "cortex.") { - message = fmt.Sprintf("%s is not defined as a built-in %s in the Cortex namespace", s.UserStr(resourceName), s.StrsOr(resource.Types(resourceTypes).StringList())) + if len(resourceTypes) == 0 { + message = fmt.Sprintf("%s is not defined in the Cortex namespace", s.UserStr(resourceName)) + } else { + message = fmt.Sprintf("%s is not defined as a built-in %s in the Cortex namespace", s.UserStr(resourceName), s.StrsOr(resource.Types(resourceTypes).StringList())) + } } return Error{ diff --git a/pkg/operator/context/models.go b/pkg/operator/context/models.go index 6285e1ef69..7e2be24067 100644 --- a/pkg/operator/context/models.go +++ b/pkg/operator/context/models.go @@ -67,7 +67,7 @@ func getModels( castedInput, inputID, err := ValidateInput( modelConfig.Input, estimator.Input, - []resource.Type{resource.RawColumnType, resource.TransformedColumnType, resource.ConstantType, resource.AggregateType, resource.TransformedColumnType}, + []resource.Type{resource.RawColumnType, resource.TransformedColumnType, resource.ConstantType, resource.AggregateType}, validInputResources, config.Resources, aggregators, @@ -85,7 +85,7 @@ func getModels( castedTrainingInput, trainingInputID, err := ValidateInput( modelConfig.TrainingInput, estimator.TrainingInput, - []resource.Type{resource.RawColumnType, resource.TransformedColumnType, resource.ConstantType, resource.AggregateType, resource.TransformedColumnType}, + []resource.Type{resource.RawColumnType, resource.TransformedColumnType, resource.ConstantType, 
resource.AggregateType}, validInputResources, config.Resources, aggregators, From 3321e5b45b5f18f35118968b5dd125550796b315 Mon Sep 17 00:00:00 2001 From: David Eliahu Date: Wed, 12 Jun 2019 16:01:06 -0700 Subject: [PATCH 31/44] Allow ints for categorical_columns_with_vocab --- pkg/estimators/estimators.yaml | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/pkg/estimators/estimators.yaml b/pkg/estimators/estimators.yaml index da0c86d533..7f494d97dc 100644 --- a/pkg/estimators/estimators.yaml +++ b/pkg/estimators/estimators.yaml @@ -44,7 +44,7 @@ categorical_columns_with_vocab: _type: - col: STRING_COLUMN - vocab: [STRING] + vocab: [STRING|INT] embedding_size: # If not specified, an indicator column will be used instead _type: INT _optional: true @@ -99,7 +99,7 @@ categorical_columns_with_vocab: _type: - col: STRING_COLUMN - vocab: [STRING] + vocab: [STRING|INT] embedding_size: # If not specified, an indicator column will be used instead _type: INT _optional: true @@ -161,7 +161,7 @@ categorical_columns_with_vocab: _type: - col: STRING_COLUMN - vocab: [STRING] + vocab: [STRING|INT] weight_column: _type: INT_COLUMN|FLOAT_COLUMN _optional: True @@ -205,7 +205,7 @@ categorical_columns_with_vocab: _type: - col: STRING_COLUMN - vocab: [STRING] + vocab: [STRING|INT] weight_column: _type: INT_COLUMN|FLOAT_COLUMN _optional: True @@ -263,7 +263,7 @@ categorical_columns_with_vocab: _type: - col: STRING_COLUMN - vocab: [STRING] + vocab: [STRING|INT] embedding_size: # If not specified, an indicator column will be used instead _type: INT _optional: true @@ -303,7 +303,7 @@ categorical_columns_with_vocab: _type: - col: STRING_COLUMN - vocab: [STRING] + vocab: [STRING|INT] weight_column: _type: INT_COLUMN|FLOAT_COLUMN _optional: True @@ -354,7 +354,7 @@ categorical_columns_with_vocab: _type: - col: STRING_COLUMN - vocab: [STRING] + vocab: [STRING|INT] embedding_size: # If not specified, an indicator column will be used instead _type: INT _optional: true @@ -394,7 +394,7 @@ categorical_columns_with_vocab: _type: - col: STRING_COLUMN - vocab: [STRING] + vocab: [STRING|INT] weight_column: _type: INT_COLUMN|FLOAT_COLUMN _optional: True @@ -454,7 +454,7 @@ categorical_columns_with_vocab: _type: - col: STRING_COLUMN - vocab: [STRING] + vocab: [STRING|INT] embedding_size: # If not specified, an indicator column will be used instead _type: INT _optional: true @@ -536,7 +536,7 @@ categorical_columns_with_vocab: _type: - col: STRING_COLUMN - vocab: [STRING] + vocab: [STRING|INT] embedding_size: # If not specified, an indicator column will be used instead _type: INT _optional: true From cab43c1d80b2be6d656c5c0989818de1f0f1a76d Mon Sep 17 00:00:00 2001 From: David Eliahu Date: Wed, 12 Jun 2019 16:24:10 -0700 Subject: [PATCH 32/44] Fix categorical_columns_with_vocab --- pkg/estimators/estimators.yaml | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/pkg/estimators/estimators.yaml b/pkg/estimators/estimators.yaml index 7f494d97dc..9898fc9577 100644 --- a/pkg/estimators/estimators.yaml +++ b/pkg/estimators/estimators.yaml @@ -43,7 +43,7 @@ _default: [] categorical_columns_with_vocab: _type: - - col: STRING_COLUMN + - col: STRING_COLUMN|INT_COLUMN vocab: [STRING|INT] embedding_size: # If not specified, an indicator column will be used instead _type: INT @@ -98,7 +98,7 @@ _default: [] categorical_columns_with_vocab: _type: - - col: STRING_COLUMN + - col: STRING_COLUMN|INT_COLUMN vocab: [STRING|INT] embedding_size: # If not specified, an indicator 
column will be used instead _type: INT @@ -160,7 +160,7 @@ _default: [] categorical_columns_with_vocab: _type: - - col: STRING_COLUMN + - col: STRING_COLUMN|INT_COLUMN vocab: [STRING|INT] weight_column: _type: INT_COLUMN|FLOAT_COLUMN @@ -204,7 +204,7 @@ _default: [] categorical_columns_with_vocab: _type: - - col: STRING_COLUMN + - col: STRING_COLUMN|INT_COLUMN vocab: [STRING|INT] weight_column: _type: INT_COLUMN|FLOAT_COLUMN @@ -262,7 +262,7 @@ _default: [] categorical_columns_with_vocab: _type: - - col: STRING_COLUMN + - col: STRING_COLUMN|INT_COLUMN vocab: [STRING|INT] embedding_size: # If not specified, an indicator column will be used instead _type: INT @@ -302,7 +302,7 @@ _default: [] categorical_columns_with_vocab: _type: - - col: STRING_COLUMN + - col: STRING_COLUMN|INT_COLUMN vocab: [STRING|INT] weight_column: _type: INT_COLUMN|FLOAT_COLUMN @@ -353,7 +353,7 @@ _default: [] categorical_columns_with_vocab: _type: - - col: STRING_COLUMN + - col: STRING_COLUMN|INT_COLUMN vocab: [STRING|INT] embedding_size: # If not specified, an indicator column will be used instead _type: INT @@ -393,7 +393,7 @@ _default: [] categorical_columns_with_vocab: _type: - - col: STRING_COLUMN + - col: STRING_COLUMN|INT_COLUMN vocab: [STRING|INT] weight_column: _type: INT_COLUMN|FLOAT_COLUMN @@ -453,7 +453,7 @@ _default: [] categorical_columns_with_vocab: _type: - - col: STRING_COLUMN + - col: STRING_COLUMN|INT_COLUMN vocab: [STRING|INT] embedding_size: # If not specified, an indicator column will be used instead _type: INT @@ -535,7 +535,7 @@ _default: [] categorical_columns_with_vocab: _type: - - col: STRING_COLUMN + - col: STRING_COLUMN|INT_COLUMN vocab: [STRING|INT] embedding_size: # If not specified, an indicator column will be used instead _type: INT From 97348b6f9f12ab6004670f361711350ad6bb4100 Mon Sep 17 00:00:00 2001 From: Vishal Bollu Date: Thu, 13 Jun 2019 01:43:44 +0200 Subject: [PATCH 33/44] Update python tests (#160) --- .../insurance/implementations/models/dnn.py | 18 +- examples/insurance/resources/apis.yaml | 2 +- .../insurance/resources/environments.yaml | 9 +- examples/insurance/resources/features.yaml | 22 +- examples/insurance/resources/models.yaml | 22 +- pkg/workloads/lib/context.py | 4 +- pkg/workloads/lib/util.py | 9 +- .../test/integration/insurance_context.py | 631 ++++++++++++++++++ .../{iris_test.py => insurance_test.py} | 46 +- .../test/integration/iris_context.py | 572 ---------------- .../spark_job/test/unit/spark_util_test.py | 205 +++--- 11 files changed, 773 insertions(+), 767 deletions(-) create mode 100644 pkg/workloads/spark_job/test/integration/insurance_context.py rename pkg/workloads/spark_job/test/integration/{iris_test.py => insurance_test.py} (71%) delete mode 100644 pkg/workloads/spark_job/test/integration/iris_context.py diff --git a/examples/insurance/implementations/models/dnn.py b/examples/insurance/implementations/models/dnn.py index 5f1db542ad..3c5a2ccec0 100644 --- a/examples/insurance/implementations/models/dnn.py +++ b/examples/insurance/implementations/models/dnn.py @@ -2,27 +2,33 @@ def create_estimator(run_config, model_config): + aggregates = model_config["input"]["aggregates"] + feature_columns = [ tf.feature_column.indicator_column( - tf.feature_column.categorical_column_with_vocabulary_list("sex", ["female", "male"]) + tf.feature_column.categorical_column_with_vocabulary_list( + "sex", aggregates["sex_vocab"] + ) ), tf.feature_column.indicator_column( - tf.feature_column.categorical_column_with_vocabulary_list("smoker", ["yes", "no"]) + 
tf.feature_column.categorical_column_with_vocabulary_list( + "smoker", aggregates["smoker_vocab"] + ) ), tf.feature_column.indicator_column( tf.feature_column.categorical_column_with_vocabulary_list( - "region", ["northwest", "northeast", "southwest", "southeast"] + "region", aggregates["region_vocab"] ) ), tf.feature_column.bucketized_column( - tf.feature_column.numeric_column("age"), [15, 20, 25, 35, 40, 45, 50, 55, 60, 65] + tf.feature_column.numeric_column("age"), aggregates["age_buckets"] ), tf.feature_column.bucketized_column( - tf.feature_column.numeric_column("bmi"), [15, 20, 25, 35, 40, 45, 50, 55] + tf.feature_column.numeric_column("bmi"), aggregates["bmi_buckets"] ), tf.feature_column.indicator_column( tf.feature_column.categorical_column_with_vocabulary_list( - "children", model_config["aggregates"]["children_set"] + "children", aggregates["children_set"] ) ), ] diff --git a/examples/insurance/resources/apis.yaml b/examples/insurance/resources/apis.yaml index 4507607637..fb1611e74e 100644 --- a/examples/insurance/resources/apis.yaml +++ b/examples/insurance/resources/apis.yaml @@ -1,5 +1,5 @@ - kind: api name: cost - model_name: dnn + model: @dnn compute: replicas: 1 diff --git a/examples/insurance/resources/environments.yaml b/examples/insurance/resources/environments.yaml index f1d0b44ad3..fa1444d1c2 100644 --- a/examples/insurance/resources/environments.yaml +++ b/examples/insurance/resources/environments.yaml @@ -3,11 +3,4 @@ data: type: csv path: s3a://cortex-examples/insurance.csv - schema: - - age - - sex - - bmi - - children - - smoker - - region - - charges + schema: [@age, @sex, @bmi, @children, @smoker, @region, @charges] diff --git a/examples/insurance/resources/features.yaml b/examples/insurance/resources/features.yaml index 0465fec6a2..4a299c36a9 100644 --- a/examples/insurance/resources/features.yaml +++ b/examples/insurance/resources/features.yaml @@ -47,30 +47,22 @@ - kind: aggregate name: charges_mean aggregator: cortex.mean - inputs: - columns: - col: charges + input: @charges - kind: aggregate name: charges_stddev aggregator: cortex.stddev - inputs: - columns: - col: charges + input: @charges - kind: aggregate name: children_set aggregator: cortex.collect_set_int - inputs: - columns: - col: children + input: @children - kind: transformed_column name: charges_normalized transformer: cortex.normalize - inputs: - columns: - num: charges - args: - mean: charges_mean - stddev: charges_stddev + input: + col: @charges + mean: @charges_mean + stddev: @charges_stddev diff --git a/examples/insurance/resources/models.yaml b/examples/insurance/resources/models.yaml index 4b68106722..76ad5d1119 100644 --- a/examples/insurance/resources/models.yaml +++ b/examples/insurance/resources/models.yaml @@ -1,14 +1,16 @@ - kind: model name: dnn - type: regression - target_column: charges_normalized - feature_columns: - - age - - sex - - bmi - - children - - smoker - - region + estimator_path: implementations/models/dnn.py + target_column: @charges_normalized + input: + features: [@age, @sex, @bmi, @children, @smoker, @region, @charges] + aggregates: + children_set: @children_set + region_vocab: ["northwest", "northeast", "southwest", "southeast"] + age_buckets: [15, 20, 25, 35, 40, 45, 50, 55, 60, 65] + bmi_buckets: [15, 20, 25, 35, 40, 45, 50, 55] + smoker_vocab: ["yes", "no"] + sex_vocab: ["female", "male"] hparams: hidden_units: [100, 100, 100] data_partition_ratio: @@ -16,5 +18,3 @@ evaluation: 0.2 training: num_steps: 10000 - aggregates: - - children_set diff --git 
a/pkg/workloads/lib/context.py b/pkg/workloads/lib/context.py index dcb8556c9d..9135147771 100644 --- a/pkg/workloads/lib/context.py +++ b/pkg/workloads/lib/context.py @@ -721,8 +721,8 @@ def _deserialize_raw_ctx(raw_ctx): # input should already have non-column arguments replaced, and all types validated def create_transformer_inputs_from_map(input, col_value_map): if util.is_str(input): - res_name = util.get_resource_ref(input) - if res_name is not None and res_name in col_value_map: + if util.is_resource_ref(input): + res_name = util.get_resource_ref(input) return col_value_map[res_name] return input diff --git a/pkg/workloads/lib/util.py b/pkg/workloads/lib/util.py index 3d60b80938..4a35aec5b1 100644 --- a/pkg/workloads/lib/util.py +++ b/pkg/workloads/lib/util.py @@ -879,19 +879,18 @@ def is_resource_ref(obj): def get_resource_ref(obj): if not is_str(obj): - return None + raise ValueError("expected input of type string but received " + str(type(obj))) if obj.startswith(resource_escape_seq): return obj[len(resource_escape_seq) :] elif obj.startswith(resource_escape_seq_raw): return obj[len(resource_escape_seq_raw) :] - return None + raise ValueError("expected a resource reference but got " + obj) def extract_resource_refs(input): if is_str(input): - res = get_resource_ref(input) - if res is not None: - return {res} + if is_resource_ref(input): + return {get_resource_ref(input)} return set() if is_list(input): diff --git a/pkg/workloads/spark_job/test/integration/insurance_context.py b/pkg/workloads/spark_job/test/integration/insurance_context.py new file mode 100644 index 0000000000..65d30365c7 --- /dev/null +++ b/pkg/workloads/spark_job/test/integration/insurance_context.py @@ -0,0 +1,631 @@ +# Copyright 2019 Cortex Labs, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import consts + +""" +HOW TO GENERATE CONTEXT + +1. cx deploy +2. get a path to a context +3. ssh into a docker container (spark or tf_train) +docker run -it --entrypoint "/bin/bash" cortexlabs/spark +4. run the following in python3 shell + +from lib import util +from lib.storage import S3 + +S3(bucket, client_config={}).get_msgpack(key) + +bucket, key = S3.deconstruct_s3_path('s3:///apps//contexts/.msgpack') + +5. update path of any implementations being used to // e.g. /transformers/normalize.py +6. 
delete images in cortex_config +""" + + +def get(input_data_path): + raw_ctx["environment_data"]["csv_data"]["path"] = input_data_path + raw_ctx["cortex_config"]["api_version"] = consts.CORTEX_VERSION + + return raw_ctx + + +raw_ctx = { + "constants": {}, + "root": "apps/insurance/data/2019-06-12-19-22-55-936375/483b4537be2db30b81be4809ab0c787f65230540b9b6779af7420922d654011", + "key": "apps/insurance/contexts/58feed5c3620ce81e3cb18ba0268078b28f102a70aa55cb8d30fa29a0d0f323.msgpack", + "estimators": { + "c5e53f81d6d57bd1b46ed04020509401b139864d483c35e781338f11b2cc301": { + "target_column": None, + "id": "be000f367fb9cd7167d4eb9bebc6469d7fd839fc55adb3cb6ff7e34e33335e6", + "index": 0, + "resource_type": "estimator", + "namespace": None, + "input": None, + "embed": None, + "impl_key": "estimators/c5e53f81d6d57bd1b46ed04020509401b139864d483c35e781338f11b2cc301.py", + "path": "implementations/models/dnn.py", + "hparams": None, + "prediction_key": "", + "file_path": "", + "name": "c5e53f81d6d57bd1b46ed04020509401b139864d483c35e781338f11b2cc301", + "training_input": None, + } + }, + "transformed_columns": { + "charges_normalized": { + "index": 10, + "type": "FLOAT_COLUMN", + "tags": {}, + "resource_type": "transformed_column", + "transformer_path": None, + "input": { + "col": "\U0001f31d\U0001f31d\U0001f31d\U0001f31d\U0001f31dcharges", + "stddev": "\U0001f31d\U0001f31d\U0001f31d\U0001f31d\U0001f31dcharges_stddev", + "mean": "\U0001f31d\U0001f31d\U0001f31d\U0001f31d\U0001f31dcharges_mean", + }, + "embed": None, + "workload_id": "cs7eodcfv22u5ztm2l8g", + "name": "charges_normalized", + "transformer": "cortex.normalize", + "file_path": "resources/features.yaml", + "id": "a91d01f7a9a260d094fa90d3f6a1e77372f5c271d633ac129637f2eb6196d7e", + "compute": { + "driver_mem_overhead": None, + "executors": 1, + "executor_mem": '{"Quantity":"500Mi","UserString":"500Mi"}', + "driver_cpu": '{"Quantity":"1","UserString":"1"}', + "mem_overhead_factor": None, + "executor_mem_overhead": None, + "driver_mem": '{"Quantity":"500Mi","UserString":"500Mi"}', + "executor_cpu": '{"Quantity":"1","UserString":"1"}', + }, + } + }, + "id": "58feed5c3620ce81e3cb18ba0268078b28f102a70aa55cb8d30fa29a0d0f323", + "environment_data": { + "csv_data": { + "schema": [ + "\U0001f31d\U0001f31d\U0001f31d\U0001f31d\U0001f31dage", + "\U0001f31d\U0001f31d\U0001f31d\U0001f31d\U0001f31dsex", + "\U0001f31d\U0001f31d\U0001f31d\U0001f31d\U0001f31dbmi", + "\U0001f31d\U0001f31d\U0001f31d\U0001f31d\U0001f31dchildren", + "\U0001f31d\U0001f31d\U0001f31d\U0001f31d\U0001f31dsmoker", + "\U0001f31d\U0001f31d\U0001f31d\U0001f31d\U0001f31dregion", + "\U0001f31d\U0001f31d\U0001f31d\U0001f31d\U0001f31dcharges", + ], + "csv_config": { + "ignore_trailing_white_space": None, + "null_value": None, + "positive_inf": None, + "escape": None, + "ignore_leading_white_space": None, + "header": None, + "negative_inf": None, + "max_chars_per_column": None, + "char_to_escape_quote_escaping": None, + "nan_value": None, + "encoding": None, + "multiline": None, + "quote": None, + "comment": None, + "sep": None, + "max_columns": None, + "empty_value": None, + }, + "type": "csv", + "path": "s3a://cortex-examples/insurance.csv", + "drop_null": False, + }, + "parquet_data": None, + }, + "apis": { + "cost": { + "model_name": "dnn", + "index": 0, + "tags": {}, + "resource_type": "api", + "embed": None, + "workload_id": "zixo1slsqroa3dheq2mg", + "path": "/insurance/cost", + "name": "cost", + "model": "\U0001f31d\U0001f31d\U0001f31d\U0001f31d\U0001f31ddnn", + "file_path": 
"resources/apis.yaml", + "id": "a5341babd551a85d82b7b49e488983113e97eb6a0c4c8643986a92a3c44ec32", + "compute": {"gpu": 0, "cpu": None, "replicas": 1, "mem": None}, + } + }, + "cortex_config": { + "enable_telemetry": False, + "id": "da5e65b994ba4ebb069bdc19cf73da64aee79e5d83f466038dc75b3ef04fa63", + "operator_in_cluster": False, + "log_group": "cortex", + "namespace": "cortex", + "region": "us-west-2", + "api_version": "master", + }, + "metadata_root": "apps/insurance/data/2019-06-12-19-22-55-936375/483b4537be2db30b81be4809ab0c787f65230540b9b6779af7420922d654011/metadata", + "raw_columns": { + "raw_string_columns": { + "smoker": { + "index": 4, + "required": True, + "tags": {}, + "type": "STRING_COLUMN", + "embed": None, + "workload_id": "cs7eodcfv22u5ztm2l8g", + "name": "smoker", + "file_path": "resources/features.yaml", + "id": "3625e8890ecad8d2f64382dac7a06b7546b21eeba32de894265fbb95e8c2140", + "values": ["yes", "no"], + "compute": { + "driver_mem_overhead": None, + "executors": 1, + "executor_mem": '{"Quantity":"500Mi","UserString":"500Mi"}', + "driver_cpu": '{"Quantity":"1","UserString":"1"}', + "mem_overhead_factor": None, + "executor_mem_overhead": None, + "driver_mem": '{"Quantity":"500Mi","UserString":"500Mi"}', + "executor_cpu": '{"Quantity":"1","UserString":"1"}', + }, + "resource_type": "raw_column", + }, + "sex": { + "index": 1, + "required": True, + "tags": {}, + "type": "STRING_COLUMN", + "embed": None, + "workload_id": "cs7eodcfv22u5ztm2l8g", + "name": "sex", + "file_path": "resources/features.yaml", + "id": "fe94653b7cc3ea3d6cdee9e42b0feab7ad8f1a5b7d2bd4f9f6860c1b3ce5071", + "values": ["female", "male"], + "compute": { + "driver_mem_overhead": None, + "executors": 1, + "executor_mem": '{"Quantity":"500Mi","UserString":"500Mi"}', + "driver_cpu": '{"Quantity":"1","UserString":"1"}', + "mem_overhead_factor": None, + "executor_mem_overhead": None, + "driver_mem": '{"Quantity":"500Mi","UserString":"500Mi"}', + "executor_cpu": '{"Quantity":"1","UserString":"1"}', + }, + "resource_type": "raw_column", + }, + "region": { + "index": 5, + "required": True, + "tags": {}, + "type": "STRING_COLUMN", + "embed": None, + "workload_id": "cs7eodcfv22u5ztm2l8g", + "name": "region", + "file_path": "resources/features.yaml", + "id": "a51c74f814e6a28891b6677698fc5129794c8bf60d90741ac2e51d5a112e024", + "values": ["northwest", "northeast", "southwest", "southeast"], + "compute": { + "driver_mem_overhead": None, + "executors": 1, + "executor_mem": '{"Quantity":"500Mi","UserString":"500Mi"}', + "driver_cpu": '{"Quantity":"1","UserString":"1"}', + "mem_overhead_factor": None, + "executor_mem_overhead": None, + "driver_mem": '{"Quantity":"500Mi","UserString":"500Mi"}', + "executor_cpu": '{"Quantity":"1","UserString":"1"}', + }, + "resource_type": "raw_column", + }, + }, + "raw_float_columns": { + "bmi": { + "index": 2, + "resource_type": "raw_column", + "tags": {}, + "required": True, + "embed": None, + "workload_id": "cs7eodcfv22u5ztm2l8g", + "max": 60.0, + "name": "bmi", + "min": 0.0, + "file_path": "resources/features.yaml", + "id": "05734a3de0cc2a28050b811241eb27c59634c9a175382fc0c0d8ceeb0840036", + "values": None, + "compute": { + "driver_mem_overhead": None, + "executors": 1, + "executor_mem": '{"Quantity":"500Mi","UserString":"500Mi"}', + "driver_cpu": '{"Quantity":"1","UserString":"1"}', + "mem_overhead_factor": None, + "executor_mem_overhead": None, + "driver_mem": '{"Quantity":"500Mi","UserString":"500Mi"}', + "executor_cpu": '{"Quantity":"1","UserString":"1"}', + }, + "type": 
"FLOAT_COLUMN", + }, + "charges": { + "index": 6, + "resource_type": "raw_column", + "tags": {}, + "required": True, + "embed": None, + "workload_id": "cs7eodcfv22u5ztm2l8g", + "max": 100000.0, + "name": "charges", + "min": 0.0, + "file_path": "resources/features.yaml", + "id": "9e8a5917e4a8c803af17ad0792b89201b1784415bfeb7801ca0308c1a8f6090", + "values": None, + "compute": { + "driver_mem_overhead": None, + "executors": 1, + "executor_mem": '{"Quantity":"500Mi","UserString":"500Mi"}', + "driver_cpu": '{"Quantity":"1","UserString":"1"}', + "mem_overhead_factor": None, + "executor_mem_overhead": None, + "driver_mem": '{"Quantity":"500Mi","UserString":"500Mi"}', + "executor_cpu": '{"Quantity":"1","UserString":"1"}', + }, + "type": "FLOAT_COLUMN", + }, + }, + "raw_int_columns": { + "age": { + "index": 0, + "resource_type": "raw_column", + "tags": {}, + "required": True, + "embed": None, + "workload_id": "cs7eodcfv22u5ztm2l8g", + "max": 100, + "name": "age", + "min": 0, + "file_path": "resources/features.yaml", + "id": "2957b75e2f53e5c74f6036022ef1681d1c6444e1c8a7ca424813f642463f503", + "values": None, + "compute": { + "driver_mem_overhead": None, + "executors": 1, + "executor_mem": '{"Quantity":"500Mi","UserString":"500Mi"}', + "driver_cpu": '{"Quantity":"1","UserString":"1"}', + "mem_overhead_factor": None, + "executor_mem_overhead": None, + "driver_mem": '{"Quantity":"500Mi","UserString":"500Mi"}', + "executor_cpu": '{"Quantity":"1","UserString":"1"}', + }, + "type": "INT_COLUMN", + }, + "children": { + "index": 3, + "resource_type": "raw_column", + "tags": {}, + "required": True, + "embed": None, + "workload_id": "cs7eodcfv22u5ztm2l8g", + "max": 10, + "name": "children", + "min": 0, + "file_path": "resources/features.yaml", + "id": "a782d9ba0db596ad7c3e1a46be4dcd5a0c24e626e8ee308e941e33de04149bc", + "values": None, + "compute": { + "driver_mem_overhead": None, + "executors": 1, + "executor_mem": '{"Quantity":"500Mi","UserString":"500Mi"}', + "driver_cpu": '{"Quantity":"1","UserString":"1"}', + "mem_overhead_factor": None, + "executor_mem_overhead": None, + "driver_mem": '{"Quantity":"500Mi","UserString":"500Mi"}', + "executor_cpu": '{"Quantity":"1","UserString":"1"}', + }, + "type": "INT_COLUMN", + }, + }, + "raw_inferred_columns": {}, + }, + "python_packages": {}, + "status_prefix": "apps/insurance/resource_statuses", + "aggregators": { + "cortex.mean": { + "input": { + "_type": "FLOAT_COLUMN|INT_COLUMN", + "_min_count": None, + "_optional": False, + "_allow_null": False, + "_default": None, + "_max_count": None, + }, + "file_path": "/home/ubuntu/src/github.com/cortexlabs/cortex/pkg/aggregators/aggregators.yaml", + "impl_key": "aggregators/71c8aa1ce07d9d7059e305ed2b180504c36a41452e73fb251ef532bf679f851.py", + "path": "spark/mean.py", + "name": "mean", + "index": 13, + "embed": None, + "output_type": "FLOAT", + "id": "945acd1e82ed2178b7937215f6f82c814abfa967ec7da545e0a8c776759a37f", + "resource_type": "aggregator", + "namespace": "cortex", + }, + "cortex.stddev": { + "input": { + "_type": "FLOAT_COLUMN|INT_COLUMN", + "_min_count": None, + "_optional": False, + "_allow_null": False, + "_default": None, + "_max_count": None, + }, + "file_path": "/home/ubuntu/src/github.com/cortexlabs/cortex/pkg/aggregators/aggregators.yaml", + "impl_key": "aggregators/b8fa468e54c55083bf350f8b482c5323bd4bc12dd5fa0d859908ab2829aea7f.py", + "path": "spark/stddev.py", + "name": "stddev", + "index": 18, + "embed": None, + "output_type": "FLOAT", + "id": 
"5b076c542b7e8ad7ebd57c1bedad32875ccddccf66bb11c4d87a6058ec765c9", + "resource_type": "aggregator", + "namespace": "cortex", + }, + "cortex.collect_set_int": { + "input": { + "_type": "INT_COLUMN", + "_min_count": None, + "_optional": False, + "_allow_null": False, + "_default": None, + "_max_count": None, + }, + "file_path": "/home/ubuntu/src/github.com/cortexlabs/cortex/pkg/aggregators/aggregators.yaml", + "impl_key": "aggregators/4a26fe7551ea175ae68b998ea12766a8e4ffc6a6763816b141bc84d42275e90.py", + "path": "spark/collect_set_int.py", + "name": "collect_set_int", + "index": 2, + "embed": None, + "output_type": ["INT"], + "id": "d35f3139f033053d7cc923869d15501cab7078c31ca023a32e307edaac35ae3", + "resource_type": "aggregator", + "namespace": "cortex", + }, + }, + "environment": { + "file_path": "resources/environments.yaml", + "name": "dev", + "index": 0, + "embed": None, + "limit": { + "randomize": None, + "random_seed": None, + "fraction_of_rows": None, + "num_rows": None, + }, + "id": "483b4537be2db30b81be4809ab0c787f65230540b9b6779af7420922d654011", + "log_level": {"spark": "WARN", "tensorflow": "DEBUG"}, + }, + "aggregates": { + "charges_stddev": { + "tags": {}, + "id": "ddacaf489cf1311c12eb898068bbb9f214be66ac789f3f705d6e1c57cdc943a", + "type": "FLOAT", + "aggregator": "cortex.stddev", + "resource_type": "aggregate", + "aggregator_path": None, + "input": "\U0001f31d\U0001f31d\U0001f31d\U0001f31d\U0001f31dcharges", + "embed": None, + "name": "charges_stddev", + "file_path": "resources/features.yaml", + "workload_id": "cs7eodcfv22u5ztm2l8g", + "compute": { + "driver_mem_overhead": None, + "executors": 1, + "executor_mem": '{"Quantity":"500Mi","UserString":"500Mi"}', + "driver_cpu": '{"Quantity":"1","UserString":"1"}', + "mem_overhead_factor": None, + "executor_mem_overhead": None, + "driver_mem": '{"Quantity":"500Mi","UserString":"500Mi"}', + "executor_cpu": '{"Quantity":"1","UserString":"1"}', + }, + "key": "apps/insurance/data/2019-06-12-19-22-55-936375/483b4537be2db30b81be4809ab0c787f65230540b9b6779af7420922d654011/aggregates/ddacaf489cf1311c12eb898068bbb9f214be66ac789f3f705d6e1c57cdc943a.msgpack", + "index": 8, + }, + "charges_mean": { + "tags": {}, + "id": "c08af633f53a95ace3e71152b00d595bbbbbd282037ee5fd5e708adf1d96b38", + "type": "FLOAT", + "aggregator": "cortex.mean", + "resource_type": "aggregate", + "aggregator_path": None, + "input": "\U0001f31d\U0001f31d\U0001f31d\U0001f31d\U0001f31dcharges", + "embed": None, + "name": "charges_mean", + "file_path": "resources/features.yaml", + "workload_id": "cs7eodcfv22u5ztm2l8g", + "compute": { + "driver_mem_overhead": None, + "executors": 1, + "executor_mem": '{"Quantity":"500Mi","UserString":"500Mi"}', + "driver_cpu": '{"Quantity":"1","UserString":"1"}', + "mem_overhead_factor": None, + "executor_mem_overhead": None, + "driver_mem": '{"Quantity":"500Mi","UserString":"500Mi"}', + "executor_cpu": '{"Quantity":"1","UserString":"1"}', + }, + "key": "apps/insurance/data/2019-06-12-19-22-55-936375/483b4537be2db30b81be4809ab0c787f65230540b9b6779af7420922d654011/aggregates/c08af633f53a95ace3e71152b00d595bbbbbd282037ee5fd5e708adf1d96b38.msgpack", + "index": 7, + }, + "children_set": { + "tags": {}, + "id": "800ae506a9af394af56ae5fd7727801eed39a79f0a7bb3bcb58d5722953c748", + "type": ["INT"], + "aggregator": "cortex.collect_set_int", + "resource_type": "aggregate", + "aggregator_path": None, + "input": "\U0001f31d\U0001f31d\U0001f31d\U0001f31d\U0001f31dchildren", + "embed": None, + "name": "children_set", + "file_path": 
"resources/features.yaml", + "workload_id": "cs7eodcfv22u5ztm2l8g", + "compute": { + "driver_mem_overhead": None, + "executors": 1, + "executor_mem": '{"Quantity":"500Mi","UserString":"500Mi"}', + "driver_cpu": '{"Quantity":"1","UserString":"1"}', + "mem_overhead_factor": None, + "executor_mem_overhead": None, + "driver_mem": '{"Quantity":"500Mi","UserString":"500Mi"}', + "executor_cpu": '{"Quantity":"1","UserString":"1"}', + }, + "key": "apps/insurance/data/2019-06-12-19-22-55-936375/483b4537be2db30b81be4809ab0c787f65230540b9b6779af7420922d654011/aggregates/800ae506a9af394af56ae5fd7727801eed39a79f0a7bb3bcb58d5722953c748.msgpack", + "index": 9, + }, + }, + "app": { + "name": "insurance", + "id": "64e8937abc6d71cf2d2f0fe05e52c33666883443ae4e8af7924c71198caa1f9", + }, + "raw_dataset": { + "key": "apps/insurance/data/2019-06-12-19-22-55-936375/483b4537be2db30b81be4809ab0c787f65230540b9b6779af7420922d654011/data_raw/raw.parquet" + }, + "dataset_version": "2019-06-12-19-22-55-936375", + "transformers": { + "cortex.normalize": { + "input": { + "_type": { + "col": { + "_type": "FLOAT_COLUMN|INT_COLUMN", + "_min_count": None, + "_optional": False, + "_allow_null": False, + "_default": None, + "_max_count": None, + }, + "stddev": { + "_type": "INT|FLOAT", + "_min_count": None, + "_optional": False, + "_allow_null": False, + "_default": None, + "_max_count": None, + }, + "mean": { + "_type": "INT|FLOAT", + "_min_count": None, + "_optional": False, + "_allow_null": False, + "_default": None, + "_max_count": None, + }, + }, + "_min_count": None, + "_optional": False, + "_allow_null": False, + "_default": None, + "_max_count": None, + }, + "file_path": "/home/ubuntu/src/github.com/cortexlabs/cortex/pkg/transformers/transformers.yaml", + "impl_key": "/transformers/normalize.py", + "path": "normalize.py", + "name": "normalize", + "index": 1, + "embed": None, + "output_type": "FLOAT_COLUMN", + "id": "0dcf14b7b6208633b8805b29eb564f860736a07691176095fd4f19aa5ef75ab", + "resource_type": "transformer", + "namespace": "cortex", + } + }, + "models": { + "dnn": { + "training_input": None, + "data_partition_ratio": {"evaluation": 0.2, "training": 0.8}, + "dataset_compute": { + "driver_mem_overhead": None, + "executors": 1, + "executor_mem": '{"Quantity":"500Mi","UserString":"500Mi"}', + "driver_cpu": '{"Quantity":"1","UserString":"1"}', + "mem_overhead_factor": None, + "executor_mem_overhead": None, + "driver_mem": '{"Quantity":"500Mi","UserString":"500Mi"}', + "executor_cpu": '{"Quantity":"1","UserString":"1"}', + }, + "training": { + "tf_random_seed": 1788, + "keep_checkpoint_every_n_hours": 10000, + "num_steps": 10000, + "keep_checkpoint_max": 3, + "log_step_count_steps": 100, + "save_checkpoints_steps": None, + "num_epochs": None, + "save_summary_steps": 100, + "batch_size": 40, + "save_checkpoints_secs": 600, + "tf_randomize_seed": False, + "shuffle": True, + }, + "input": { + "features": [ + "\U0001f31d\U0001f31d\U0001f31d\U0001f31d\U0001f31dage", + "\U0001f31d\U0001f31d\U0001f31d\U0001f31d\U0001f31dsex", + "\U0001f31d\U0001f31d\U0001f31d\U0001f31d\U0001f31dbmi", + "\U0001f31d\U0001f31d\U0001f31d\U0001f31d\U0001f31dchildren", + "\U0001f31d\U0001f31d\U0001f31d\U0001f31d\U0001f31dsmoker", + "\U0001f31d\U0001f31d\U0001f31d\U0001f31d\U0001f31dregion", + ], + "aggregates": { + "region_vocab": ["northwest", "northeast", "southwest", "southeast"], + "smoker_vocab": ["yes", "no"], + "children_set": "\U0001f31d\U0001f31d\U0001f31d\U0001f31d\U0001f31dchildren_set", + "sex_vocab": ["female", 
"male"], + "age_buckets": [15, 20, 25, 35, 40, 45, 50, 55, 60, 65], + "bmi_buckets": [15, 20, 25, 35, 40, 45, 50, 55], + }, + }, + "target_column": "\U0001f31d\U0001f31d\U0001f31d\U0001f31d\U0001f31dcharges_normalized", + "embed": None, + "hparams": {"hidden_units": [100, 100, 100]}, + "workload_id": "gj4bdao7ocrfa1xbls9y", + "file_path": "resources/models.yaml", + "evaluation": { + "throttle_secs": 600, + "start_delay_secs": 120, + "batch_size": 40, + "num_epochs": None, + "shuffle": False, + "num_steps": 100, + }, + "estimator_path": "implementations/models/dnn.py", + "dataset": { + "file_path": "resources/models.yaml", + "model_name": "dnn", + "workload_id": "cs7eodcfv22u5ztm2l8g", + "name": "dnn/training_dataset", + "index": 0, + "embed": None, + "train_key": "apps/insurance/data/2019-06-12-19-22-55-936375/483b4537be2db30b81be4809ab0c787f65230540b9b6779af7420922d654011/data_training/1020431bc033a0aa2cc427349486b7636f2b0bb65e75553811fea82b9fa12fd/train.tfrecord", + "id": "1020431bc033a0aa2cc427349486b7636f2b0bb65e75553811fea82b9fa12fd", + "resource_type": "training_dataset", + "eval_key": "apps/insurance/data/2019-06-12-19-22-55-936375/483b4537be2db30b81be4809ab0c787f65230540b9b6779af7420922d654011/data_training/1020431bc033a0aa2cc427349486b7636f2b0bb65e75553811fea82b9fa12fd/eval.tfrecord", + }, + "id": "4145945c21c9f9fd616e666fef2b61aa501e8db15e756e9dcd2f2e3ec47575e", + "index": 0, + "resource_type": "model", + "tags": {}, + "prediction_key": "", + "name": "dnn", + "estimator": "c5e53f81d6d57bd1b46ed04020509401b139864d483c35e781338f11b2cc301", + "compute": {"gpu": None, "cpu": None, "mem": None}, + "key": "apps/insurance/data/2019-06-12-19-22-55-936375/483b4537be2db30b81be4809ab0c787f65230540b9b6779af7420922d654011/models/4145945c21c9f9fd616e666fef2b61aa501e8db15e756e9dcd2f2e3ec47575e.zip", + } + }, +} diff --git a/pkg/workloads/spark_job/test/integration/iris_test.py b/pkg/workloads/spark_job/test/integration/insurance_test.py similarity index 71% rename from pkg/workloads/spark_job/test/integration/iris_test.py rename to pkg/workloads/spark_job/test/integration/insurance_test.py index 7fcff3d3c4..50bea0d0a3 100644 --- a/pkg/workloads/spark_job/test/integration/iris_test.py +++ b/pkg/workloads/spark_job/test/integration/insurance_test.py @@ -17,7 +17,7 @@ from spark_job import spark_job from lib.exceptions import UserException from lib import Context -from test.integration import iris_context +from test.integration import insurance_context import pytest from pyspark.sql.types import * @@ -31,22 +31,22 @@ pytestmark = pytest.mark.usefixtures("spark") -iris_data = [ - [5.1, 3.5, 1.4, 0.2, "Iris-setosa"], - [4.9, 3.0, 1.4, 0.2, "Iris-setosa"], - [4.7, 3.2, 1.3, 0.2, "Iris-setosa"], - [4.6, 3.1, 1.5, 0.2, "Iris-setosa"], - [5.0, 3.6, 1.4, 0.2, "Iris-setosa"], - [7.0, 3.2, 4.7, 1.4, "Iris-versicolor"], - [6.4, 3.2, 4.5, 1.5, "Iris-versicolor"], - [6.9, 3.1, 4.9, 1.5, "Iris-versicolor"], - [5.5, 2.3, 4.0, 1.3, "Iris-versicolor"], - [6.5, 2.8, 4.6, 1.5, "Iris-versicolor"], - [6.3, 3.3, 6.0, 2.5, "Iris-virginica"], - [5.8, 2.7, 5.1, 1.9, "Iris-virginica"], - [7.1, 3.0, 5.9, 2.1, "Iris-virginica"], - [6.3, 2.9, 5.6, 1.8, "Iris-virginica"], - [6.5, 3.0, 5.8, 2.2, "Iris-virginica"], +insurance_data = [ + [19, "female", 27.9, 0, "yes", "southwest", 16884.924], + [18, "male", 33.77, 1, "no", "southeast", 1725.5523], + [28, "male", 33, 3, "no", "southeast", 4449.462], + [33, "male", 22.705, 0, "no", "northwest", 21984.47061], + [32, "male", 28.88, 0, "no", "northwest", 3866.8552], + [31, "female", 
25.74, 0, "no", "southeast", 3756.6216], + [46, "female", 33.44, 1, "no", "southeast", 8240.5896], + [37, "female", 27.74, 3, "no", "northwest", 7281.5056], + [37, "male", 29.83, 2, "no", "northeast", 6406.4107], + [60, "female", 25.84, 0, "no", "northwest", 28923.13692], + [25, "male", 26.22, 0, "no", "northeast", 2721.3208], + [62, "female", 26.29, 0, "yes", "southeast", 27808.7251], + [23, "male", 34.4, 0, "no", "southwest", 1826.843], + [56, "female", 39.82, 0, "no", "southeast", 11090.7178], + [27, "male", 42.13, 0, "yes", "southeast", 39611.7577], ] @@ -54,11 +54,11 @@ def test_simple_end_to_end(spark): local_storage_path = Path("/workspace/local_storage") local_storage_path.mkdir(parents=True, exist_ok=True) should_ingest = True - input_data_path = os.path.join(str(local_storage_path), "iris.csv") + input_data_path = os.path.join(str(local_storage_path), "insurance.csv") - raw_ctx = iris_context.get(input_data_path) + raw_ctx = insurance_context.get(input_data_path) - workload_id = raw_ctx["raw_columns"]["raw_float_columns"]["sepal_length"]["workload_id"] + workload_id = raw_ctx["raw_columns"]["raw_string_columns"]["smoker"]["workload_id"] cols_to_validate = [] @@ -66,8 +66,8 @@ def test_simple_end_to_end(spark): for raw_column in column_type.values(): cols_to_validate.append(raw_column["id"]) - iris_data_string = "\n".join(",".join(str(val) for val in line) for line in iris_data) - Path(os.path.join(str(local_storage_path), "iris.csv")).write_text(iris_data_string) + insurance_data_string = "\n".join(",".join(str(val) for val in line) for line in insurance_data) + Path(os.path.join(str(local_storage_path), "insurance.csv")).write_text(insurance_data_string) ctx = Context( raw_obj=raw_ctx, cache_dir="/workspace/cache", local_storage_path=str(local_storage_path) @@ -86,7 +86,7 @@ def test_simple_end_to_end(spark): cols_to_aggregate = [r["id"] for r in raw_ctx["aggregates"].values()] - spark_job.run_custom_aggregators(spark, ctx, cols_to_aggregate, raw_df) + spark_job.run_aggregators(spark, ctx, cols_to_aggregate, raw_df) for aggregate_id in cols_to_aggregate: for aggregate_resource in raw_ctx["aggregates"].values(): diff --git a/pkg/workloads/spark_job/test/integration/iris_context.py b/pkg/workloads/spark_job/test/integration/iris_context.py deleted file mode 100644 index d36a9e4dd1..0000000000 --- a/pkg/workloads/spark_job/test/integration/iris_context.py +++ /dev/null @@ -1,572 +0,0 @@ -# Copyright 2019 Cortex Labs, Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import consts - -""" -HOW TO GENERATE CONTEXT - -1. cx deploy -2. get a path to a context -3. ssh into a docker container (spark or tf_train) -docker run -it --entrypoint "/bin/bash" cortexlabs/spark -4. 
run the following in python3 shell - -from lib import util -from lib.storage import S3 -bucket, key = S3.deconstruct_s3_path('s3:///apps//contexts/.msgpack') -S3(bucket, client_config={}).get_msgpack(key) -""" - - -def get(input_data_path): - raw_ctx["environment_data"]["csv_data"]["path"] = input_data_path - raw_ctx["cortex_config"]["api_version"] = consts.CORTEX_VERSION - - return raw_ctx - - -raw_ctx = { - "raw_dataset": { - "key": "apps/iris/data/2019-03-08-09-58-35-701834/3976c5679bcf7cb550453802f4c3a9333c5f193f6097f1f5642de48d2397554/data_raw/raw.parquet" - }, - "aggregates": { - "class_index": { - "workload_id": "jjd3l0fi4fhwqtgmpatg", - "key": "apps/iris/data/2019-03-08-09-58-35-701834/3976c5679bcf7cb550453802f4c3a9333c5f193f6097f1f5642de48d2397554/aggregates/54ead5d565a57cad06972cc11d2f01f05c4e9e1dbfc525d1fa66b7999213722.msgpack", - "tags": {}, - "type": {"index": ["STRING"], "reversed_index": {"STRING": "INT"}}, - "embed": None, - "file_path": "resources/aggregates.yaml", - "name": "class_index", - "id": "54ead5d565a57cad06972cc11d2f01f05c4e9e1dbfc525d1fa66b7999213722", - "aggregator": "cortex.index_string", - "index": 8, - "id_with_tags": "2bd062924097b0add1143dab547387307cf68f40870f52443ce5902006e00d9", - "resource_type": "aggregate", - "inputs": {"columns": {"col": "class"}, "args": {}}, - }, - "sepal_width_mean": { - "workload_id": "jjd3l0fi4fhwqtgmpatg", - "key": "apps/iris/data/2019-03-08-09-58-35-701834/3976c5679bcf7cb550453802f4c3a9333c5f193f6097f1f5642de48d2397554/aggregates/38159191e6018b929b42c7e73e8bfd19f5778bba79e84d9909f5d448ac15fc9.msgpack", - "tags": {}, - "type": "FLOAT", - "embed": None, - "file_path": "resources/aggregates.yaml", - "name": "sepal_width_mean", - "id": "38159191e6018b929b42c7e73e8bfd19f5778bba79e84d9909f5d448ac15fc9", - "aggregator": "cortex.mean", - "index": 2, - "id_with_tags": "850aaab46427d39c331dd996e4a44af1bb326e45b47caaf699a5676863463f6", - "resource_type": "aggregate", - "inputs": {"columns": {"col": "sepal_width"}, "args": {}}, - }, - "petal_width_stddev": { - "workload_id": "jjd3l0fi4fhwqtgmpatg", - "key": "apps/iris/data/2019-03-08-09-58-35-701834/3976c5679bcf7cb550453802f4c3a9333c5f193f6097f1f5642de48d2397554/aggregates/986fd2cbc2b1d74aa06cf533b67d7dd7f54b5b7bf58689c58d0ec8c2568bae8.msgpack", - "tags": {}, - "type": "FLOAT", - "embed": None, - "file_path": "resources/aggregates.yaml", - "name": "petal_width_stddev", - "id": "986fd2cbc2b1d74aa06cf533b67d7dd7f54b5b7bf58689c58d0ec8c2568bae8", - "aggregator": "cortex.stddev", - "index": 7, - "id_with_tags": "b8f7440e0f71ec502cccbead6b52da80c119b833baaee1d00315542a6ab907c", - "resource_type": "aggregate", - "inputs": {"columns": {"col": "petal_width"}, "args": {}}, - }, - "petal_width_mean": { - "workload_id": "jjd3l0fi4fhwqtgmpatg", - "key": "apps/iris/data/2019-03-08-09-58-35-701834/3976c5679bcf7cb550453802f4c3a9333c5f193f6097f1f5642de48d2397554/aggregates/317856401885874d95fffd349fe0878595e8c04833ba63c4546233ffd899e4d.msgpack", - "tags": {}, - "type": "FLOAT", - "embed": None, - "file_path": "resources/aggregates.yaml", - "name": "petal_width_mean", - "id": "317856401885874d95fffd349fe0878595e8c04833ba63c4546233ffd899e4d", - "aggregator": "cortex.mean", - "index": 6, - "id_with_tags": "d8b09430972c97c3679fe3e90d67b69f19f1effcc2b587eb0876fb8f7d7dc55", - "resource_type": "aggregate", - "inputs": {"columns": {"col": "petal_width"}, "args": {}}, - }, - "sepal_length_stddev": { - "workload_id": "jjd3l0fi4fhwqtgmpatg", - "key": 
"apps/iris/data/2019-03-08-09-58-35-701834/3976c5679bcf7cb550453802f4c3a9333c5f193f6097f1f5642de48d2397554/aggregates/e7191b1effd1e4d351580f251aa35dc7c0b9825745b207fbb8cce904c94a937.msgpack", - "tags": {}, - "type": "FLOAT", - "embed": None, - "file_path": "resources/aggregates.yaml", - "name": "sepal_length_stddev", - "id": "e7191b1effd1e4d351580f251aa35dc7c0b9825745b207fbb8cce904c94a937", - "aggregator": "cortex.stddev", - "index": 1, - "id_with_tags": "a7da8887f18dfac4523dd1f2135913046001c1dc04701c9a1010e6b049db943", - "resource_type": "aggregate", - "inputs": {"columns": {"col": "sepal_length"}, "args": {}}, - }, - "petal_length_stddev": { - "workload_id": "jjd3l0fi4fhwqtgmpatg", - "key": "apps/iris/data/2019-03-08-09-58-35-701834/3976c5679bcf7cb550453802f4c3a9333c5f193f6097f1f5642de48d2397554/aggregates/6a9481dc91eb3f82458356f1f5f98da6f25a69b460679e09a67988543f79e3f.msgpack", - "tags": {}, - "type": "FLOAT", - "embed": None, - "file_path": "resources/aggregates.yaml", - "name": "petal_length_stddev", - "id": "6a9481dc91eb3f82458356f1f5f98da6f25a69b460679e09a67988543f79e3f", - "aggregator": "cortex.stddev", - "index": 5, - "id_with_tags": "325ae37b8684886d194ca37af19f4660c549c07880833a45ced4848895670a4", - "resource_type": "aggregate", - "inputs": {"columns": {"col": "petal_length"}, "args": {}}, - }, - "sepal_width_stddev": { - "workload_id": "jjd3l0fi4fhwqtgmpatg", - "key": "apps/iris/data/2019-03-08-09-58-35-701834/3976c5679bcf7cb550453802f4c3a9333c5f193f6097f1f5642de48d2397554/aggregates/64594f51d3cfb55a3776d013102d5fdab29bfe7332ce0c4f7c916d64d3ca29f.msgpack", - "tags": {}, - "type": "FLOAT", - "embed": None, - "file_path": "resources/aggregates.yaml", - "name": "sepal_width_stddev", - "id": "64594f51d3cfb55a3776d013102d5fdab29bfe7332ce0c4f7c916d64d3ca29f", - "aggregator": "cortex.stddev", - "index": 3, - "id_with_tags": "29e071e808b973aa21af717e9ee3d91f077b40baa5b890d4a58c9dc64a12d4b", - "resource_type": "aggregate", - "inputs": {"columns": {"col": "sepal_width"}, "args": {}}, - }, - "sepal_length_mean": { - "workload_id": "jjd3l0fi4fhwqtgmpatg", - "key": "apps/iris/data/2019-03-08-09-58-35-701834/3976c5679bcf7cb550453802f4c3a9333c5f193f6097f1f5642de48d2397554/aggregates/690f97171881c08770cac55137c672167a84324efba478cfd583ec98dd18844.msgpack", - "tags": {}, - "type": "FLOAT", - "embed": None, - "file_path": "resources/aggregates.yaml", - "name": "sepal_length_mean", - "id": "690f97171881c08770cac55137c672167a84324efba478cfd583ec98dd18844", - "aggregator": "cortex.mean", - "index": 0, - "id_with_tags": "1feedb4635d8955765dc82f58122e9756b1b797b3a7dcf3477ec99e655f05f2", - "resource_type": "aggregate", - "inputs": {"columns": {"col": "sepal_length"}, "args": {}}, - }, - "petal_length_mean": { - "workload_id": "jjd3l0fi4fhwqtgmpatg", - "key": "apps/iris/data/2019-03-08-09-58-35-701834/3976c5679bcf7cb550453802f4c3a9333c5f193f6097f1f5642de48d2397554/aggregates/4deea2705f55fa8a38658546ea5c2d31e37d4aad43a874e091f1c1667b63a6e.msgpack", - "tags": {}, - "type": "FLOAT", - "embed": None, - "file_path": "resources/aggregates.yaml", - "name": "petal_length_mean", - "id": "4deea2705f55fa8a38658546ea5c2d31e37d4aad43a874e091f1c1667b63a6e", - "aggregator": "cortex.mean", - "index": 4, - "id_with_tags": "34a2792b8c2c8489ea3c8db81533a946d8005eb9e547a4863673e0eae011259", - "resource_type": "aggregate", - "inputs": {"columns": {"col": "petal_length"}, "args": {}}, - }, - }, - "transformers": { - "cortex.normalize": { - "id": "eab74305749aa9eaff514882156111fd49b8b740018da396693147cd4443a9e", - 
"impl_key": "/transformers/normalize.py", - "embed": None, - "file_path": "/home/ubuntu/src/github.com/cortexlabs/cortex/pkg/transformers/transformers.yaml", - "name": "normalize", - "namespace": "cortex", - "path": "", - "output_type": "FLOAT_COLUMN", - "index": 1, - "id_with_tags": "eab74305749aa9eaff514882156111fd49b8b740018da396693147cd4443a9e", - "resource_type": "transformer", - "inputs": { - "columns": {"num": "FLOAT_COLUMN|INT_COLUMN"}, - "args": {"stddev": "INT|FLOAT", "mean": "INT|FLOAT"}, - }, - }, - "cortex.index_string": { - "id": "81bcee8795009e19f3378b2c3ea10fa6048741f2ad6ef841e5ed55c81319a0c", - "impl_key": "/transformers/index_string.py", - "embed": None, - "file_path": "/home/ubuntu/src/github.com/cortexlabs/cortex/pkg/transformers/transformers.yaml", - "name": "index_string", - "namespace": "cortex", - "path": "", - "output_type": "INT_COLUMN", - "index": 2, - "id_with_tags": "81bcee8795009e19f3378b2c3ea10fa6048741f2ad6ef841e5ed55c81319a0c", - "resource_type": "transformer", - "inputs": { - "columns": {"text": "STRING_COLUMN"}, - "args": {"indexes": {"index": ["STRING"], "reversed_index": {"STRING": "INT"}}}, - }, - }, - }, - "python_packages": {}, - "key": "apps/iris/contexts/33d7d279749ec97d342614cd77c5e81314a74ae0c0407ff71a120e83736a658.msgpack", - "raw_columns": { - "raw_float_columns": { - "sepal_length": { - "tags": {}, - "workload_id": "jjd3l0fi4fhwqtgmpatg", - "values": None, - "embed": None, - "type": "FLOAT_COLUMN", - "required": False, - "file_path": "resources/raw_columns.yaml", - "name": "sepal_length", - "id": "9479e84647a126fe5ce36e6eeac35aacb7156cd8c8e0773e572a91a7f9c1e92", - "min": 0.0, - "max": 10.0, - "index": 0, - "resource_type": "raw_column", - "id_with_tags": "b68cec533b973640329709bd4f7628bd8e8da5e3040bf227a1df5a7ce05807c", - }, - "sepal_width": { - "tags": {}, - "workload_id": "jjd3l0fi4fhwqtgmpatg", - "values": None, - "embed": None, - "type": "FLOAT_COLUMN", - "required": False, - "file_path": "resources/raw_columns.yaml", - "name": "sepal_width", - "id": "690b9a1c2e717c7ec4304804d4d7fd54fba554d8ce4829062467a3dc4d5f0f8", - "min": 0.0, - "max": 10.0, - "index": 1, - "resource_type": "raw_column", - "id_with_tags": "01430cc2265647e61dd8d8f9bec1b3918468968bdd8b27d0c6088848501da44", - }, - "petal_length": { - "tags": {}, - "workload_id": "jjd3l0fi4fhwqtgmpatg", - "values": None, - "embed": None, - "type": "FLOAT_COLUMN", - "required": False, - "file_path": "resources/raw_columns.yaml", - "name": "petal_length", - "id": "eb81ff65ce934e409ce18627cbb7d77c804289404fd62850fa5f915a1a9d87f", - "min": 0.0, - "max": 10.0, - "index": 2, - "resource_type": "raw_column", - "id_with_tags": "b89d5ef63dc22bfeb81684631d0f6e387e33cb26a52f7bb1b5de73ab49f40df", - }, - "petal_width": { - "tags": {}, - "workload_id": "jjd3l0fi4fhwqtgmpatg", - "values": None, - "embed": None, - "type": "FLOAT_COLUMN", - "required": False, - "file_path": "resources/raw_columns.yaml", - "name": "petal_width", - "id": "98ee0c5e9935442ea77835297777f4ab916830db5cb1ec82590d8b03f53eb6c", - "min": 0.0, - "max": 10.0, - "index": 3, - "resource_type": "raw_column", - "id_with_tags": "0d148b5ccc0e9266e3fd349793efd12e6880c4f1577d311e6bbef792b939d85", - }, - }, - "raw_string_columns": { - "class": { - "workload_id": "jjd3l0fi4fhwqtgmpatg", - "values": ["Iris-setosa", "Iris-versicolor", "Iris-virginica"], - "embed": None, - "required": False, - "type": "STRING_COLUMN", - "tags": {}, - "file_path": "resources/raw_columns.yaml", - "name": "class", - "id": 
"397a3c2785bcfdab244acdd11d65b415e3e4258b762deb8c17e600ce187c425", - "index": 4, - "resource_type": "raw_column", - "id_with_tags": "7fa09a7ca3544e1631bd60a792d23719a76bc9f77a350277a68efd3670a1f66", - } - }, - "raw_int_columns": {}, - }, - "environment_data": { - "csv_data": { - "drop_null": False, - "type": "csv", - "path": "/workspace/iris.csv", - "csv_config": { - "negative_inf": None, - "null_value": None, - "sep": None, - "ignore_leading_white_space": None, - "empty_value": None, - "max_columns": None, - "positive_inf": None, - "max_chars_per_column": None, - "nan_value": None, - "comment": None, - "ignore_trailing_white_space": None, - "multiline": None, - "char_to_escape_quote_escaping": None, - "encoding": None, - "escape": None, - "quote": None, - "header": None, - }, - "schema": ["sepal_length", "sepal_width", "petal_length", "petal_width", "class"], - }, - "parquet_data": None, - }, - "apis": { - "iris-type": { - "workload_id": "rvxejtv4uoy3jfuawokc", - "embed": None, - "tags": {}, - "file_path": "resources/apis.yaml", - "name": "iris-type", - "id": "07c60601edc687a5c3106bad0ad49fef497bb207487c5fd0a36226068a3166b", - "path": "/iris/iris-type", - "index": 0, - "id_with_tags": "e574b47b359b60f47132d3c919fc5c82fb5684c7784665a0e61bb42316b5b31", - "resource_type": "api", - "compute": {"mem": None, "replicas": 1, "cpu": None, "gpu": 0}, - "model_name": "dnn", - } - }, - "constants": {}, - "id": "33d7d279749ec97d342614cd77c5e81314a74ae0c0407ff71a120e83736a658", - "dataset_version": "2019-03-08-09-58-35-701834", - "environment": { - "log_level": {"spark": "WARN", "tensorflow": "INFO"}, - "index": 0, - "limit": { - "random_seed": None, - "randomize": None, - "fraction_of_rows": None, - "num_rows": None, - }, - "embed": None, - "file_path": "resources/environments.yaml", - "name": "dev", - "id": "3976c5679bcf7cb550453802f4c3a9333c5f193f6097f1f5642de48d2397554", - }, - "status_prefix": "apps/iris/resource_statuses", - "transformed_columns": { - "sepal_length_normalized": { - "workload_id": "jjd3l0fi4fhwqtgmpatg", - "embed": None, - "type": "FLOAT_COLUMN", - "tags": {}, - "file_path": "resources/transformed_columns.yaml", - "name": "sepal_length_normalized", - "id": "a44a0acbb54123d03d67b47469cf83712df2045b90aa99036dab99f37583d46", - "transformer": "cortex.normalize", - "index": 0, - "id_with_tags": "d4f335e49dec681bd7a79766a79ab7682c8205e51a2ec46e40207785835f35a", - "resource_type": "transformed_column", - "inputs": { - "columns": {"num": "sepal_length"}, - "args": {"stddev": "sepal_length_stddev", "mean": "sepal_length_mean"}, - }, - }, - "petal_width_normalized": { - "workload_id": "jjd3l0fi4fhwqtgmpatg", - "embed": None, - "type": "FLOAT_COLUMN", - "tags": {}, - "file_path": "resources/transformed_columns.yaml", - "name": "petal_width_normalized", - "id": "41221f15eea0328c2987c44171f323529bfa7a196a697b1a87ff4915c143531", - "transformer": "cortex.normalize", - "index": 3, - "id_with_tags": "a3ad56b29de6467931c40c992ac84b8a238a9f6d9611345bf4e338df314bf6d", - "resource_type": "transformed_column", - "inputs": { - "columns": {"num": "petal_width"}, - "args": {"stddev": "petal_width_stddev", "mean": "petal_width_mean"}, - }, - }, - "sepal_width_normalized": { - "workload_id": "jjd3l0fi4fhwqtgmpatg", - "embed": None, - "type": "FLOAT_COLUMN", - "tags": {}, - "file_path": "resources/transformed_columns.yaml", - "name": "sepal_width_normalized", - "id": "360fe839dbc1ee1db2d0e0f0e8ca0d1a2cc54aed69e29843e0361d285ddb700", - "transformer": "cortex.normalize", - "index": 1, - "id_with_tags": 
"9aa22e1962c62aab2ea56f4cfc1369ed3e559bfaa70a5e8a2e17b82d1042f48", - "resource_type": "transformed_column", - "inputs": { - "columns": {"num": "sepal_width"}, - "args": {"stddev": "sepal_width_stddev", "mean": "sepal_width_mean"}, - }, - }, - "petal_length_normalized": { - "workload_id": "jjd3l0fi4fhwqtgmpatg", - "embed": None, - "type": "FLOAT_COLUMN", - "tags": {}, - "file_path": "resources/transformed_columns.yaml", - "name": "petal_length_normalized", - "id": "7cbc111099c4bf38e27d6a05f9b2d37bdb9038f6f934be10298a718deae6db5", - "transformer": "cortex.normalize", - "index": 2, - "id_with_tags": "0c1923b1cc93679e8df3ec21f212656c050cb32d980604ffbd89fac0815ddcc", - "resource_type": "transformed_column", - "inputs": { - "columns": {"num": "petal_length"}, - "args": {"stddev": "petal_length_stddev", "mean": "petal_length_mean"}, - }, - }, - "class_indexed": { - "workload_id": "jjd3l0fi4fhwqtgmpatg", - "embed": None, - "type": "INT_COLUMN", - "tags": {}, - "file_path": "resources/transformed_columns.yaml", - "name": "class_indexed", - "id": "6097e63c46b62b3cf70d86d9e1282bdd77d15d62bc4d132d9154bb5ddc1861d", - "transformer": "cortex.index_string", - "index": 4, - "id_with_tags": "f3b94376e20e64f67d0808c3589d8a4bb09196e38ff81ba775408be38148c1e", - "resource_type": "transformed_column", - "inputs": {"columns": {"text": "class"}, "args": {"indexes": "class_index"}}, - }, - }, - "models": { - "dnn": { - "aggregates": ["class_index"], - "impl_id": "2d7091a3fff24213d9e67cf2a846e5e31fd27f406fffbdb341140419f138f48", - "training_columns": [], - "key": "apps/iris/data/2019-03-08-09-58-35-701834/3976c5679bcf7cb550453802f4c3a9333c5f193f6097f1f5642de48d2397554/models/4989cb227eb56c2d3ccc1904cb3dbcab9a1ceb1ebf8cdb9f95a20b86a8df019.zip", - "embed": None, - "type": "classification", - "tags": {}, - "id": "4989cb227eb56c2d3ccc1904cb3dbcab9a1ceb1ebf8cdb9f95a20b86a8df019", - "name": "dnn", - "impl_key": "model_implementations/2d7091a3fff24213d9e67cf2a846e5e31fd27f406fffbdb341140419f138f48.py", - "feature_columns": [ - "sepal_length_normalized", - "sepal_width_normalized", - "petal_length_normalized", - "petal_width_normalized", - ], - "target_column": "class_indexed", - "resource_type": "model", - "hparams": {"hidden_units": [4, 2]}, - "prediction_key": "", - "workload_id": "aokhfrzyw6ju730nbwli", - "dataset": { - "train_key": "apps/iris/data/2019-03-08-09-58-35-701834/3976c5679bcf7cb550453802f4c3a9333c5f193f6097f1f5642de48d2397554/data_training/5bdaecf9c5a0094d4a18df15348f709be8acfd3c6faf72c3f243956c3896e76/train.tfrecord", - "workload_id": "jjd3l0fi4fhwqtgmpatg", - "eval_key": "apps/iris/data/2019-03-08-09-58-35-701834/3976c5679bcf7cb550453802f4c3a9333c5f193f6097f1f5642de48d2397554/data_training/5bdaecf9c5a0094d4a18df15348f709be8acfd3c6faf72c3f243956c3896e76/eval.tfrecord", - "embed": None, - "file_path": "resources/models.yaml", - "name": "dnn/training_dataset", - "id": "5bdaecf9c5a0094d4a18df15348f709be8acfd3c6faf72c3f243956c3896e76", - "index": 0, - "id_with_tags": "166a9c191c7d058a596fc2396ded7c39e27c8021bffd7b91ff2bbb07e26f729", - "resource_type": "training_dataset", - "model_name": "dnn", - }, - "data_partition_ratio": {"evaluation": 0.2, "training": 0.8}, - "file_path": "resources/models.yaml", - "path": "implementations/models/dnn.py", - "training": { - "keep_checkpoint_every_n_hours": 10000, - "shuffle": True, - "batch_size": 10, - "log_step_count_steps": 100, - "num_epochs": None, - "keep_checkpoint_max": 3, - "tf_random_seed": 1788, - "save_summary_steps": 100, - "save_checkpoints_steps": None, - 
"num_steps": 1000, - "save_checkpoints_secs": 600, - "tf_randomize_seed": False, - }, - "evaluation": { - "shuffle": False, - "batch_size": 40, - "num_epochs": None, - "throttle_secs": 600, - "num_steps": 100, - "start_delay_secs": 120, - }, - "index": 0, - "compute": {"mem": None, "cpu": None, "gpu": None}, - "id_with_tags": "4989cb227eb56c2d3ccc1904cb3dbcab9a1ceb1ebf8cdb9f95a20b86a8df019", - } - }, - "app": { - "id": "47612b3175fece07f6c3e91992412c5b16ca88a9068cb72fecbcf653eb5ffcd", - "name": "iris", - }, - "cortex_config": { - "region": "us-west-2", - "log_group": "cortex", - "api_version": "master", - "id": "da5e65b994ba4ebb069bdc19cf73da64aee79e5d83f466038dc75b3ef04fa63", - }, - "root": "apps/iris/data/2019-03-08-09-58-35-701834/3976c5679bcf7cb550453802f4c3a9333c5f193f6097f1f5642de48d2397554", - "metadata_root": "apps/iris/data/2019-03-08-09-58-35-701834/3976c5679bcf7cb550453802f4c3a9333c5f193f6097f1f5642de48d2397554/metadata", - "aggregators": { - "cortex.mean": { - "id": "a68b354ddadc2e14348698e03af74db72cba92d7acb162e3163629e3e343373", - "impl_key": "aggregators/71c8aa1ce07d9d7059e305ed2b180504c36a41452e73fb251ef532bf679f851.py", - "embed": None, - "file_path": "/home/ubuntu/src/github.com/cortexlabs/cortex/pkg/aggregators/aggregators.yaml", - "name": "mean", - "namespace": "cortex", - "path": "", - "output_type": "FLOAT", - "index": 13, - "id_with_tags": "a68b354ddadc2e14348698e03af74db72cba92d7acb162e3163629e3e343373", - "resource_type": "aggregator", - "inputs": {"columns": {"col": "FLOAT_COLUMN|INT_COLUMN"}, "args": {}}, - }, - "cortex.stddev": { - "id": "51ca32fabf602a0c8fd7a9b4f5bb9a3d92bb6b3bbc356a727d7a25b19787353", - "impl_key": "aggregators/b8fa468e54c55083bf350f8b482c5323bd4bc12dd5fa0d859908ab2829aea7f.py", - "embed": None, - "file_path": "/home/ubuntu/src/github.com/cortexlabs/cortex/pkg/aggregators/aggregators.yaml", - "name": "stddev", - "namespace": "cortex", - "path": "", - "output_type": "FLOAT", - "index": 18, - "id_with_tags": "51ca32fabf602a0c8fd7a9b4f5bb9a3d92bb6b3bbc356a727d7a25b19787353", - "resource_type": "aggregator", - "inputs": {"columns": {"col": "FLOAT_COLUMN|INT_COLUMN"}, "args": {}}, - }, - "cortex.index_string": { - "id": "c32f21159377d5dc3ddc664fe5cabbe7b275eadc82b5f6ed711faa1a988deb4", - "impl_key": "/aggregators/index_string.py", - "embed": None, - "file_path": "/home/ubuntu/src/github.com/cortexlabs/cortex/pkg/aggregators/aggregators.yaml", - "name": "index_string", - "namespace": "cortex", - "path": "", - "output_type": {"index": ["STRING"], "reversed_index": {"STRING": "INT"}}, - "index": 29, - "id_with_tags": "c32f21159377d5dc3ddc664fe5cabbe7b275eadc82b5f6ed711faa1a988deb4", - "resource_type": "aggregator", - "inputs": {"columns": {"col": "STRING_COLUMN"}, "args": {}}, - }, - }, -} diff --git a/pkg/workloads/spark_job/test/unit/spark_util_test.py b/pkg/workloads/spark_job/test/unit/spark_util_test.py index a95b95085a..b5e02ff950 100644 --- a/pkg/workloads/spark_job/test/unit/spark_util_test.py +++ b/pkg/workloads/spark_job/test/unit/spark_util_test.py @@ -16,7 +16,7 @@ import spark_util import consts from lib.exceptions import UserException - +from lib import util import pytest from pyspark.sql.types import * from pyspark.sql import Row @@ -28,13 +28,16 @@ pytestmark = pytest.mark.usefixtures("spark") +def add_res_ref(input): + return util.resource_escape_seq_raw + input + def test_read_csv_valid(spark, write_csv_file, ctx_obj, get_context): csv_str = "\n".join(["a,0.1,", "b,1,1", 
"c,1.1,4"]) path_to_file = write_csv_file(csv_str) ctx_obj["environment"] = { - "data": {"type": "csv", "path": path_to_file, "schema": ["a_str", "b_float", "c_long"]} + "data": {"type": "csv", "path": path_to_file, "schema": [add_res_ref("a_str"), add_res_ref("b_float"), add_res_ref("c_long")]} } ctx_obj["raw_columns"] = { @@ -52,7 +55,7 @@ def test_read_csv_invalid_type(spark, write_csv_file, ctx_obj, get_context): path_to_file = write_csv_file(csv_str) ctx_obj["environment"] = { - "data": {"type": "csv", "path": path_to_file, "schema": ["a_str", "b_long", "c_long"]} + "data": {"type": "csv", "path": path_to_file, "schema": [add_res_ref("a_str"), add_res_ref("b_long"), add_res_ref("c_long")]} } ctx_obj["raw_columns"] = { @@ -69,7 +72,7 @@ def test_read_csv_infer_type(spark, write_csv_file, ctx_obj, get_context): test_cases = [ { "csv": ["a,0.1,", "b,0.1,1", "c,1.1,4"], - "schema": ["a_str", "b_float", "c_long"], + "schema": [add_res_ref("a_str"), add_res_ref("b_float"), add_res_ref("c_long")], "raw_columns": { "a_str": {"name": "a_str", "type": "INFERRED_COLUMN", "required": True, "id": "-"}, "b_float": { @@ -89,7 +92,7 @@ def test_read_csv_infer_type(spark, write_csv_file, ctx_obj, get_context): }, { "csv": ["1,4,4.5", "1,3,1.2", "1,5,4.7"], - "schema": ["a_str", "b_int", "c_float"], + "schema": [add_res_ref("a_str"), add_res_ref("b_int"), add_res_ref("c_float")], "raw_columns": { "a_str": {"name": "a_str", "type": "STRING_COLUMN", "required": True, "id": "-"}, "b_int": {"name": "b_int", "type": "INFERRED_COLUMN", "required": True, "id": "-"}, @@ -104,7 +107,7 @@ def test_read_csv_infer_type(spark, write_csv_file, ctx_obj, get_context): }, { "csv": ["1,4,2017-09-16", "1,3,2017-09-16", "1,5,2017-09-16"], - "schema": ["a_str", "b_int", "c_str"], + "schema": [add_res_ref("a_str"), add_res_ref("b_int"), add_res_ref("c_str")], "raw_columns": { "a_str": {"name": "a_str", "type": "STRING_COLUMN", "required": True, "id": "-"}, "b_int": {"name": "b_int", "type": "INFERRED_COLUMN", "required": True, "id": "-"}, @@ -114,7 +117,7 @@ def test_read_csv_infer_type(spark, write_csv_file, ctx_obj, get_context): }, { "csv": ["1,4,2017-09-16", "1,3,2017-09-16", "1,5,2017-09-16"], - "schema": ["a_float", "b_int", "c_str"], + "schema": [add_res_ref("a_float"), add_res_ref("b_int"), add_res_ref("c_str")], "raw_columns": { "a_float": {"name": "a_float", "type": "FLOAT_COLUMN", "required": True, "id": "-"}, "b_int": {"name": "b_int", "type": "INFERRED_COLUMN", "required": True, "id": "-"}, @@ -145,7 +148,7 @@ def test_read_csv_infer_invalid(spark, write_csv_file, ctx_obj, get_context): test_cases = [ { "csv": ["a,0.1,", "a,0.1,1", "a,1.1,4"], - "schema": ["a_int", "b_float", "c_long"], + "schema": [add_res_ref("a_int"), add_res_ref("b_float"), add_res_ref("c_long")], "raw_columns": { "a_int": {"name": "a_int", "type": "INT_COLUMN", "required": True, "id": "-"}, "b_float": { @@ -164,7 +167,7 @@ def test_read_csv_infer_invalid(spark, write_csv_file, ctx_obj, get_context): }, { "csv": ["a,1.1,", "a,1.1,1", "a,1.1,4"], - "schema": ["a_str", "b_int", "c_int"], + "schema": [add_res_ref("a_str"), add_res_ref("b_int"), add_res_ref("c_int")], "raw_columns": { "a_str": {"name": "a_str", "type": "INFERRED_COLUMN", "required": True, "id": "-"}, "b_int": {"name": "b_int", "type": "INT_COLUMN", "required": True, "id": "-"}, @@ -193,7 +196,7 @@ def test_read_csv_missing_column(spark, write_csv_file, ctx_obj, get_context): path_to_file = write_csv_file(csv_str) ctx_obj["environment"] = { - "data": {"type": "csv", "path": 
path_to_file, "schema": ["a_str", "b_long", "c_long"]} + "data": {"type": "csv", "path": path_to_file, "schema": [add_res_ref("a_str"), add_res_ref("b_long"), add_res_ref("c_long")]} } ctx_obj["raw_columns"] = { @@ -221,7 +224,7 @@ def test_read_csv_valid_options(spark, write_csv_file, ctx_obj, get_context): "data": { "type": "csv", "path": path_to_file, - "schema": ["a_str", "b_float", "c_long"], + "schema": [add_res_ref("a_str"), add_res_ref("b_float"), add_res_ref("c_long")], "csv_config": { "header": True, "sep": "|", @@ -437,9 +440,9 @@ def test_ingest_parquet_valid(spark, write_parquet_file, ctx_obj, get_context): "type": "parquet", "path": path_to_file, "schema": [ - {"parquet_column_name": "a_str", "raw_column_name": "a_str"}, - {"parquet_column_name": "b_float", "raw_column_name": "b_float"}, - {"parquet_column_name": "c_long", "raw_column_name": "c_long"}, + {"parquet_column_name": "a_str", "raw_column": add_res_ref("a_str")}, + {"parquet_column_name": "b_float", "raw_column": add_res_ref("b_float")}, + {"parquet_column_name": "c_long", "raw_column": add_res_ref("c_long")}, ], } } @@ -473,9 +476,9 @@ def test_ingest_parquet_infer_valid(spark, write_parquet_file, ctx_obj, get_cont ] ), "env": [ - {"parquet_column_name": "a_str", "raw_column_name": "a_str"}, - {"parquet_column_name": "b_float", "raw_column_name": "b_float"}, - {"parquet_column_name": "c_long", "raw_column_name": "c_long"}, + {"parquet_column_name": "a_str", "raw_column": add_res_ref("a_str")}, + {"parquet_column_name": "b_float", "raw_column": add_res_ref("b_float")}, + {"parquet_column_name": "c_long", "raw_column": add_res_ref("c_long")}, ], "raw_columns": { "a_str": {"name": "a_str", "type": "INFERRED_COLUMN", "required": True, "id": "1"}, @@ -508,9 +511,9 @@ def test_ingest_parquet_infer_valid(spark, write_parquet_file, ctx_obj, get_cont ] ), "env": [ - {"parquet_column_name": "a_str", "raw_column_name": "a_str"}, - {"parquet_column_name": "b_float", "raw_column_name": "b_float"}, - {"parquet_column_name": "c_long", "raw_column_name": "c_long"}, + {"parquet_column_name": "a_str", "raw_column": add_res_ref("a_str")}, + {"parquet_column_name": "b_float", "raw_column": add_res_ref("b_float")}, + {"parquet_column_name": "c_long", "raw_column": add_res_ref("c_long")}, ], "raw_columns": { "a_str": {"name": "a_str", "type": "INFERRED_COLUMN", "required": True, "id": "1"}, @@ -547,9 +550,9 @@ def test_ingest_parquet_infer_valid(spark, write_parquet_file, ctx_obj, get_cont ] ), "env": [ - {"parquet_column_name": "a_str", "raw_column_name": "a_str"}, - {"parquet_column_name": "b_float", "raw_column_name": "b_float"}, - {"parquet_column_name": "c_str", "raw_column_name": "c_str"}, + {"parquet_column_name": "a_str", "raw_column": add_res_ref("a_str")}, + {"parquet_column_name": "b_float", "raw_column": add_res_ref("b_float")}, + {"parquet_column_name": "c_str", "raw_column": add_res_ref("c_str")}, ], "raw_columns": { "a_str": {"name": "a_str", "type": "INFERRED_COLUMN", "required": True, "id": "1"}, @@ -581,9 +584,9 @@ def test_ingest_parquet_infer_valid(spark, write_parquet_file, ctx_obj, get_cont ] ), "env": [ - {"parquet_column_name": "a_str", "raw_column_name": "a_str"}, - {"parquet_column_name": "b_float", "raw_column_name": "b_float"}, - {"parquet_column_name": "c_str", "raw_column_name": "c_str"}, + {"parquet_column_name": "a_str", "raw_column": add_res_ref("a_str")}, + {"parquet_column_name": "b_float", "raw_column": add_res_ref("b_float")}, + {"parquet_column_name": "c_str", "raw_column": add_res_ref("c_str")}, ], 
"raw_columns": { "a_str": {"name": "a_str", "type": "INFERRED_COLUMN", "required": True, "id": "1"}, @@ -637,9 +640,9 @@ def test_read_parquet_infer_invalid(spark, write_parquet_file, ctx_obj, get_cont ] ), "env": [ - {"parquet_column_name": "a_str", "raw_column_name": "a_str"}, - {"parquet_column_name": "b_float", "raw_column_name": "b_float"}, - {"parquet_column_name": "c_long", "raw_column_name": "c_long"}, + {"parquet_column_name": "a_str", "raw_column": add_res_ref("a_str")}, + {"parquet_column_name": "b_float", "raw_column": add_res_ref("b_float")}, + {"parquet_column_name": "c_long", "raw_column": add_res_ref("c_long")}, ], "raw_columns": { "a_str": {"name": "a_str", "type": "INFERRED_COLUMN", "required": True, "id": "1"}, @@ -662,9 +665,9 @@ def test_read_parquet_infer_invalid(spark, write_parquet_file, ctx_obj, get_cont ] ), "env": [ - {"parquet_column_name": "a_str", "raw_column_name": "a_str"}, - {"parquet_column_name": "b_float", "raw_column_name": "b_float"}, - {"parquet_column_name": "c_str", "raw_column_name": "c_str"}, + {"parquet_column_name": "a_str", "raw_column": add_res_ref("a_str")}, + {"parquet_column_name": "b_float", "raw_column": add_res_ref("b_float")}, + {"parquet_column_name": "c_str", "raw_column": add_res_ref("c_str")}, ], "raw_columns": { "a_str": {"name": "a_str", "type": "INFERRED_COLUMN", "required": True, "id": "1"}, @@ -687,9 +690,9 @@ def test_read_parquet_infer_invalid(spark, write_parquet_file, ctx_obj, get_cont ] ), "env": [ - {"parquet_column_name": "a_str", "raw_column_name": "a_str"}, - {"parquet_column_name": "b_float", "raw_column_name": "b_float"}, - {"parquet_column_name": "c_long", "raw_column_name": "c_long"}, + {"parquet_column_name": "a_str", "raw_column": add_res_ref("a_str")}, + {"parquet_column_name": "b_float", "raw_column": add_res_ref("b_float")}, + {"parquet_column_name": "c_long", "raw_column": add_res_ref("c_long")}, ], "raw_columns": { "a_str": {"name": "a_str", "type": "INFERRED_COLUMN", "required": True, "id": "1"}, @@ -712,9 +715,9 @@ def test_read_parquet_infer_invalid(spark, write_parquet_file, ctx_obj, get_cont ] ), "env": [ - {"parquet_column_name": "a_str", "raw_column_name": "a_str"}, - {"parquet_column_name": "b_float", "raw_column_name": "b_float"}, - {"parquet_column_name": "c_long", "raw_column_name": "c_long"}, + {"parquet_column_name": "a_str", "raw_column": add_res_ref("a_str")}, + {"parquet_column_name": "b_float", "raw_column": add_res_ref("b_float")}, + {"parquet_column_name": "c_long", "raw_column": add_res_ref("c_long")}, ], "raw_columns": { "a_str": {"name": "a_str", "type": "INT_COLUMN", "required": True, "id": "1"}, @@ -742,9 +745,9 @@ def test_read_parquet_infer_invalid(spark, write_parquet_file, ctx_obj, get_cont ] ), "env": [ - {"parquet_column_name": "a_str", "raw_column_name": "a_str"}, - {"parquet_column_name": "b_float", "raw_column_name": "b_float"}, - {"parquet_column_name": "c_long", "raw_column_name": "c_long"}, + {"parquet_column_name": "a_str", "raw_column": add_res_ref("a_str")}, + {"parquet_column_name": "b_float", "raw_column": add_res_ref("b_float")}, + {"parquet_column_name": "c_long", "raw_column": add_res_ref("c_long")}, ], "raw_columns": { "a_str": {"name": "a_str", "type": "INT_COLUMN", "required": True, "id": "1"}, @@ -799,10 +802,10 @@ def test_ingest_parquet_extra_cols(spark, write_parquet_file, ctx_obj, get_conte "type": "parquet", "path": path_to_file, "schema": [ - {"parquet_column_name": "a_str", "raw_column_name": "a_str"}, - {"parquet_column_name": "b_float", 
"raw_column_name": "b_float"}, - {"parquet_column_name": "c_long", "raw_column_name": "c_long"}, - {"parquet_column_name": "d_long", "raw_column_name": "d_long"}, + {"parquet_column_name": "a_str", "raw_column": add_res_ref("a_str")}, + {"parquet_column_name": "b_float", "raw_column": add_res_ref("b_float")}, + {"parquet_column_name": "c_long", "raw_column": add_res_ref("c_long")}, + {"parquet_column_name": "d_long", "raw_column": add_res_ref("d_long")}, ], } } @@ -834,9 +837,9 @@ def test_ingest_parquet_missing_cols(spark, write_parquet_file, ctx_obj, get_con "type": "parquet", "path": path_to_file, "schema": [ - {"parquet_column_name": "a_str", "raw_column_name": "a_str"}, - {"parquet_column_name": "b_float", "raw_column_name": "b_float"}, - {"parquet_column_name": "c_long", "raw_column_name": "c_long"}, + {"parquet_column_name": "a_str", "raw_column": add_res_ref("a_str")}, + {"parquet_column_name": "b_float", "raw_column": add_res_ref("b_float")}, + {"parquet_column_name": "c_long", "raw_column": add_res_ref("c_long")}, ], } } @@ -870,9 +873,9 @@ def test_ingest_parquet_type_mismatch(spark, write_parquet_file, ctx_obj, get_co "type": "parquet", "path": path_to_file, "schema": [ - {"parquet_column_name": "a_str", "raw_column_name": "a_str"}, - {"parquet_column_name": "b_float", "raw_column_name": "b_float"}, - {"parquet_column_name": "c_long", "raw_column_name": "c_long"}, + {"parquet_column_name": "a_str", "raw_column": add_res_ref("a_str")}, + {"parquet_column_name": "b_float", "raw_column": add_res_ref("b_float")}, + {"parquet_column_name": "c_long", "raw_column": add_res_ref("c_long")}, ], } } @@ -908,9 +911,9 @@ def test_ingest_parquet_failed_requirements( "type": "parquet", "path": path_to_file, "schema": [ - {"parquet_column_name": "a_str", "raw_column_name": "a_str"}, - {"parquet_column_name": "b_float", "raw_column_name": "b_float"}, - {"parquet_column_name": "c_long", "raw_column_name": "c_long"}, + {"parquet_column_name": "a_str", "raw_column": add_res_ref("a_str")}, + {"parquet_column_name": "b_float", "raw_column": add_res_ref("b_float")}, + {"parquet_column_name": "c_long", "raw_column": add_res_ref("c_long")}, ], } } @@ -928,43 +931,29 @@ def test_ingest_parquet_failed_requirements( assert validations == {"a_str": [("(a_str IN (a, b))", 1)]} -def test_column_names_to_index(): - sample_columns_input_config = {"b": "b_col", "a": "a_col"} - actual_list, actual_dict = spark_util.column_names_to_index(sample_columns_input_config) - assert (["a_col", "b_col"], {"b": 1, "a": 0}) == (actual_list, actual_dict) - - sample_columns_input_config = {"a": "a_col"} - - actual_list, actual_dict = spark_util.column_names_to_index(sample_columns_input_config) - assert (["a_col"], {"a": 0}) == (actual_list, actual_dict) - - sample_columns_input_config = {"nums": ["a_long", "a_col", "b_col", "b_col"], "a": "a_long"} - - expected_col_list = ["a_col", "a_long", "b_col"] - expected_columns_input_config = {"nums": [1, 0, 2, 2], "a": 1} - actual_list, actual_dict = spark_util.column_names_to_index(sample_columns_input_config) - - assert (expected_col_list, expected_columns_input_config) == (actual_list, actual_dict) - - def test_run_builtin_aggregators_success(spark, ctx_obj, get_context): + ctx_obj["raw_columns"] = { + "a": { + "id": "2", + "name": "a", + "type": "INT_COLUMN" + } + } ctx_obj["aggregators"] = { - "cortex.sum": {"name": "sum", "namespace": "cortex"}, - "cortex.first": {"name": "first", "namespace": "cortex"}, + "cortex.sum_int": { + "name": "sum_int", + "namespace": "cortex", + 
"input": {"_type": "INT_COLUMN"}, + "output_type": "INT_COLUMN", + } } ctx_obj["aggregates"] = { "sum_a": { "name": "sum_a", "id": "1", - "aggregator": "cortex.sum", - "inputs": {"columns": {"col": "a"}}, - }, - "first_a": { - "id": "2", - "name": "first_a", - "aggregator": "cortex.first", - "inputs": {"columns": {"col": "a"}, "args": {"ignorenulls": "some_constant"}}, - }, + "aggregator": "cortex.sum_int", + "input": add_res_ref("a"), + } } aggregate_list = [v for v in ctx_obj["aggregates"].values()] @@ -976,47 +965,15 @@ def test_run_builtin_aggregators_success(spark, ctx_obj, get_context): df = spark.createDataFrame(data, StructType([StructField("a", LongType())])) spark_util.run_builtin_aggregators(aggregate_list, df, ctx, spark) - calls = [call(6, ctx_obj["aggregates"]["sum_a"]), call(1, ctx_obj["aggregates"]["first_a"])] + calls = [call(6, ctx_obj["aggregates"]["sum_a"])] ctx.store_aggregate_result.assert_has_calls(calls, any_order=True) - ctx.populate_args.assert_called_once_with({"ignorenulls": "some_constant"}) - - -def test_run_builtin_aggregators_error(spark, ctx_obj, get_context): - ctx_obj["aggregators"] = {"cortex.first": {"name": "first", "namespace": "cortex"}} - ctx_obj["aggregates"] = { - "first_a": { - "name": "first_a", - "aggregator": "cortex.first", - "inputs": { - "columns": {"col": "a"}, - "args": {"ignoreNulls": "some_constant"}, # supposed to be ignorenulls - }, - "id": "1", - } - } - - aggregate_list = [v for v in ctx_obj["aggregates"].values()] - - ctx = get_context(ctx_obj) - ctx.store_aggregate_result = MagicMock() - ctx.populate_args = MagicMock(return_value={"ignoreNulls": True}) - - data = [Row(a=None), Row(a=1), Row(a=2), Row(a=3)] - df = spark.createDataFrame(data, StructType([StructField("a", LongType())])) - - with pytest.raises(Exception) as exec_info: - spark_util.run_builtin_aggregators(aggregate_list, df, ctx, spark) - - ctx.store_aggregate_result.assert_not_called() - ctx.populate_args.assert_called_once_with({"ignoreNulls": "some_constant"}) - -def test_infer_type(): - assert spark_util.infer_type(1) == consts.COLUMN_TYPE_INT - assert spark_util.infer_type(1.0) == consts.COLUMN_TYPE_FLOAT - assert spark_util.infer_type("cortex") == consts.COLUMN_TYPE_STRING +def test_infer_python_type(): + assert spark_util.infer_python_type(1) == consts.COLUMN_TYPE_INT + assert spark_util.infer_python_type(1.0) == consts.COLUMN_TYPE_FLOAT + assert spark_util.infer_python_type("cortex") == consts.COLUMN_TYPE_STRING - assert spark_util.infer_type([1]) == consts.COLUMN_TYPE_INT_LIST - assert spark_util.infer_type([1.0]) == consts.COLUMN_TYPE_FLOAT_LIST - assert spark_util.infer_type(["cortex"]) == consts.COLUMN_TYPE_STRING_LIST + assert spark_util.infer_python_type([1]) == consts.COLUMN_TYPE_INT_LIST + assert spark_util.infer_python_type([1.0]) == consts.COLUMN_TYPE_FLOAT_LIST + assert spark_util.infer_python_type(["cortex"]) == consts.COLUMN_TYPE_STRING_LIST From df0ee36d06cb15392003827df5078b0b9572e913 Mon Sep 17 00:00:00 2001 From: David Eliahu Date: Wed, 12 Jun 2019 16:51:20 -0700 Subject: [PATCH 34/44] Fix lint --- pkg/estimators/boosted_trees_classifier.py | 15 +++++++++ pkg/estimators/boosted_trees_regressor.py | 15 +++++++++ pkg/estimators/dnn_classifier.py | 15 +++++++++ .../dnn_linear_combined_classifier.py | 15 +++++++++ .../dnn_linear_combined_regressor.py | 15 +++++++++ pkg/estimators/dnn_regressor.py | 15 +++++++++ pkg/estimators/linear_classifier.py | 15 +++++++++ pkg/estimators/linear_regressor.py | 15 +++++++++ pkg/operator/context/resources.go | 3 
+- .../spark_job/test/unit/spark_util_test.py | 31 ++++++++++++------- 10 files changed, 140 insertions(+), 14 deletions(-) diff --git a/pkg/estimators/boosted_trees_classifier.py b/pkg/estimators/boosted_trees_classifier.py index 67594299a3..21a4aa5d8a 100644 --- a/pkg/estimators/boosted_trees_classifier.py +++ b/pkg/estimators/boosted_trees_classifier.py @@ -1,3 +1,18 @@ +# Copyright 2019 Cortex Labs, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + import tensorflow as tf diff --git a/pkg/estimators/boosted_trees_regressor.py b/pkg/estimators/boosted_trees_regressor.py index 55881eec4f..aa6d9a9418 100644 --- a/pkg/estimators/boosted_trees_regressor.py +++ b/pkg/estimators/boosted_trees_regressor.py @@ -1,3 +1,18 @@ +# Copyright 2019 Cortex Labs, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + import tensorflow as tf diff --git a/pkg/estimators/dnn_classifier.py b/pkg/estimators/dnn_classifier.py index de5bfcc599..410ca0b0e8 100644 --- a/pkg/estimators/dnn_classifier.py +++ b/pkg/estimators/dnn_classifier.py @@ -1,3 +1,18 @@ +# Copyright 2019 Cortex Labs, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + import tensorflow as tf diff --git a/pkg/estimators/dnn_linear_combined_classifier.py b/pkg/estimators/dnn_linear_combined_classifier.py index 2857bcc88c..21121a4272 100644 --- a/pkg/estimators/dnn_linear_combined_classifier.py +++ b/pkg/estimators/dnn_linear_combined_classifier.py @@ -1,3 +1,18 @@ +# Copyright 2019 Cortex Labs, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + + import tensorflow as tf diff --git a/pkg/estimators/dnn_linear_combined_regressor.py b/pkg/estimators/dnn_linear_combined_regressor.py index 14ba51f7a6..25434a7bed 100644 --- a/pkg/estimators/dnn_linear_combined_regressor.py +++ b/pkg/estimators/dnn_linear_combined_regressor.py @@ -1,3 +1,18 @@ +# Copyright 2019 Cortex Labs, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + import tensorflow as tf diff --git a/pkg/estimators/dnn_regressor.py b/pkg/estimators/dnn_regressor.py index 4a3cf491cb..b3ad780c6c 100644 --- a/pkg/estimators/dnn_regressor.py +++ b/pkg/estimators/dnn_regressor.py @@ -1,3 +1,18 @@ +# Copyright 2019 Cortex Labs, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + import tensorflow as tf diff --git a/pkg/estimators/linear_classifier.py b/pkg/estimators/linear_classifier.py index 70514f84af..3eaf0abb48 100644 --- a/pkg/estimators/linear_classifier.py +++ b/pkg/estimators/linear_classifier.py @@ -1,3 +1,18 @@ +# Copyright 2019 Cortex Labs, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + import tensorflow as tf diff --git a/pkg/estimators/linear_regressor.py b/pkg/estimators/linear_regressor.py index 7c990d2829..2c9d7bb9e9 100644 --- a/pkg/estimators/linear_regressor.py +++ b/pkg/estimators/linear_regressor.py @@ -1,3 +1,18 @@ +# Copyright 2019 Cortex Labs, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ + import tensorflow as tf diff --git a/pkg/operator/context/resources.go b/pkg/operator/context/resources.go index cbab68bca4..465c057263 100644 --- a/pkg/operator/context/resources.go +++ b/pkg/operator/context/resources.go @@ -55,9 +55,8 @@ func ValidateInput( if input == nil { if schema.Optional { return nil, hash.Any(nil), nil - } else { - return nil, "", userconfig.ErrorMustBeDefined(schema) } + return nil, "", userconfig.ErrorMustBeDefined(schema) } castedInput, err = validateRuntimeTypes(input, schema, validResourcesMap, aggregators, transformers, false) diff --git a/pkg/workloads/spark_job/test/unit/spark_util_test.py b/pkg/workloads/spark_job/test/unit/spark_util_test.py index b5e02ff950..9a8bf11233 100644 --- a/pkg/workloads/spark_job/test/unit/spark_util_test.py +++ b/pkg/workloads/spark_job/test/unit/spark_util_test.py @@ -31,13 +31,18 @@ def add_res_ref(input): return util.resource_escape_seq_raw + input + def test_read_csv_valid(spark, write_csv_file, ctx_obj, get_context): csv_str = "\n".join(["a,0.1,", "b,1,1", "c,1.1,4"]) path_to_file = write_csv_file(csv_str) ctx_obj["environment"] = { - "data": {"type": "csv", "path": path_to_file, "schema": [add_res_ref("a_str"), add_res_ref("b_float"), add_res_ref("c_long")]} + "data": { + "type": "csv", + "path": path_to_file, + "schema": [add_res_ref("a_str"), add_res_ref("b_float"), add_res_ref("c_long")], + } } ctx_obj["raw_columns"] = { @@ -55,7 +60,11 @@ def test_read_csv_invalid_type(spark, write_csv_file, ctx_obj, get_context): path_to_file = write_csv_file(csv_str) ctx_obj["environment"] = { - "data": {"type": "csv", "path": path_to_file, "schema": [add_res_ref("a_str"), add_res_ref("b_long"), add_res_ref("c_long")]} + "data": { + "type": "csv", + "path": path_to_file, + "schema": [add_res_ref("a_str"), add_res_ref("b_long"), add_res_ref("c_long")], + } } ctx_obj["raw_columns"] = { @@ -196,7 +205,11 @@ def test_read_csv_missing_column(spark, write_csv_file, ctx_obj, get_context): path_to_file = write_csv_file(csv_str) ctx_obj["environment"] = { - "data": {"type": "csv", "path": path_to_file, "schema": [add_res_ref("a_str"), add_res_ref("b_long"), add_res_ref("c_long")]} + "data": { + "type": "csv", + "path": path_to_file, + "schema": [add_res_ref("a_str"), add_res_ref("b_long"), add_res_ref("c_long")], + } } ctx_obj["raw_columns"] = { @@ -932,17 +945,11 @@ def test_ingest_parquet_failed_requirements( def test_run_builtin_aggregators_success(spark, ctx_obj, get_context): - ctx_obj["raw_columns"] = { - "a": { - "id": "2", - "name": "a", - "type": "INT_COLUMN" - } - } + ctx_obj["raw_columns"] = {"a": {"id": "2", "name": "a", "type": "INT_COLUMN"}} ctx_obj["aggregators"] = { "cortex.sum_int": { - "name": "sum_int", - "namespace": "cortex", + "name": "sum_int", + "namespace": "cortex", "input": {"_type": "INT_COLUMN"}, "output_type": "INT_COLUMN", } From 588c28cff476da4459c77e129a773d05c2dc8c87 Mon Sep 17 00:00:00 2001 From: Omer Spillinger Date: Wed, 12 Jun 2019 17:59:38 -0700 Subject: [PATCH 35/44] Update examples --- examples/fraud/implementations/models/dnn.py | 16 ----- .../implementations/transformers/weight.py | 6 +- examples/fraud/resources/apis.yaml | 2 +- examples/fraud/resources/dnn.yaml | 19 ----- examples/fraud/resources/models.yaml | 19 +++++ examples/fraud/resources/normalize.yaml | 26 ------- .../fraud/resources/normalized_columns.yaml | 21 ++++++ examples/fraud/resources/raw_columns.yaml | 2 +- examples/fraud/resources/weight_column.yaml | 23 ++----- .../insurance/implementations/models/dnn.py | 40 ----------- 
.../{environments.yaml => data.yaml} | 0 examples/insurance/resources/models.yaml | 27 ++++---- examples/poker/resources/apis.yaml | 2 +- examples/poker/resources/data.yaml | 6 ++ examples/poker/resources/environments.yaml | 17 ----- examples/poker/resources/models.yaml | 38 ++++++---- examples/poker/resources/raw_columns.yaml | 54 --------------- examples/wine/implementations/models/dnn.py | 15 ---- examples/wine/resources/apis.yaml | 2 +- examples/wine/resources/data.yaml | 9 +++ examples/wine/resources/models.yaml | 8 ++- examples/wine/resources/normalize.yaml | 26 ------- .../wine/resources/normalized_columns.yaml | 21 ++++++ examples/wine/resources/quality.yaml | 8 +-- examples/wine/resources/raw_columns.yaml | 69 ------------------- 25 files changed, 136 insertions(+), 340 deletions(-) delete mode 100644 examples/fraud/implementations/models/dnn.py delete mode 100644 examples/fraud/resources/dnn.yaml create mode 100644 examples/fraud/resources/models.yaml delete mode 100644 examples/fraud/resources/normalize.yaml delete mode 100644 examples/insurance/implementations/models/dnn.py rename examples/insurance/resources/{environments.yaml => data.yaml} (100%) create mode 100644 examples/poker/resources/data.yaml delete mode 100644 examples/poker/resources/environments.yaml delete mode 100644 examples/poker/resources/raw_columns.yaml delete mode 100644 examples/wine/implementations/models/dnn.py create mode 100644 examples/wine/resources/data.yaml delete mode 100644 examples/wine/resources/normalize.yaml delete mode 100644 examples/wine/resources/raw_columns.yaml diff --git a/examples/fraud/implementations/models/dnn.py b/examples/fraud/implementations/models/dnn.py deleted file mode 100644 index c4bca40016..0000000000 --- a/examples/fraud/implementations/models/dnn.py +++ /dev/null @@ -1,16 +0,0 @@ -import tensorflow as tf - - -def create_estimator(run_config, model_config): - feature_columns = [ - tf.feature_column.numeric_column(feature_column["name"]) - for feature_column in model_config["feature_columns"] - ] - - return tf.estimator.DNNClassifier( - feature_columns=feature_columns, - hidden_units=model_config["hparams"]["hidden_units"], - n_classes=2, - weight_column="weight_column", - config=run_config, - ) diff --git a/examples/fraud/implementations/transformers/weight.py b/examples/fraud/implementations/transformers/weight.py index ca81863ccf..f3ce6027cd 100644 --- a/examples/fraud/implementations/transformers/weight.py +++ b/examples/fraud/implementations/transformers/weight.py @@ -1,9 +1,9 @@ -def transform_spark(data, columns, args, transformed_column_name): +def transform_spark(data, input, transformed_column_name): import pyspark.sql.functions as F - distribution = args["class_distribution"] + distribution = input["class_distribution"] return data.withColumn( transformed_column_name, - F.when(data[columns["col"]] == 0, distribution[1]).otherwise(distribution[0]), + F.when(data[input["col"]] == 0, distribution[1]).otherwise(distribution[0]), ) diff --git a/examples/fraud/resources/apis.yaml b/examples/fraud/resources/apis.yaml index 8d98858464..6c67398ae1 100644 --- a/examples/fraud/resources/apis.yaml +++ b/examples/fraud/resources/apis.yaml @@ -1,5 +1,5 @@ - kind: api name: fraud - model_name: dnn + model: @dnn compute: replicas: 1 diff --git a/examples/fraud/resources/dnn.yaml b/examples/fraud/resources/dnn.yaml deleted file mode 100644 index 87ca86009e..0000000000 --- a/examples/fraud/resources/dnn.yaml +++ /dev/null @@ -1,19 +0,0 @@ -- kind: model - name: dnn - type: 
classification - target_column: class - feature_columns: - [time_normalized, v1_normalized, v2_normalized, v3_normalized, v4_normalized, - v5_normalized, v6_normalized, v7_normalized, v8_normalized, v9_normalized, - v10_normalized, v11_normalized, v12_normalized, v13_normalized, v14_normalized, - v15_normalized, v16_normalized, v17_normalized, v18_normalized, v19_normalized, - v20_normalized, v21_normalized, v22_normalized, v23_normalized, v24_normalized, - v25_normalized, v26_normalized, v27_normalized, v28_normalized, amount_normalized] - training_columns: [weight_column] - hparams: - hidden_units: [100, 100, 100] - data_partition_ratio: - training: 0.8 - evaluation: 0.2 - training: - num_steps: 5000 diff --git a/examples/fraud/resources/models.yaml b/examples/fraud/resources/models.yaml new file mode 100644 index 0000000000..9330309dce --- /dev/null +++ b/examples/fraud/resources/models.yaml @@ -0,0 +1,19 @@ +- kind: model + name: dnn + estimator: cortex.dnn_classifier + target_column: @class + input: + num_classes: 2 + numeric_columns: + [@time_normalized, @v1_normalized, @v2_normalized, @v3_normalized, @v4_normalized, + @v5_normalized, @v6_normalized, @v7_normalized, @v8_normalized, @v9_normalized, + @v10_normalized, @v11_normalized, @v12_normalized, @v13_normalized, @v14_normalized, + @v15_normalized, @v16_normalized, @v17_normalized, @v18_normalized, @v19_normalized, + @v20_normalized, @v21_normalized, @v22_normalized, @v23_normalized, @v24_normalized, + @v25_normalized, @v26_normalized, @v27_normalized, @v28_normalized, @amount_normalized] + training_input: + weight_column: @weight_column + hparams: + hidden_units: [100, 100, 100] + training: + num_steps: 5000 diff --git a/examples/fraud/resources/normalize.yaml b/examples/fraud/resources/normalize.yaml deleted file mode 100644 index 88848bc7b2..0000000000 --- a/examples/fraud/resources/normalize.yaml +++ /dev/null @@ -1,26 +0,0 @@ -- kind: template - name: normalize - yaml: | - - kind: aggregate - name: {column}_mean - aggregator: cortex.mean - inputs: - columns: - col: {column} - - - kind: aggregate - name: {column}_stddev - aggregator: cortex.stddev - inputs: - columns: - col: {column} - - - kind: transformed_column - name: {column}_normalized - transformer: cortex.normalize - inputs: - columns: - num: {column} - args: - mean: {column}_mean - stddev: {column}_stddev diff --git a/examples/fraud/resources/normalized_columns.yaml b/examples/fraud/resources/normalized_columns.yaml index 2e1975ab3c..1aa0580355 100644 --- a/examples/fraud/resources/normalized_columns.yaml +++ b/examples/fraud/resources/normalized_columns.yaml @@ -1,3 +1,24 @@ +- kind: template + name: normalize + yaml: | + - kind: aggregate + name: {column}_mean + aggregator: cortex.mean + input: @{column} + + - kind: aggregate + name: {column}_stddev + aggregator: cortex.stddev + input: @{column} + + - kind: transformed_column + name: {column}_normalized + transformer: cortex.normalize + input: + col: @{column} + mean: @{column}_mean + stddev: @{column}_stddev + - kind: embed template: normalize args: diff --git a/examples/fraud/resources/raw_columns.yaml b/examples/fraud/resources/raw_columns.yaml index 9ed537b832..ba87addd46 100644 --- a/examples/fraud/resources/raw_columns.yaml +++ b/examples/fraud/resources/raw_columns.yaml @@ -5,7 +5,7 @@ path: s3a://cortex-examples/fraud.csv csv_config: header: true - schema: [time, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, amount, class] + 
schema: [@time, @v1, @v2, @v3, @v4, @v5, @v6, @v7, @v8, @v9, @v10, @v11, @v12, @v13, @v14, @v15, @v16, @v17, @v18, @v19, @v20, @v21, @v22, @v23, @v24, @v25, @v26, @v27, @v28, @amount, @class] - kind: raw_column name: time diff --git a/examples/fraud/resources/weight_column.yaml b/examples/fraud/resources/weight_column.yaml index a6f87aa58b..a5112dd8c5 100644 --- a/examples/fraud/resources/weight_column.yaml +++ b/examples/fraud/resources/weight_column.yaml @@ -1,24 +1,11 @@ - kind: aggregate name: class_distribution aggregator: cortex.class_distribution_int - inputs: - columns: - col: class - -- kind: transformer - name: weight - inputs: - columns: - col: INT_COLUMN - args: - class_distribution: {INT: FLOAT} - output_type: FLOAT_COLUMN + input: @class - kind: transformed_column name: weight_column - transformer: weight - inputs: - columns: - col: class - args: - class_distribution: class_distribution + transformer_path: implementations/transformers/weight.py + input: + col: @class + class_distribution: @class_distribution diff --git a/examples/insurance/implementations/models/dnn.py b/examples/insurance/implementations/models/dnn.py deleted file mode 100644 index 3c5a2ccec0..0000000000 --- a/examples/insurance/implementations/models/dnn.py +++ /dev/null @@ -1,40 +0,0 @@ -import tensorflow as tf - - -def create_estimator(run_config, model_config): - aggregates = model_config["input"]["aggregates"] - - feature_columns = [ - tf.feature_column.indicator_column( - tf.feature_column.categorical_column_with_vocabulary_list( - "sex", aggregates["sex_vocab"] - ) - ), - tf.feature_column.indicator_column( - tf.feature_column.categorical_column_with_vocabulary_list( - "smoker", aggregates["smoker_vocab"] - ) - ), - tf.feature_column.indicator_column( - tf.feature_column.categorical_column_with_vocabulary_list( - "region", aggregates["region_vocab"] - ) - ), - tf.feature_column.bucketized_column( - tf.feature_column.numeric_column("age"), aggregates["age_buckets"] - ), - tf.feature_column.bucketized_column( - tf.feature_column.numeric_column("bmi"), aggregates["bmi_buckets"] - ), - tf.feature_column.indicator_column( - tf.feature_column.categorical_column_with_vocabulary_list( - "children", aggregates["children_set"] - ) - ), - ] - - return tf.estimator.DNNRegressor( - feature_columns=feature_columns, - hidden_units=model_config["hparams"]["hidden_units"], - config=run_config, - ) diff --git a/examples/insurance/resources/environments.yaml b/examples/insurance/resources/data.yaml similarity index 100% rename from examples/insurance/resources/environments.yaml rename to examples/insurance/resources/data.yaml diff --git a/examples/insurance/resources/models.yaml b/examples/insurance/resources/models.yaml index 76ad5d1119..d3f4646df5 100644 --- a/examples/insurance/resources/models.yaml +++ b/examples/insurance/resources/models.yaml @@ -1,20 +1,23 @@ - kind: model name: dnn - estimator_path: implementations/models/dnn.py + estimator: cortex.dnn_regressor target_column: @charges_normalized input: - features: [@age, @sex, @bmi, @children, @smoker, @region, @charges] - aggregates: - children_set: @children_set - region_vocab: ["northwest", "northeast", "southwest", "southeast"] - age_buckets: [15, 20, 25, 35, 40, 45, 50, 55, 60, 65] - bmi_buckets: [15, 20, 25, 35, 40, 45, 50, 55] - smoker_vocab: ["yes", "no"] - sex_vocab: ["female", "male"] + categorical_columns_with_vocab: + - col: @sex + vocab: ['female', 'male'] + - col: @smoker + vocab: ['yes', 'no'] + - col: @region + vocab: ['northwest', 'northeast', 
'southwest', 'southeast'] + - col: @children + vocab: @children_set + bucketized_columns: + - col: @age + boundaries: [15, 20, 25, 35, 40, 45, 50, 55, 60, 65] + - col: @bmi + boundaries: [15, 20, 25, 35, 40, 45, 50, 55] hparams: hidden_units: [100, 100, 100] - data_partition_ratio: - training: 0.8 - evaluation: 0.2 training: num_steps: 10000 diff --git a/examples/poker/resources/apis.yaml b/examples/poker/resources/apis.yaml index 0f0dd10752..b3a9458458 100644 --- a/examples/poker/resources/apis.yaml +++ b/examples/poker/resources/apis.yaml @@ -1,5 +1,5 @@ - kind: api name: hand - model_name: dnn + model: @dnn compute: replicas: 1 diff --git a/examples/poker/resources/data.yaml b/examples/poker/resources/data.yaml new file mode 100644 index 0000000000..743a525a4e --- /dev/null +++ b/examples/poker/resources/data.yaml @@ -0,0 +1,6 @@ +- kind: environment + name: dev + data: + type: csv + path: s3a://cortex-examples/poker.csv + schema: [@card_1_suit, @card_1_rank, @card_2_suit, @card_2_rank, @card_3_suit, @card_3_rank, @card_4_suit, @card_4_rank, @card_5_suit, @card_5_rank, @class] diff --git a/examples/poker/resources/environments.yaml b/examples/poker/resources/environments.yaml deleted file mode 100644 index 4e5f8f7e44..0000000000 --- a/examples/poker/resources/environments.yaml +++ /dev/null @@ -1,17 +0,0 @@ -- kind: environment - name: dev - data: - type: csv - path: s3a://cortex-examples/poker.csv - schema: - - card_1_suit - - card_1_rank - - card_2_suit - - card_2_rank - - card_3_suit - - card_3_rank - - card_4_suit - - card_4_rank - - card_5_suit - - card_5_rank - - class diff --git a/examples/poker/resources/models.yaml b/examples/poker/resources/models.yaml index 53be744036..87c6101087 100644 --- a/examples/poker/resources/models.yaml +++ b/examples/poker/resources/models.yaml @@ -1,18 +1,30 @@ - kind: model name: dnn - type: classification - target_column: class - feature_columns: - - card_1_suit - - card_1_rank - - card_2_suit - - card_2_rank - - card_3_suit - - card_3_rank - - card_4_suit - - card_4_rank - - card_5_suit - - card_5_rank + estimator: cortex.dnn_classifier + target_column: @class + input: + num_classes: 10 + categorical_columns_with_vocab: + - col: @card_1_suit + vocab: [1, 2, 3, 4] + - col: @card_1_rank + vocab: [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13] + - col: @card_2_suit + vocab: [1, 2, 3, 4] + - col: @card_2_rank + vocab: [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13] + - col: @card_3_suit + vocab: [1, 2, 3, 4] + - col: @card_3_rank + vocab: [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13] + - col: @card_4_suit + vocab: [1, 2, 3, 4] + - col: @card_4_rank + vocab: [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13] + - col: @card_5_suit + vocab: [1, 2, 3, 4] + - col: @card_5_rank + vocab: [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13] hparams: hidden_units: [100, 100, 100, 100, 100] data_partition_ratio: diff --git a/examples/poker/resources/raw_columns.yaml b/examples/poker/resources/raw_columns.yaml deleted file mode 100644 index ac3d6d2096..0000000000 --- a/examples/poker/resources/raw_columns.yaml +++ /dev/null @@ -1,54 +0,0 @@ -- kind: raw_column - name: card_1_suit - type: INT_COLUMN - required: true - -- kind: raw_column - name: card_1_rank - type: INT_COLUMN - required: true - -- kind: raw_column - name: card_2_suit - type: INT_COLUMN - required: true - -- kind: raw_column - name: card_2_rank - type: INT_COLUMN - required: true - -- kind: raw_column - name: card_3_suit - type: INT_COLUMN - required: true - -- kind: raw_column - name: card_3_rank - type: INT_COLUMN - 
required: true - -- kind: raw_column - name: card_4_suit - type: INT_COLUMN - required: true - -- kind: raw_column - name: card_4_rank - type: INT_COLUMN - required: true - -- kind: raw_column - name: card_5_suit - type: INT_COLUMN - required: true - -- kind: raw_column - name: card_5_rank - type: INT_COLUMN - required: true - -- kind: raw_column - name: class - type: INT_COLUMN - required: true diff --git a/examples/wine/implementations/models/dnn.py b/examples/wine/implementations/models/dnn.py deleted file mode 100644 index e3bd25e52f..0000000000 --- a/examples/wine/implementations/models/dnn.py +++ /dev/null @@ -1,15 +0,0 @@ -import tensorflow as tf - - -def create_estimator(run_config, model_config): - feature_columns = [ - tf.feature_column.numeric_column(feature_column["name"]) - for feature_column in model_config["feature_columns"] - ] - - return tf.estimator.DNNClassifier( - feature_columns=feature_columns, - hidden_units=model_config["hparams"]["hidden_units"], - n_classes=2, - config=run_config, - ) diff --git a/examples/wine/resources/apis.yaml b/examples/wine/resources/apis.yaml index 2a051ba435..2c44a8ffce 100644 --- a/examples/wine/resources/apis.yaml +++ b/examples/wine/resources/apis.yaml @@ -1,5 +1,5 @@ - kind: api name: quality - model_name: dnn + model: @dnn compute: replicas: 1 diff --git a/examples/wine/resources/data.yaml b/examples/wine/resources/data.yaml new file mode 100644 index 0000000000..bafdd363b5 --- /dev/null +++ b/examples/wine/resources/data.yaml @@ -0,0 +1,9 @@ +- kind: environment + name: dev + data: + type: csv + path: s3a://cortex-examples/wine-quality.csv + csv_config: + header: true + sep: ';' + schema: [@fixed_acidity, @volatile_acidity, @citric_acid, @residual_sugar, @chlorides, @free_sulfur_dioxide, @total_sulfur_dioxide, @density, @pH, @sulphates, @alcohol, @quality] diff --git a/examples/wine/resources/models.yaml b/examples/wine/resources/models.yaml index b48ab8b9ab..2963cefa78 100644 --- a/examples/wine/resources/models.yaml +++ b/examples/wine/resources/models.yaml @@ -1,8 +1,10 @@ - kind: model name: dnn - type: classification - target_column: quality_bucketized - feature_columns: [fixed_acidity_normalized, volatile_acidity_normalized, citric_acid_normalized, residual_sugar_normalized, chlorides_normalized, free_sulfur_dioxide_normalized, total_sulfur_dioxide_normalized, density_normalized, pH_normalized, sulphates_normalized, alcohol_normalized] + estimator: cortex.dnn_classifier + target_column: @quality_bucketized + input: + num_classes: 2 + numeric_columns: [@fixed_acidity, @volatile_acidity, @citric_acid, @residual_sugar, @chlorides, @free_sulfur_dioxide, @total_sulfur_dioxide, @density, @pH, @sulphates, @alcohol] hparams: hidden_units: [100, 100, 100, 100, 100] data_partition_ratio: diff --git a/examples/wine/resources/normalize.yaml b/examples/wine/resources/normalize.yaml deleted file mode 100644 index 88848bc7b2..0000000000 --- a/examples/wine/resources/normalize.yaml +++ /dev/null @@ -1,26 +0,0 @@ -- kind: template - name: normalize - yaml: | - - kind: aggregate - name: {column}_mean - aggregator: cortex.mean - inputs: - columns: - col: {column} - - - kind: aggregate - name: {column}_stddev - aggregator: cortex.stddev - inputs: - columns: - col: {column} - - - kind: transformed_column - name: {column}_normalized - transformer: cortex.normalize - inputs: - columns: - num: {column} - args: - mean: {column}_mean - stddev: {column}_stddev diff --git a/examples/wine/resources/normalized_columns.yaml 
b/examples/wine/resources/normalized_columns.yaml index d2fa7d0cae..75f4c7f952 100644 --- a/examples/wine/resources/normalized_columns.yaml +++ b/examples/wine/resources/normalized_columns.yaml @@ -1,3 +1,24 @@ +- kind: template + name: normalize + yaml: | + - kind: aggregate + name: {column}_mean + aggregator: cortex.mean + input: @{column} + + - kind: aggregate + name: {column}_stddev + aggregator: cortex.stddev + input: @{column} + + - kind: transformed_column + name: {column}_normalized + transformer: cortex.normalize + input: + col: @{column} + mean: @{column}_mean + stddev: @{column}_stddev + - kind: embed template: normalize args: diff --git a/examples/wine/resources/quality.yaml b/examples/wine/resources/quality.yaml index 9b6a788e53..0d3bb0214b 100644 --- a/examples/wine/resources/quality.yaml +++ b/examples/wine/resources/quality.yaml @@ -1,8 +1,6 @@ - kind: transformed_column name: quality_bucketized transformer: cortex.bucketize - inputs: - columns: - num: quality - args: - bucket_boundaries: [0, 5, 10] + input: + col: @quality + bucket_boundaries: [0, 5, 10] diff --git a/examples/wine/resources/raw_columns.yaml b/examples/wine/resources/raw_columns.yaml deleted file mode 100644 index 7a67758fca..0000000000 --- a/examples/wine/resources/raw_columns.yaml +++ /dev/null @@ -1,69 +0,0 @@ -- kind: environment - name: dev - data: - type: csv - path: s3a://cortex-examples/wine-quality.csv - csv_config: - header: true - sep: ';' - schema: [fixed_acidity, volatile_acidity, citric_acid, residual_sugar, chlorides, free_sulfur_dioxide, total_sulfur_dioxide, density, pH, sulphates, alcohol, quality] - -- kind: raw_column - name: fixed_acidity - type: FLOAT_COLUMN - required: true - -- kind: raw_column - name: volatile_acidity - type: FLOAT_COLUMN - required: true - -- kind: raw_column - name: citric_acid - type: FLOAT_COLUMN - required: true - -- kind: raw_column - name: residual_sugar - type: FLOAT_COLUMN - required: true - -- kind: raw_column - name: chlorides - type: FLOAT_COLUMN - required: true - -- kind: raw_column - name: free_sulfur_dioxide - type: FLOAT_COLUMN - required: true - -- kind: raw_column - name: total_sulfur_dioxide - type: FLOAT_COLUMN - required: true - -- kind: raw_column - name: density - type: FLOAT_COLUMN - required: true - -- kind: raw_column - name: pH - type: FLOAT_COLUMN - required: true - -- kind: raw_column - name: sulphates - type: FLOAT_COLUMN - required: true - -- kind: raw_column - name: alcohol - type: FLOAT_COLUMN - required: true - -- kind: raw_column - name: quality - type: INT_COLUMN - required: true From 3187b0a3d1b6dcdc43eeceb93616f8f6f47c1ca3 Mon Sep 17 00:00:00 2001 From: Omer Spillinger Date: Wed, 12 Jun 2019 17:59:41 -0700 Subject: [PATCH 36/44] Update tutorial.md --- docs/tutorial.md | 446 ++--------------------------------------------- 1 file changed, 10 insertions(+), 436 deletions(-) diff --git a/docs/tutorial.md b/docs/tutorial.md index e45e8d4330..cebf3e2493 100644 --- a/docs/tutorial.md +++ b/docs/tutorial.md @@ -46,187 +46,14 @@ Add to `app.yaml`: data: type: csv path: s3a://cortex-examples/iris.csv - schema: [sepal_length, sepal_width, petal_length, petal_width, class] + schema: [@sepal_length, @sepal_width, @petal_length, @petal_width, @class] ``` Cortex will be able to read from any S3 bucket that your AWS credentials grant access to. -#### Define raw columns - -The iris data set consists of four attributes and a label. 
We ensure that the data matches the types we expect, the numerical data is within a reasonable range, and the class labels are within the set of expected strings. - -Add to `app.yaml`: - -```yaml -# Raw Columns - -- kind: raw_column - name: sepal_length - type: FLOAT_COLUMN - min: 0 - max: 10 - -- kind: raw_column - name: sepal_width - type: FLOAT_COLUMN - min: 0 - max: 10 - -- kind: raw_column - name: petal_length - type: FLOAT_COLUMN - min: 0 - max: 10 - -- kind: raw_column - name: petal_width - type: FLOAT_COLUMN - min: 0 - max: 10 - -- kind: raw_column - name: class - type: STRING_COLUMN - values: ['Iris-setosa', 'Iris-versicolor', 'Iris-virginica'] -``` - -#### Define aggregates - -Aggregates are computations that require processing a full column of data. We want to normalize the numeric columns, so we need mean and standard deviation values for each numeric column. We also need a mapping of strings to integers for the label column. Here we use the built-in `mean`, `stddev`, and `index_string` aggregators. - -Add to `app.yaml`: - -```yaml -# Aggregates - -- kind: aggregate - name: sepal_length_mean - aggregator: cortex.mean - inputs: - columns: - col: sepal_length - -- kind: aggregate - name: sepal_length_stddev - aggregator: cortex.stddev - inputs: - columns: - col: sepal_length - -- kind: aggregate - name: sepal_width_mean - aggregator: cortex.mean - inputs: - columns: - col: sepal_width - -- kind: aggregate - name: sepal_width_stddev - aggregator: cortex.stddev - inputs: - columns: - col: sepal_width - -- kind: aggregate - name: petal_length_mean - aggregator: cortex.mean - inputs: - columns: - col: petal_length - -- kind: aggregate - name: petal_length_stddev - aggregator: cortex.stddev - inputs: - columns: - col: petal_length - -- kind: aggregate - name: petal_width_mean - aggregator: cortex.mean - inputs: - columns: - col: petal_width - -- kind: aggregate - name: petal_width_stddev - aggregator: cortex.stddev - inputs: - columns: - col: petal_width - -- kind: aggregate - name: class_index - aggregator: cortex.index_string - inputs: - columns: - col: class -``` - -#### Define transformed columns - -Transformers convert the raw columns into the appropriate inputs for a TensorFlow Estimator. Here we use the built-in `normalize` and `index_string` transformers using the aggregates we computed earlier. - -Add to `app.yaml`: - -```yaml -# Transformed Columns - -- kind: transformed_column - name: sepal_length_normalized - transformer: cortex.normalize - inputs: - columns: - num: sepal_length - args: - mean: sepal_length_mean - stddev: sepal_length_stddev - -- kind: transformed_column - name: sepal_width_normalized - transformer: cortex.normalize - inputs: - columns: - num: sepal_width - args: - mean: sepal_width_mean - stddev: sepal_width_stddev - -- kind: transformed_column - name: petal_length_normalized - transformer: cortex.normalize - inputs: - columns: - num: petal_length - args: - mean: petal_length_mean - stddev: petal_length_stddev - -- kind: transformed_column - name: petal_width_normalized - transformer: cortex.normalize - inputs: - columns: - num: petal_width - args: - mean: petal_width_mean - stddev: petal_width_stddev - -- kind: transformed_column - name: class_indexed - transformer: cortex.index_string - inputs: - columns: - text: class - args: - indexes: class_index -``` - -You can simplify the configuration for aggregates and transformed columns using [templates](applications/advanced/templates.md). 
- #### Define the model -This configuration will generate a training dataset with the specified columns and train our classifier using the generated dataset. +This configuration will generate a training dataset with the specified columns and train our classifier using the generated dataset. Here we're using TensorFlow's [DNNClassifier](https://www.tensorflow.org/api_docs/python/tf/estimator/DNNClassifier) but Cortex supports any TensorFlow code that adheres to the [tf.estimator API](https://www.tensorflow.org/guide/estimators). Add to `app.yaml`: @@ -235,48 +62,18 @@ Add to `app.yaml`: - kind: model name: dnn - path: dnn.py - type: classification - target_column: class_indexed - feature_columns: [sepal_length_normalized, sepal_width_normalized, petal_length_normalized, petal_width_normalized] + estimator: cortex.dnn_classifier + target_column: @class + input: + numeric_columns: [@sepal_length, @sepal_width, @petal_length, @petal_width] + target_vocab: ['Iris-setosa', 'Iris-versicolor', 'Iris-virginica'] hparams: hidden_units: [4, 2] - data_partition_ratio: - training: 80 - evaluation: 20 training: - num_steps: 1000 batch_size: 10 - aggregates: [class_index] -``` - -#### Implement the Estimator - -Define an Estimator in `dnn.py`: - -```python -import tensorflow as tf - - -def create_estimator(run_config, model_config): - feature_columns = [ - tf.feature_column.numeric_column("sepal_length_normalized"), - tf.feature_column.numeric_column("sepal_width_normalized"), - tf.feature_column.numeric_column("petal_length_normalized"), - tf.feature_column.numeric_column("petal_width_normalized"), - ] - - # returns an instance of tf.estimator.Estimator - return tf.estimator.DNNClassifier( - feature_columns=feature_columns, - hidden_units=model_config["hparams"]["hidden_units"], - n_classes=len(model_config["aggregates"]["class_index"]), - config=run_config, - ) + num_steps: 1000 ``` -Cortex supports any TensorFlow code that adheres to the [tf.estimator API](https://www.tensorflow.org/guide/estimators). - #### Define web APIs This will make the model available as a live web service that can serve real-time predictions. @@ -288,7 +85,7 @@ Add to `app.yaml`: - kind: api name: iris-type - model_name: dnn + model: @dnn compute: replicas: 1 ``` @@ -301,78 +98,6 @@ $ cortex deploy Deployment started ``` -The first deployment may take some extra time as Cortex's dependencies are downloaded. 
- -You can get an overview of the deployment using `cortex get` (see [resource statuses](applications/resources/statuses.md) for the meaning of each status): - -``` -$ cortex get - ---------------- -Python Packages ---------------- - -None - ------------ -Raw Columns ------------ - -NAME STATUS AGE -class ready 56s -petal_length ready 56s -petal_width ready 56s -sepal_length ready 56s -sepal_width ready 56s - ----------- -Aggregates ----------- - -NAME STATUS AGE -class_index ready 33s -petal_length_mean ready 44s -petal_length_stddev ready 44s -petal_width_mean ready 44s -petal_width_stddev ready 44s -sepal_length_mean ready 44s -sepal_length_stddev ready 44s -sepal_width_mean ready 44s -sepal_width_stddev ready 44s - -------------------- -Transformed Columns -------------------- - -NAME STATUS AGE -class_indexed ready 29s -petal_length_normalized ready 26s -petal_width_normalized ready 23s -sepal_length_normalized ready 20s -sepal_width_normalized ready 17s - ------------------ -Training Datasets ------------------ - -NAME STATUS AGE -dnn/training_dataset ready 9s - ------- -Models ------- - -NAME STATUS AGE -dnn training - - ----- -APIs ----- - -NAME STATUS LAST UPDATE -iris-type pending - -``` - You can get a summary of the status of resources using `cortex status`: ``` @@ -387,126 +112,6 @@ Models: 1 training APIs: 1 pending ``` -You can also view the status of individual resources using the `status` command: - -``` -$ cortex status sepal_length_normalized - ---------- -Ingesting ---------- - -Ingesting iris data from s3a://cortex-examples/iris.csv -Caching iris data (version: 2019-02-12-22-55-07-611766) -150 rows ingested - -Reading iris data (version: 2019-02-12-22-55-07-611766) - -First 3 samples: - -class: Iris-setosa Iris-setosa Iris-setosa -petal_length: 1.40 1.40 1.30 -petal_width: 0.20 0.20 0.20 -sepal_length: 5.10 4.90 4.70 -sepal_width: 3.50 3.00 3.20 - ------------ -Aggregating ------------ - -Aggregating petal_length_mean -Aggregating petal_length_stddev -Aggregating petal_width_mean -Aggregating petal_width_stddev -Aggregating sepal_length_mean -Aggregating sepal_length_stddev -Aggregating sepal_width_mean -Aggregating sepal_width_stddev -Aggregating class_index - -Aggregates: - -class_index: ["Iris-setosa", "Iris-versicolor", "Iris-virginica"] -petal_length_mean: 3.7586666552225747 -petal_length_stddev: 1.7644204144315179 -petal_width_mean: 1.198666658103466 -petal_width_stddev: 0.7631607319020202 -sepal_length_mean: 5.843333326975505 -sepal_length_stddev: 0.8280661128539085 -sepal_width_mean: 3.0540000025431313 -sepal_width_stddev: 0.43359431104332985 - ------------------------ -Validating Transformers ------------------------ - -Sanity checking transformers against the first 100 samples - -Transforming class to class_indexed -class: Iris-setosa Iris-setosa Iris-setosa -class_indexed: 0 0 0 - -Transforming petal_length to petal_length_normalized -petal_length: 1.40 1.40 1.30 -petal_length_norm...: -1.34 -1.34 -1.39 - -Transforming petal_width to petal_width_normalized -petal_width: 0.20 0.20 0.20 -petal_width_norma...: -1.31 -1.31 -1.31 - -Transforming sepal_length to sepal_length_normalized -sepal_length: 5.10 4.90 4.70 -sepal_length_norm...: -0.90 -1.14 -1.38 - -Transforming sepal_width to sepal_width_normalized -sepal_width: 3.50 3.00 3.20 -sepal_width_norma...: 1.03 -0.12 0.34 - ----------------------------- -Generating Training Datasets ----------------------------- - -Generating dnn/training_dataset - -Completed on Tuesday, February 14, 2019 at 2:56pm PST -``` - 
-``` -$ cortex status dnn - --------- -Training --------- - -loss = 13.321785, step = 1 -loss = 3.8588388, step = 101 (0.226 sec) -loss = 4.1841183, step = 201 (0.241 sec) -loss = 4.089279, step = 301 (0.194 sec) -loss = 1.646344, step = 401 (0.174 sec) -loss = 2.367354, step = 501 (0.189 sec) -loss = 2.0011806, step = 601 (0.192 sec) -loss = 1.7621514, step = 701 (0.211 sec) -loss = 0.8322474, step = 801 (0.190 sec) -loss = 1.3244338, step = 901 (0.194 sec) - ----------- -Evaluating ----------- - -accuracy = 0.96153843 -average_loss = 0.13040856 -global_step = 1000 -loss = 3.3906221 - -------- -Caching -------- - -Caching model dnn - -Completed on Tuesday, February 14, 2019 at 2:56pm PST -``` - #### Test the iris classification service Define a sample in `irises.json`: @@ -542,38 +147,7 @@ Get the API's endpoint: ``` $ cortex get api iris-type -------- -Summary -------- - -Status: ready -Updated replicas: 1/1 ready -Created at: 2019-02-14 14:57:04 PST -Refreshed at: 2019-02-14 14:57:35 PST - --------- -Endpoint --------- - -URL: https://a84607a462f1811e9aa3b020abd0a844-645332984.us-west-2.elb.amazonaws.com/iris/iris-type -Method: POST -Header: "Content-Type: application/json" -Payload: { "samples": [ { "petal_length": FLOAT, "petal_width": FLOAT, "sepal_length": FLOAT, "sepal_width": FLOAT } ] } - -------------- -Configuration -------------- - -{ - "name": "iris-type", - "model_name": "dnn", - "compute": { - "replicas": 1, - "cpu": , - "mem": - }, - "tags": {} -} +# https://abc.amazonaws.com/iris/iris-type ``` Use cURL to test the API: From deb8360da9e45aff3d60e8d5dc3ad336a5ce33d5 Mon Sep 17 00:00:00 2001 From: Omer Spillinger Date: Wed, 12 Jun 2019 18:14:00 -0700 Subject: [PATCH 37/44] Update tutorial.md --- docs/tutorial.md | 17 ----------------- 1 file changed, 17 deletions(-) diff --git a/docs/tutorial.md b/docs/tutorial.md index cebf3e2493..cddba9e990 100644 --- a/docs/tutorial.md +++ b/docs/tutorial.md @@ -102,14 +102,6 @@ You can get a summary of the status of resources using `cortex status`: ``` $ cortex status --watch - -Python Packages: none -Raw Columns: 5 ready -Aggregates: 9 ready -Transformed Columns: 5 ready -Training Datasets: 1 ready -Models: 1 training -APIs: 1 pending ``` #### Test the iris classification service @@ -133,11 +125,6 @@ When the API is ready, request a prediction from the API: ``` $ cortex predict iris-type irises.json - -iris-type was last updated on Tuesday, February 14, 2019 at 2:57pm PST - -Predicted class: -Iris-setosa ``` #### Call the API from other clients (e.g. 
cURL) @@ -146,8 +133,6 @@ Get the API's endpoint: ``` $ cortex get api iris-type - -# https://abc.amazonaws.com/iris/iris-type ``` Use cURL to test the API: @@ -158,8 +143,6 @@ $ curl -k \ -H "Content-Type: application/json" \ -d '{ "samples": [ { "sepal_length": 5.2, "sepal_width": 3.6, "petal_length": 1.4, "petal_width": 0.3 } ] }' \ - -{"classification_predictions":[{"class_ids":["0"],"classes":["MA=="],"logits":[1.501487135887146,-0.6141998171806335,-1.4335800409317017],"predicted_class":0,"predicted_class_reversed":"Iris-setosa","probabilities":[0.8520227670669556,0.10271172970533371,0.04526554048061371]}],"resource_id":"18ef9f6fb4a1a8b2a3d3e8068f179f89f65d1ae3d8ac9d96b782b1cec3b39d2"} ``` ## Cleanup From 4146e2ff367f07d634e9c3812744e5dce9a4dbea Mon Sep 17 00:00:00 2001 From: Ivan Zhang Date: Thu, 13 Jun 2019 02:49:39 -0400 Subject: [PATCH 38/44] External model inputs update (#159) --- cli/cmd/get.go | 36 +++-- cli/cmd/lib_cli_config.go | 16 +- cli/cmd/predict.go | 10 ++ docs/applications/advanced/external-models.md | 29 ++++ docs/applications/resources/apis.md | 1 + docs/summary.md | 1 + examples/external-model/app.yaml | 8 + examples/external-model/samples.json | 10 ++ pkg/lib/aws/errors.go | 11 +- pkg/lib/aws/s3.go | 14 ++ pkg/lib/configreader/float32_ptr.go | 13 +- pkg/lib/configreader/float64_ptr.go | 13 +- pkg/lib/configreader/int32_ptr.go | 13 +- pkg/lib/configreader/int64_ptr.go | 13 +- pkg/lib/configreader/int_ptr.go | 13 +- pkg/lib/configreader/string_ptr.go | 13 +- pkg/lib/configreader/validators.go | 60 ++----- pkg/operator/api/context/apis.go | 2 +- pkg/operator/api/context/dependencies.go | 3 + pkg/operator/api/context/serialize.go | 24 +-- pkg/operator/api/userconfig/apis.go | 46 ++++-- pkg/operator/api/userconfig/config.go | 30 +++- pkg/operator/api/userconfig/environments.go | 7 +- pkg/operator/api/userconfig/errors.go | 9 ++ pkg/operator/context/apis.go | 25 ++- pkg/operator/context/context.go | 18 ++- pkg/operator/workloads/workflow.go | 30 ++-- pkg/workloads/lib/context.py | 46 +++--- pkg/workloads/lib/storage/s3.py | 15 ++ pkg/workloads/lib/util.py | 9 ++ pkg/workloads/tf_api/api.py | 148 ++++++++++++------ 31 files changed, 486 insertions(+), 200 deletions(-) create mode 100644 docs/applications/advanced/external-models.md create mode 100644 examples/external-model/app.yaml create mode 100644 examples/external-model/samples.json diff --git a/cli/cmd/get.go b/cli/cmd/get.go index b3b3d5a0d0..67ddf1090b 100644 --- a/cli/cmd/get.go +++ b/cli/cmd/get.go @@ -382,7 +382,6 @@ func describeAPI(name string, resourcesRes *schema.GetResourcesResponse) (string ctx := resourcesRes.Context api := ctx.APIs[name] - model := ctx.Models[api.ModelName] var staleReplicas int32 var ctxAPIStatus *resource.APIStatus @@ -412,26 +411,29 @@ func describeAPI(name string, resourcesRes *schema.GetResourcesResponse) (string } out += titleStr("Endpoint") - resIDs := strset.New() - combinedInput := []interface{}{model.Input, model.TrainingInput} - for _, res := range ctx.ExtractCortexResources(combinedInput, resource.ConstantType, resource.RawColumnType, resource.AggregateType, resource.TransformedColumnType) { - resIDs.Add(res.GetID()) - resIDs.Merge(ctx.AllComputedResourceDependencies(res.GetID())) - } - var samplePlaceholderFields []string - for rawColumnName, rawColumn := range ctx.RawColumns { - if resIDs.Has(rawColumn.GetID()) { - fieldStr := fmt.Sprintf("\"%s\": %s", rawColumnName, rawColumn.GetColumnType().JSONPlaceholder()) - samplePlaceholderFields = append(samplePlaceholderFields, 
fieldStr) - } - } - sort.Strings(samplePlaceholderFields) - samplesPlaceholderStr := `{ "samples": [ { ` + strings.Join(samplePlaceholderFields, ", ") + " } ] }" out += "URL: " + urls.Join(resourcesRes.APIsBaseURL, anyAPIStatus.Path) + "\n" out += "Method: POST\n" out += `Header: "Content-Type: application/json"` + "\n" - out += "Payload: " + samplesPlaceholderStr + "\n" + if api.Model != nil { + model := ctx.Models[api.ModelName] + resIDs := strset.New() + combinedInput := []interface{}{model.Input, model.TrainingInput} + for _, res := range ctx.ExtractCortexResources(combinedInput, resource.ConstantType, resource.RawColumnType, resource.AggregateType, resource.TransformedColumnType) { + resIDs.Add(res.GetID()) + resIDs.Merge(ctx.AllComputedResourceDependencies(res.GetID())) + } + var samplePlaceholderFields []string + for rawColumnName, rawColumn := range ctx.RawColumns { + if resIDs.Has(rawColumn.GetID()) { + fieldStr := fmt.Sprintf("\"%s\": %s", rawColumnName, rawColumn.GetColumnType().JSONPlaceholder()) + samplePlaceholderFields = append(samplePlaceholderFields, fieldStr) + } + } + sort.Strings(samplePlaceholderFields) + samplesPlaceholderStr := `{ "samples": [ { ` + strings.Join(samplePlaceholderFields, ", ") + " } ] }" + out += "Payload: " + samplesPlaceholderStr + "\n" + } if api != nil { out += resourceStr(api.API) } diff --git a/cli/cmd/lib_cli_config.go b/cli/cmd/lib_cli_config.go index 811c19f3d1..e881adcc2f 100644 --- a/cli/cmd/lib_cli_config.go +++ b/cli/cmd/lib_cli_config.go @@ -59,10 +59,11 @@ func getPromptValidation(defaults *CliConfig) *cr.PromptValidation { PromptOpts: &cr.PromptOptions{ Prompt: "Enter Cortex operator endpoint", }, - StringValidation: cr.GetURLValidation(&cr.URLValidation{ - Required: true, - Default: defaults.CortexURL, - }), + StringValidation: &cr.StringValidation{ + Required: true, + Default: defaults.CortexURL, + Validator: cr.GetURLValidator(false, false), + }, }, { StructField: "AWSAccessKeyID", @@ -97,9 +98,10 @@ var fileValidation = &cr.StructValidation{ { Key: "cortex_url", StructField: "CortexURL", - StringValidation: cr.GetURLValidation(&cr.URLValidation{ - Required: true, - }), + StringValidation: &cr.StringValidation{ + Required: true, + Validator: cr.GetURLValidator(false, false), + }, }, { Key: "aws_access_key_id", diff --git a/cli/cmd/predict.go b/cli/cmd/predict.go index 4fe20925e1..78c8e7d7fb 100644 --- a/cli/cmd/predict.go +++ b/cli/cmd/predict.go @@ -109,6 +109,16 @@ var predictCmd = &cobra.Command{ } for _, prediction := range predictResponse.Predictions { + if prediction.Prediction == nil { + prettyResp, err := json.Pretty(prediction.Response) + if err != nil { + errors.Exit(err) + } + + fmt.Println(prettyResp) + continue + } + value := prediction.Prediction if prediction.PredictionReversed != nil { value = prediction.PredictionReversed diff --git a/docs/applications/advanced/external-models.md b/docs/applications/advanced/external-models.md new file mode 100644 index 0000000000..d82abdb466 --- /dev/null +++ b/docs/applications/advanced/external-models.md @@ -0,0 +1,29 @@ +# Importing External Models + +You can serve a model that was trained outside of Cortex as an API. + +1. Zip the exported estimator output in your checkpoint directory, e.g. + +```bash +$ ls export/estimator +saved_model.pb variables/ + +$ zip -r model.zip export/estimator +``` + +2. Upload the zipped file to Amazon S3, e.g. + +```bash +$ aws s3 cp model.zip s3://your-bucket/model.zip +``` + +3. Specify `model_path` in an API, e.g. 
+ +```yaml +- kind: api + name: my-api + model_path: s3://your-bucket/model.zip + compute: + replicas: 5 + gpu: 1 +``` diff --git a/docs/applications/resources/apis.md b/docs/applications/resources/apis.md index aa2fb07183..29c8ac0638 100644 --- a/docs/applications/resources/apis.md +++ b/docs/applications/resources/apis.md @@ -8,6 +8,7 @@ Serve models at scale and use them to build smarter applications. - kind: api # (required) name: # API name (required) model_name: # name of a Cortex model (required) + model_path: # path to a zipped model dir (optional) compute: replicas: # number of replicas to launch (default: 1) cpu: # CPU request (default: Null) diff --git a/docs/summary.md b/docs/summary.md index b169d99715..5de9956fc7 100644 --- a/docs/summary.md +++ b/docs/summary.md @@ -38,6 +38,7 @@ * [Compute](applications/advanced/compute.md) * [Python Packages](applications/advanced/python-packages.md) * [Development](development.md) + * [Importing External Models](applications/advanced/external-models.md) ## Operator diff --git a/examples/external-model/app.yaml b/examples/external-model/app.yaml new file mode 100644 index 0000000000..a9c60cd5ed --- /dev/null +++ b/examples/external-model/app.yaml @@ -0,0 +1,8 @@ +- kind: app + name: iris + +- kind: api + name: iris + model_path: s3://cortex-examples/iris-model.zip + compute: + replicas: 1 diff --git a/examples/external-model/samples.json b/examples/external-model/samples.json new file mode 100644 index 0000000000..33d1e6a5b5 --- /dev/null +++ b/examples/external-model/samples.json @@ -0,0 +1,10 @@ +{ + "samples": [ + { + "sepal_length": 5.2, + "sepal_width": 3.6, + "petal_length": 1.4, + "petal_width": 0.3 + } + ] +} diff --git a/pkg/lib/aws/errors.go b/pkg/lib/aws/errors.go index 2a2d09e16d..634ba81e49 100644 --- a/pkg/lib/aws/errors.go +++ b/pkg/lib/aws/errors.go @@ -29,12 +29,14 @@ type ErrorKind int const ( ErrUnknown ErrorKind = iota ErrInvalidS3aPath + ErrInvalidS3Path ErrAuth ) var errorKinds = []string{ "err_unknown", "err_invalid_s3a_path", + "err_invalid_s3_path", "err_auth", } @@ -105,7 +107,14 @@ func (e Error) Error() string { func ErrorInvalidS3aPath(provided string) error { return Error{ Kind: ErrInvalidS3aPath, - message: fmt.Sprintf("%s is not a valid s3a path", s.UserStr(provided)), + message: fmt.Sprintf("%s is not a valid s3a path (e.g. s3a://cortex-examples/iris.csv is a valid s3a path)", s.UserStr(provided)), + } +} + +func ErrorInvalidS3Path(provided string) error { + return Error{ + Kind: ErrInvalidS3Path, + message: fmt.Sprintf("%s is not a valid s3 path (e.g. 
s3://cortex-examples/iris-model.zip is a valid s3 path)", s.UserStr(provided)), } } diff --git a/pkg/lib/aws/s3.go b/pkg/lib/aws/s3.go index 9c62935b2b..4bcd6a7c4f 100644 --- a/pkg/lib/aws/s3.go +++ b/pkg/lib/aws/s3.go @@ -233,6 +233,20 @@ func (c *Client) DeleteFromS3ByPrefix(prefix string, continueIfFailure bool) err return errors.Wrap(err, prefix) } +func IsValidS3Path(s3Path string) bool { + if !strings.HasPrefix(s3Path, "s3://") { + return false + } + parts := strings.Split(s3Path[5:], "/") + if len(parts) < 2 { + return false + } + if parts[0] == "" || parts[1] == "" { + return false + } + return true +} + func IsValidS3aPath(s3aPath string) bool { if !strings.HasPrefix(s3aPath, "s3a://") { return false diff --git a/pkg/lib/configreader/float32_ptr.go b/pkg/lib/configreader/float32_ptr.go index 540da3ccfb..bc069e2c70 100644 --- a/pkg/lib/configreader/float32_ptr.go +++ b/pkg/lib/configreader/float32_ptr.go @@ -33,7 +33,7 @@ type Float32PtrValidation struct { GreaterThanOrEqualTo *float32 LessThan *float32 LessThanOrEqualTo *float32 - Validator func(*float32) (*float32, error) + Validator func(float32) (float32, error) } func makeFloat32ValValidation(v *Float32PtrValidation) *Float32Validation { @@ -171,8 +171,17 @@ func validateFloat32Ptr(val *float32, v *Float32PtrValidation) (*float32, error) } } + if val == nil { + return val, nil + } + if v.Validator != nil { - return v.Validator(val) + validated, err := v.Validator(*val) + if err != nil { + return nil, err + } + return &validated, nil } + return val, nil } diff --git a/pkg/lib/configreader/float64_ptr.go b/pkg/lib/configreader/float64_ptr.go index 51808f6c63..d5678d43f5 100644 --- a/pkg/lib/configreader/float64_ptr.go +++ b/pkg/lib/configreader/float64_ptr.go @@ -33,7 +33,7 @@ type Float64PtrValidation struct { GreaterThanOrEqualTo *float64 LessThan *float64 LessThanOrEqualTo *float64 - Validator func(*float64) (*float64, error) + Validator func(float64) (float64, error) } func makeFloat64ValValidation(v *Float64PtrValidation) *Float64Validation { @@ -171,8 +171,17 @@ func validateFloat64Ptr(val *float64, v *Float64PtrValidation) (*float64, error) } } + if val == nil { + return val, nil + } + if v.Validator != nil { - return v.Validator(val) + validated, err := v.Validator(*val) + if err != nil { + return nil, err + } + return &validated, nil } + return val, nil } diff --git a/pkg/lib/configreader/int32_ptr.go b/pkg/lib/configreader/int32_ptr.go index e8f3f45d07..c9f8e91f0c 100644 --- a/pkg/lib/configreader/int32_ptr.go +++ b/pkg/lib/configreader/int32_ptr.go @@ -33,7 +33,7 @@ type Int32PtrValidation struct { GreaterThanOrEqualTo *int32 LessThan *int32 LessThanOrEqualTo *int32 - Validator func(*int32) (*int32, error) + Validator func(int32) (int32, error) } func makeInt32ValValidation(v *Int32PtrValidation) *Int32Validation { @@ -171,8 +171,17 @@ func validateInt32Ptr(val *int32, v *Int32PtrValidation) (*int32, error) { } } + if val == nil { + return val, nil + } + if v.Validator != nil { - return v.Validator(val) + validated, err := v.Validator(*val) + if err != nil { + return nil, err + } + return &validated, nil } + return val, nil } diff --git a/pkg/lib/configreader/int64_ptr.go b/pkg/lib/configreader/int64_ptr.go index 560403a7f7..c20c1571ff 100644 --- a/pkg/lib/configreader/int64_ptr.go +++ b/pkg/lib/configreader/int64_ptr.go @@ -33,7 +33,7 @@ type Int64PtrValidation struct { GreaterThanOrEqualTo *int64 LessThan *int64 LessThanOrEqualTo *int64 - Validator func(*int64) (*int64, error) + Validator func(int64) (int64, error) 
} func makeInt64ValValidation(v *Int64PtrValidation) *Int64Validation { @@ -171,8 +171,17 @@ func validateInt64Ptr(val *int64, v *Int64PtrValidation) (*int64, error) { } } + if val == nil { + return val, nil + } + if v.Validator != nil { - return v.Validator(val) + validated, err := v.Validator(*val) + if err != nil { + return nil, err + } + return &validated, nil } + return val, nil } diff --git a/pkg/lib/configreader/int_ptr.go b/pkg/lib/configreader/int_ptr.go index 5f7d43b5f8..f86e2e1f38 100644 --- a/pkg/lib/configreader/int_ptr.go +++ b/pkg/lib/configreader/int_ptr.go @@ -33,7 +33,7 @@ type IntPtrValidation struct { GreaterThanOrEqualTo *int LessThan *int LessThanOrEqualTo *int - Validator func(*int) (*int, error) + Validator func(int) (int, error) } func makeIntValValidation(v *IntPtrValidation) *IntValidation { @@ -171,8 +171,17 @@ func validateIntPtr(val *int, v *IntPtrValidation) (*int, error) { } } + if val == nil { + return val, nil + } + if v.Validator != nil { - return v.Validator(val) + validated, err := v.Validator(*val) + if err != nil { + return nil, err + } + return &validated, nil } + return val, nil } diff --git a/pkg/lib/configreader/string_ptr.go b/pkg/lib/configreader/string_ptr.go index f9ee57e2e3..a18ce2f553 100644 --- a/pkg/lib/configreader/string_ptr.go +++ b/pkg/lib/configreader/string_ptr.go @@ -35,7 +35,7 @@ type StringPtrValidation struct { DNS1123 bool AllowCortexResources bool RequireCortexResources bool - Validator func(*string) (*string, error) + Validator func(string) (string, error) } func makeStringValValidation(v *StringPtrValidation) *StringValidation { @@ -170,8 +170,17 @@ func validateStringPtr(val *string, v *StringPtrValidation) (*string, error) { } } + if val == nil { + return val, nil + } + if v.Validator != nil { - return v.Validator(val) + validated, err := v.Validator(*val) + if err != nil { + return nil, err + } + return &validated, nil } + return val, nil } diff --git a/pkg/lib/configreader/validators.go b/pkg/lib/configreader/validators.go index 0e8e34bc8e..e6d7745299 100644 --- a/pkg/lib/configreader/validators.go +++ b/pkg/lib/configreader/validators.go @@ -31,69 +31,49 @@ func init() { portRe = regexp.MustCompile(`:[0-9]+$`) } -type PathValidation struct { - Required bool - Default string - BaseDir string -} - -func GetFilePathValidation(v *PathValidation) *StringValidation { - validator := func(val string) (string, error) { - val = files.RelPath(val, v.BaseDir) +func GetFilePathValidator(baseDir string) func(string) (string, error) { + return func(val string) (string, error) { + val = files.RelPath(val, baseDir) if err := files.CheckFile(val); err != nil { return "", err } return val, nil } - - return &StringValidation{ - Required: v.Required, - Default: v.Default, - Validator: validator, - } } -type S3aPathValidation struct { - Required bool - Default string -} - -func GetS3aPathValidation(v *S3aPathValidation) *StringValidation { - validator := func(val string) (string, error) { +func GetS3aPathValidator() func(string) (string, error) { + return func(val string) (string, error) { if !aws.IsValidS3aPath(val) { return "", aws.ErrorInvalidS3aPath(val) } return val, nil } - - return &StringValidation{ - Required: v.Required, - Default: v.Default, - Validator: validator, - } } -type URLValidation struct { - Required bool - Default string - DefaultHTTP bool // Otherwise default is https - AddPort bool +func GetS3PathValidator() func(string) (string, error) { + return func(val string) (string, error) { + if !aws.IsValidS3Path(val) { + 
return "", aws.ErrorInvalidS3Path(val) + } + return val, nil + } } -func GetURLValidation(v *URLValidation) *StringValidation { - validator := func(val string) (string, error) { +// uses https unless defaultHTTP == true +func GetURLValidator(defaultHTTP bool, addPort bool) func(string) (string, error) { + return func(val string) (string, error) { urlStr := strings.TrimSpace(val) if !strings.HasPrefix(strings.ToLower(urlStr), "http") { - if v.DefaultHTTP { + if defaultHTTP { urlStr = "http://" + urlStr } else { urlStr = "https://" + urlStr } } - if v.AddPort { + if addPort { if !portRe.MatchString(urlStr) { if strings.HasPrefix(strings.ToLower(urlStr), "https") { urlStr = urlStr + ":443" @@ -109,10 +89,4 @@ func GetURLValidation(v *URLValidation) *StringValidation { return urlStr, nil } - - return &StringValidation{ - Required: v.Required, - Default: v.Default, - Validator: validator, - } } diff --git a/pkg/operator/api/context/apis.go b/pkg/operator/api/context/apis.go index a6165d40eb..0a6dd7d2e1 100644 --- a/pkg/operator/api/context/apis.go +++ b/pkg/operator/api/context/apis.go @@ -26,7 +26,7 @@ type API struct { *userconfig.API *ComputedResourceFields Path string `json:"path"` - ModelName string `json:"model_name"` // This is just a convenience which removes the @ from userconfig.API.Model + ModelName string `json:"model_name"` // This removes the @ from userconfig.API.Model, or sets it to userconfig.API.ModelPath if it's external } func APIPath(apiName string, appName string) string { diff --git a/pkg/operator/api/context/dependencies.go b/pkg/operator/api/context/dependencies.go index eb7a8fe3c9..4c80076f53 100644 --- a/pkg/operator/api/context/dependencies.go +++ b/pkg/operator/api/context/dependencies.go @@ -151,6 +151,9 @@ func (ctx *Context) modelDependencies(model *Model) strset.Set { } func (ctx *Context) apiDependencies(api *API) strset.Set { + if api.Model == nil { + return strset.New() + } model := ctx.Models[api.ModelName] return strset.New(model.ID) } diff --git a/pkg/operator/api/context/serialize.go b/pkg/operator/api/context/serialize.go index df147dc310..9a68e373e5 100644 --- a/pkg/operator/api/context/serialize.go +++ b/pkg/operator/api/context/serialize.go @@ -189,6 +189,10 @@ func (ctx *Context) castSchemaTypes() error { } func (ctx *Context) ToSerial() *Serial { + if ctx.Environment == nil { + return &Serial{Context: ctx} + } + serial := Serial{ Context: ctx, RawColumnSplit: ctx.splitRawColumns(), @@ -201,17 +205,19 @@ func (ctx *Context) ToSerial() *Serial { func (serial *Serial) ContextFromSerial() (*Context, error) { ctx := serial.Context - ctx.RawColumns = serial.collectRawColumns() + if ctx.Environment != nil { + ctx.RawColumns = serial.collectRawColumns() - environment, err := serial.collectEnvironment() - if err != nil { - return nil, err - } - ctx.Environment = environment + environment, err := serial.collectEnvironment() + if err != nil { + return nil, err + } + ctx.Environment = environment - err = ctx.castSchemaTypes() - if err != nil { - return nil, err + err = ctx.castSchemaTypes() + if err != nil { + return nil, err + } } return ctx, nil diff --git a/pkg/operator/api/userconfig/apis.go b/pkg/operator/api/userconfig/apis.go index e7cd27909d..fbc960d628 100644 --- a/pkg/operator/api/userconfig/apis.go +++ b/pkg/operator/api/userconfig/apis.go @@ -17,9 +17,8 @@ limitations under the License. 
package userconfig import ( - "github.com/cortexlabs/yaml" - cr "github.com/cortexlabs/cortex/pkg/lib/configreader" + "github.com/cortexlabs/cortex/pkg/lib/errors" "github.com/cortexlabs/cortex/pkg/operator/api/resource" ) @@ -27,9 +26,10 @@ type APIs []*API type API struct { ResourceFields - Model string `json:"model" yaml:"model"` - Compute *APICompute `json:"compute" yaml:"compute"` - Tags Tags `json:"tags" yaml:"tags"` + Model *string `json:"model" yaml:"model"` + ModelPath *string `json:"model_path" yaml:"model_path"` + Compute *APICompute `json:"compute" yaml:"compute"` + Tags Tags `json:"tags" yaml:"tags"` } var apiValidation = &cr.StructValidation{ @@ -42,18 +42,17 @@ var apiValidation = &cr.StructValidation{ }, }, { - StructField: "Model", - DefaultField: "Name", - DefaultFieldFunc: func(name interface{}) interface{} { - model := "@" + name.(string) - escapedModel, _ := yaml.EscapeAtSymbol(model) - return escapedModel - }, - StringValidation: &cr.StringValidation{ - Required: false, + StructField: "Model", + StringPtrValidation: &cr.StringPtrValidation{ RequireCortexResources: true, }, }, + { + StructField: "ModelPath", + StringPtrValidation: &cr.StringPtrValidation{ + Validator: cr.GetS3PathValidator(), + }, + }, apiComputeFieldValidation, tagsFieldValidation, typeFieldValidation, @@ -61,6 +60,12 @@ var apiValidation = &cr.StructValidation{ } func (apis APIs) Validate() error { + for _, api := range apis { + if err := api.Validate(); err != nil { + return err + } + } + resources := make([]Resource, len(apis)) for i, res := range apis { resources[i] = res @@ -70,6 +75,19 @@ func (apis APIs) Validate() error { if len(dups) > 0 { return ErrorDuplicateResourceName(dups...) 
} + + return nil +} + +func (api *API) Validate() error { + if api.ModelPath == nil && api.Model == nil { + return errors.Wrap(ErrorSpecifyOnlyOneMissing("model_name", "model_path"), Identify(api)) + } + + if api.ModelPath != nil && api.Model != nil { + return errors.Wrap(ErrorSpecifyOnlyOne("model_name", "model_path"), Identify(api)) + } + return nil } diff --git a/pkg/operator/api/userconfig/config.go b/pkg/operator/api/userconfig/config.go index 49afdca6e2..df7523d796 100644 --- a/pkg/operator/api/userconfig/config.go +++ b/pkg/operator/api/userconfig/config.go @@ -74,6 +74,15 @@ func mergeConfigs(target *Config, source *Config) error { target.App = source.App } + if target.Resources == nil { + target.Resources = make(map[string][]Resource) + } + for resourceName, resources := range source.Resources { + for _, res := range resources { + target.Resources[resourceName] = append(target.Resources[resourceName], res) + } + } + return nil } @@ -190,8 +199,27 @@ func (config *Config) Validate(envName string) error { config.Environment = env } } + + apisAllExternal := true + for _, api := range config.APIs { + if api.Model != nil { + apisAllExternal = false + break + } + } + if config.Environment == nil { - return ErrorUndefinedResource(envName, resource.EnvironmentType) + if !apisAllExternal || len(config.APIs) == 0 { + return ErrorUndefinedResource(envName, resource.EnvironmentType) + } + + for _, resources := range config.Resources { + for _, res := range resources { + if res.GetResourceType() != resource.APIType { + return ErrorExtraResourcesWithExternalAPIs(res) + } + } + } } return nil diff --git a/pkg/operator/api/userconfig/environments.go b/pkg/operator/api/userconfig/environments.go index 78f2ac5355..0e2c5e956a 100644 --- a/pkg/operator/api/userconfig/environments.go +++ b/pkg/operator/api/userconfig/environments.go @@ -153,9 +153,10 @@ type ExternalData struct { var externalDataValidation = []*cr.StructFieldValidation{ { StructField: "Path", - StringValidation: cr.GetS3aPathValidation(&cr.S3aPathValidation{ - Required: true, - }), + StringValidation: &cr.StringValidation{ + Required: true, + Validator: cr.GetS3aPathValidator(), + }, }, { StructField: "Region", diff --git a/pkg/operator/api/userconfig/errors.go b/pkg/operator/api/userconfig/errors.go index b5efdc55c7..19b1416996 100644 --- a/pkg/operator/api/userconfig/errors.go +++ b/pkg/operator/api/userconfig/errors.go @@ -73,6 +73,7 @@ const ( ErrPredictionKeyOnModelWithEstimator ErrSpecifyOnlyOneMissing ErrEnvSchemaMismatch + ErrExtraResourcesWithExternalAPIs ErrImplDoesNotExist ) @@ -121,6 +122,7 @@ var errorKinds = []string{ "err_prediction_key_on_model_with_estimator", "err_specify_only_one_missing", "err_env_schema_mismatch", + "err_extra_resources_with_external_a_p_is", "err_impl_does_not_exist", } @@ -560,6 +562,13 @@ func ErrorEnvSchemaMismatch(env1, env2 *Environment) error { } } +func ErrorExtraResourcesWithExternalAPIs(res Resource) error { + return Error{ + Kind: ErrExtraResourcesWithExternalAPIs, + message: fmt.Sprintf("only apis can be defined if environment is not defined (found %s)", Identify(res)), + } +} + func ErrorImplDoesNotExist(path string) error { return Error{ Kind: ErrImplDoesNotExist, diff --git a/pkg/operator/context/apis.go b/pkg/operator/context/apis.go index b2024f5c93..c3383d49de 100644 --- a/pkg/operator/context/apis.go +++ b/pkg/operator/context/apis.go @@ -29,20 +29,31 @@ import ( func getAPIs(config *userconfig.Config, models context.Models, + datasetVersion string, ) (context.APIs, error) { 
apis := context.APIs{} for _, apiConfig := range config.APIs { - modelName, _ := yaml.ExtractAtSymbolText(apiConfig.Model) - - model := models[modelName] - if model == nil { - return nil, errors.Wrap(userconfig.ErrorUndefinedResource(modelName, resource.ModelType), userconfig.Identify(apiConfig), userconfig.ModelNameKey) - } var buf bytes.Buffer + var modelName string buf.WriteString(apiConfig.Name) - buf.WriteString(model.ID) + + if apiConfig.Model != nil { + modelName, _ = yaml.ExtractAtSymbolText(*apiConfig.Model) + model := models[modelName] + if model == nil { + return nil, errors.Wrap(userconfig.ErrorUndefinedResource(modelName, resource.ModelType), userconfig.Identify(apiConfig), userconfig.ModelNameKey) + } + buf.WriteString(model.ID) + } + + if apiConfig.ModelPath != nil { + modelName = *apiConfig.ModelPath + buf.WriteString(datasetVersion) + buf.WriteString(*apiConfig.ModelPath) + } + id := hash.Bytes(buf.Bytes()) apis[apiConfig.Name] = &context.API{ diff --git a/pkg/operator/context/context.go b/pkg/operator/context/context.go index 641d7f87e9..9742c715da 100644 --- a/pkg/operator/context/context.go +++ b/pkg/operator/context/context.go @@ -133,16 +133,24 @@ func New( } ctx.DatasetVersion = datasetVersion - ctx.Environment = getEnvironment(userconf, datasetVersion) + if userconf.Environment != nil { + ctx.Environment = getEnvironment(userconf, datasetVersion) + } ctx.Root = filepath.Join( consts.AppsDir, ctx.App.Name, consts.DataDir, ctx.DatasetVersion, - ctx.Environment.ID, ) + if ctx.Environment != nil { + ctx.Root = filepath.Join( + ctx.Root, + ctx.Environment.ID, + ) + } + ctx.MetadataRoot = filepath.Join( ctx.Root, consts.MetadataDir, @@ -223,7 +231,7 @@ func New( } ctx.Models = models - apis, err := getAPIs(userconf, ctx.Models) + apis, err := getAPIs(userconf, ctx.Models, ctx.DatasetVersion) if err != nil { return nil, err } @@ -256,7 +264,9 @@ func calculateID(ctx *context.Context) string { ids = append(ids, ctx.RawDataset.Key) ids = append(ids, ctx.StatusPrefix) ids = append(ids, ctx.App.ID) - ids = append(ids, ctx.Environment.ID) + if ctx.Environment != nil { + ids = append(ids, ctx.Environment.ID) + } for _, resource := range ctx.AllResources() { ids = append(ids, resource.GetID()) diff --git a/pkg/operator/workloads/workflow.go b/pkg/operator/workloads/workflow.go index 8827bfd3a5..17c623b5a1 100644 --- a/pkg/operator/workloads/workflow.go +++ b/pkg/operator/workloads/workflow.go @@ -65,23 +65,25 @@ func Create(ctx *context.Context) (*awfv1.Workflow, error) { var allSpecs []*WorkloadSpec - pythonPackageJobSpecs, err := pythonPackageWorkloadSpecs(ctx) - if err != nil { - return nil, err - } - allSpecs = append(allSpecs, pythonPackageJobSpecs...) + if ctx.Environment != nil { + pythonPackageJobSpecs, err := pythonPackageWorkloadSpecs(ctx) + if err != nil { + return nil, err + } + allSpecs = append(allSpecs, pythonPackageJobSpecs...) - dataJobSpecs, err := dataWorkloadSpecs(ctx) - if err != nil { - return nil, err - } - allSpecs = append(allSpecs, dataJobSpecs...) + dataJobSpecs, err := dataWorkloadSpecs(ctx) + if err != nil { + return nil, err + } + allSpecs = append(allSpecs, dataJobSpecs...) - trainingJobSpecs, err := trainingWorkloadSpecs(ctx) - if err != nil { - return nil, err + trainingJobSpecs, err := trainingWorkloadSpecs(ctx) + if err != nil { + return nil, err + } + allSpecs = append(allSpecs, trainingJobSpecs...) } - allSpecs = append(allSpecs, trainingJobSpecs...) 
apiSpecs, err := apiWorkloadSpecs(ctx) if err != nil { diff --git a/pkg/workloads/lib/context.py b/pkg/workloads/lib/context.py index 9135147771..e2730161fb 100644 --- a/pkg/workloads/lib/context.py +++ b/pkg/workloads/lib/context.py @@ -99,11 +99,12 @@ def __init__(self, **kwargs): ) ) - self.columns = util.merge_dicts_overwrite(self.raw_columns, self.transformed_columns) + if self.environment is not None: + self.columns = util.merge_dicts_overwrite(self.raw_columns, self.transformed_columns) - self.raw_column_names = list(self.raw_columns.keys()) - self.transformed_column_names = list(self.transformed_columns.keys()) - self.column_names = list(self.columns.keys()) + self.raw_column_names = list(self.raw_columns.keys()) + self.transformed_column_names = list(self.transformed_columns.keys()) + self.column_names = list(self.columns.keys()) # Internal caches self._transformer_impls = {} @@ -117,15 +118,14 @@ def __init__(self, **kwargs): os.environ["AWS_REGION"] = self.cortex_config.get("region", "") # Id map - self.pp_id_map = ResourceMap(self.python_packages) - self.rf_id_map = ResourceMap(self.raw_columns) - self.ag_id_map = ResourceMap(self.aggregates) - self.tf_id_map = ResourceMap(self.transformed_columns) - self.td_id_map = ResourceMap(self.training_datasets) - self.models_id_map = ResourceMap(self.models) - self.apis_id_map = ResourceMap(self.apis) - self.constants_id_map = ResourceMap(self.constants) - + self.pp_id_map = ResourceMap(self.python_packages) if self.python_packages else None + self.rf_id_map = ResourceMap(self.raw_columns) if self.raw_columns else None + self.ag_id_map = ResourceMap(self.aggregates) if self.aggregates else None + self.tf_id_map = ResourceMap(self.transformed_columns) if self.transformed_columns else None + self.td_id_map = ResourceMap(self.training_datasets) if self.training_datasets else None + self.models_id_map = ResourceMap(self.models) if self.models else None + self.apis_id_map = ResourceMap(self.apis) if self.apis else None + self.constants_id_map = ResourceMap(self.constants) if self.constants else None self.id_map = util.merge_dicts_overwrite( self.pp_id_map, self.rf_id_map, @@ -704,17 +704,19 @@ def _validate_required_fn_args(impl, fn_name, args): def _deserialize_raw_ctx(raw_ctx): - raw_columns = raw_ctx["raw_columns"] - raw_ctx["raw_columns"] = util.merge_dicts_overwrite(*raw_columns.values()) + if raw_ctx.get("environment") is not None: + raw_columns = raw_ctx["raw_columns"] + raw_ctx["raw_columns"] = util.merge_dicts_overwrite(*raw_columns.values()) + + data_split = raw_ctx["environment_data"] - data_split = raw_ctx["environment_data"] + if data_split["csv_data"] is not None and data_split["parquet_data"] is None: + raw_ctx["environment"]["data"] = data_split["csv_data"] + elif data_split["parquet_data"] is not None and data_split["csv_data"] is None: + raw_ctx["environment"]["data"] = data_split["parquet_data"] + else: + raise CortexException("expected csv_data or parquet_data but found " + data_split) - if data_split["csv_data"] is not None and data_split["parquet_data"] is None: - raw_ctx["environment"]["data"] = data_split["csv_data"] - elif data_split["parquet_data"] is not None and data_split["csv_data"] is None: - raw_ctx["environment"]["data"] = data_split["parquet_data"] - else: - raise CortexException("expected csv_data or parquet_data but found " + data_split) return raw_ctx diff --git a/pkg/workloads/lib/storage/s3.py b/pkg/workloads/lib/storage/s3.py index afdb36b4b8..cd8dea6d58 100644 --- a/pkg/workloads/lib/storage/s3.py +++ 
b/pkg/workloads/lib/storage/s3.py @@ -175,3 +175,18 @@ def download_and_unzip(self, key, local_dir): local_zip = os.path.join(local_dir, "zip.zip") self.download_file(key, local_zip) util.extract_zip(local_zip, delete_zip_file=True) + + def download_and_unzip_external(self, s3_path, local_dir): + util.mkdir_p(local_dir) + local_zip = os.path.join(local_dir, "zip.zip") + self.download_file_external(s3_path, local_zip) + util.extract_zip(local_zip, delete_zip_file=True) + + def download_file_external(self, s3_path, local_path): + try: + util.mkdir_p(os.path.dirname(local_path)) + bucket, key = self.deconstruct_s3_path(s3_path) + self.s3.download_file(bucket, key, local_path) + return local_path + except Exception as e: + raise CortexException("bucket " + bucket, "key " + key) from e diff --git a/pkg/workloads/lib/util.py b/pkg/workloads/lib/util.py index 4a35aec5b1..77103a0994 100644 --- a/pkg/workloads/lib/util.py +++ b/pkg/workloads/lib/util.py @@ -421,6 +421,15 @@ def merge_dicts_no_overwrite(*dicts): def merge_two_dicts_in_place_overwrite(x, y): """Merge y into x, with overwriting. x is updated in place""" + if x is None: + return y + + if y is None: + return x + + if y is None and x is None: + return None + for k, v in y.items(): if k in x and isinstance(x[k], dict) and isinstance(y[k], collections.Mapping): merge_dicts_in_place_overwrite(x[k], y[k]) diff --git a/pkg/workloads/tf_api/api.py b/pkg/workloads/tf_api/api.py index ce763fca79..69fe3ff7da 100644 --- a/pkg/workloads/tf_api/api.py +++ b/pkg/workloads/tf_api/api.py @@ -64,6 +64,14 @@ "DT_COMPLEX128": "dcomplexVal", } +DTYPE_TO_TF_TYPE = { + "DT_INT32": tf.int32, + "DT_INT64": tf.int64, + "DT_FLOAT": tf.float32, + "DT_STRING": tf.string, + "DT_BOOL": tf.bool, +} + def transform_sample(sample): ctx = local_cache["ctx"] @@ -96,8 +104,8 @@ def transform_sample(sample): def create_prediction_request(transformed_sample): ctx = local_cache["ctx"] - signatureDef = local_cache["metadata"]["signatureDef"] - signature_key = list(signatureDef.keys())[0] + signature_def = local_cache["metadata"]["signatureDef"] + signature_key = list(signature_def.keys())[0] prediction_request = predict_pb2.PredictRequest() prediction_request.model_spec.name = "default" prediction_request.model_spec.signature_name = signature_key @@ -114,6 +122,24 @@ def create_prediction_request(transformed_sample): return prediction_request +def create_raw_prediction_request(sample): + signature_def = local_cache["metadata"]["signatureDef"] + signature_key = list(signature_def.keys())[0] + prediction_request = predict_pb2.PredictRequest() + prediction_request.model_spec.name = "default" + prediction_request.model_spec.signature_name = signature_key + + for column_name, value in sample.items(): + shape = [1] + if util.is_list(value): + shape = [len(value)] + sig_type = signature_def[signature_key]["inputs"][column_name]["dtype"] + tensor_proto = tf.make_tensor_proto([value], dtype=DTYPE_TO_TF_TYPE[sig_type], shape=shape) + prediction_request.inputs[column_name].CopyFrom(tensor_proto) + + return prediction_request + + def reverse_transform(value): ctx = local_cache["ctx"] model = local_cache["model"] @@ -204,19 +230,42 @@ def run_get_model_metadata(): return sigmap +def parse_response_proto_raw(response_proto): + results_dict = json_format.MessageToDict(response_proto) + outputs = results_dict["outputs"] + + outputs_simplified = {} + for key in outputs.keys(): + value_key = DTYPE_TO_VALUE_KEY[outputs[key]["dtype"]] + outputs_simplified[key] = outputs[key][value_key] + + 
return {"response": outputs_simplified} + + def run_predict(sample): - transformed_sample = transform_sample(sample) - prediction_request = create_prediction_request(transformed_sample) - response_proto = local_cache["stub"].Predict(prediction_request, timeout=10.0) - result = parse_response_proto(response_proto) - util.log_indent("Raw sample:", indent=4) - util.log_pretty(sample, indent=6) - util.log_indent("Transformed sample:", indent=4) - util.log_pretty(transformed_sample, indent=6) - util.log_indent("Prediction:", indent=4) - util.log_pretty(result, indent=6) - - result["transformed_sample"] = transformed_sample + if local_cache["ctx"].environment is not None: + transformed_sample = transform_sample(sample) + prediction_request = create_prediction_request(transformed_sample) + response_proto = local_cache["stub"].Predict(prediction_request, timeout=10.0) + result = parse_response_proto(response_proto) + + util.log_indent("Raw sample:", indent=4) + util.log_pretty(sample, indent=6) + util.log_indent("Transformed sample:", indent=4) + util.log_pretty(transformed_sample, indent=6) + util.log_indent("Prediction:", indent=4) + util.log_pretty(result, indent=6) + + result["transformed_sample"] = transformed_sample + + else: + prediction_request = create_raw_prediction_request(sample) + response_proto = local_cache["stub"].Predict(prediction_request, timeout=10.0) + result = parse_response_proto_raw(response_proto) + util.log_indent("Sample:", indent=4) + util.log_pretty(sample, indent=6) + util.log_indent("Prediction:", indent=4) + util.log_pretty(result, indent=6) return result @@ -281,13 +330,14 @@ def predict(app_name, api_name): for i, sample in enumerate(payload["samples"]): util.log_indent("sample {}".format(i + 1), 2) - is_valid, reason = is_valid_sample(sample) - if not is_valid: - return prediction_failed(sample, reason) + if local_cache["ctx"].environment is not None: + is_valid, reason = is_valid_sample(sample) + if not is_valid: + return prediction_failed(sample, reason) - for column in local_cache["required_inputs"]: - column_type = ctx.get_inferred_column_type(column["name"]) - sample[column["name"]] = util.upcast(sample[column["name"]], column_type) + for column in local_cache["required_inputs"]: + column_type = ctx.get_inferred_column_type(column["name"]) + sample[column["name"]] = util.upcast(sample[column["name"]], column_type) try: result = run_predict(sample) @@ -317,40 +367,48 @@ def start(args): package.install_packages(ctx.python_packages, ctx.storage) api = ctx.apis_id_map[args.api] - model = ctx.models[api["model_name"]] - estimator = ctx.estimators[model["estimator"]] - tf_lib.set_logging_verbosity(ctx.environment["log_level"]["tensorflow"]) - local_cache["ctx"] = ctx local_cache["api"] = api - local_cache["model"] = model - local_cache["estimator"] = estimator - local_cache["target_col"] = ctx.columns[util.get_resource_ref(model["target_column"])] - local_cache["target_col_type"] = ctx.get_inferred_column_type( - util.get_resource_ref(model["target_column"]) - ) + local_cache["ctx"] = ctx - if not os.path.isdir(args.model_dir): - ctx.storage.download_and_unzip(model["key"], args.model_dir) + if ctx.environment is not None: + model = ctx.models[api["model_name"]] + estimator = ctx.estimators[model["estimator"]] - for column_name in ctx.extract_column_names([model["input"], model["target_column"]]): - if ctx.is_transformed_column(column_name): - trans_impl, _ = ctx.get_transformer_impl(column_name) - local_cache["trans_impls"][column_name] = trans_impl - 
transformed_column = ctx.transformed_columns[column_name] + local_cache["model"] = model + local_cache["estimator"] = estimator + local_cache["target_col"] = ctx.columns[util.get_resource_ref(model["target_column"])] + local_cache["target_col_type"] = ctx.get_inferred_column_type( + util.get_resource_ref(model["target_column"]) + ) + + tf_lib.set_logging_verbosity(ctx.environment["log_level"]["tensorflow"]) + + if not os.path.isdir(args.model_dir): + ctx.storage.download_and_unzip(model["key"], args.model_dir) + + for column_name in ctx.extract_column_names([model["input"], model["target_column"]]): + if ctx.is_transformed_column(column_name): + trans_impl, _ = ctx.get_transformer_impl(column_name) + local_cache["trans_impls"][column_name] = trans_impl + transformed_column = ctx.transformed_columns[column_name] - # cache aggregate values - for resource_name in util.extract_resource_refs(transformed_column["input"]): - if resource_name in ctx.aggregates: - ctx.get_obj(ctx.aggregates[resource_name]["key"]) + # cache aggregate values + for resource_name in util.extract_resource_refs(transformed_column["input"]): + if resource_name in ctx.aggregates: + ctx.get_obj(ctx.aggregates[resource_name]["key"]) + + local_cache["required_inputs"] = tf_lib.get_base_input_columns(model["name"], ctx) + + else: + if not os.path.isdir(args.model_dir): + ctx.storage.download_and_unzip_external(api["model_path"], args.model_dir) channel = grpc.insecure_channel("localhost:" + str(args.tf_serve_port)) local_cache["stub"] = prediction_service_pb2_grpc.PredictionServiceStub(channel) - local_cache["required_inputs"] = tf_lib.get_base_input_columns(model["name"], ctx) - # wait a bit for tf serving to start before querying metadata - limit = 600 + limit = 300 for i in range(limit): try: local_cache["metadata"] = run_get_model_metadata() @@ -364,7 +422,7 @@ def start(args): time.sleep(1) - logger.info("Serving model: {}".format(model["name"])) + logger.info("Serving model: {}".format(api["model_name"])) serve(app, listen="*:{}".format(args.port)) From 28098b342f6b0be903ed0e1e704d51d6cd1611dc Mon Sep 17 00:00:00 2001 From: David Eliahu Date: Thu, 13 Jun 2019 00:29:15 -0700 Subject: [PATCH 39/44] Improve python error handling --- pkg/workloads/spark_job/spark_util.py | 155 +++++++++++++++++--------- 1 file changed, 100 insertions(+), 55 deletions(-) diff --git a/pkg/workloads/spark_job/spark_util.py b/pkg/workloads/spark_job/spark_util.py index ce1ebda8d8..04c3cdbaae 100644 --- a/pkg/workloads/spark_job/spark_util.py +++ b/pkg/workloads/spark_job/spark_util.py @@ -422,71 +422,91 @@ def split_aggregators(aggregate_names, ctx): def run_builtin_aggregators(builtin_aggregates, df, ctx, spark): - agg_cols = [] + agg_cols = [get_builtin_aggregator_column(agg, ctx) for agg in builtin_aggregates] + results = df.agg(*agg_cols).collect()[0].asDict() + for agg in builtin_aggregates: + result = results[agg["name"]] aggregator = ctx.aggregators[agg["aggregator"]] - input = ctx.populate_values(agg["input"], aggregator["input"], preserve_column_refs=False) + result = util.cast_output_type(result, aggregator["output_type"]) - if aggregator["name"] == "approx_count_distinct": - agg_cols.append( - F.approxCountDistinct(input["col"], input.get("rsd")).alias(agg["name"]) + results[agg["name"]] = result + ctx.store_aggregate_result(result, agg) + + return results + + +def get_builtin_aggregator_column(agg, ctx): + try: + aggregator = ctx.aggregators[agg["aggregator"]] + + try: + input = ctx.populate_values( + agg["input"], 
aggregator["input"], preserve_column_refs=False ) + except CortexException as e: + e.wrap("input") + raise + + if aggregator["name"] == "approx_count_distinct": + return F.approxCountDistinct(input["col"], input.get("rsd")).alias(agg["name"]) if aggregator["name"] == "avg": - agg_cols.append(F.avg(input).alias(agg["name"])) + return F.avg(input).alias(agg["name"]) if aggregator["name"] in {"collect_set_int", "collect_set_float", "collect_set_string"}: - agg_cols.append(F.collect_set(input).alias(agg["name"])) + return F.collect_set(input).alias(agg["name"]) if aggregator["name"] == "count": - agg_cols.append(F.count(input).alias(agg["name"])) + return F.count(input).alias(agg["name"]) if aggregator["name"] == "count_distinct": - agg_cols.append(F.countDistinct(*input).alias(agg["name"])) + return F.countDistinct(*input).alias(agg["name"]) if aggregator["name"] == "covar_pop": - agg_cols.append(F.covar_pop(input["col1"], input["col2"]).alias(agg["name"])) + return F.covar_pop(input["col1"], input["col2"]).alias(agg["name"]) if aggregator["name"] == "covar_samp": - agg_cols.append(F.covar_samp(input["col1"], input["col2"]).alias(agg["name"])) + return F.covar_samp(input["col1"], input["col2"]).alias(agg["name"]) if aggregator["name"] == "kurtosis": - agg_cols.append(F.kurtosis(input).alias(agg["name"])) + return F.kurtosis(input).alias(agg["name"]) if aggregator["name"] in {"max_int", "max_float", "max_string"}: - agg_cols.append(F.max(input).alias(agg["name"])) + return F.max(input).alias(agg["name"]) if aggregator["name"] == "mean": - agg_cols.append(F.mean(input).alias(agg["name"])) + return F.mean(input).alias(agg["name"]) if aggregator["name"] in {"min_int", "min_float", "min_string"}: - agg_cols.append(F.min(input).alias(agg["name"])) + return F.min(input).alias(agg["name"]) if aggregator["name"] == "skewness": - agg_cols.append(F.skewness(input).alias(agg["name"])) + return F.skewness(input).alias(agg["name"]) if aggregator["name"] == "stddev": - agg_cols.append(F.stddev(input).alias(agg["name"])) + return F.stddev(input).alias(agg["name"]) if aggregator["name"] == "stddev_pop": - agg_cols.append(F.stddev_pop(input).alias(agg["name"])) + return F.stddev_pop(input).alias(agg["name"]) if aggregator["name"] == "stddev_samp": - agg_cols.append(F.stddev_samp(input).alias(agg["name"])) + return F.stddev_samp(input).alias(agg["name"]) if aggregator["name"] in {"sum_int", "sum_float"}: - agg_cols.append(F.sum(input).alias(agg["name"])) + return F.sum(input).alias(agg["name"]) if aggregator["name"] in {"sum_distinct_int", "sum_distinct_float"}: - agg_cols.append(F.sumDistinct(input).alias(agg["name"])) + return F.sumDistinct(input).alias(agg["name"]) if aggregator["name"] == "var_pop": - agg_cols.append(F.var_pop(input).alias(agg["name"])) + return F.var_pop(input).alias(agg["name"]) if aggregator["name"] == "var_samp": - agg_cols.append(F.var_samp(input).alias(agg["name"])) + return F.var_samp(input).alias(agg["name"]) if aggregator["name"] == "variance": - agg_cols.append(F.variance(input).alias(agg["name"])) - - results = df.agg(*agg_cols).collect()[0].asDict() - - for agg in builtin_aggregates: - result = results[agg["name"]] - aggregator = ctx.aggregators[agg["aggregator"]] - result = util.cast_output_type(result, aggregator["output_type"]) + return F.variance(input).alias(agg["name"]) - results[agg["name"]] = result - ctx.store_aggregate_result(result, agg) + raise ValueError("missing builtin aggregator") # unexpected - return results + except CortexException as e: + e.wrap("aggregate " 
+ agg["name"]) + raise def run_custom_aggregator(aggregate, df, ctx, spark): aggregator = ctx.aggregators[aggregate["aggregator"]] aggregator_impl, _ = ctx.get_aggregator_impl(aggregate["name"]) - input = ctx.populate_values(aggregate["input"], aggregator["input"], preserve_column_refs=False) + + try: + input = ctx.populate_values( + aggregate["input"], aggregator["input"], preserve_column_refs=False + ) + except CortexException as e: + e.wrap("aggregate " + aggregate["name"], "input") + raise try: result = aggregator_impl.aggregate_spark(df, input) @@ -522,9 +542,14 @@ def execute_transform_spark(column_name, df, ctx, spark): spark.sparkContext.addPyFile(trans_impl_path) # Executor pods need this because of the UDF ctx.spark_uploaded_impls[trans_impl_path] = True - input = ctx.populate_values( - transformed_column["input"], transformer["input"], preserve_column_refs=False - ) + try: + input = ctx.populate_values( + transformed_column["input"], transformer["input"], preserve_column_refs=False + ) + except CortexException as e: + e.wrap("input") + raise + try: return trans_impl.transform_spark(df, input, column_name) except Exception as e: @@ -537,9 +562,14 @@ def execute_transform_python(column_name, df, ctx, spark, validate=False): transformer = ctx.transformers[transformed_column["transformer"]] input_cols_sorted = sorted(ctx.extract_column_names(transformed_column["input"])) - input = ctx.populate_values( - transformed_column["input"], transformer["input"], preserve_column_refs=True - ) + + try: + input = ctx.populate_values( + transformed_column["input"], transformer["input"], preserve_column_refs=True + ) + except CortexException as e: + e.wrap("input") + raise if trans_impl_path not in ctx.spark_uploaded_impls: spark.sparkContext.addPyFile(trans_impl_path) # Executor pods need this because of the UDF @@ -599,9 +629,13 @@ def validate_transformer(column_name, test_df, ctx, spark): if transformer["output_type"] == consts.COLUMN_TYPE_INFERRED: sample_df = test_df.collect() sample = sample_df[0] - input = ctx.populate_values( - transformed_column["input"], transformer["input"], preserve_column_refs=True - ) + try: + input = ctx.populate_values( + transformed_column["input"], transformer["input"], preserve_column_refs=True + ) + except CortexException as e: + e.wrap("input") + raise transformer_input = create_transformer_inputs_from_map(input, sample) initial_transformed_value = trans_impl.transform_python(transformer_input) inferred_python_type = infer_python_type(initial_transformed_value) @@ -705,11 +739,10 @@ def validate_transformer(column_name, test_df, ctx, spark): "a column besides {} was modifed in the output dataframe".format(column_name) ) except CortexException as e: - e.wrap( + raise UserRuntimeException( "transformed column " + column_name, transformed_column["transformer"] + ".transform_spark", - ) - raise + ) from e if hasattr(trans_impl, "transform_spark") and hasattr(trans_impl, "transform_python"): if ( @@ -756,15 +789,27 @@ def transform_column(column_name, df, ctx, spark): trans_impl, _ = ctx.get_transformer_impl(column_name) if hasattr(trans_impl, "transform_spark"): - df = execute_transform_spark(column_name, df, ctx, spark) - return df.withColumn( - column_name, - F.col(column_name).cast( - CORTEX_TYPE_TO_SPARK_TYPE[ctx.get_inferred_column_type(column_name)] - ), - ) + try: + df = execute_transform_spark(column_name, df, ctx, spark) + return df.withColumn( + column_name, + F.col(column_name).cast( + 
CORTEX_TYPE_TO_SPARK_TYPE[ctx.get_inferred_column_type(column_name)] + ), + ) + except CortexException as e: + raise UserRuntimeException( + "transformed column " + column_name, + transformed_column["transformer"] + ".transform_spark", + ) from e elif hasattr(trans_impl, "transform_python"): - return execute_transform_python(column_name, df, ctx, spark) + try: + return execute_transform_python(column_name, df, ctx, spark) + except Exception as e: + raise UserRuntimeException( + "transformed column " + column_name, + transformed_column["transformer"] + ".transform_python", + ) from e else: raise UserException( "transformed column " + column_name, From 8f9e56793551fcb9e7ffe52932e35cb4ff0d5d35 Mon Sep 17 00:00:00 2001 From: David Eliahu Date: Thu, 13 Jun 2019 01:05:43 -0700 Subject: [PATCH 40/44] Cleanup --- pkg/workloads/lib/context.py | 29 ++++++++++++++--------------- pkg/workloads/lib/util.py | 7 ++----- 2 files changed, 16 insertions(+), 20 deletions(-) diff --git a/pkg/workloads/lib/context.py b/pkg/workloads/lib/context.py index e2730161fb..50f9e94d86 100644 --- a/pkg/workloads/lib/context.py +++ b/pkg/workloads/lib/context.py @@ -70,16 +70,16 @@ def __init__(self, **kwargs): self.status_prefix = self.ctx["status_prefix"] self.app = self.ctx["app"] self.environment = self.ctx["environment"] - self.python_packages = self.ctx["python_packages"] - self.raw_columns = self.ctx["raw_columns"] - self.transformed_columns = self.ctx["transformed_columns"] - self.transformers = self.ctx["transformers"] - self.aggregators = self.ctx["aggregators"] - self.aggregates = self.ctx["aggregates"] - self.constants = self.ctx["constants"] - self.models = self.ctx["models"] - self.estimators = self.ctx["estimators"] - self.apis = self.ctx["apis"] + self.python_packages = self.ctx["python_packages"] or {} + self.raw_columns = self.ctx["raw_columns"] or {} + self.transformed_columns = self.ctx["transformed_columns"] or {} + self.transformers = self.ctx["transformers"] or {} + self.aggregators = self.ctx["aggregators"] or {} + self.aggregates = self.ctx["aggregates"] or {} + self.constants = self.ctx["constants"] or {} + self.models = self.ctx["models"] or {} + self.estimators = self.ctx["estimators"] or {} + self.apis = self.ctx["apis"] or {} self.training_datasets = {k: v["dataset"] for k, v in self.models.items()} self.api_version = self.cortex_config["api_version"] @@ -99,12 +99,11 @@ def __init__(self, **kwargs): ) ) - if self.environment is not None: - self.columns = util.merge_dicts_overwrite(self.raw_columns, self.transformed_columns) + self.columns = util.merge_dicts_overwrite(self.raw_columns, self.transformed_columns) - self.raw_column_names = list(self.raw_columns.keys()) - self.transformed_column_names = list(self.transformed_columns.keys()) - self.column_names = list(self.columns.keys()) + self.raw_column_names = list(self.raw_columns.keys()) + self.transformed_column_names = list(self.transformed_columns.keys()) + self.column_names = list(self.columns.keys()) # Internal caches self._transformer_impls = {} diff --git a/pkg/workloads/lib/util.py b/pkg/workloads/lib/util.py index 77103a0994..ac142fb560 100644 --- a/pkg/workloads/lib/util.py +++ b/pkg/workloads/lib/util.py @@ -422,13 +422,10 @@ def merge_dicts_no_overwrite(*dicts): def merge_two_dicts_in_place_overwrite(x, y): """Merge y into x, with overwriting. 
x is updated in place""" if x is None: - return y + x = {} if y is None: - return x - - if y is None and x is None: - return None + y = {} for k, v in y.items(): if k in x and isinstance(x[k], dict) and isinstance(y[k], collections.Mapping): From 152f39e1cf79702fef3a21c0adfbbb114cf24495 Mon Sep 17 00:00:00 2001 From: Ivan Zhang Date: Thu, 13 Jun 2019 14:53:31 -0400 Subject: [PATCH 41/44] Add region to external models (#161) --- docs/applications/advanced/external-models.md | 6 +- docs/applications/resources/apis.md | 4 +- examples/external-model/app.yaml | 4 +- pkg/lib/aws/s3.go | 35 ++++++++++- pkg/operator/api/userconfig/apis.go | 59 +++++++++++++++---- pkg/operator/api/userconfig/config_key.go | 5 +- pkg/operator/api/userconfig/errors.go | 11 +++- pkg/operator/context/apis.go | 7 ++- pkg/workloads/tf_api/api.py | 2 +- 9 files changed, 109 insertions(+), 24 deletions(-) diff --git a/docs/applications/advanced/external-models.md b/docs/applications/advanced/external-models.md index d82abdb466..ea767666d0 100644 --- a/docs/applications/advanced/external-models.md +++ b/docs/applications/advanced/external-models.md @@ -17,12 +17,14 @@ $ zip -r model.zip export/estimator $ aws s3 cp model.zip s3://your-bucket/model.zip ``` -3. Specify `model_path` in an API, e.g. +3. Specify `external_model` in an API, e.g. ```yaml - kind: api name: my-api - model_path: s3://your-bucket/model.zip + external_model: + path: s3://your-bucket/model.zip + region: us-west-2 compute: replicas: 5 gpu: 1 diff --git a/docs/applications/resources/apis.md b/docs/applications/resources/apis.md index 29c8ac0638..a8260b7cb9 100644 --- a/docs/applications/resources/apis.md +++ b/docs/applications/resources/apis.md @@ -8,7 +8,9 @@ Serve models at scale and use them to build smarter applications. 
- kind: api # (required) name: # API name (required) model_name: # name of a Cortex model (required) - model_path: # path to a zipped model dir (optional) + external_model: + path: # path to a zipped model dir (optional) + region: # region of external model compute: replicas: # number of replicas to launch (default: 1) cpu: # CPU request (default: Null) diff --git a/examples/external-model/app.yaml b/examples/external-model/app.yaml index a9c60cd5ed..918656db7a 100644 --- a/examples/external-model/app.yaml +++ b/examples/external-model/app.yaml @@ -3,6 +3,8 @@ - kind: api name: iris - model_path: s3://cortex-examples/iris-model.zip + external_model: + path: s3://cortex-examples/iris-model.zip + region: us-west-2 compute: replicas: 1 diff --git a/pkg/lib/aws/s3.go b/pkg/lib/aws/s3.go index 4bcd6a7c4f..51211c0a02 100644 --- a/pkg/lib/aws/s3.go +++ b/pkg/lib/aws/s3.go @@ -265,7 +265,19 @@ func SplitS3aPath(s3aPath string) (string, string, error) { if !IsValidS3aPath(s3aPath) { return "", "", ErrorInvalidS3aPath(s3aPath) } - fullPath := s3aPath[6:] + fullPath := s3aPath[len("s3a://"):] + slashIndex := strings.Index(fullPath, "/") + bucket := fullPath[0:slashIndex] + key := fullPath[slashIndex+1:] + + return bucket, key, nil +} + +func SplitS3Path(s3Path string) (string, string, error) { + if !IsValidS3Path(s3Path) { + return "", "", ErrorInvalidS3aPath(s3Path) + } + fullPath := s3Path[len("s3://"):] slashIndex := strings.Index(fullPath, "/") bucket := fullPath[0:slashIndex] key := fullPath[slashIndex+1:] @@ -291,6 +303,27 @@ func IsS3PrefixExternal(bucket string, prefix string, region string) (bool, erro return hasPrefix, nil } +func IsS3FileExternal(bucket string, key string, region string) (bool, error) { + sess := session.Must(session.NewSession(&aws.Config{ + Region: aws.String(region), + })) + + _, err := s3.New(sess).HeadObject(&s3.HeadObjectInput{ + Bucket: aws.String(bucket), + Key: aws.String(key), + }) + + if IsNotFoundErr(err) { + return false, nil + } + + if err != nil { + return false, errors.Wrap(err, key) + } + + return true, nil +} + func IsS3aPrefixExternal(s3aPath string, region string) (bool, error) { bucket, prefix, err := SplitS3aPath(s3aPath) if err != nil { diff --git a/pkg/operator/api/userconfig/apis.go b/pkg/operator/api/userconfig/apis.go index fbc960d628..87048dd85c 100644 --- a/pkg/operator/api/userconfig/apis.go +++ b/pkg/operator/api/userconfig/apis.go @@ -17,6 +17,7 @@ limitations under the License. 
package userconfig import ( + "github.com/cortexlabs/cortex/pkg/lib/aws" cr "github.com/cortexlabs/cortex/pkg/lib/configreader" "github.com/cortexlabs/cortex/pkg/lib/errors" "github.com/cortexlabs/cortex/pkg/operator/api/resource" @@ -26,10 +27,10 @@ type APIs []*API type API struct { ResourceFields - Model *string `json:"model" yaml:"model"` - ModelPath *string `json:"model_path" yaml:"model_path"` - Compute *APICompute `json:"compute" yaml:"compute"` - Tags Tags `json:"tags" yaml:"tags"` + Model *string `json:"model" yaml:"model"` + ExternalModel *ExternalModel `json:"external_model" yaml:"external_model"` + Compute *APICompute `json:"compute" yaml:"compute"` + Tags Tags `json:"tags" yaml:"tags"` } var apiValidation = &cr.StructValidation{ @@ -48,10 +49,8 @@ var apiValidation = &cr.StructValidation{ }, }, { - StructField: "ModelPath", - StringPtrValidation: &cr.StringPtrValidation{ - Validator: cr.GetS3PathValidator(), - }, + StructField: "ExternalModel", + StructValidation: externalModelFieldValidation, }, apiComputeFieldValidation, tagsFieldValidation, @@ -59,6 +58,31 @@ var apiValidation = &cr.StructValidation{ }, } +type ExternalModel struct { + Path string `json:"path" yaml:"path"` + Region string `json:"region" yaml:"region"` +} + +var externalModelFieldValidation = &cr.StructValidation{ + DefaultNil: true, + StructFieldValidations: []*cr.StructFieldValidation{ + { + StructField: "Path", + StringValidation: &cr.StringValidation{ + Validator: cr.GetS3PathValidator(), + Required: true, + }, + }, + { + StructField: "Region", + StringValidation: &cr.StringValidation{ + Default: aws.DefaultS3Region, + AllowedValues: aws.S3Regions.Slice(), + }, + }, + }, +} + func (apis APIs) Validate() error { for _, api := range apis { if err := api.Validate(); err != nil { @@ -80,12 +104,23 @@ func (apis APIs) Validate() error { } func (api *API) Validate() error { - if api.ModelPath == nil && api.Model == nil { - return errors.Wrap(ErrorSpecifyOnlyOneMissing("model_name", "model_path"), Identify(api)) + if api.ExternalModel == nil && api.Model == nil { + return errors.Wrap(ErrorSpecifyOnlyOneMissing(ModelKey, ExternalModelKey), Identify(api)) } - if api.ModelPath != nil && api.Model != nil { - return errors.Wrap(ErrorSpecifyOnlyOne("model_name", "model_path"), Identify(api)) + if api.ExternalModel != nil && api.Model != nil { + return errors.Wrap(ErrorSpecifyOnlyOne(ModelKey, ExternalModelKey), Identify(api)) + } + + if api.ExternalModel != nil { + bucket, key, err := aws.SplitS3Path(api.ExternalModel.Path) + if err != nil { + return errors.Wrap(err, Identify(api), ExternalModelKey, PathKey) + } + + if ok, err := aws.IsS3FileExternal(bucket, key, api.ExternalModel.Region); err != nil || !ok { + return errors.Wrap(ErrorExternalModelNotFound(api.ExternalModel.Path), Identify(api), ExternalModelKey, PathKey) + } } return nil diff --git a/pkg/operator/api/userconfig/config_key.go b/pkg/operator/api/userconfig/config_key.go index b5a5830e20..f4ea733b48 100644 --- a/pkg/operator/api/userconfig/config_key.go +++ b/pkg/operator/api/userconfig/config_key.go @@ -93,6 +93,7 @@ const ( DatasetComputeKey = "dataset_compute" // API - ModelKey = "model" - ModelNameKey = "model_name" + ModelKey = "model" + ModelNameKey = "model_name" + ExternalModelKey = "external_model" ) diff --git a/pkg/operator/api/userconfig/errors.go b/pkg/operator/api/userconfig/errors.go index 19b1416996..87334d0fa0 100644 --- 
a/pkg/operator/api/userconfig/errors.go +++ b/pkg/operator/api/userconfig/errors.go @@ -75,6 +75,7 @@ const ( ErrEnvSchemaMismatch ErrExtraResourcesWithExternalAPIs ErrImplDoesNotExist + ErrExternalModelNotFound ) var errorKinds = []string{ @@ -124,9 +125,10 @@ var errorKinds = []string{ "err_env_schema_mismatch", "err_extra_resources_with_external_a_p_is", "err_impl_does_not_exist", + "err_external_model_not_found", } -var _ = [1]int{}[int(ErrImplDoesNotExist)-(len(errorKinds)-1)] // Ensure list length matches +var _ = [1]int{}[int(ErrExternalModelNotFound)-(len(errorKinds)-1)] // Ensure list length matches func (t ErrorKind) String() string { return errorKinds[t] @@ -575,3 +577,10 @@ func ErrorImplDoesNotExist(path string) error { message: fmt.Sprintf("%s: implementation file does not exist", path), } } + +func ErrorExternalModelNotFound(path string) error { + return Error{ + Kind: ErrExternalModelNotFound, + message: fmt.Sprintf("%s: file not found or inaccessible", path), + } +} diff --git a/pkg/operator/context/apis.go b/pkg/operator/context/apis.go index c3383d49de..dbc7322bf4 100644 --- a/pkg/operator/context/apis.go +++ b/pkg/operator/context/apis.go @@ -48,10 +48,11 @@ func getAPIs(config *userconfig.Config, buf.WriteString(model.ID) } - if apiConfig.ModelPath != nil { - modelName = *apiConfig.ModelPath + if apiConfig.ExternalModel != nil { + modelName = apiConfig.ExternalModel.Path buf.WriteString(datasetVersion) - buf.WriteString(*apiConfig.ModelPath) + buf.WriteString(apiConfig.ExternalModel.Path) + buf.WriteString(apiConfig.ExternalModel.Region) } id := hash.Bytes(buf.Bytes()) diff --git a/pkg/workloads/tf_api/api.py b/pkg/workloads/tf_api/api.py index 69fe3ff7da..5135758559 100644 --- a/pkg/workloads/tf_api/api.py +++ b/pkg/workloads/tf_api/api.py @@ -402,7 +402,7 @@ def start(args): else: if not os.path.isdir(args.model_dir): - ctx.storage.download_and_unzip_external(api["model_path"], args.model_dir) + ctx.storage.download_and_unzip_external(api["external_model"]["path"], args.model_dir) channel = grpc.insecure_channel("localhost:" + str(args.tf_serve_port)) local_cache["stub"] = prediction_service_pb2_grpc.PredictionServiceStub(channel) From 566c77c01c6d703f2a0a63e7b4cab54529facaaf Mon Sep 17 00:00:00 2001 From: Omer Spillinger Date: Thu, 13 Jun 2019 12:47:17 -0700 Subject: [PATCH 42/44] Update poker example --- examples/poker/app.yaml | 43 ++++++++++++++++++++ examples/poker/implementations/models/dnn.py | 29 ------------- examples/poker/resources/apis.yaml | 5 --- examples/poker/resources/data.yaml | 6 --- examples/poker/resources/models.yaml | 34 ---------------- 5 files changed, 43 insertions(+), 74 deletions(-) delete mode 100644 examples/poker/implementations/models/dnn.py delete mode 100644 examples/poker/resources/apis.yaml delete mode 100644 examples/poker/resources/data.yaml delete mode 100644 examples/poker/resources/models.yaml diff --git a/examples/poker/app.yaml b/examples/poker/app.yaml index 2ff055269d..cd94bbce2b 100644 --- a/examples/poker/app.yaml +++ b/examples/poker/app.yaml @@ -1,2 +1,45 @@ - kind: app name: poker + +- kind: environment + name: dev + data: + type: csv + path: s3a://cortex-examples/poker.csv + schema: [@card_1_suit, @card_1_rank, @card_2_suit, @card_2_rank, @card_3_suit, @card_3_rank, @card_4_suit, @card_4_rank, @card_5_suit, @card_5_rank, @class] + +- kind: model + name: dnn + estimator: cortex.dnn_classifier + target_column: @class + input: + num_classes: 10 + categorical_columns_with_vocab: + - col: @card_1_suit + vocab: [1, 2, 3, 
4] + - col: @card_1_rank + vocab: [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13] + - col: @card_2_suit + vocab: [1, 2, 3, 4] + - col: @card_2_rank + vocab: [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13] + - col: @card_3_suit + vocab: [1, 2, 3, 4] + - col: @card_3_rank + vocab: [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13] + - col: @card_4_suit + vocab: [1, 2, 3, 4] + - col: @card_4_rank + vocab: [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13] + - col: @card_5_suit + vocab: [1, 2, 3, 4] + - col: @card_5_rank + vocab: [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13] + hparams: + hidden_units: [100, 100, 100, 100, 100] + training: + num_steps: 20000 + +- kind: api + name: hand + model: @dnn diff --git a/examples/poker/implementations/models/dnn.py b/examples/poker/implementations/models/dnn.py deleted file mode 100644 index 5da788bb98..0000000000 --- a/examples/poker/implementations/models/dnn.py +++ /dev/null @@ -1,29 +0,0 @@ -import tensorflow as tf - - -def create_estimator(run_config, model_config): - feature_columns = [] - suits = [1, 2, 3, 4] - ranks = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13] - - for feature_column in model_config["feature_columns"]: - if feature_column["tags"]["type"] == "suit": - categorical_column = tf.feature_column.categorical_column_with_vocabulary_list( - feature_column["name"], suits - ) - indicator_column = tf.feature_column.indicator_column(categorical_column) - feature_columns.append(indicator_column) - - elif feature_column["tags"]["type"] == "rank": - categorical_column = tf.feature_column.categorical_column_with_vocabulary_list( - feature_column["name"], ranks - ) - indicator_column = tf.feature_column.indicator_column(categorical_column) - feature_columns.append(indicator_column) - - return tf.estimator.DNNClassifier( - feature_columns=feature_columns, - hidden_units=model_config["hparams"]["hidden_units"], - n_classes=10, - config=run_config, - ) diff --git a/examples/poker/resources/apis.yaml b/examples/poker/resources/apis.yaml deleted file mode 100644 index b3a9458458..0000000000 --- a/examples/poker/resources/apis.yaml +++ /dev/null @@ -1,5 +0,0 @@ -- kind: api - name: hand - model: @dnn - compute: - replicas: 1 diff --git a/examples/poker/resources/data.yaml b/examples/poker/resources/data.yaml deleted file mode 100644 index 743a525a4e..0000000000 --- a/examples/poker/resources/data.yaml +++ /dev/null @@ -1,6 +0,0 @@ -- kind: environment - name: dev - data: - type: csv - path: s3a://cortex-examples/poker.csv - schema: [@card_1_suit, @card_1_rank, @card_2_suit, @card_2_rank, @card_3_suit, @card_3_rank, @card_4_suit, @card_4_rank, @card_5_suit, @card_5_rank, @class] diff --git a/examples/poker/resources/models.yaml b/examples/poker/resources/models.yaml deleted file mode 100644 index 87c6101087..0000000000 --- a/examples/poker/resources/models.yaml +++ /dev/null @@ -1,34 +0,0 @@ -- kind: model - name: dnn - estimator: cortex.dnn_classifier - target_column: @class - input: - num_classes: 10 - categorical_columns_with_vocab: - - col: @card_1_suit - vocab: [1, 2, 3, 4] - - col: @card_1_rank - vocab: [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13] - - col: @card_2_suit - vocab: [1, 2, 3, 4] - - col: @card_2_rank - vocab: [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13] - - col: @card_3_suit - vocab: [1, 2, 3, 4] - - col: @card_3_rank - vocab: [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13] - - col: @card_4_suit - vocab: [1, 2, 3, 4] - - col: @card_4_rank - vocab: [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13] - - col: @card_5_suit - vocab: [1, 2, 3, 4] - - col: @card_5_rank - vocab: 
[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13] - hparams: - hidden_units: [100, 100, 100, 100, 100] - data_partition_ratio: - training: 0.9 - evaluation: 0.1 - training: - num_steps: 20000 From e26a255587a1a7ef60871917426ea604cf37f802 Mon Sep 17 00:00:00 2001 From: David Eliahu Date: Thu, 13 Jun 2019 12:54:57 -0700 Subject: [PATCH 43/44] Update docs --- cli/cmd/init.go | 176 ++++++++++-------- docs/applications/advanced/python-packages.md | 4 +- docs/applications/advanced/templates.md | 25 +-- .../implementations/aggregators.md | 18 +- .../{models.md => estimators.md} | 32 ++-- .../implementations/transformers.md | 47 +++-- docs/applications/resources/aggregates.md | 36 ++-- docs/applications/resources/aggregators.md | 27 +-- docs/applications/resources/apis.md | 11 +- docs/applications/resources/app.md | 2 +- docs/applications/resources/constants.md | 14 +- docs/applications/resources/data-types.md | 138 ++++++++++---- docs/applications/resources/environments.md | 22 +-- docs/applications/resources/estimators.md | 38 ++++ docs/applications/resources/models.md | 39 ++-- docs/applications/resources/overview.md | 2 + docs/applications/resources/raw-columns.md | 9 - .../resources/transformed-columns.md | 44 ++--- docs/applications/resources/transformers.md | 24 +-- 19 files changed, 374 insertions(+), 334 deletions(-) rename docs/applications/implementations/{models.md => estimators.md} (67%) create mode 100644 docs/applications/resources/estimators.md diff --git a/cli/cmd/init.go b/cli/cmd/init.go index 6fd4192cc0..84a0296ebb 100644 --- a/cli/cmd/init.go +++ b/cli/cmd/init.go @@ -91,10 +91,10 @@ func appInitFiles(appName string) map[string]string { # csv_config: # header: true # schema: -# - column1 -# - column2 -# - column3 -# - label +# - @column1 +# - @column2 +# - @column3 +# - @label `, "resources/raw_columns.yaml": `## Sample raw columns: @@ -125,11 +125,9 @@ func appInitFiles(appName string) map[string]string { # - kind: aggregate # name: column1_bucket_boundaries # aggregator: cortex.bucket_boundaries -# inputs: -# columns: -# col: column1 -# args: -# num_buckets: 3 +# input: +# col: @column1 +# num_buckets: 3 `, "resources/transformed_columns.yaml": `## Sample transformed columns: @@ -137,33 +135,27 @@ func appInitFiles(appName string) map[string]string { # - kind: transformed_column # name: column1_bucketized # transformer: cortex.bucketize # Cortex provided transformer in pkg/transformers -# inputs: -# columns: -# num: column1 -# args: -# bucket_boundaries: column2_bucket_boundaries +# input: +# num: @column1 +# bucket_boundaries: @column2_bucket_boundaries # # - kind: transformed_column # name: column2_transformed -# transformer: my_transformer # Your own custom transformer from the transformers folder +# transformer: my_transformer # Your own custom transformer # inputs: -# columns: -# num: column2 -# args: -# arg1: 10 -# arg2: 100 +# col: @column2 +# arg1: 10 +# arg2: 100 `, "resources/models.yaml": `## Sample model: # # - kind: model -# name: my_model -# type: classification -# target_column: label -# feature_columns: -# - column1 -# - column2 -# - column3 +# name: dnn +# estimator: cortex.dnn_classifier +# target_column: @class +# input: +# numeric_columns: [@column1, @column2, @column3] # hparams: # hidden_units: [4, 2] # data_partition_ratio: @@ -178,7 +170,7 @@ func appInitFiles(appName string) map[string]string { # # - kind: api # name: my-api -# model_name: my_model +# model: @my_model # compute: # replicas: 1 `, @@ -204,10 +196,11 @@ def create_estimator(run_config, 
model_config): run_config: An instance of tf.estimator.RunConfig to be used when creating the estimator. - model_config: The Cortex configuration for the model. - Note: nested resources are expanded (e.g. model_config["target_column"]) - will be the configuration for the target column, rather than the - name of the target column). + model_config: The Cortex configuration for the model. Column references in all + inputs (i.e. model_config["target_column"], model_config["input"], and + model_config["training_input"]) are replaced by their names (e.g. "@column1" + will be replaced with "column1"). All other resource references (e.g. constants + and aggregates) are replaced by their runtime values. Returns: An instance of tf.estimator.Estimator to train the model. @@ -215,15 +208,13 @@ def create_estimator(run_config, model_config): ## Sample create_estimator implementation: # - # feature_columns = [ - # tf.feature_column.numeric_column("column1"), - # tf.feature_column.indicator_column( - # tf.feature_column.categorical_column_with_identity("column2", num_buckets=3) - # ), - # ] + # feature_columns = [] + # for col_name in model_config["input"]["numeric_columns"]: + # feature_columns.append(tf.feature_column.numeric_column(col_name)) # - # return tf.estimator.DNNRegressor( + # return tf.estimator.DNNClassifier( # feature_columns=feature_columns, + # n_classes=model_config["input"]["num_classes"], # hidden_units=model_config["hparams"]["hidden_units"], # config=run_config, # ) @@ -235,7 +226,6 @@ def create_estimator(run_config, model_config): # # - kind: constant # name: my_constant -# type: [INT] # value: [0, 50, 100] `, @@ -244,14 +234,12 @@ def create_estimator(run_config, model_config): # - kind: aggregator # name: my_aggregator # output_type: [FLOAT] -# inputs: -# columns: -# column1: FLOAT_COLUMN|INT_COLUMN -# args: -# arg1: INT +# input: +# column1: FLOAT_COLUMN|INT_COLUMN +# arg1: INT `, - "implementations/aggregators/my_aggregator.py": `def aggregate_spark(data, columns, args): + "implementations/aggregators/my_aggregator.py": `def aggregate_spark(data, input): """Aggregate a column in a PySpark context. This function is required. @@ -259,15 +247,13 @@ def create_estimator(run_config, model_config): Args: data: A dataframe including all of the raw columns. - columns: A dict with the same structure as the aggregator's input - columns specifying the names of the dataframe's columns that - contain the input columns. - - args: A dict with the same structure as the aggregator's input args - containing the runtime values of the args. + input: The aggregate's input object. Column references in the input are + replaced by their names (e.g. "@column1" will be replaced with "column1"), + and all other resource references (e.g. constants) are replaced by their + runtime values. Returns: - Any json-serializable object that matches the data type of the aggregator. + Any serializable object that matches the output type of the aggregator. 
""" ## Sample aggregate_spark implementation: @@ -275,7 +261,7 @@ def create_estimator(run_config, model_config): # from pyspark.ml.feature import QuantileDiscretizer # # discretizer = QuantileDiscretizer( - # numBuckets=args["num_buckets"], inputCol=columns["col"], outputCol="_" + # numBuckets=input["num_buckets"], inputCol=input["col"], outputCol="_" # ).fit(data) # # return discretizer.getSplits() @@ -288,15 +274,13 @@ def create_estimator(run_config, model_config): # - kind: transformer # name: my_transformer # output_type: INT_COLUMN -# inputs: -# columns: -# column1: INT_COLUMN|FLOAT_COLUMN -# args: -# arg1: FLOAT -# arg2: FLOAT +# input: +# column1: INT_COLUMN|FLOAT_COLUMN +# arg1: FLOAT +# arg2: FLOAT `, - "implementations/transformers/my_transformer.py": `def transform_spark(data, columns, args, transformed_column_name): + "implementations/transformers/my_transformer.py": `def transform_spark(data, input, transformed_column_name): """Transform a column in a PySpark context. This function is optional (recommended for large-scale data processing). @@ -304,12 +288,10 @@ def create_estimator(run_config, model_config): Args: data: A dataframe including all of the raw columns. - columns: A dict with the same structure as the transformer's input - columns specifying the names of the dataframe's columns that - contain the input columns. - - args: A dict with the same structure as the transformer's input args - containing the runtime values of the args. + input: The transformed column's input object. Column references in the input are + replaced by their names (e.g. "@column1" will be replaced with "column1"), + and all other resource references (e.g. constants and aggregates) are replaced + by their runtime values. transformed_column_name: The name of the column containing the transformed data that is to be appended to the dataframe. @@ -322,23 +304,22 @@ def create_estimator(run_config, model_config): ## Sample transform_spark implementation: # # return data.withColumn( - # transformed_column_name, ((data[columns["num"]] - args["mean"]) / args["stddev"]) + # transformed_column_name, ((data[input["col"]] - input["mean"]) / input["stddev"]) # ) pass -def transform_python(sample, args): +def transform_python(input): """Transform a single data sample outside of a PySpark context. - This function is required. + This function is required for any columns that are used during inference. Args: - sample: A dict with the same structure as the transformer's input - columns containing a data sample to transform. - - args: A dict with the same structure as the transformer's input args - containing the runtime values of the args. + input: The transformed column's input object. Column references in the input are + replaced by their values in the sample (e.g. "@column1" will be replaced with + the value for column1), and all other resource references (e.g. constants and + aggregates) are replaced by their runtime values. Returns: The transformed value. @@ -346,12 +327,12 @@ def transform_python(sample, args): ## Sample transform_python implementation: # - # return (sample["num"] - args["mean"]) / args["stddev"] + # return (input["col"] - input["mean"]) / input["stddev"] pass -def reverse_transform_python(transformed_value, args): +def reverse_transform_python(transformed_value, input): """Reverse transform a single data sample outside of a PySpark context. 
This function is optional, and only relevant for certain one-to-one @@ -360,8 +341,10 @@ def reverse_transform_python(transformed_value, args): Args: transformed_value: The transformed data value. - args: A dict with the same structure as the transformer's input args - containing the runtime values of the args. + input: The transformed column's input object. Column references in the input are + replaced by their names (e.g. "@column1" will be replaced with "column1"), + and all other resource references (e.g. constants and aggregates) are replaced + by their runtime values. Returns: The raw data value that corresponds to the transformed value. @@ -369,7 +352,40 @@ def reverse_transform_python(transformed_value, args): ## Sample reverse_transform_python implementation: # - # return args["mean"] + (transformed_value * args["stddev"]) + # return input["mean"] + (transformed_value * input["stddev"]) + + pass +`, + + "implementations/estimators/my_estimator.py": `def create_estimator(run_config, model_config): + """Create an estimator to train the model. + + Args: + run_config: An instance of tf.estimator.RunConfig to be used when creating + the estimator. + + model_config: The Cortex configuration for the model. Column references in all + inputs (i.e. model_config["target_column"], model_config["input"], and + model_config["training_input"]) are replaced by their names (e.g. "@column1" + will be replaced with "column1"). All other resource references (e.g. constants + and aggregates) are replaced by their runtime values. + + Returns: + An instance of tf.estimator.Estimator to train the model. + """ + + ## Sample create_estimator implementation: + # + # feature_columns = [] + # for col_name in model_config["input"]["numeric_columns"]: + # feature_columns.append(tf.feature_column.numeric_column(col_name)) + # + # return tf.estimator.DNNClassifier( + # feature_columns=feature_columns, + # n_classes=model_config["input"]["num_classes"], + # hidden_units=model_config["hparams"]["hidden_units"], + # config=run_config, + # ) pass `, diff --git a/docs/applications/advanced/python-packages.md b/docs/applications/advanced/python-packages.md index 9c7cd52a39..a4425476fb 100644 --- a/docs/applications/advanced/python-packages.md +++ b/docs/applications/advanced/python-packages.md @@ -1,10 +1,10 @@ # Python Packages -Cortex allows you to install additional Python packages that can be made available to aggregators, transformers, and models. +Cortex allows you to install additional Python packages that can be made available to aggregators, transformers, and estimators. ## PyPI Packages -Cortex looks for a `requirements.txt` file in the top level directory of the app (in the same level as `app.yaml`). All packages listed in `requirements.txt` will be made available to aggregators, transformers, and models. +Cortex looks for a `requirements.txt` file in the top level directory of the app (in the same level as `app.yaml`). All packages listed in `requirements.txt` will be made available to aggregators, transformers, and estimators. ```text ./iris/ diff --git a/docs/applications/advanced/templates.md b/docs/applications/advanced/templates.md index eb6244b898..25a0c5a544 100644 --- a/docs/applications/advanced/templates.md +++ b/docs/applications/advanced/templates.md @@ -5,15 +5,14 @@ Templates allow you to reuse resource configuration within your application. 
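Before the configuration reference below, it may help to see what embedding a template actually does: conceptually, each `{arg}` placeholder in the template's `yaml` string is replaced by the embed's argument value, and the resulting YAML is then parsed like any other resource configuration. The sketch below illustrates that idea with plain string substitution; it is an illustration only (not Cortex's actual implementation), and the `sepal_length` column name is hypothetical.

```python
# Illustration of template expansion (not Cortex's actual code): each {arg}
# placeholder in the template's yaml string is replaced by the embed's value.
template_yaml = """
- kind: aggregate
  name: {column}_mean
  aggregator: cortex.mean
  input: @{column}
"""


def expand(template: str, args: dict) -> str:
    expanded = template
    for name, value in args.items():
        expanded = expanded.replace("{" + name + "}", str(value))
    return expanded


print(expand(template_yaml, {"column": "sepal_length"}))
# Produces an aggregate named "sepal_length_mean" whose input is @sepal_length.
```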
## Config ```yaml -- kind: template # (required) +- kind: template name: # template name (required) yaml: # YAML string including named arguments enclosed by {} (required) -- kind: embed # (required) +- kind: embed template: # name of a Cortex template (required) args: : # (required) - ... ``` ## Example @@ -25,28 +24,20 @@ Templates allow you to reuse resource configuration within your application. - kind: aggregate name: {column}_mean aggregator: cortex.mean - inputs: - columns: - col: {column} + input: @{column} - kind: aggregate name: {column}_stddev aggregator: cortex.stddev - inputs: - columns: - col: {column} + input: @{column} - kind: transformed_column name: {column}_normalized - tags: - type: numeric transformer: cortex.normalize - inputs: - columns: - num: {column} - args: - mean: {column}_mean - stddev: {column}_stddev + input: + col: @{column} + mean: @{column}_mean + stddev: @{column}_stddev - kind: embed template: normalize diff --git a/docs/applications/implementations/aggregators.md b/docs/applications/implementations/aggregators.md index 611d264690..0b4e380f0f 100644 --- a/docs/applications/implementations/aggregators.md +++ b/docs/applications/implementations/aggregators.md @@ -3,7 +3,7 @@ ## Implementation ```python -def aggregate_spark(data, columns, args): +def aggregate_spark(data, input): """Aggregate a column in a PySpark context. This function is required. @@ -11,15 +11,13 @@ def aggregate_spark(data, columns, args): Args: data: A dataframe including all of the raw columns. - columns: A dict with the same structure as the aggregator's input - columns specifying the names of the dataframe's columns that - contain the input columns. - - args: A dict with the same structure as the aggregator's input args - containing the values of the args. + input: The aggregate's input object. Column references in the input are + replaced by their names (e.g. "@column1" will be replaced with "column1"), + and all other resource references (e.g. constants) are replaced by their + runtime values. Returns: - Any json-serializable object that matches the data type of the aggregator. + Any serializable object that matches the output type of the aggregator. """ pass ``` @@ -27,11 +25,11 @@ def aggregate_spark(data, columns, args): ## Example ```python -def aggregate_spark(data, columns, args): +def aggregate_spark(data, input): from pyspark.ml.feature import QuantileDiscretizer discretizer = QuantileDiscretizer( - numBuckets=args["num_buckets"], inputCol=columns["col"], outputCol="_" + numBuckets=input["num_buckets"], inputCol=input["col"], outputCol="_" ).fit(data) return discretizer.getSplits() diff --git a/docs/applications/implementations/models.md b/docs/applications/implementations/estimators.md similarity index 67% rename from docs/applications/implementations/models.md rename to docs/applications/implementations/estimators.md index afa601407f..0df01604ba 100644 --- a/docs/applications/implementations/models.md +++ b/docs/applications/implementations/estimators.md @@ -1,4 +1,4 @@ -# Models +# Estimators Cortex can train any model that implements the TensorFlow Estimator API. Models can be trained using any subset of the raw and transformed columns. @@ -14,10 +14,11 @@ def create_estimator(run_config, model_config): run_config: An instance of tf.estimator.RunConfig to be used when creating the estimator. - model_config: The Cortex configuration for the model. - Note: nested resources are expanded (e.g. 
model_config["target_column"]) - will be the configuration for the target column, rather than the - name of the target column). + model_config: The Cortex configuration for the model. Column references in all + inputs (i.e. model_config["target_column"], model_config["input"], and + model_config["training_input"]) are replaced by their names (e.g. "@column1" + will be replaced with "column1"). All other resource references (e.g. constants + and aggregates) are replaced by their runtime values. Returns: An instance of tf.estimator.Estimator to train the model. @@ -31,17 +32,14 @@ def create_estimator(run_config, model_config): import tensorflow as tf def create_estimator(run_config, model_config): - feature_columns = [ - tf.feature_column.numeric_column("sepal_length_normalized"), - tf.feature_column.numeric_column("sepal_width_normalized"), - tf.feature_column.numeric_column("petal_length_normalized"), - tf.feature_column.numeric_column("petal_width_normalized"), - ] + feature_columns = [] + for col_name in model_config["input"]["numeric_columns"]: + feature_columns.append(tf.feature_column.numeric_column(col_name)) return tf.estimator.DNNClassifier( feature_columns=feature_columns, + n_classes=model_config["input"]["num_classes"], hidden_units=model_config["hparams"]["hidden_units"], - n_classes=len(model_config["aggregates"]["class_index"]), config=run_config, ) ``` @@ -78,11 +76,11 @@ def transform_tensorflow(features, labels, model_config): labels: The label tensor. - model_config: The Cortex configuration for the model. - Note: nested resources are expanded (e.g. model_config["target_column"]) - will be the configuration for the target column, rather than the - name of the target column). - + model_config: The Cortex configuration for the model. Column references in all + inputs (i.e. model_config["target_column"], model_config["input"], and + model_config["training_input"]) are replaced by their names (e.g. "@column1" + will be replaced with "column1"). All other resource references (e.g. constants + and aggregates) are replaced by their runtime values. Returns: features and labels tensors. diff --git a/docs/applications/implementations/transformers.md b/docs/applications/implementations/transformers.md index 64a4ca91af..aa274de36c 100644 --- a/docs/applications/implementations/transformers.md +++ b/docs/applications/implementations/transformers.md @@ -1,11 +1,11 @@ # Transformers -Transformers run both when transforming data before model training and when responding to prediction requests. You may define transformers for both a PySpark and a Python context. The PySpark implementation is optional but recommended for large-scale data processing. +Transformers run when transforming data before model training and when responding to prediction requests. You may define transformers for both a PySpark and a Python context. The PySpark implementation is optional but recommended for large-scale data processing. ## Implementation ```python -def transform_spark(data, columns, args, transformed_column_name): +def transform_spark(data, input, transformed_column_name): """Transform a column in a PySpark context. This function is optional (recommended for large-scale data processing). @@ -13,12 +13,10 @@ def transform_spark(data, columns, args, transformed_column_name): Args: data: A dataframe including all of the raw columns. - columns: A dict with the same structure as the transformer's input - columns specifying the names of the dataframe's columns that - contain the input columns. 
- - args: A dict with the same structure as the transformer's input args - containing the runtime values of the args. + input: The transformed column's input object. Column references in the input are + replaced by their names (e.g. "@column1" will be replaced with "column1"), + and all other resource references (e.g. constants and aggregates) are replaced + by their runtime values. transformed_column_name: The name of the column containing the transformed data that is to be appended to the dataframe. @@ -30,17 +28,16 @@ def transform_spark(data, columns, args, transformed_column_name): pass -def transform_python(sample, args): +def transform_python(input): """Transform a single data sample outside of a PySpark context. - This function is required. + This function is required for any columns that are used during inference. Args: - sample: A dict with the same structure as the transformer's input - columns containing a data sample to transform. - - args: A dict with the same structure as the transformer's input args - containing the runtime values of the args. + input: The transformed column's input object. Column references in the input are + replaced by their values in the sample (e.g. "@column1" will be replaced with + the value for column1), and all other resource references (e.g. constants and + aggregates) are replaced by their runtime values. Returns: The transformed value. @@ -48,7 +45,7 @@ def transform_python(sample, args): pass -def reverse_transform_python(transformed_value, args): +def reverse_transform_python(transformed_value, input): """Reverse transform a single data sample outside of a PySpark context. This function is optional, and only relevant for certain one-to-one @@ -57,8 +54,10 @@ def reverse_transform_python(transformed_value, args): Args: transformed_value: The transformed data value. - args: A dict with the same structure as the transformer's input args - containing the runtime values of the args. + input: The transformed column's input object. Column references in the input are + replaced by their names (e.g. "@column1" will be replaced with "column1"), + and all other resource references (e.g. constants and aggregates) are replaced + by their runtime values. Returns: The raw data value that corresponds to the transformed value. @@ -69,16 +68,16 @@ def reverse_transform_python(transformed_value, args): ## Example ```python -def transform_spark(data, columns, args, transformed_column_name): +def transform_spark(data, input, transformed_column_name): return data.withColumn( - transformed_column_name, ((data[columns["num"]] - args["mean"]) / args["stddev"]) + transformed_column_name, ((data[input["col"]] - input["mean"]) / input["stddev"]) ) -def transform_python(sample, args): - return (sample["num"] - args["mean"]) / args["stddev"] +def transform_python(input): + return (input["col"] - input["mean"]) / input["stddev"] -def reverse_transform_python(transformed_value, args): - return args["mean"] + (transformed_value * args["stddev"]) +def reverse_transform_python(transformed_value, input): + return input["mean"] + (transformed_value * input["stddev"]) ``` ## Pre-installed Packages diff --git a/docs/applications/resources/aggregates.md b/docs/applications/resources/aggregates.md index 2a95bb0316..8a33312d9d 100644 --- a/docs/applications/resources/aggregates.md +++ b/docs/applications/resources/aggregates.md @@ -5,16 +5,11 @@ Aggregate columns at scale. 
## Config

```yaml
-- kind: aggregate # (required)
+- kind: aggregate
  name: # aggregate name (required)
-  aggregator: # the name of the aggregator to use (required)
-  inputs:
-    columns:
-      : or <[string]> # map of column input name to raw column name(s) (required)
-      ...
-    args:
-      : # value may be a constant or literal value (optional)
-      ...
+  aggregator: # the name of the aggregator to use (this or aggregator_path must be specified)
+  aggregator_path: # a path to an aggregator implementation file (this or aggregator must be specified)
+  input: # the input to the aggregator, which may contain references to columns and constants (e.g. @column1) (required)
  compute:
    executors: # number of spark executors (default: 1)
    driver_cpu: # CPU request for spark driver (default: 1)
@@ -24,14 +19,9 @@ Aggregate columns at scale.
    executor_mem: # memory request for each spark executor (default: 500Mi)
    executor_mem_overhead: # off-heap (non-JVM) memory allocated to each executor (overrides mem_overhead_factor) (default: min[executor_mem * 0.4, 384Mi])
    mem_overhead_factor: # the proportion of driver_mem/executor_mem which will be additionally allocated for off-heap (non-JVM) memory (default: 0.4)
-  tags:
-    : # arbitrary key/value pairs to attach to the resource (optional)
-    ...
```

-Note: the `columns` and `args` fields of the the aggregate must match the data types of the `columns` and `args` fields of the selected aggregator.
-
-Each `args` value may be the name of a constant or a literal value. Any string value will be assumed to be the name of a constant. To use a string literal as an arg, escape it with double quotes (e.g. `arg_name: "\"string literal\""`.
+See [Data Types](data-types.md) for details about input values. Note: the `input` of the aggregate must match the input type of the aggregator (if specified).

See [`aggregators.yaml`](https://github.com/cortexlabs/cortex/blob/master/pkg/aggregators/aggregators.yaml) for a list of built-in aggregators.

@@ -41,18 +31,14 @@ See [`aggregators.yaml`](https://github.com/cortexl
- kind: aggregate
  name: age_bucket_boundaries
  aggregator: cortex.bucket_boundaries
-  inputs:
-    columns:
-      col: age # the name of a numeric raw column
-    args:
-      num_buckets: 5 # a value to be used as num_buckets
+  input:
+    col: @age # "age" is the name of a numeric raw column
+    num_buckets: 5

- kind: aggregate
  name: price_bucket_boundaries
  aggregator: cortex.bucket_boundaries
-  inputs:
-    columns:
-      col: price # the name of a numeric raw column
-    args:
-      num_buckets: num_buckets # the name of an INT constant
+  input:
+    col: @price # "price" is the name of a numeric raw column
+    num_buckets: @num_buckets # "num_buckets" is the name of an INT constant
```

diff --git a/docs/applications/resources/aggregators.md b/docs/applications/resources/aggregators.md
index 2c0516f5ec..38cfe6b26f 100644
---
See the [implementation docs](../implementations/aggregators.md) for a detailed guide. ## Config ```yaml -- kind: aggregator # (required) +- kind: aggregator name: # aggregator name (required) path: # path to the implementation file, relative to the application root (default: implementations/aggregators/.py) - output_type: # output data type (required) - inputs: - columns: - : # map of column input name to column input type(s) (required) - ... - args: - : # map of arg input name to value input type(s) (optional) - ... + output_type: # the output type of the aggregator (required) + input: # the input type of the aggregator (required) ``` -See [Data Types](data-types.md) for a list of valid data types. +See [Data Types](data-types.md) for details about input and output types. ## Example ```yaml - kind: aggregator name: bucket_boundaries + path: bucket_boundaries.py output_type: [FLOAT] - inputs: - columns: - num: FLOAT_COLUMN|INT_COLUMN - args: - num_buckets: INT + input: + num: FLOAT_COLUMN|INT_COLUMN + num_buckets: INT ``` ## Built-in Aggregators -Cortex includes common aggregators that can be used out of the box (see [`aggregators.yaml`](https://github.com/cortexlabs/cortex/blob/master/pkg/aggregators/aggregators.yaml)). To use built-in aggregators, use the `cortex` namespace in the aggregator name (e.g. `cortex.normalize`). +Cortex includes common aggregators that can be used out of the box (see [`aggregators.yaml`](https://github.com/cortexlabs/cortex/blob/master/pkg/aggregators/aggregators.yaml)). To use built-in aggregators, use the `cortex` namespace in the aggregator name (e.g. `cortex.mean`). diff --git a/docs/applications/resources/apis.md b/docs/applications/resources/apis.md index 29c8ac0638..2959af7a37 100644 --- a/docs/applications/resources/apis.md +++ b/docs/applications/resources/apis.md @@ -5,18 +5,15 @@ Serve models at scale and use them to build smarter applications. ## Config ```yaml -- kind: api # (required) +- kind: api name: # API name (required) - model_name: # name of a Cortex model (required) - model_path: # path to a zipped model dir (optional) + model: # name of a Cortex model (this or model_path must be specified) TODO + model_path: # path to a zipped model dir (this or model must be specified) compute: replicas: # number of replicas to launch (default: 1) cpu: # CPU request (default: Null) mem: # memory request (default: Null) gpu: # gpu request (default: Null) - tags: - : # arbitrary key/value pairs to attach to the resource (optional) - ... ``` ## Example @@ -24,7 +21,7 @@ Serve models at scale and use them to build smarter applications. ```yaml - kind: api name: classifier - model_name: dnn + model: @dnn compute: replicas: 3 ``` diff --git a/docs/applications/resources/app.md b/docs/applications/resources/app.md index b0efecf34f..ad456ac678 100644 --- a/docs/applications/resources/app.md +++ b/docs/applications/resources/app.md @@ -5,7 +5,7 @@ The app resource is used to group a set of resources into an application that ca ## Config ```yaml -- kind: app # (required) +- kind: app name: # app name (required) ``` diff --git a/docs/applications/resources/constants.md b/docs/applications/resources/constants.md index 93be255c8b..c4a06d788c 100644 --- a/docs/applications/resources/constants.md +++ b/docs/applications/resources/constants.md @@ -1,25 +1,23 @@ # Constants -Constants represent literal values which can be used in other Cortex configuration files. 
They can be useful for extracting repetitive literals into a single variable. +Constants represent literal values which can be used in other Cortex resources. They can be useful for extracting repetitive literals into a single variable. ## Config ```yaml -- kind: constant # (required) +- kind: constant name: # constant name (required) - type: # the data type of the constant (required) - value: # a literal value (required) - tags: - : # arbitrary key/value pairs to attach to the resource (optional) - ... + type: # the type of the constant (optional, will be inferred from value if not specified) + value: # a literal value (required) ``` +See [Data Types](data-types.md) for details about output types and values. + ## Example ```yaml - kind: constant name: num_buckets - type: INT value: 5 - kind: constant diff --git a/docs/applications/resources/data-types.md b/docs/applications/resources/data-types.md index 35df1a5b64..cec5f8a495 100644 --- a/docs/applications/resources/data-types.md +++ b/docs/applications/resources/data-types.md @@ -1,6 +1,7 @@ # Data Types -Data types are used in config files to help validate data and ensure your Cortex application is functioning as expected. +Data types are used in configuration files to help validate data and ensure your Cortex application is functioning as expected. + ## Raw Column Types @@ -10,6 +11,7 @@ These are the valid types for raw columns: * `FLOAT_COLUMN` * `STRING_COLUMN` + ## Transformed Column Types These are the valid types for transformed columns (i.e. output types of transformers): @@ -21,60 +23,124 @@ These are the valid types for transformed columns (i.e. output types of transfor * `FLOAT_LIST_COLUMN` * `STRING_LIST_COLUMN` -## Input Column Types - -Some resources specify the types of columns that are to be used as inputs (e.g. `transformer.inputs.columns` and `aggregator.inputs.columns`). For these types, any of the column types may be used: - -* `INT_COLUMN` -* `FLOAT_COLUMN` -* `STRING_COLUMN` -* `INT_LIST_COLUMN` -* `FLOAT_LIST_COLUMN` -* `STRING_LIST_COLUMN` - -Ambiguous input types are also supported, and are represented by joining column types with `|`. For example, `INT_COLUMN|FLOAT_COLUMN` indicates that either a column of type `INT_COLUMN` or a column of type `FLOAT_COLUMN` may be used an the input. Any two or more column types may be combined in this way (e.g. `INT_COLUMN|FLOAT_COLUMN|STRING_COLUMN` is supported). All permutations of ambiguous types are valid (e.g. `INT_COLUMN|FLOAT_COLUMN` and `FLOAT_COLUMN|INT_COLUMN` are equivalent). - -In addition, an input type may be a list of columns. To denote this, use any of the supported input column types in a length-one list. For example, `[INT_COLUMN]` represents a list of integer columns; `[INT_COLUMN|FLOAT_COLUMN]` represents a list of integer or float columns. - -Note: `[INT_COLUMN]` is not equivalent to `INT_LIST_COLUMN`: the former denotes a list of integer columns, whereas the latter denotes a single column which contains a list of integers. -## Value types +## Output Types -These are valid types for all values (e.g. aggregator args, aggregator output types, transformer args, constants). +Output types are used to define the output types of aggregators and constants. There are four base scalar types: * `INT` * `FLOAT` * `STRING` * `BOOL` -As with column input types, ambiguous types are supported (e.g. `INT|FLOAT`), length-one lists of types are valid (e.g. `[STRING]`), and all permutations of ambiguous types are valid (e.g. `INT|FLOAT` is equivalent to `FLOAT|INT`). 
+In addition, an output type may be a list of scalar types. This is denoted by a length-one list of any of the supported types. For example, `[INT]` represents a list of integers. -In addition, maps are valid value types. There are two types of maps: maps with a single data type key, and maps with any number of arbitrary keys: +An output type may also be a map containing these types. There are two types of maps: generic maps and fixed maps. **Generic maps** represent maps which may have any number of items. The types of the keys and values must match the declared types. For example, `{STRING: INT}` fits `{"San Francisco": -7, "Toronto": -4}`. **Fixed maps** represent maps which must define values for each of the pre-defined keys. For example: `{value1: INT, value2: FLOAT}` fits `{value1: 17, value2: 8.8}`. -### Maps with a single data type key +The values in lists, generic maps, and fixed maps may be arbitrarily nested. -This represents a map which, at runtime, may have any number of items. The types of the keys and values must match the declared types. +These are all valid output types: -Example: `{STRING: INT}` +* `INT` +* `[STRING]` +* `{STRING: INT}` +* `[{STRING: INT}]` +* `{value1: INT, value2: FLOAT}` +* `{STRING: {value1: INT, value2: FLOAT}}` +* `{value1: {STRING: BOOL}, value2: [FLOAT], value2: STRING}` -Example value: `{"San Francisco": -7, "Toronto": -4}` +### Example -### Maps with arbitrary keys +Output type: -This represents a map which, at runtime, must define values for each of the pre-defined keys. +```yaml +output_type: + value1: BOOL + value2: INT|FLOAT + value3: [STRING] + value4: {INT: STRING} +``` -Example: `{"value1": INT, "value2": FLOAT}` +Output value: -Example value: `{"value1": 17, "value2": 8.8}` +```yaml +output_type: + value1: True + value2: 2.2 + value3: [test1, test2, test3] + value4: {1: test1, 2: test2} +``` -### Nested maps -The values in either of the map types may be arbitrarily nested data types. +## Input Types -## Examples +Input types are used to define the inputs to aggregators, transformers, and estimators. Typically, input types can be any combination of the column or scalar types: +* `INT_COLUMN` +* `FLOAT_COLUMN` +* `STRING_COLUMN` +* `INT_LIST_COLUMN` +* `FLOAT_LIST_COLUMN` +* `STRING_LIST_COLUMN` * `INT` -* `[FLOAT|STRING]` -* `{STRING: INT}` -* `{"value1": INT, "value2": FLOAT}` -* `{"value1": {STRING: BOOL|INT}, "value2": [FLOAT], "value2": STRING}` +* `FLOAT` +* `STRING` +* `BOOL` + +Like with output types, input types may occur within lists, generic maps, and fixed maps. + +Ambiguous input types are also supported, and are represented by joining types with `|`. For example, `INT_COLUMN|FLOAT_COLUMN` indicates that either a column of type `INT_COLUMN` or a column of type `FLOAT_COLUMN` may be used an the input. Any two or more types may be combined in this way (e.g. `INT|FLOAT|STRING` is supported). All permutations of ambiguous types are valid (e.g. `INT|FLOAT` and `FLOAT|INT` are equivalent). Column types and scalar types may not be combined (e.g. `INT|FLOAT_COLUMN` is not valid). + +### Input Type Validations + +By default, all declared inputs are required. For example, if the input type is `{value1: INT, value2: FLOAT}`, both `value1` and `value2` must be provided (and cannot be `Null`). With Cortex, it is possible to declare inputs as optional, set default values, allow values to be `Null`, and specify minimum and maximum map/list lengths. + +To specify validation options, the "long form" input schema is used. 
In the long form, the input type is always a map, with the `_type` key specifying the type, and other keys (which all start with `_`) specifying the options. The available options are:
+
+* `_optional`: If set to `True`, allows the value to be missing from the input. Only applicable to values in maps.
+* `_default`: Specifies a default value to use if the value is missing from the input. Only applicable to values in maps. Setting `_default` implies `_optional: True`.
+* `_allow_null`: If set to `True`, allows the value to be explicitly set to `Null`.
+* `_min_count`: Specifies the minimum number of elements that must be in the list or map.
+* `_max_count`: Specifies the maximum number of elements that may be in the list or map.
+
+### Example
+
+Short form input type:
+
+```yaml
+input:
+  value1: INT_COLUMN
+  value2: INT|FLOAT
+  value3: [STRING]
+  value4: {INT: STRING}
+```
+
+Long form input type:
+
+```yaml
+input:
+  value1:
+    _type: INT_COLUMN
+    _optional: True
+  value2:
+    _type: INT|FLOAT
+    _default: 2.2
+    _allow_null: True
+  value3:
+    _type: [STRING]
+    _min_count: 1
+  value4:
+    _type: {INT: STRING}
+    _min_count: 1
+    _max_count: 100
+```
+
+Input value (assuming `column1` is an `INT_COLUMN`, `constant1` is a `[STRING]`, and `aggregate1` is an `INT`):
+
+```yaml
+input:
+  value1: @column1
+  value2: 2.2
+  value3: @constant1
+  value4: {1: test1, 2: @aggregate1}
+```
diff --git a/docs/applications/resources/environments.md b/docs/applications/resources/environments.md
index b24f4b6497..4bdcd392b1 100644
--- a/docs/applications/resources/environments.md
+++ b/docs/applications/resources/environments.md
@@ -5,7 +5,7 @@ Transfer data at scale from data warehouses like S3 into the Cortex environment.
## Config

```yaml
-- kind: environment # (required)
+- kind: environment
  name: # environment name (required)
  limit: # specify `num_rows` or `fraction_of_rows` if using `limit`
@@ -30,7 +30,7 @@ data:
  drop_null: # drop any rows that contain at least 1 null value (default: false)
  csv_config: # optional configuration that can be provided
  schema:
-    - # raw column names listed in the CSV columns' order (required)
+    - # raw column references listed in the CSV columns' order (required)
    ...
@@ -69,7 +69,7 @@ data:
  drop_null: # drop any rows that contain at least 1 null value (default: false)
  schema:
    - parquet_column_name: # name of the column in the parquet file (required)
-      raw_column_name: # raw column name (required)
+      raw_column: # raw column reference (required)
    ...
@@ -82,10 +82,10 @@ data:
    type: csv
    path: s3a://my-bucket/data.csv
    schema:
-      - column1
-      - column2
-      - column3
-      - label
+      - @column1
+      - @column2
+      - @column3
+      - @label

- kind: environment
  name: prod
@@ -94,11 +94,11 @@ data:
    path: s3a://my-bucket/data.parquet
    schema:
      - parquet_column_name: column1
-        raw_column_name: column1
+        raw_column: @column1
      - parquet_column_name: column2
-        raw_column_name: column2
+        raw_column: @column2
      - parquet_column_name: column3
-        raw_column_name: column3
+        raw_column: @column3
      - parquet_column_name: column4
-        raw_column_name: label
+        raw_column: @label
```
diff --git a/docs/applications/resources/estimators.md b/docs/applications/resources/estimators.md
new file mode 100644
index 0000000000..4cb02092a0
--- /dev/null
+++ b/docs/applications/resources/estimators.md
@@ -0,0 +1,38 @@
+# Estimators
+
+An estimator defines how to train a model.
+
+Custom estimators can be implemented in Python or PySpark. See the [implementation docs](../implementations/estimators.md) for a detailed guide.
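Before the configuration reference below, here is a minimal sketch of what such a custom estimator implementation might look like. It mirrors the `cortex.dnn_classifier`-style example used elsewhere in these docs (numeric feature columns, `num_classes`, and `hidden_units`); the built-in estimator's actual implementation may differ.

```python
import tensorflow as tf


def create_estimator(run_config, model_config):
    # Build one numeric feature column per input column name. Column references
    # in model_config["input"] have already been replaced by column names.
    feature_columns = [
        tf.feature_column.numeric_column(col_name)
        for col_name in model_config["input"]["numeric_columns"]
    ]

    return tf.estimator.DNNClassifier(
        feature_columns=feature_columns,
        n_classes=model_config["input"]["num_classes"],
        hidden_units=model_config["hparams"]["hidden_units"],
        config=run_config,
    )
```

The estimator's declared `input` and `hparams` types (see the configuration below) are what make `model_config["input"]["numeric_columns"]`, `model_config["input"]["num_classes"]`, and `model_config["hparams"]["hidden_units"]` available to this function.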
+ +## Config + +```yaml +- kind: estimator + name: # estimator name (required) + path: # path to the implementation file, relative to the application root (default: implementations/estimators/.py) + target_column: # The type of column that can be used as a target (ambiguous types like INT_COLUMN|FLOAT_COLUMN are supported) (required) + input: # the input type of the estimator (required) + training_input: # the input type of the training input to the estimator (optional) + hparams: # the input type of the hyperparameters to pass into the estimator, which may not contain column types (optional) + prediction_key: # key of the target value in the estimator's exported predict outputs (default: "class_ids" for INT_COLUMN and STRING_COLUMN targets, "predictions" otherwise) +``` + +See [Data Types](data-types.md) for details about input types. + +## Example + +```yaml +- kind: estimator + name: dnn_classifier + path: dnn_classifier.py + target_column: INT_COLUMN + input: + num_classes: INT + numeric_columns: [INT_COLUMN|FLOAT_COLUMN] + hparams: + hidden_units: [INT] +``` + +## Built-in Estimators + +Cortex includes common estimators that can be used out of the box (see [`estimators.yaml`](https://github.com/cortexlabs/cortex/blob/master/pkg/estimators/estimators.yaml)). To use built-in estimators, use the `cortex` namespace in the estimator name (e.g. `cortex.dnn_classifier`). diff --git a/docs/applications/resources/models.md b/docs/applications/resources/models.md index b2d067f063..b5c3bae02c 100644 --- a/docs/applications/resources/models.md +++ b/docs/applications/resources/models.md @@ -1,20 +1,19 @@ # Models -Train custom TensorFlow models at scale. +Train TensorFlow models at scale. ## Config ```yaml -- kind: # (required) +- kind: model name: # model name (required) - type: # "classification" or "regression" (required) - target_column: # the column to predict (must be an integer column for classification, or an integer or float column for regression) (required) - feature_columns: <[string]> # a list of the columns used as input for this model (required) - training_columns: <[string]> # a list of the columns used only during training (optional) - aggregates: <[string]> # a list of aggregates to pass into model training (optional) - hparams: # a map of hyperparameters to pass into model training (optional) - prediction_key: # key of the target value in the estimator's exported predict outputs (default: "class_ids" for classification, "predictions" for regression) - path: # path to the implementation file, relative to the application root (default: implementations/models/.py) + estimator: # the name of the estimator to use (this or estimator_path must be specified) + estimator_path: # a path to an estimator implementation file (this or estimator must be specified) + target_column: # a reference to the column to predict (e.g. @column1) (required) + input: # the input to the model, which may contain references to columns, constants, and aggregates (e.g. @column1) (required) + training_input: # input to the model which is only used during training, which may contain references to columns, constants, and aggregates (e.g. 
@column1) (optional) + hparams: # hyperparameters to pass into model training, which may not contain references to other resources (optional) + prediction_key: # key of the target value in the estimator's exported predict outputs (default: "class_ids" for INT_COLUMN and STRING_COLUMN targets, "predictions" otherwise) data_partition_ratio: training: # the proportion of data to be used for training (default: 0.8) @@ -55,12 +54,10 @@ Train custom TensorFlow models at scale. executor_mem: # memory request for each spark executor (default: 500Mi) executor_mem_overhead: # off-heap (non-JVM) memory allocated to each executor (overrides mem_overhead_factor) (default: min[executor_mem * 0.4, 384Mi]) mem_overhead_factor: # the proportion of driver_mem/executor_mem which will be additionally allocated for off-heap (non-JVM) memory (default: 0.4) - - tags: - : # arbitrary key/value pairs to attach to the resource (optional) - ... ``` +See [Data Types](data-types.md) for details about input and output values. Note: the `target_column`, `input`, `training_input`, and `hparams` of the model must match the input types of the estimator (if specified). + See the [tf.estimator.RunConfig](https://www.tensorflow.org/api_docs/python/tf/estimator/RunConfig) and [tf.estimator.EvalSpec](https://www.tensorflow.org/api_docs/python/tf/estimator/EvalSpec) documentation for more information. ## Example @@ -68,16 +65,10 @@ See the [tf.estimator.RunConfig](https://www.tensorflow.org/api_docs/python/tf/e ```yaml - kind: model name: dnn - type: classification - target_column: label - feature_columns: - - column1 - - column2 - - column3 - training_columns: - - class_weight - aggregates: - - column1_index + estimator: cortex.dnn_classifier + target_column: @class + input: + numeric_columns: [@column1, @column2, @column3] hparams: hidden_units: [4, 2] data_partition_ratio: diff --git a/docs/applications/resources/overview.md b/docs/applications/resources/overview.md index e6ad3b1695..61fb8149eb 100644 --- a/docs/applications/resources/overview.md +++ b/docs/applications/resources/overview.md @@ -13,6 +13,8 @@ Cortex applications consist of declarative resource configuration written in YAM * [api](apis.md) * [constant](constants.md) +Resources can reference other resources from within their configuration (e.g. when defining input values) by prefixing the other resource's name with an `@` symbol. For example, a model may specify `input: @column1`, which denotes that a resource named "column1" is an input to this model. + With the exception of the `app` kind (which must be defined in a top-level `app.yaml` file), resources may be defined in any YAML file within your Cortex application folder or any subdirectories. The `cortex deploy` command will validate all resource configuration and attempt to create the requested state on the cluster. diff --git a/docs/applications/resources/raw-columns.md b/docs/applications/resources/raw-columns.md index d4c7802881..ee33e198dc 100644 --- a/docs/applications/resources/raw-columns.md +++ b/docs/applications/resources/raw-columns.md @@ -21,9 +21,6 @@ Validate raw data at scale and define columns.
executor_mem: # memory request for each spark executor (default: 500Mi) executor_mem_overhead: # off-heap (non-JVM) memory allocated to each executor (overrides mem_overhead_factor) (default: min[executor_mem * 0.4, 384Mi]) mem_overhead_factor: # the proportion of driver_mem/executor_mem which will be additionally allocated for off-heap (non-JVM) memory (default: 0.4) - tags: - : # arbitrary key/value pairs to attach to the resource (optional) - ... - kind: raw_column name: # raw column name (required) @@ -41,9 +38,6 @@ Validate raw data at scale and define columns. executor_mem: # memory request for each spark executor (default: 500Mi) executor_mem_overhead: # off-heap (non-JVM) memory allocated to each executor (overrides mem_overhead_factor) (default: min[executor_mem * 0.4, 384Mi]) mem_overhead_factor: # the proportion of driver_mem/executor_mem which will be additionally allocated for off-heap (non-JVM) memory (default: 0.4) - tags: - : # arbitrary key/value pairs to attach to the resource (optional) - ... - kind: raw_column name: # raw column name (required) @@ -59,9 +53,6 @@ Validate raw data at scale and define columns. executor_mem: # memory request for each spark executor (default: 500Mi) executor_mem_overhead: # off-heap (non-JVM) memory allocated to each executor (overrides mem_overhead_factor) (default: min[executor_mem * 0.4, 384Mi]) mem_overhead_factor: # the proportion of driver_mem/executor_mem which will be additionally allocated for off-heap (non-JVM) memory (default: 0.4) - tags: - : # arbitrary key/value pairs to attach to the resource (optional) - ... ``` ## Example diff --git a/docs/applications/resources/transformed-columns.md b/docs/applications/resources/transformed-columns.md index 2bd0dff9e8..519f9965eb 100644 --- a/docs/applications/resources/transformed-columns.md +++ b/docs/applications/resources/transformed-columns.md @@ -7,14 +7,9 @@ Transform data at scale. ```yaml - kind: transformed_column name: # transformed column name (required) - transformer: # the name of the transformer to use (required) - inputs: - columns: - : or <[string]> # map of column input name to raw column name(s) (required) - ... - args: - : # value may be an aggregate, constant, or literal value (optional) - ... + transformer: # the name of the transformer to use (this or transformer_path must be specified) + transformer_path: # a path to a transformer implementation file (this or transformer must be specified) + input: # the input to the transformer, which may contain references to columns, constants, and aggregates (e.g. @column1) (required) compute: executors: # number of spark executors (default: 1) driver_cpu: # CPU request for spark driver (default: 1) @@ -24,14 +19,9 @@ Transform data at scale. executor_mem: # memory request for each spark executor (default: 500Mi) executor_mem_overhead: # off-heap (non-JVM) memory allocated to each executor (overrides mem_overhead_factor) (default: min[executor_mem * 0.4, 384Mi]) mem_overhead_factor: # the proportion of driver_mem/executor_mem which will be additionally allocated for off-heap (non-JVM) memory (default: 0.4) - tags: - : # arbitrary key/value pairs to attach to the resource (optional) - ... ``` -Note: the `columns` and `args` fields of the the transformed column must match the data types of the `columns` and `args` fields of the selected transformer. - -Each `args` value may be the name of an aggregate, the name of a constant, or a literal value. Any string value will be assumed to be the name of an aggregate or constant.
To use a string literal as an arg, escape it with double quotes (e.g. `arg_name: "\"string literal\""`. +See [Data Types](data-types.md) for details about input values. Note: the `input` of the transformed column must match the input type of the transformer (if specified). See [`transformers.yaml`](https://github.com/cortexlabs/cortex/blob/master/pkg/transformers/transformers.yaml) for a list of built-in transformers. @@ -41,30 +31,24 @@ See [`transformers.yaml`](https://github.com/cortex - kind: transformed_column name: age_normalized transformer: cortex.normalize - inputs: - columns: - num: age # column name - args: - mean: age_mean # the name of a cortex.mean aggregator - stddev: age_stddev # the name of a cortex.stddev aggregator + input: + num: @age # "age" is the name of a numeric raw column + mean: @age_mean # "age_mean" is the name of an aggregate which used the cortex.mean aggregator + stddev: @age_stddev # "age_stddev" is the name of an aggregate which used the cortex.stddev aggregator - kind: transformed_column name: class_indexed transformer: cortex.index_string - inputs: - columns: - col: class # the name of a string column - args: - index: {"indexes": ["t", "f"], "reversed_index": ["t": 0, "f": 1]} # a value to be used as the index + input: + col: @class # "class" is the name of a string raw column + index: {"indexes": ["t", "f"], "reversed_index": {"t": 0, "f": 1}} # a value to be used as the index - kind: transformed_column name: price_bucketized transformer: cortex.bucketize - inputs: - columns: - num: price # column name - args: - bucket_boundaries: bucket_boundaries # the name of a [FLOAT] constant + input: + num: @price # "price" is the name of a numeric raw column + bucket_boundaries: @bucket_boundaries # "bucket_boundaries" is the name of a [FLOAT] constant ``` ## Validating Transformers diff --git a/docs/applications/resources/transformers.md b/docs/applications/resources/transformers.md index dc51b1f902..3ba071b9da 100644 --- a/docs/applications/resources/transformers.md +++ b/docs/applications/resources/transformers.md @@ -1,6 +1,6 @@ # Transformers -A transformer converts a set of columns and arbitrary values into a single transformed column. Each transformer has an input schema and an output data type. The input schema is a map which specifies the name and data type of each input column and argument. +A transformer converts a set of columns and arbitrary values into a single transformed column. Each transformer has an input type and an output column type. Custom transformers can be implemented in Python or PySpark. See the [implementation docs](../implementations/transformers.md) for a detailed guide. @@ -10,17 +10,11 @@ Custom transformers can be implemented in Python or PySpark. See the [implementa - kind: transformer name: # transformer name (required) path: # path to the implementation file, relative to the application root (default: implementations/transformers/.py) - output_type: # output data type (required) - inputs: - columns: - : # map of column input name to column input type(s) (required) - ... - args: - : # map of arg input name to value input type(s) (optional) - ... + output_type: # The type of column that will be generated by this transformer (required) + input: # the input type of the transformer (required) ``` -See [Data Types](datatypes.md) for a list of valid data types. +See [Data Types](data-types.md) for details about input and column types.
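To make the relationship between a transformer's declared `input` type and its implementation concrete, here is a minimal sketch (assuming a normalization transformer like the `normalize` example below) using the `transform_python(input)` signature that the example implementations in this patch use; the field names `num`, `mean`, and `stddev` mirror that example.

```python
def transform_python(input):
    # `input` is the resolved input value: a fixed map containing the "num"
    # column's value for the current sample plus the "mean" and "stddev" scalars.
    return (input["num"] - input["mean"]) / input["stddev"]
```

A PySpark implementation can be provided instead; see the implementation docs referenced above for details.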
## Example @@ -28,12 +22,10 @@ See [Data Types](datatypes.md) for a list of valid data types. - kind: transformer name: normalize output_type: FLOAT_COLUMN - inputs: - columns: - num: INT_COLUMN|FLOAT_COLUMN - args: - mean: FLOAT - stddev: FLOAT + input: + num: INT_COLUMN|FLOAT_COLUMN + mean: FLOAT + stddev: FLOAT ``` ## Built-in Transformers From efe898a874f2464645ffba3640953f12e2dac8e3 Mon Sep 17 00:00:00 2001 From: Vishal Bollu Date: Thu, 13 Jun 2019 23:15:49 +0200 Subject: [PATCH 44/44] Input redesign update examples (#162) --- examples/mnist/implementations/models/dnn.py | 2 +- .../transformers/decode_and_normalize.py | 6 +-- examples/mnist/resources/apis.yaml | 6 +-- examples/mnist/resources/data.yaml | 25 +++++++++++ examples/mnist/resources/environments.yaml | 10 ----- examples/mnist/resources/models.yaml | 27 +++++------- examples/mnist/resources/raw_columns.yaml | 11 ----- .../mnist/resources/transformed_columns.yaml | 6 --- .../implementations/models/basic_embedding.py | 22 +++------- examples/movie-ratings/resources/apis.yaml | 2 +- examples/movie-ratings/resources/data.yaml | 44 +++++++++++++++++++ .../movie-ratings/resources/environments.yaml | 20 --------- examples/movie-ratings/resources/models.yaml | 16 ++++--- .../resources/transformed_columns.yaml | 31 ------------- .../implementations/aggregators/max_length.py | 4 +- .../implementations/aggregators/vocab.py | 6 +-- .../implementations/models/sentiment_dnn.py | 2 +- .../models/sentiment_linear.py | 2 +- .../implementations/models/transformer.py | 4 +- .../transformers/tokenize_string_to_int.py | 10 ++--- examples/reviews/resources/apis.yaml | 6 +-- examples/reviews/resources/columns.yaml | 28 ------------ examples/reviews/resources/data.yaml | 41 +++++++++++++++++ examples/reviews/resources/max_length.yaml | 6 --- examples/reviews/resources/models.yaml | 35 +++++++-------- examples/reviews/resources/vocab.yaml | 15 ------- 26 files changed, 177 insertions(+), 210 deletions(-) create mode 100644 examples/mnist/resources/data.yaml delete mode 100644 examples/mnist/resources/environments.yaml delete mode 100644 examples/mnist/resources/raw_columns.yaml delete mode 100644 examples/mnist/resources/transformed_columns.yaml create mode 100644 examples/movie-ratings/resources/data.yaml delete mode 100644 examples/movie-ratings/resources/environments.yaml delete mode 100644 examples/movie-ratings/resources/transformed_columns.yaml delete mode 100644 examples/reviews/resources/columns.yaml create mode 100644 examples/reviews/resources/data.yaml delete mode 100644 examples/reviews/resources/max_length.yaml delete mode 100644 examples/reviews/resources/vocab.yaml diff --git a/examples/mnist/implementations/models/dnn.py b/examples/mnist/implementations/models/dnn.py index 7f3c97d95b..fbf317106f 100644 --- a/examples/mnist/implementations/models/dnn.py +++ b/examples/mnist/implementations/models/dnn.py @@ -4,7 +4,7 @@ def create_estimator(run_config, model_config): feature_columns = [ tf.feature_column.numeric_column( - "image_pixels", shape=model_config["hparams"]["input_shape"] + model_config["input"], shape=model_config["hparams"]["input_shape"] ) ] diff --git a/examples/mnist/implementations/transformers/decode_and_normalize.py b/examples/mnist/implementations/transformers/decode_and_normalize.py index 2ff41440e3..4933445c11 100644 --- a/examples/mnist/implementations/transformers/decode_and_normalize.py +++ b/examples/mnist/implementations/transformers/decode_and_normalize.py @@ -5,10 +5,8 @@ import math -def 
transform_python(sample, args): - image = sample["image"] - - decoded = base64.b64decode(image) +def transform_python(input): + decoded = base64.b64decode(input) decoded_image = np.asarray(Image.open(BytesIO(decoded)), dtype=np.uint8) # reimplmenting tf.per_image_standardization diff --git a/examples/mnist/resources/apis.yaml b/examples/mnist/resources/apis.yaml index 3bab92bd2a..5b1b966892 100644 --- a/examples/mnist/resources/apis.yaml +++ b/examples/mnist/resources/apis.yaml @@ -1,17 +1,17 @@ - kind: api name: dnn-classifier - model_name: dnn + model: @dnn compute: replicas: 1 - kind: api name: conv-classifier - model_name: conv + model: @conv compute: replicas: 1 - kind: api name: t2t-classifier - model_name: t2t + model: @t2t compute: replicas: 1 diff --git a/examples/mnist/resources/data.yaml b/examples/mnist/resources/data.yaml new file mode 100644 index 0000000000..3eee3607e8 --- /dev/null +++ b/examples/mnist/resources/data.yaml @@ -0,0 +1,25 @@ +- kind: environment + name: dev + data: + type: csv + path: s3a://cortex-examples/mnist.csv + csv_config: + header: true + schema: [@image, @label] + +- kind: raw_column + name: image + type: STRING_COLUMN + required: true + +- kind: raw_column + name: label + type: INT_COLUMN + required: true + min: 0 + max: 9 + +- kind: transformed_column + name: image_pixels + transformer_path: implementations/transformers/decode_and_normalize.py + input: @image diff --git a/examples/mnist/resources/environments.yaml b/examples/mnist/resources/environments.yaml deleted file mode 100644 index 660927f011..0000000000 --- a/examples/mnist/resources/environments.yaml +++ /dev/null @@ -1,10 +0,0 @@ -- kind: environment - name: dev - data: - type: csv - path: s3a://cortex-examples/mnist.csv - csv_config: - header: true - schema: - - image - - label diff --git a/examples/mnist/resources/models.yaml b/examples/mnist/resources/models.yaml index a789ab9d0c..a7f41c6118 100644 --- a/examples/mnist/resources/models.yaml +++ b/examples/mnist/resources/models.yaml @@ -1,14 +1,11 @@ - kind: model name: dnn - path: implementations/models/dnn.py - type: classification - target_column: label - feature_columns: - - image_pixels + estimator_path: implementations/models/dnn.py + target_column: @label + input: @image_pixels hparams: - learning_rate: 0.01 input_shape: [784] - output_shape: [10] + learning_rate: 0.01 hidden_units: [100, 200] data_partition_ratio: training: 0.7 @@ -16,11 +13,9 @@ - kind: model name: conv - path: implementations/models/custom.py - type: classification - target_column: label - feature_columns: - - image_pixels + estimator_path: implementations/models/custom.py + target_column: @label + input: @image_pixels hparams: layer_type: conv learning_rate: 0.01 @@ -38,11 +33,9 @@ - kind: model name: t2t - path: implementations/models/t2t.py - type: classification - target_column: label - feature_columns: - - image_pixels + estimator_path: implementations/models/t2t.py + target_column: @label + input: @image_pixels prediction_key: outputs hparams: input_shape: [28, 28, 1] diff --git a/examples/mnist/resources/raw_columns.yaml b/examples/mnist/resources/raw_columns.yaml deleted file mode 100644 index 742e506f9c..0000000000 --- a/examples/mnist/resources/raw_columns.yaml +++ /dev/null @@ -1,11 +0,0 @@ -- kind: raw_column - name: image - type: STRING_COLUMN - required: true - -- kind: raw_column - name: label - type: INT_COLUMN - required: true - min: 0 - max: 9 diff --git a/examples/mnist/resources/transformed_columns.yaml 
b/examples/mnist/resources/transformed_columns.yaml deleted file mode 100644 index 4c736b0880..0000000000 --- a/examples/mnist/resources/transformed_columns.yaml +++ /dev/null @@ -1,6 +0,0 @@ -- kind: transformed_column - name: image_pixels - transformer_path: implementations/transformers/decode_and_normalize.py - inputs: - columns: - image: image diff --git a/examples/movie-ratings/implementations/models/basic_embedding.py b/examples/movie-ratings/implementations/models/basic_embedding.py index d2e64c43d5..3a0a507b3e 100644 --- a/examples/movie-ratings/implementations/models/basic_embedding.py +++ b/examples/movie-ratings/implementations/models/basic_embedding.py @@ -2,26 +2,18 @@ def create_estimator(run_config, model_config): - user_id_index = model_config["aggregates"]["user_id_index"] - movie_id_index = model_config["aggregates"]["movie_id_index"] - - feature_columns = [ - tf.feature_column.embedding_column( - tf.feature_column.categorical_column_with_identity( - "user_id_indexed", len(user_id_index) - ), - model_config["hparams"]["embedding_size"], - ), - tf.feature_column.embedding_column( + embedding_feature_columns = [] + for feature_col_data in model_config["input"]["embedding_columns"]: + embedding_col = tf.feature_column.embedding_column( tf.feature_column.categorical_column_with_identity( - "movie_id_indexed", len(movie_id_index) + feature_col_data["col"], len(feature_col_data["vocab"]["index"]) ), model_config["hparams"]["embedding_size"], - ), - ] + ) + embedding_feature_columns.append(embedding_col) return tf.estimator.DNNRegressor( - feature_columns=feature_columns, + feature_columns=embedding_feature_columns, hidden_units=model_config["hparams"]["hidden_units"], config=run_config, ) diff --git a/examples/movie-ratings/resources/apis.yaml b/examples/movie-ratings/resources/apis.yaml index 187089009a..d5e5b259dd 100644 --- a/examples/movie-ratings/resources/apis.yaml +++ b/examples/movie-ratings/resources/apis.yaml @@ -1,5 +1,5 @@ - kind: api name: ratings - model_name: basic_embedding + model: @basic_embedding compute: replicas: 1 diff --git a/examples/movie-ratings/resources/data.yaml b/examples/movie-ratings/resources/data.yaml new file mode 100644 index 0000000000..40182c1727 --- /dev/null +++ b/examples/movie-ratings/resources/data.yaml @@ -0,0 +1,44 @@ +- kind: environment + name: dev + data: + type: csv + path: s3a://cortex-examples/movie-ratings.csv + csv_config: + header: true + schema: [@user_id, @movie_id, @rating, @timestamp] + +- kind: raw_column + name: user_id + type: STRING_COLUMN + +- kind: raw_column + name: movie_id + type: STRING_COLUMN + +- kind: raw_column + name: rating + type: FLOAT_COLUMN + +- kind: aggregate + name: user_id_index + aggregator: cortex.index_string + input: @user_id + +- kind: transformed_column + name: user_id_indexed + transformer: cortex.index_string + input: + col: @user_id + indexes: @user_id_index + +- kind: aggregate + name: movie_id_index + aggregator: cortex.index_string + input: @movie_id + +- kind: transformed_column + name: movie_id_indexed + transformer: cortex.index_string + input: + col: @movie_id + indexes: @movie_id_index diff --git a/examples/movie-ratings/resources/environments.yaml b/examples/movie-ratings/resources/environments.yaml deleted file mode 100644 index a8604c6525..0000000000 --- a/examples/movie-ratings/resources/environments.yaml +++ /dev/null @@ -1,20 +0,0 @@ -- kind: environment - name: dev - data: - type: csv - path: s3a://cortex-examples/movie-ratings.csv - csv_config: - header: true - schema: 
['user_id','movie_id','rating','timestamp'] - -- kind: raw_column - name: user_id - type: STRING_COLUMN - -- kind: raw_column - name: movie_id - type: STRING_COLUMN - -- kind: raw_column - name: rating - type: FLOAT_COLUMN diff --git a/examples/movie-ratings/resources/models.yaml b/examples/movie-ratings/resources/models.yaml index 8f5e0ef5c8..6cf6db550a 100644 --- a/examples/movie-ratings/resources/models.yaml +++ b/examples/movie-ratings/resources/models.yaml @@ -1,12 +1,16 @@ - kind: model name: basic_embedding - type: regression - target_column: rating - feature_columns: [user_id_indexed, movie_id_indexed] - aggregates: [user_id_index, movie_id_index] + estimator_path: implementations/models/basic_embedding.py + target_column: @rating + input: + embedding_columns: + - col: @user_id_indexed + vocab: @user_id_index + - col: @movie_id_indexed + vocab: @movie_id_index hparams: - embedding_size: 10 - hidden_units: [128] + embedding_size: 20 + hidden_units: [10, 10] data_partition_ratio: training: 0.8 evaluation: 0.2 diff --git a/examples/movie-ratings/resources/transformed_columns.yaml b/examples/movie-ratings/resources/transformed_columns.yaml deleted file mode 100644 index ac1a27250f..0000000000 --- a/examples/movie-ratings/resources/transformed_columns.yaml +++ /dev/null @@ -1,31 +0,0 @@ -- kind: aggregate - name: user_id_index - aggregator: cortex.index_string - inputs: - columns: - col: user_id - -- kind: transformed_column - name: user_id_indexed - transformer: cortex.index_string - inputs: - columns: - text: user_id - args: - indexes: user_id_index - -- kind: aggregate - name: movie_id_index - aggregator: cortex.index_string - inputs: - columns: - col: movie_id - -- kind: transformed_column - name: movie_id_indexed - transformer: cortex.index_string - inputs: - columns: - text: movie_id - args: - indexes: movie_id_index diff --git a/examples/reviews/implementations/aggregators/max_length.py b/examples/reviews/implementations/aggregators/max_length.py index 5552024dbf..8034888dd9 100644 --- a/examples/reviews/implementations/aggregators/max_length.py +++ b/examples/reviews/implementations/aggregators/max_length.py @@ -1,9 +1,9 @@ -def aggregate_spark(data, columns, args): +def aggregate_spark(data, input): from pyspark.ml.feature import RegexTokenizer import pyspark.sql.functions as F from pyspark.sql.types import IntegerType - regexTokenizer = RegexTokenizer(inputCol=columns["col"], outputCol="token_list", pattern="\\W") + regexTokenizer = RegexTokenizer(inputCol=input, outputCol="token_list", pattern="\\W") regexTokenized = regexTokenizer.transform(data) max_review_length_row = ( diff --git a/examples/reviews/implementations/aggregators/vocab.py b/examples/reviews/implementations/aggregators/vocab.py index 39e0dc7c71..7377585eb8 100644 --- a/examples/reviews/implementations/aggregators/vocab.py +++ b/examples/reviews/implementations/aggregators/vocab.py @@ -1,8 +1,8 @@ -def aggregate_spark(data, columns, args): +def aggregate_spark(data, input): import pyspark.sql.functions as F from pyspark.ml.feature import RegexTokenizer - regexTokenizer = RegexTokenizer(inputCol=columns["col"], outputCol="token_list", pattern="\\W") + regexTokenizer = RegexTokenizer(inputCol=input["col"], outputCol="token_list", pattern="\\W") regexTokenized = regexTokenizer.transform(data) vocab_rows = ( @@ -10,7 +10,7 @@ def aggregate_spark(data, columns, args): .groupBy("word") .count() .orderBy(F.col("count").desc()) - .limit(args["vocab_size"]) + .limit(input["vocab_size"]) .select("word") .collect() ) diff 
--git a/examples/reviews/implementations/models/sentiment_dnn.py b/examples/reviews/implementations/models/sentiment_dnn.py index b6506000b9..c9aa1155cf 100644 --- a/examples/reviews/implementations/models/sentiment_dnn.py +++ b/examples/reviews/implementations/models/sentiment_dnn.py @@ -4,7 +4,7 @@ def create_estimator(run_config, model_config): hparams = model_config["hparams"] - vocab_size = len(model_config["aggregates"]["reviews_vocab"]) + vocab_size = len(model_config["input"]["vocab"]) def model_fn(features, labels, mode, params): embedding_input = features["embedding_input"] diff --git a/examples/reviews/implementations/models/sentiment_linear.py b/examples/reviews/implementations/models/sentiment_linear.py index 1795e4cec7..6b1e512c14 100644 --- a/examples/reviews/implementations/models/sentiment_linear.py +++ b/examples/reviews/implementations/models/sentiment_linear.py @@ -2,7 +2,7 @@ def create_estimator(run_config, model_config): - vocab_size = len(model_config["aggregates"]["reviews_vocab"]) + vocab_size = len(model_config["input"]["vocab"]) feature_column = tf.feature_column.categorical_column_with_identity( "embedding_input", vocab_size ) diff --git a/examples/reviews/implementations/models/transformer.py b/examples/reviews/implementations/models/transformer.py index 0b03d52825..f4442abddc 100644 --- a/examples/reviews/implementations/models/transformer.py +++ b/examples/reviews/implementations/models/transformer.py @@ -13,7 +13,7 @@ def create_estimator(run_config, model_config): hparams = trainer_lib.create_hparams("transformer_base_single_gpu") # SentimentIMDBCortex subclasses SentimentIMDB - problem = SentimentIMDBCortex(list(model_config["aggregates"]["reviews_vocab"])) + problem = SentimentIMDBCortex(list(model_config["input"]["vocab"])) hparams.problem = problem hparams.problem_hparams = problem.get_hparams(hparams) @@ -39,7 +39,7 @@ def create_estimator(run_config, model_config): def transform_tensorflow(features, labels, model_config): - max_length = model_config["aggregates"]["max_review_length"] + max_length = model_config["input"]["max_review_length"] features["inputs"] = tf.expand_dims(tf.reshape(features["embedding_input"], [max_length]), -1) features["targets"] = tf.expand_dims(tf.expand_dims(labels, -1), -1) diff --git a/examples/reviews/implementations/transformers/tokenize_string_to_int.py b/examples/reviews/implementations/transformers/tokenize_string_to_int.py index 1ab078cc69..151c18f754 100644 --- a/examples/reviews/implementations/transformers/tokenize_string_to_int.py +++ b/examples/reviews/implementations/transformers/tokenize_string_to_int.py @@ -3,19 +3,19 @@ non_word = re.compile("\\W") -def transform_python(sample, args): - text = sample["col"].lower() +def transform_python(input): + text = input["col"].lower() token_index_list = [] - vocab = args["vocab"] + vocab = input["vocab"] for token in non_word.split(text): if len(token) == 0: continue token_index_list.append(vocab.get(token, vocab[""])) - if len(token_index_list) == args["max_len"]: + if len(token_index_list) == input["max_len"]: break - for i in range(args["max_len"] - len(token_index_list)): + for i in range(input["max_len"] - len(token_index_list)): token_index_list.append(vocab[""]) return token_index_list diff --git a/examples/reviews/resources/apis.yaml b/examples/reviews/resources/apis.yaml index 56b549968b..0819d5eba4 100644 --- a/examples/reviews/resources/apis.yaml +++ b/examples/reviews/resources/apis.yaml @@ -1,17 +1,17 @@ - kind: api name: sentiment-dnn - model_name: 
sentiment_dnn + model: @sentiment_dnn compute: replicas: 1 - kind: api name: sentiment-linear - model_name: sentiment_linear + model: @sentiment_linear compute: replicas: 1 - kind: api name: sentiment-t2t - model_name: transformer + model: @transformer compute: replicas: 1 diff --git a/examples/reviews/resources/columns.yaml b/examples/reviews/resources/columns.yaml deleted file mode 100644 index 562a132d56..0000000000 --- a/examples/reviews/resources/columns.yaml +++ /dev/null @@ -1,28 +0,0 @@ -- kind: environment - name: dev - data: - type: csv - path: s3a://cortex-examples/reviews.csv - csv_config: - header: true - escape: "\"" - schema: ["review", "label"] - -- kind: transformed_column - name: embedding_input - transformer_path: implementations/transformers/tokenize_string_to_int.py - inputs: - columns: - col: review - args: - max_len: max_review_length - vocab: reviews_vocab - -- kind: transformed_column - name: label_indexed - transformer: cortex.index_string - inputs: - columns: - text: label - args: - indexes: label_index diff --git a/examples/reviews/resources/data.yaml b/examples/reviews/resources/data.yaml new file mode 100644 index 0000000000..9da7e15d9c --- /dev/null +++ b/examples/reviews/resources/data.yaml @@ -0,0 +1,41 @@ +- kind: environment + name: dev + data: + type: csv + path: s3a://cortex-examples/reviews.csv + csv_config: + header: true + escape: "\"" + schema: [@review, @label] + +- kind: aggregate + name: max_review_length + aggregator_path: implementations/aggregators/max_length.py + input: @review + +- kind: aggregate + name: reviews_vocab + aggregator_path: implementations/aggregators/vocab.py + input: + col: @review + vocab_size: 10000 + +- kind: aggregate + name: label_index + aggregator: cortex.index_string + input: @label + +- kind: transformed_column + name: embedding_input + transformer_path: implementations/transformers/tokenize_string_to_int.py + input: + col: @review + max_len: @max_review_length + vocab: @reviews_vocab + +- kind: transformed_column + name: label_indexed + transformer: cortex.index_string + input: + col: @label + indexes: @label_index diff --git a/examples/reviews/resources/max_length.yaml b/examples/reviews/resources/max_length.yaml deleted file mode 100644 index 168190f03f..0000000000 --- a/examples/reviews/resources/max_length.yaml +++ /dev/null @@ -1,6 +0,0 @@ -- kind: aggregate - name: max_review_length - aggregator_path: implementations/aggregators/max_length.py - inputs: - columns: - col: review diff --git a/examples/reviews/resources/models.yaml b/examples/reviews/resources/models.yaml index c6aa6ca629..0c353d4508 100644 --- a/examples/reviews/resources/models.yaml +++ b/examples/reviews/resources/models.yaml @@ -1,11 +1,10 @@ - kind: model name: sentiment_dnn - type: classification - target_column: label_indexed - feature_columns: - - embedding_input - aggregates: - - reviews_vocab + estimator_path: implementations/models/sentiment_dnn.py + target_column: @label_indexed + input: + embedding_input: @embedding_input + vocab: @reviews_vocab hparams: learning_rate: 0.01 data_partition_ratio: @@ -17,12 +16,11 @@ - kind: model name: sentiment_linear - type: classification - target_column: label_indexed - feature_columns: - - embedding_input - aggregates: - - reviews_vocab + estimator_path: implementations/models/sentiment_linear.py + target_column: @label_indexed + input: + embedding_input: @embedding_input + vocab: @reviews_vocab data_partition_ratio: training: 0.8 evaluation: 0.2 @@ -32,13 +30,12 @@ - kind: model name: transformer 
- type: classification - target_column: label_indexed - feature_columns: - - embedding_input - aggregates: - - max_review_length - - reviews_vocab + estimator_path: implementations/models/transformer.py + target_column: @label_indexed + input: + embedding_input: @embedding_input + max_review_length: @max_review_length + vocab: @reviews_vocab prediction_key: outputs data_partition_ratio: training: 0.8 diff --git a/examples/reviews/resources/vocab.yaml b/examples/reviews/resources/vocab.yaml deleted file mode 100644 index a2da56f53f..0000000000 --- a/examples/reviews/resources/vocab.yaml +++ /dev/null @@ -1,15 +0,0 @@ -- kind: aggregate - name: reviews_vocab - aggregator_path: implementations/aggregators/vocab.py - inputs: - columns: - col: review - args: - vocab_size: 10000 - -- kind: aggregate - name: label_index - aggregator: cortex.index_string - inputs: - columns: - col: label
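For reference, a custom aggregator under the new single-`input` signature (as used by `max_length.py` and `vocab.py` above) can be sketched as follows; the use of a simple mean over a single column-name input is an illustrative assumption.

```python
def aggregate_spark(data, input):
    # `data` is a PySpark DataFrame of the ingested dataset; `input` is the
    # resolved input value (assumed here to be a single column name).
    from pyspark.sql import functions as F

    return data.agg(F.mean(input)).collect()[0][0]
```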