diff --git a/examples/mnist/resources/transformed_columns.yaml b/examples/mnist/resources/transformed_columns.yaml
index f46f08aba4..4c736b0880 100644
--- a/examples/mnist/resources/transformed_columns.yaml
+++ b/examples/mnist/resources/transformed_columns.yaml
@@ -1,6 +1,6 @@
 - kind: transformed_column
   name: image_pixels
-  transformer: decode_and_normalize
+  transformer_path: implementations/transformers/decode_and_normalize.py
   inputs:
     columns:
       image: image
diff --git a/examples/mnist/resources/transformers.yaml b/examples/mnist/resources/transformers.yaml
deleted file mode 100644
index 3d13b2d5ee..0000000000
--- a/examples/mnist/resources/transformers.yaml
+++ /dev/null
@@ -1,6 +0,0 @@
-- kind: transformer
-  name: decode_and_normalize
-  output_type: FLOAT_LIST_COLUMN
-  inputs:
-    columns:
-      image: STRING_COLUMN
diff --git a/examples/reviews/resources/max_length.yaml b/examples/reviews/resources/max_length.yaml
index d5f642e096..168190f03f 100644
--- a/examples/reviews/resources/max_length.yaml
+++ b/examples/reviews/resources/max_length.yaml
@@ -1,13 +1,6 @@
-- kind: aggregator
-  name: max_length
-  inputs:
-    columns:
-      col: STRING_COLUMN
-  output_type: INT
-
 - kind: aggregate
   name: max_review_length
-  aggregator: max_length
+  aggregator_path: implementations/aggregators/max_length.py
   inputs:
     columns:
       col: review
diff --git a/examples/reviews/resources/tokenized_columns.yaml b/examples/reviews/resources/tokenized_columns.yaml
index 225bd38f60..cc2715e76f 100644
--- a/examples/reviews/resources/tokenized_columns.yaml
+++ b/examples/reviews/resources/tokenized_columns.yaml
@@ -1,16 +1,6 @@
-- kind: transformer
-  name: tokenize_string_to_int
-  output_type: INT_LIST_COLUMN
-  inputs:
-    columns:
-      col: STRING_COLUMN
-    args:
-      max_len: INT
-      vocab: {STRING: INT}
-
 - kind: transformed_column
   name: embedding_input
-  transformer: tokenize_string_to_int
+  transformer_path: implementations/transformers/tokenize_string_to_int.py
   inputs:
     columns:
       col: review
diff --git a/pkg/consts/consts.go b/pkg/consts/consts.go
index daf5e39737..e0835c7c7d 100644
--- a/pkg/consts/consts.go
+++ b/pkg/consts/consts.go
@@ -36,21 +36,24 @@ var (
     RequirementsTxt = "requirements.txt"
     PackageDir      = "packages"

-    AppsDir             = "apps"
-    DataDir             = "data"
-    RawDataDir          = "data_raw"
-    TrainingDataDir     = "data_training"
-    AggregatorsDir      = "aggregators"
-    AggregatesDir       = "aggregates"
-    TransformersDir     = "transformers"
-    ModelImplsDir       = "model_implementations"
-    PythonPackagesDir   = "python_packages"
-    ModelsDir           = "models"
-    ConstantsDir        = "constants"
-    ContextsDir         = "contexts"
-    ResourceStatusesDir = "resource_statuses"
-    WorkloadSpecsDir    = "workload_specs"
-    LogPrefixesDir      = "log_prefixes"
+    AppsDir               = "apps"
+    APIsDir               = "apis"
+    DataDir               = "data"
+    RawDataDir            = "data_raw"
+    TrainingDataDir       = "data_training"
+    AggregatorsDir        = "aggregators"
+    AggregatesDir         = "aggregates"
+    TransformersDir       = "transformers"
+    ModelImplsDir         = "model_implementations"
+    PythonPackagesDir     = "python_packages"
+    ModelsDir             = "models"
+    ConstantsDir          = "constants"
+    ContextsDir           = "contexts"
+    ResourceStatusesDir   = "resource_statuses"
+    WorkloadSpecsDir      = "workload_specs"
+    LogPrefixesDir        = "log_prefixes"
+    RawColumnsDir         = "raw_columns"
+    TransformedColumnsDir = "transformed_columns"

     TelemetryURL = "https://telemetry.cortexlabs.dev"
 )
diff --git a/pkg/lib/configreader/string.go b/pkg/lib/configreader/string.go
index 45e990715f..e0ecec4189 100644
--- a/pkg/lib/configreader/string.go
+++ b/pkg/lib/configreader/string.go
@@ -28,15 +28,16 @@ import (
 )

 type StringValidation struct {
-    Required                      bool
-    Default                       string
-    AllowEmpty                    bool
-    AllowedValues                 []string
-    Prefix                        string
-    AlphaNumericDashDotUnderscore bool
-    AlphaNumericDashUnderscore    bool
-    DNS1035                       bool
-    Validator                     func(string) (string, error)
+    Required                             bool
+    Default                              string
+    AllowEmpty                           bool
+    AllowedValues                        []string
+    Prefix                               string
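+    // matches the pattern below, but also accepts the empty string (see ValidateStringVal)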
+    AlphaNumericDashDotUnderscoreOrEmpty bool
+    AlphaNumericDashDotUnderscore        bool
+    AlphaNumericDashUnderscore           bool
+    DNS1035                              bool
+    Validator                            func(string) (string, error)
 }

 func EnvVar(envVarName string) string {
@@ -190,6 +191,12 @@ func ValidateStringVal(val string, v *StringValidation) error {
         }
     }

+    if v.AlphaNumericDashDotUnderscoreOrEmpty {
+        if !regex.CheckAlphaNumericDashDotUnderscore(val) && val != "" {
+            return ErrorAlphaNumericDashDotUnderscore(val)
+        }
+    }
+
     if v.DNS1035 {
         if err := urls.CheckDNS1035(val); err != nil {
             return err
diff --git a/pkg/operator/api/context/context.go b/pkg/operator/api/context/context.go
index ab062a32ed..594df54b87 100644
--- a/pkg/operator/api/context/context.go
+++ b/pkg/operator/api/context/context.go
@@ -55,6 +55,7 @@ type Resource interface {
     GetID() string
     GetIDWithTags() string
     GetResourceFields() *ResourceFields
+    GetMetadataKey() string
 }

 type ComputedResource interface {
@@ -72,6 +73,7 @@ type ResourceFields struct {
     ID           string        `json:"id"`
     IDWithTags   string        `json:"id_with_tags"`
     ResourceType resource.Type `json:"resource_type"`
+    MetadataKey  string        `json:"metadata_key"`
 }

 type ComputedResourceFields struct {
@@ -91,6 +93,10 @@ func (r *ResourceFields) GetResourceFields() *ResourceFields {
     return r
 }

+func (r *ResourceFields) GetMetadataKey() string {
+    return r.MetadataKey
+}
+
 func (r *ComputedResourceFields) GetWorkloadID() string {
     return r.WorkloadID
 }
diff --git a/pkg/operator/api/context/models.go b/pkg/operator/api/context/models.go
index dd45711a76..9df19a0bc2 100644
--- a/pkg/operator/api/context/models.go
+++ b/pkg/operator/api/context/models.go
@@ -38,12 +38,11 @@ type Model struct {
 }

 type TrainingDataset struct {
-    userconfig.ResourceConfigFields
+    userconfig.ResourceFields
     *ComputedResourceFields
-    ModelName   string `json:"model_name"`
-    TrainKey    string `json:"train_key"`
-    EvalKey     string `json:"eval_key"`
-    MetadataKey string `json:"metadata_key"`
+    ModelName string `json:"model_name"`
+    TrainKey  string `json:"train_key"`
+    EvalKey   string `json:"eval_key"`
 }

 func (trainingDataset *TrainingDataset) GetResourceType() resource.Type {
diff --git a/pkg/operator/api/context/python_packages.go b/pkg/operator/api/context/python_packages.go
index 9818e0082b..4b5627c7c9 100644
--- a/pkg/operator/api/context/python_packages.go
+++ b/pkg/operator/api/context/python_packages.go
@@ -24,7 +24,7 @@ import (
 type PythonPackages map[string]*PythonPackage

 type PythonPackage struct {
-    userconfig.ResourceConfigFields
+    userconfig.ResourceFields
     *ComputedResourceFields
     SrcKey     string `json:"src_key"`
     PackageKey string `json:"package_key"`
diff --git a/pkg/operator/api/userconfig/aggregates.go b/pkg/operator/api/userconfig/aggregates.go
index cae31fe5c1..613ad89b3a 100644
--- a/pkg/operator/api/userconfig/aggregates.go
+++ b/pkg/operator/api/userconfig/aggregates.go
@@ -26,11 +26,12 @@ import (
 type Aggregates []*Aggregate

 type Aggregate struct {
-    ResourceConfigFields
-    Aggregator string        `json:"aggregator" yaml:"aggregator"`
-    Inputs     *Inputs       `json:"inputs" yaml:"inputs"`
-    Compute    *SparkCompute `json:"compute" yaml:"compute"`
-    Tags       Tags          `json:"tags" yaml:"tags"`
+    ResourceFields
+    Aggregator     string        `json:"aggregator" yaml:"aggregator"`
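+    // a path to a Python implementation may be given in place of a named aggregator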
yaml:"aggregator"` + AggregatorPath *string `json:"aggregator_path" yaml:"aggregator_path"` + Inputs *Inputs `json:"inputs" yaml:"inputs"` + Compute *SparkCompute `json:"compute" yaml:"compute"` + Tags Tags `json:"tags" yaml:"tags"` } var aggregateValidation = &configreader.StructValidation{ @@ -45,10 +46,14 @@ var aggregateValidation = &configreader.StructValidation{ { StructField: "Aggregator", StringValidation: &configreader.StringValidation{ - Required: true, - AlphaNumericDashDotUnderscore: true, + AllowEmpty: true, + AlphaNumericDashDotUnderscoreOrEmpty: true, }, }, + { + StructField: "AggregatorPath", + StringPtrValidation: &configreader.StringPtrValidation{}, + }, inputValuesFieldValidation, sparkComputeFieldValidation("Compute"), tagsFieldValidation, diff --git a/pkg/operator/api/userconfig/aggregators.go b/pkg/operator/api/userconfig/aggregators.go index 226606cdbd..3eefddf752 100644 --- a/pkg/operator/api/userconfig/aggregators.go +++ b/pkg/operator/api/userconfig/aggregators.go @@ -24,7 +24,7 @@ import ( type Aggregators []*Aggregator type Aggregator struct { - ResourceConfigFields + ResourceFields Inputs *Inputs `json:"inputs" yaml:"inputs"` OutputType interface{} `json:"output_type" yaml:"output_type"` Path string `json:"path" yaml:"path"` diff --git a/pkg/operator/api/userconfig/apis.go b/pkg/operator/api/userconfig/apis.go index d7add9a17e..854dfee201 100644 --- a/pkg/operator/api/userconfig/apis.go +++ b/pkg/operator/api/userconfig/apis.go @@ -24,7 +24,7 @@ import ( type APIs []*API type API struct { - ResourceConfigFields + ResourceFields ModelName string `json:"model_name" yaml:"model_name"` Compute *APICompute `json:"compute" yaml:"compute"` Tags Tags `json:"tags" yaml:"tags"` diff --git a/pkg/operator/api/userconfig/config.go b/pkg/operator/api/userconfig/config.go index f245f6436b..99a3086b17 100644 --- a/pkg/operator/api/userconfig/config.go +++ b/pkg/operator/api/userconfig/config.go @@ -200,18 +200,38 @@ func (config *Config) Validate(envName string) error { } } - // Check local aggregators exist + // Check local aggregators exist or a path to one is defined aggregatorNames := config.Aggregators.Names() for _, aggregate := range config.Aggregates { - if !strings.Contains(aggregate.Aggregator, ".") && !slices.HasString(aggregatorNames, aggregate.Aggregator) { + if aggregate.AggregatorPath == nil && aggregate.Aggregator == "" { + return errors.Wrap(ErrorSpecifyOnlyOneMissing("aggregator", "aggregator_path"), Identify(aggregate)) + } + + if aggregate.AggregatorPath != nil && aggregate.Aggregator != "" { + return errors.Wrap(ErrorSpecifyOnlyOne("aggregator", "aggregator_path"), Identify(aggregate)) + } + + if aggregate.Aggregator != "" && + !strings.Contains(aggregate.Aggregator, ".") && + !slices.HasString(aggregatorNames, aggregate.Aggregator) { return errors.Wrap(ErrorUndefinedResource(aggregate.Aggregator, resource.AggregatorType), Identify(aggregate), AggregatorKey) } } - // Check local transformers exist + // Check local transformers exist or a path to one is defined transformerNames := config.Transformers.Names() for _, transformedColumn := range config.TransformedColumns { - if !strings.Contains(transformedColumn.Transformer, ".") && !slices.HasString(transformerNames, transformedColumn.Transformer) { + if transformedColumn.TransformerPath == nil && transformedColumn.Transformer == "" { + return errors.Wrap(ErrorSpecifyOnlyOneMissing("transformer", "transformer_path"), Identify(transformedColumn)) + } + + if transformedColumn.TransformerPath != nil && 
+        if aggregate.AggregatorPath == nil && aggregate.Aggregator == "" {
+            return errors.Wrap(ErrorSpecifyOnlyOneMissing("aggregator", "aggregator_path"), Identify(aggregate))
+        }
+
+        if aggregate.AggregatorPath != nil && aggregate.Aggregator != "" {
+            return errors.Wrap(ErrorSpecifyOnlyOne("aggregator", "aggregator_path"), Identify(aggregate))
+        }
+
+        if aggregate.Aggregator != "" &&
+            !strings.Contains(aggregate.Aggregator, ".") &&
+            !slices.HasString(aggregatorNames, aggregate.Aggregator) {
             return errors.Wrap(ErrorUndefinedResource(aggregate.Aggregator, resource.AggregatorType), Identify(aggregate), AggregatorKey)
         }
     }

-    // Check local transformers exist
+    // Check local transformers exist or a path to one is defined
     transformerNames := config.Transformers.Names()
     for _, transformedColumn := range config.TransformedColumns {
-        if !strings.Contains(transformedColumn.Transformer, ".") && !slices.HasString(transformerNames, transformedColumn.Transformer) {
+        if transformedColumn.TransformerPath == nil && transformedColumn.Transformer == "" {
+            return errors.Wrap(ErrorSpecifyOnlyOneMissing("transformer", "transformer_path"), Identify(transformedColumn))
+        }
+
+        if transformedColumn.TransformerPath != nil && transformedColumn.Transformer != "" {
+            return errors.Wrap(ErrorSpecifyOnlyOne("transformer", "transformer_path"), Identify(transformedColumn))
+        }
+
+        if transformedColumn.Transformer != "" &&
+            !strings.Contains(transformedColumn.Transformer, ".") &&
+            !slices.HasString(transformerNames, transformedColumn.Transformer) {
             return errors.Wrap(ErrorUndefinedResource(transformedColumn.Transformer, resource.TransformerType), Identify(transformedColumn), TransformerKey)
         }
     }
diff --git a/pkg/operator/api/userconfig/constants.go b/pkg/operator/api/userconfig/constants.go
index b1217622a7..a84c678940 100644
--- a/pkg/operator/api/userconfig/constants.go
+++ b/pkg/operator/api/userconfig/constants.go
@@ -25,7 +25,7 @@ import (
 type Constants []*Constant

 type Constant struct {
-    ResourceConfigFields
+    ResourceFields
     Type  interface{} `json:"type" yaml:"type"`
     Value interface{} `json:"value" yaml:"value"`
     Tags  Tags        `json:"tags" yaml:"tags"`
diff --git a/pkg/operator/api/userconfig/embed.go b/pkg/operator/api/userconfig/embed.go
index 8713130b35..3778e189d0 100644
--- a/pkg/operator/api/userconfig/embed.go
+++ b/pkg/operator/api/userconfig/embed.go
@@ -24,7 +24,7 @@ import (
 type Embeds []*Embed

 type Embed struct {
-    ResourceConfigFields
+    ResourceFields
     Template string                 `json:"template" yaml:"template"`
     Args     map[string]interface{} `json:"args" yaml:"args"`
 }
diff --git a/pkg/operator/api/userconfig/environments.go b/pkg/operator/api/userconfig/environments.go
index 26c06534c2..940765f080 100644
--- a/pkg/operator/api/userconfig/environments.go
+++ b/pkg/operator/api/userconfig/environments.go
@@ -28,7 +28,7 @@ import (
 type Environments []*Environment

 type Environment struct {
-    ResourceConfigFields
+    ResourceFields
     LogLevel *LogLevel `json:"log_level" yaml:"log_level"`
     Limit    *Limit    `json:"limit" yaml:"limit"`
     Data     Data      `json:"-" yaml:"-"`
diff --git a/pkg/operator/api/userconfig/errors.go b/pkg/operator/api/userconfig/errors.go
index 5ff2b9511f..142f311733 100644
--- a/pkg/operator/api/userconfig/errors.go
+++ b/pkg/operator/api/userconfig/errors.go
@@ -58,6 +58,7 @@ const (
     ErrK8sQuantityMustBeInt
     ErrRegressionTargetType
     ErrClassificationTargetType
+    ErrSpecifyOnlyOneMissing
 )

 var errorKinds = []string{
@@ -90,9 +91,10 @@ var errorKinds = []string{
     "err_k8s_quantity_must_be_int",
     "err_regression_target_type",
     "err_classification_target_type",
+    "err_specify_only_one_missing",
 }

-var _ = [1]int{}[int(ErrClassificationTargetType)-(len(errorKinds)-1)] // Ensure list length matches
+var _ = [1]int{}[int(ErrSpecifyOnlyOneMissing)-(len(errorKinds)-1)] // Ensure list length matches

 func (t ErrorKind) String() string {
     return errorKinds[t]
@@ -376,9 +378,22 @@ func ErrorRegressionTargetType() error {
         message: "regression models can only predict float target values",
     }
 }
+
 func ErrorClassificationTargetType() error {
     return Error{
         Kind:    ErrClassificationTargetType,
         message: "classification models can only predict integer target values (i.e. {0, 1, ..., num_classes-1})",
     }
 }
+
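+// returned when none of a set of mutually exclusive fields (e.g. aggregator / aggregator_path) is specified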
+func ErrorSpecifyOnlyOneMissing(vals ...string) error {
+    message := fmt.Sprintf("please specify one of %s", s.UserStrsOr(vals))
+    if len(vals) == 2 {
+        message = fmt.Sprintf("please specify either %s or %s", s.UserStr(vals[0]), s.UserStr(vals[1]))
+    }
+
+    return Error{
+        Kind:    ErrSpecifyOnlyOneMissing,
+        message: message,
+    }
+}
diff --git a/pkg/operator/api/userconfig/models.go b/pkg/operator/api/userconfig/models.go
index 6e30a3b79d..b1f7acfea1 100644
--- a/pkg/operator/api/userconfig/models.go
+++ b/pkg/operator/api/userconfig/models.go
@@ -29,7 +29,7 @@ import (
 type Models []*Model

 type Model struct {
-    ResourceConfigFields
+    ResourceFields
     Type         ModelType `json:"type" yaml:"type"`
     Path         string    `json:"path" yaml:"path"`
     TargetColumn string    `json:"target_column" yaml:"target_column"`
diff --git a/pkg/operator/api/userconfig/raw_columns.go b/pkg/operator/api/userconfig/raw_columns.go
index c321c38199..1f04b7dad7 100644
--- a/pkg/operator/api/userconfig/raw_columns.go
+++ b/pkg/operator/api/userconfig/raw_columns.go
@@ -53,7 +53,7 @@ var rawColumnValidation = &cr.InterfaceStructValidation{
 }

 type RawIntColumn struct {
-    ResourceConfigFields
+    ResourceFields
     Type     ColumnType `json:"type" yaml:"type"`
     Required bool       `json:"required" yaml:"required"`
     Min      *int64     `json:"min" yaml:"min"`
@@ -100,7 +100,7 @@ var rawIntColumnFieldValidations = []*cr.StructFieldValidation{
 }

 type RawFloatColumn struct {
-    ResourceConfigFields
+    ResourceFields
     Type     ColumnType `json:"type" yaml:"type"`
     Required bool       `json:"required" yaml:"required"`
     Min      *float32   `json:"min" yaml:"min"`
@@ -147,7 +147,7 @@ var rawFloatColumnFieldValidations = []*cr.StructFieldValidation{
 }

 type RawStringColumn struct {
-    ResourceConfigFields
+    ResourceFields
     Type     ColumnType `json:"type" yaml:"type"`
     Required bool       `json:"required" yaml:"required"`
     Values   []string   `json:"values" yaml:"values"`
diff --git a/pkg/operator/api/userconfig/resource.go b/pkg/operator/api/userconfig/resource.go
index a996845604..da91815692 100644
--- a/pkg/operator/api/userconfig/resource.go
+++ b/pkg/operator/api/userconfig/resource.go
@@ -34,39 +34,39 @@ type Resource interface {
     SetEmbed(*Embed)
 }

-type ResourceConfigFields struct {
+type ResourceFields struct {
     Name     string `json:"name" yaml:"name"`
     Index    int    `json:"index" yaml:"-"`
     FilePath string `json:"file_path" yaml:"-"`
     Embed    *Embed `json:"embed" yaml:"-"`
 }

-func (resourceConfigFields *ResourceConfigFields) GetName() string {
-    return resourceConfigFields.Name
+func (resourceFields *ResourceFields) GetName() string {
+    return resourceFields.Name
 }

-func (resourceConfigFields *ResourceConfigFields) GetIndex() int {
-    return resourceConfigFields.Index
+func (resourceFields *ResourceFields) GetIndex() int {
+    return resourceFields.Index
 }

-func (resourceConfigFields *ResourceConfigFields) SetIndex(index int) {
-    resourceConfigFields.Index = index
+func (resourceFields *ResourceFields) SetIndex(index int) {
+    resourceFields.Index = index
 }

-func (resourceConfigFields *ResourceConfigFields) GetFilePath() string {
-    return resourceConfigFields.FilePath
+func (resourceFields *ResourceFields) GetFilePath() string {
+    return resourceFields.FilePath
 }

-func (resourceConfigFields *ResourceConfigFields) SetFilePath(filePath string) {
-    resourceConfigFields.FilePath = filePath
+func (resourceFields *ResourceFields) SetFilePath(filePath string) {
+    resourceFields.FilePath = filePath
 }

-func (resourceConfigFields *ResourceConfigFields) GetEmbed() *Embed {
-    return resourceConfigFields.Embed
+func (resourceFields *ResourceFields) GetEmbed() *Embed {
+    return resourceFields.Embed
 }

-func (resourceConfigFields *ResourceConfigFields) SetEmbed(embed *Embed) {
-    resourceConfigFields.Embed = embed
+func (resourceFields *ResourceFields) SetEmbed(embed *Embed) {
+    resourceFields.Embed = embed
 }

 func Identify(r Resource) string {
diff --git a/pkg/operator/api/userconfig/templates.go b/pkg/operator/api/userconfig/templates.go
index c0149a75b1..85f728eae4 100644
--- a/pkg/operator/api/userconfig/templates.go
+++ b/pkg/operator/api/userconfig/templates.go
@@ -31,7 +31,7 @@ var templateVarRegex = regexp.MustCompile("\\{\\s*([a-zA-Z0-9_-]+)\\s*\\}")
 type Templates []*Template

 type Template struct {
-    ResourceConfigFields
+    ResourceFields
     YAML string `json:"yaml" yaml:"yaml"`
 }
diff --git a/pkg/operator/api/userconfig/transformed_columns.go b/pkg/operator/api/userconfig/transformed_columns.go
index a5118bb3a1..476b71f368 100644
--- a/pkg/operator/api/userconfig/transformed_columns.go
+++ b/pkg/operator/api/userconfig/transformed_columns.go
@@ -26,11 +26,12 @@ import (
 type TransformedColumns []*TransformedColumn

 type TransformedColumn struct {
-    ResourceConfigFields
-    Transformer string        `json:"transformer" yaml:"transformer"`
-    Inputs      *Inputs       `json:"inputs" yaml:"inputs"`
-    Compute     *SparkCompute `json:"compute" yaml:"compute"`
-    Tags        Tags          `json:"tags" yaml:"tags"`
+    ResourceFields
+    Transformer     string        `json:"transformer" yaml:"transformer"`
+    TransformerPath *string       `json:"transformer_path" yaml:"transformer_path"`
+    Inputs          *Inputs       `json:"inputs" yaml:"inputs"`
+    Compute         *SparkCompute `json:"compute" yaml:"compute"`
+    Tags            Tags          `json:"tags" yaml:"tags"`
 }

 var transformedColumnValidation = &configreader.StructValidation{
@@ -45,10 +46,14 @@ var transformedColumnValidation = &configreader.StructValidation{
         {
             StructField: "Transformer",
             StringValidation: &configreader.StringValidation{
-                Required:                      true,
-                AlphaNumericDashDotUnderscore: true,
+                AllowEmpty:                           true,
+                AlphaNumericDashDotUnderscoreOrEmpty: true,
             },
         },
+        {
+            StructField:         "TransformerPath",
+            StringPtrValidation: &configreader.StringPtrValidation{},
+        },
         inputValuesFieldValidation,
         sparkComputeFieldValidation("Compute"),
         tagsFieldValidation,
diff --git a/pkg/operator/api/userconfig/transformers.go b/pkg/operator/api/userconfig/transformers.go
index cb7b09744c..6eada9effa 100644
--- a/pkg/operator/api/userconfig/transformers.go
+++ b/pkg/operator/api/userconfig/transformers.go
@@ -24,7 +24,7 @@ import (
 type Transformers []*Transformer

 type Transformer struct {
-    ResourceConfigFields
+    ResourceFields
     Inputs     *Inputs    `json:"inputs" yaml:"inputs"`
     OutputType ColumnType `json:"output_type" yaml:"output_type"`
     Path       string     `json:"path" yaml:"path"`
diff --git a/pkg/operator/context/aggregates.go b/pkg/operator/context/aggregates.go
index f5541f9741..d4ea14374e 100644
--- a/pkg/operator/context/aggregates.go
+++ b/pkg/operator/context/aggregates.go
@@ -80,11 +80,13 @@ func getAggregates(
         buf.WriteString(aggregateConfig.Tags.ID())
         idWithTags := hash.Bytes(buf.Bytes())

-        aggregateKey := filepath.Join(
+        aggregateRootKey := filepath.Join(
             root,
             consts.AggregatesDir,
-            id+".msgpack",
+            id,
         )
+        aggregateKey := aggregateRootKey + ".msgpack"
+        aggregateMetadataKey := aggregateRootKey + "_metadata.json"

         aggregates[aggregateConfig.Name] = &context.Aggregate{
             ComputedResourceFields: &context.ComputedResourceFields{
@@ -92,6 +94,7 @@ func getAggregates(
                     ID:           id,
                     IDWithTags:   idWithTags,
                     ResourceType: resource.AggregateType,
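+                    // sidecar metadata (e.g. an inferred value type) lives next to the .msgpack payload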
+                    MetadataKey:  aggregateMetadataKey,
                 },
             },
             Aggregate: aggregateConfig,
@@ -109,6 +112,9 @@ func validateAggregateInputs(
     rawColumns context.RawColumns,
     aggregator *context.Aggregator,
 ) error {
+    if aggregateConfig.AggregatorPath != nil {
+        return nil
+    }

     columnRuntimeTypes, err := context.GetColumnRuntimeTypes(aggregateConfig.Inputs.Columns, rawColumns)
     if err != nil {
diff --git a/pkg/operator/context/aggregators.go b/pkg/operator/context/aggregators.go
index bfe2f48c27..0129bb658f 100644
--- a/pkg/operator/context/aggregators.go
+++ b/pkg/operator/context/aggregators.go
@@ -30,13 +30,13 @@ import (
 )

 func loadUserAggregators(
-    aggregatorConfigs userconfig.Aggregators,
+    config *userconfig.Config,
     impls map[string][]byte,
     pythonPackages context.PythonPackages,
 ) (map[string]*context.Aggregator, error) {
     userAggregators := make(map[string]*context.Aggregator)
-    for _, aggregatorConfig := range aggregatorConfigs {
+    for _, aggregatorConfig := range config.Aggregators {
         impl, ok := impls[aggregatorConfig.Path]
         if !ok {
             return nil, errors.Wrap(ErrorImplDoesNotExist(aggregatorConfig.Path), userconfig.Identify(aggregatorConfig))
@@ -48,6 +48,36 @@ func loadUserAggregators(
         userAggregators[aggregator.Name] = aggregator
     }

+    for _, aggregateConfig := range config.Aggregates {
+        if aggregateConfig.AggregatorPath == nil {
+            continue
+        }
+
+        impl, ok := impls[*aggregateConfig.AggregatorPath]
+        if !ok {
+            return nil, errors.Wrap(ErrorImplDoesNotExist(*aggregateConfig.AggregatorPath), userconfig.Identify(aggregateConfig))
+        }
+
+        implHash := hash.Bytes(impl)
+        if _, ok := userAggregators[implHash]; ok {
+            continue
+        }
+
+        anonAggregatorConfig := &userconfig.Aggregator{
+            ResourceFields: userconfig.ResourceFields{
+                Name: implHash,
+            },
+            Path: *aggregateConfig.AggregatorPath,
+        }
+        aggregator, err := newAggregator(*anonAggregatorConfig, impl, nil, pythonPackages)
+        if err != nil {
+            return nil, err
+        }
+
+        aggregateConfig.Aggregator = aggregator.Name
+        userAggregators[anonAggregatorConfig.Name] = aggregator
+    }
+
     return userAggregators, nil
 }

@@ -76,6 +106,7 @@ func newAggregator(
             ID:           id,
             IDWithTags:   id,
             ResourceType: resource.AggregatorType,
+            MetadataKey:  filepath.Join(consts.AggregatorsDir, id+"_metadata.json"),
         },
         Aggregator: &aggregatorConfig,
         Namespace:  namespace,
@@ -132,15 +163,14 @@ func getAggregators(
     aggregators := context.Aggregators{}
     for _, aggregateConfig := range config.Aggregates {
-        aggregatorName := aggregateConfig.Aggregator
-        if _, ok := aggregators[aggregatorName]; ok {
+        if _, ok := aggregators[aggregateConfig.Aggregator]; ok {
             continue
         }
-        aggregator, err := getAggregator(aggregatorName, userAggregators)
+        aggregator, err := getAggregator(aggregateConfig.Aggregator, userAggregators)
         if err != nil {
             return nil, errors.Wrap(err, userconfig.Identify(aggregateConfig), userconfig.AggregatorKey)
         }
-        aggregators[aggregatorName] = aggregator
+        aggregators[aggregateConfig.Aggregator] = aggregator
     }

     return aggregators, nil
diff --git a/pkg/operator/context/apis.go b/pkg/operator/context/apis.go
index 1ff5045575..9f0ef58bae 100644
--- a/pkg/operator/context/apis.go
+++ b/pkg/operator/context/apis.go
@@ -18,7 +18,9 @@ package context

 import (
     "bytes"
+    "path/filepath"

+    "github.com/cortexlabs/cortex/pkg/consts"
     "github.com/cortexlabs/cortex/pkg/lib/hash"
     "github.com/cortexlabs/cortex/pkg/operator/api/context"
     "github.com/cortexlabs/cortex/pkg/operator/api/resource"
@@ -48,6 +50,7 @@ func getAPIs(config *userconfig.Config,
                 ID:           id,
                 IDWithTags:   idWithTags,
                 ResourceType: resource.APIType,
+                MetadataKey:  filepath.Join(consts.APIsDir, id+"_metadata.json"),
             },
         },
         API: apiConfig,
diff --git a/pkg/operator/context/autogenerator.go b/pkg/operator/context/autogenerator.go
index e57259eab8..a28e0d7e6a 100644
--- a/pkg/operator/context/autogenerator.go
+++ b/pkg/operator/context/autogenerator.go
@@ -47,9 +47,14 @@ func autoGenerateConfig(
             if err != nil {
                 return errors.Wrap(err, userconfig.Identify(aggregate), userconfig.AggregatorKey)
             }
-            argType, ok := aggregator.Inputs.Args[argName]
-            if !ok {
-                return errors.Wrap(configreader.ErrorUnsupportedKey(argName), userconfig.Identify(aggregate), userconfig.InputsKey, userconfig.ArgsKey)
+
+            var argType interface{}
+            if aggregator.Inputs != nil {
+                var ok bool
+                argType, ok = aggregator.Inputs.Args[argName]
+                if !ok {
+                    return errors.Wrap(configreader.ErrorUnsupportedKey(argName), userconfig.Identify(aggregate), userconfig.InputsKey, userconfig.ArgsKey)
+                }
             }

             constantName := strings.Join([]string{
@@ -61,7 +66,7 @@ func autoGenerateConfig(
             }, "/")

             constant := &userconfig.Constant{
-                ResourceConfigFields: userconfig.ResourceConfigFields{
+                ResourceFields: userconfig.ResourceFields{
                     Name: constantName,
                 },
                 Type:  argType,
@@ -88,9 +93,14 @@ func autoGenerateConfig(
             if err != nil {
                 return errors.Wrap(err, userconfig.Identify(transformedColumn), userconfig.TransformerKey)
             }
-            argType, ok := transformer.Inputs.Args[argName]
-            if !ok {
-                return errors.Wrap(configreader.ErrorUnsupportedKey(argName), userconfig.Identify(transformedColumn), userconfig.InputsKey, userconfig.ArgsKey)
+
+            var argType interface{}
+            if transformer.Inputs != nil {
+                var ok bool
+                argType, ok = transformer.Inputs.Args[argName]
+                if !ok {
+                    return errors.Wrap(configreader.ErrorUnsupportedKey(argName), userconfig.Identify(transformedColumn), userconfig.InputsKey, userconfig.ArgsKey)
+                }
             }

             constantName := strings.Join([]string{
@@ -102,7 +112,7 @@ func autoGenerateConfig(
             }, "/")

             constant := &userconfig.Constant{
-                ResourceConfigFields: userconfig.ResourceConfigFields{
+                ResourceFields: userconfig.ResourceFields{
                     Name: constantName,
                 },
                 Type:  argType,
@@ -115,8 +125,5 @@ func autoGenerateConfig(
         }
     }

-    if err := config.Validate(config.Environment.Name); err != nil {
-        return err
-    }
     return nil
 }
diff --git a/pkg/operator/context/constants.go b/pkg/operator/context/constants.go
index bdb2edcba1..ecc1836d8b 100644
--- a/pkg/operator/context/constants.go
+++ b/pkg/operator/context/constants.go
@@ -59,6 +59,7 @@ func newConstant(constantConfig userconfig.Constant) (*context.Constant, error)
             ID:           id,
             IDWithTags:   idWithTags,
             ResourceType: resource.ConstantType,
+            MetadataKey:  filepath.Join(consts.ConstantsDir, id+"_metadata.json"),
         },
         Constant: &constantConfig,
         Key:      filepath.Join(consts.ConstantsDir, id+".msgpack"),
diff --git a/pkg/operator/context/context.go b/pkg/operator/context/context.go
index 3c5babff09..1d9d0d87ce 100644
--- a/pkg/operator/context/context.go
+++ b/pkg/operator/context/context.go
@@ -130,12 +130,12 @@ func New(
     }
     ctx.PythonPackages = pythonPackages

-    userTransformers, err := loadUserTransformers(userconf.Transformers, files, pythonPackages)
+    userTransformers, err := loadUserTransformers(userconf, files, pythonPackages)
     if err != nil {
         return nil, err
     }

-    userAggregators, err := loadUserAggregators(userconf.Aggregators, files, pythonPackages)
+    userAggregators, err := loadUserAggregators(userconf, files, pythonPackages)
     if err != nil {
         return nil, err
     }
diff --git a/pkg/operator/context/models.go b/pkg/operator/context/models.go
index 8397dd09fc..a1c900e50a 100644
--- a/pkg/operator/context/models.go
+++ b/pkg/operator/context/models.go
@@ -85,7 +85,6 @@ func getModels(
         datasetIDWithTags := hash.Bytes(buf.Bytes())

         datasetRoot := filepath.Join(root, consts.TrainingDataDir, datasetID)
-
         trainingDatasetName := strings.Join([]string{
             modelConfig.Name,
             resource.TrainingDatasetType.String(),
@@ -97,6 +96,7 @@ func getModels(
                     ID:           modelID,
                     IDWithTags:   modelID,
                     ResourceType: resource.ModelType,
+                    MetadataKey:  filepath.Join(root, consts.ModelsDir, modelID+"_metadata.json"),
                 },
             },
             Model:   modelConfig,
@@ -104,7 +104,7 @@ func getModels(
             ImplID:  modelImplID,
             ImplKey: modelImplKey,
             Dataset: &context.TrainingDataset{
-                ResourceConfigFields: userconfig.ResourceConfigFields{
+                ResourceFields: userconfig.ResourceFields{
                     Name:     trainingDatasetName,
                     FilePath: modelConfig.FilePath,
                     Embed:    modelConfig.Embed,
@@ -114,12 +114,12 @@ func getModels(
                         ID:           datasetID,
                         IDWithTags:   datasetIDWithTags,
                         ResourceType: resource.TrainingDatasetType,
+                        MetadataKey:  filepath.Join(datasetRoot, "metadata.json"),
                     },
                 },
-                ModelName:   modelConfig.Name,
-                TrainKey:    filepath.Join(datasetRoot, "train.tfrecord"),
-                EvalKey:     filepath.Join(datasetRoot, "eval.tfrecord"),
-                MetadataKey: filepath.Join(datasetRoot, "metadata.json"),
+                ModelName: modelConfig.Name,
+                TrainKey:  filepath.Join(datasetRoot, "train.tfrecord"),
+                EvalKey:   filepath.Join(datasetRoot, "eval.tfrecord"),
             },
         }
     }
diff --git a/pkg/operator/context/python_packages.go b/pkg/operator/context/python_packages.go
index adc0afa7e9..8cf89da0a5 100644
--- a/pkg/operator/context/python_packages.go
+++ b/pkg/operator/context/python_packages.go
@@ -57,13 +57,14 @@ func loadPythonPackages(files map[string][]byte, datasetVersion string) (context
     buf.WriteString(datasetVersion)
     id := hash.Bytes(buf.Bytes())
     pythonPackage := context.PythonPackage{
-        ResourceConfigFields: userconfig.ResourceConfigFields{
+        ResourceFields: userconfig.ResourceFields{
             Name: consts.RequirementsTxt,
         },
         ComputedResourceFields: &context.ComputedResourceFields{
             ResourceFields: &context.ResourceFields{
                 ID:           id,
                 ResourceType: resource.PythonPackageType,
+                MetadataKey:  filepath.Join(consts.PythonPackagesDir, id, "metadata.json"),
             },
         },
         SrcKey:     filepath.Join(consts.PythonPackagesDir, id, "src.txt"),
@@ -95,13 +96,14 @@ func loadPythonPackages(files map[string][]byte, datasetVersion string) (context
         }
         id := hash.Bytes(buf.Bytes())
         pythonPackage := context.PythonPackage{
-            ResourceConfigFields: userconfig.ResourceConfigFields{
+            ResourceFields: userconfig.ResourceFields{
                 Name: packageName,
             },
             ComputedResourceFields: &context.ComputedResourceFields{
                 ResourceFields: &context.ResourceFields{
                     ID:           id,
                     ResourceType: resource.PythonPackageType,
+                    MetadataKey:  filepath.Join(consts.PythonPackagesDir, id, "metadata.json"),
                 },
             },
             SrcKey:     filepath.Join(consts.PythonPackagesDir, id, "src.zip"),
diff --git a/pkg/operator/context/raw_columns.go b/pkg/operator/context/raw_columns.go
index 4ccf432aef..a3102cff45 100644
--- a/pkg/operator/context/raw_columns.go
+++ b/pkg/operator/context/raw_columns.go
@@ -18,7 +18,9 @@ package context

 import (
     "bytes"
+    "path/filepath"

+    "github.com/cortexlabs/cortex/pkg/consts"
     "github.com/cortexlabs/cortex/pkg/lib/configreader"
     "github.com/cortexlabs/cortex/pkg/lib/errors"
     "github.com/cortexlabs/cortex/pkg/lib/hash"
@@ -57,6 +59,7 @@ func getRawColumns(
                     ID:           id,
                     IDWithTags:   idWithTags,
                     ResourceType: resource.RawColumnType,
+                    MetadataKey:  filepath.Join(consts.RawColumnsDir, id+"_metadata.json"),
                 },
             },
             RawIntColumn: typedColumnConfig,
@@ -74,6 +77,7 @@ func getRawColumns(
                     ID:           id,
                     IDWithTags:   idWithTags,
                     ResourceType: resource.RawColumnType,
+                    MetadataKey:  filepath.Join(consts.RawColumnsDir, id+"_metadata.json"),
                 },
             },
             RawFloatColumn: typedColumnConfig,
@@ -89,6 +93,7 @@ func getRawColumns(
                     ID:           id,
                     IDWithTags:   idWithTags,
                     ResourceType: resource.RawColumnType,
+                    MetadataKey:  filepath.Join(consts.RawColumnsDir, id+"_metadata.json"),
                 },
             },
             RawStringColumn: typedColumnConfig,
diff --git a/pkg/operator/context/transformed_columns.go b/pkg/operator/context/transformed_columns.go
index c6bf707c89..bcc94cea9c 100644
--- a/pkg/operator/context/transformed_columns.go
+++ b/pkg/operator/context/transformed_columns.go
@@ -18,7 +18,9 @@ package context

 import (
     "bytes"
+    "path/filepath"

+    "github.com/cortexlabs/cortex/pkg/consts"
     "github.com/cortexlabs/cortex/pkg/lib/errors"
     "github.com/cortexlabs/cortex/pkg/lib/hash"
     s "github.com/cortexlabs/cortex/pkg/lib/strings"
@@ -80,6 +82,7 @@ func getTransformedColumns(
                     ID:           id,
                     IDWithTags:   idWithTags,
                     ResourceType: resource.TransformedColumnType,
+                    MetadataKey:  filepath.Join(consts.TransformedColumnsDir, id+"_metadata.json"),
                 },
             },
             TransformedColumn: transformedColumnConfig,
@@ -97,6 +100,9 @@ func validateTransformedColumnInputs(
     aggregates context.Aggregates,
     transformer *context.Transformer,
 ) error {
+    if transformedColumnConfig.TransformerPath != nil {
+        return nil
+    }

     columnRuntimeTypes, err := context.GetColumnRuntimeTypes(transformedColumnConfig.Inputs.Columns, rawColumns)
     if err != nil {
diff --git a/pkg/operator/context/transformers.go b/pkg/operator/context/transformers.go
index ff4bad3b2c..a335ddcd52 100644
--- a/pkg/operator/context/transformers.go
+++ b/pkg/operator/context/transformers.go
@@ -30,13 +30,13 @@ import (
 )

 func loadUserTransformers(
-    transConfigs userconfig.Transformers,
+    config *userconfig.Config,
     impls map[string][]byte,
     pythonPackages context.PythonPackages,
 ) (map[string]*context.Transformer, error) {
     userTransformers := make(map[string]*context.Transformer)
-    for _, transConfig := range transConfigs {
+    for _, transConfig := range config.Transformers {
         impl, ok := impls[transConfig.Path]
         if !ok {
             return nil, errors.Wrap(ErrorImplDoesNotExist(transConfig.Path), userconfig.Identify(transConfig))
@@ -48,6 +48,34 @@ func loadUserTransformers(
         userTransformers[transformer.Name] = transformer
     }

+    for _, transColConfig := range config.TransformedColumns {
+        if transColConfig.TransformerPath == nil {
+            continue
+        }
+
+        impl, ok := impls[*transColConfig.TransformerPath]
+        if !ok {
+            return nil, errors.Wrap(ErrorImplDoesNotExist(*transColConfig.TransformerPath), userconfig.Identify(transColConfig))
+        }
+
+        implHash := hash.Bytes(impl)
+        if _, ok := userTransformers[implHash]; ok {
+            continue
+        }
+
+        anonTransformerConfig := &userconfig.Transformer{
+            ResourceFields: userconfig.ResourceFields{
+                Name: implHash,
+            },
+            Path: *transColConfig.TransformerPath,
+        }
+        transformer, err := newTransformer(*anonTransformerConfig, impl, nil, pythonPackages)
+        if err != nil {
+            return nil, err
+        }
+        transColConfig.Transformer = transformer.Name
+        userTransformers[transformer.Name] = transformer
+    }
     return userTransformers, nil
 }

@@ -75,6 +103,7 @@ func newTransformer(
             ID:           id,
             IDWithTags:   id,
             ResourceType: resource.TransformerType,
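+            // where this transformer's metadata file is stored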
+            MetadataKey:  filepath.Join(consts.TransformersDir, id+"_metadata.json"),
         },
         Transformer: &transConfig,
         Namespace:   namespace,
@@ -114,7 +143,6 @@ func getTransformer(
     name string,
     userTransformers map[string]*context.Transformer,
 ) (*context.Transformer, error) {
-
     if transformer, ok := builtinTransformers[name]; ok {
         return transformer, nil
     }
@@ -131,15 +159,15 @@ func getTransformers(
     transformers := context.Transformers{}
     for _, transformedColumnConfig := range config.TransformedColumns {
-        transformerName := transformedColumnConfig.Transformer
-        if _, ok := transformers[transformerName]; ok {
+        if _, ok := transformers[transformedColumnConfig.Transformer]; ok {
             continue
         }
-        transformer, err := getTransformer(transformerName, userTransformers)
+
+        transformer, err := getTransformer(transformedColumnConfig.Transformer, userTransformers)
         if err != nil {
             return nil, errors.Wrap(err, userconfig.Identify(transformedColumnConfig), userconfig.TransformerKey)
         }
-        transformers[transformerName] = transformer
+        transformers[transformedColumnConfig.Transformer] = transformer
     }

     return transformers, nil
diff --git a/pkg/workloads/lib/context.py b/pkg/workloads/lib/context.py
index 7dc2de1ab6..d8286c1c63 100644
--- a/pkg/workloads/lib/context.py
+++ b/pkg/workloads/lib/context.py
@@ -80,7 +80,6 @@ def __init__(self, **kwargs):
         self.models = self.ctx["models"]
         self.apis = self.ctx["apis"]
         self.training_datasets = {k: v["dataset"] for k, v in self.models.items()}
-
         self.api_version = self.cortex_config["api_version"]

         if "local_storage_path" in kwargs:
@@ -113,6 +112,7 @@ def __init__(self, **kwargs):
         self._transformer_impls = {}
         self._aggregator_impls = {}
         self._model_impls = {}
+        self._metadatas = {}

         # This affects Tensorflow S3 access
         os.environ["AWS_REGION"] = self.cortex_config.get("region", "")
@@ -228,7 +228,6 @@ def get_transformer_impl(self, column_name):
             return None, None

         transformer_name = self.transformed_columns[column_name]["transformer"]
-
         if transformer_name in self._transformer_impls:
             return self._transformer_impls[transformer_name]

@@ -469,6 +468,30 @@ def upload_resource_status_end(self, exit_code, *resources):
     def resource_status_key(self, resource):
         return os.path.join(self.status_prefix, resource["id"], resource["workload_id"])

+    def write_metadata(self, resource_id, metadata_key, metadata):
+        if resource_id in self._metadatas and self._metadatas[resource_id] == metadata:
+            return
+
+        self._metadatas[resource_id] = metadata
+        self.storage.put_json(metadata, metadata_key)
+
+    def get_metadata(self, resource_id, metadata_key, use_cache=True):
+        if use_cache and resource_id in self._metadatas:
+            return self._metadatas[resource_id]
+
+        metadata = self.storage.get_json(metadata_key, allow_missing=True)
+        self._metadatas[resource_id] = metadata
+        return metadata
+
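+    # a column whose type is "unknown" gets its concrete type from the metadata
+    # written by the Spark job after type inference has run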
+    def get_inferred_column_type(self, column_name):
+        column = self.columns[column_name]
+        column_type = self.columns[column_name].get("type", "unknown")
+        if column_type == "unknown":
+            column_type = self.get_metadata(column["id"], column["metadata_key"])["type"]
+            self.columns[column_name]["type"] = column_type
+
+        return column_type
+

 MODEL_IMPL_VALIDATION = {
     "required": [{"name": "create_estimator", "args": ["run_config", "model_config"]}],
@@ -483,6 +506,8 @@ def resource_status_key(self, resource):
     "optional": [
         {"name": "transform_spark", "args": ["data", "columns", "args", "transformed_column_name"]},
         {"name": "reverse_transform_python", "args": ["transformed_value", "args"]},
+        # it is possible to not define transform_python()
+        # if you are only using the transformation for training columns, and not for inference
         {"name": "transform_python", "args": ["sample", "args"]},
     ]
 }
diff --git a/pkg/workloads/lib/storage/s3.py b/pkg/workloads/lib/storage/s3.py
index 2d7f77ac9f..90348dff04 100644
--- a/pkg/workloads/lib/storage/s3.py
+++ b/pkg/workloads/lib/storage/s3.py
@@ -131,10 +131,10 @@ def put_json(self, obj, key):
         self._upload_string_to_s3(json.dumps(obj), key)

     def get_json(self, key, allow_missing=False):
-        obj = self._read_bytes_from_s3(key, allow_missing).decode("utf-8")
+        obj = self._read_bytes_from_s3(key, allow_missing)
         if obj is None:
             return None
-        return json.loads(obj)
+        return json.loads(obj.decode("utf-8"))

     def put_msgpack(self, obj, key):
         self._upload_string_to_s3(msgpack.dumps(obj), key)
diff --git a/pkg/workloads/lib/tf_lib.py b/pkg/workloads/lib/tf_lib.py
index ce21da7550..9123e69160 100644
--- a/pkg/workloads/lib/tf_lib.py
+++ b/pkg/workloads/lib/tf_lib.py
@@ -57,10 +57,10 @@ def get_column_tf_types(model_name, ctx, training=True):
     """Generate a dict {column name -> tf_type}"""
     model = ctx.models[model_name]

-    column_types = {
-        column_name: CORTEX_TYPE_TO_TF_TYPE[ctx.columns[column_name]["type"]]
-        for column_name in model["feature_columns"]
-    }
+    column_types = {}
+    for column_name in model["feature_columns"]:
+        column_type = ctx.get_inferred_column_type(column_name)
+        column_types[column_name] = CORTEX_TYPE_TO_TF_TYPE[column_type]

     if training:
         target_column_name = model["target_column"]
@@ -69,7 +69,8 @@ def get_column_tf_types(model_name, ctx, training=True):
         ]

         for column_name in model["training_columns"]:
-            column_types[column_name] = CORTEX_TYPE_TO_TF_TYPE[ctx.columns[column_name]["type"]]
+            column_type = ctx.get_inferred_column_type(column_name)
+            column_types[column_name] = CORTEX_TYPE_TO_TF_TYPE[column_type]

     return column_types

@@ -79,7 +80,8 @@ def get_feature_spec(model_name, ctx, training=True):
     column_types = get_column_tf_types(model_name, ctx, training)
     feature_spec = {}
     for column_name, tf_type in column_types.items():
-        if ctx.columns[column_name]["type"] in consts.COLUMN_LIST_TYPES:
+        column_type = ctx.get_inferred_column_type(column_name)
+        if column_type in consts.COLUMN_LIST_TYPES:
             feature_spec[column_name] = tf.FixedLenSequenceFeature(
                 shape=(), dtype=tf_type, allow_missing=True
             )
diff --git a/pkg/workloads/spark_job/spark_job.py b/pkg/workloads/spark_job/spark_job.py
index 3c7472d676..53247f090d 100644
--- a/pkg/workloads/spark_job/spark_job.py
+++ b/pkg/workloads/spark_job/spark_job.py
@@ -91,7 +91,9 @@ def parse_args(args):


 def validate_dataset(ctx, raw_df, cols_to_validate):
-    total_row_count = ctx.storage.get_json(ctx.raw_dataset["metadata_key"])["dataset_size"]
+    total_row_count = ctx.get_metadata(ctx.raw_dataset["key"], ctx.raw_dataset["metadata_key"])[
+        "dataset_size"
+    ]
     conditions_dict = spark_util.value_check_data(ctx, raw_df, cols_to_validate)

     if len(conditions_dict) > 0:
@@ -160,8 +162,11 @@ def ingest_raw_dataset(spark, ctx, cols_to_validate, should_ingest):
             ingest_df = limit_dataset(full_dataset_size, ingest_df, ctx.environment["limit"])

         written_count = write_raw_dataset(ingest_df, ctx, spark)
-        metadata = {"dataset_size": written_count}
-        ctx.storage.put_json(metadata, ctx.raw_dataset["metadata_key"])
+        ctx.write_metadata(
+            ctx.raw_dataset["key"],
+            ctx.raw_dataset["metadata_key"],
+            {"dataset_size": written_count},
+        )
         if written_count != full_dataset_size:
             logger.info(
                 "{} rows read, {} rows dropped, {} rows ingested".format(
diff --git a/pkg/workloads/spark_job/spark_util.py b/pkg/workloads/spark_job/spark_util.py
index c6c75c4d38..558eb8b99a 100644
--- a/pkg/workloads/spark_job/spark_util.py
+++ b/pkg/workloads/spark_job/spark_util.py
@@ -48,6 +48,31 @@
     consts.COLUMN_TYPE_STRING_LIST: [ArrayType(StringType(), True)],
 }

+PYTHON_TYPE_TO_CORTEX_TYPE = {
+    int: consts.COLUMN_TYPE_INT,
+    float: consts.COLUMN_TYPE_FLOAT,
+    str: consts.COLUMN_TYPE_STRING,
+}
+
+PYTHON_TYPE_TO_CORTEX_LIST_TYPE = {
+    int: consts.COLUMN_TYPE_INT_LIST,
+    float: consts.COLUMN_TYPE_FLOAT_LIST,
+    str: consts.COLUMN_TYPE_STRING_LIST,
+}
+
+SPARK_TYPE_TO_CORTEX_TYPE = {
+    IntegerType(): consts.COLUMN_TYPE_INT,
+    LongType(): consts.COLUMN_TYPE_INT,
+    ArrayType(IntegerType(), True): consts.COLUMN_TYPE_INT_LIST,
+    ArrayType(LongType(), True): consts.COLUMN_TYPE_INT_LIST,
+    FloatType(): consts.COLUMN_TYPE_FLOAT,
+    DoubleType(): consts.COLUMN_TYPE_FLOAT,
+    ArrayType(FloatType(), True): consts.COLUMN_TYPE_FLOAT_LIST,
+    ArrayType(DoubleType(), True): consts.COLUMN_TYPE_FLOAT_LIST,
+    StringType(): consts.COLUMN_TYPE_STRING,
+    ArrayType(StringType(), True): consts.COLUMN_TYPE_STRING_LIST,
+}
+

 def accumulate_count(df, spark):
     acc = df._sc.accumulator(0)
@@ -94,8 +119,11 @@ def write_training_data(model_name, df, ctx, spark):
         ctx.storage.hadoop_path(training_dataset["eval_key"])
     )

-    metadata = {"training_size": train_df_acc.value, "eval_size": eval_df_acc.value}
-    ctx.storage.put_json(metadata, training_dataset["metadata_key"])
+    ctx.write_metadata(
+        training_dataset["id"],
+        training_dataset["metadata_key"],
+        {"training_size": train_df_acc.value, "eval_size": eval_df_acc.value},
+    )

     return df

@@ -373,7 +401,7 @@ def run_custom_aggregator(aggregator_resource, df, ctx, spark):
     aggregator_column_input = input_schema["columns"]
     args_schema = input_schema["args"]
     args = {}
-    if input_schema.get("args", None) is not None and len(input_schema["args"]) > 0:
+    if input_schema.get("args", None) is not None and len(args_schema) > 0:
         args = ctx.populate_args(input_schema["args"])
     try:
         result = aggregator_impl.aggregate_spark(df, aggregator_column_input, args)
@@ -384,7 +412,9 @@ def run_custom_aggregator(aggregator_resource, df, ctx, spark):
             "function aggregate_spark",
         ) from e

-    if not util.validate_value_type(result, aggregator["output_type"]):
+    if aggregator["output_type"] and not util.validate_value_type(
+        result, aggregator["output_type"]
+    ):
         raise UserException(
             "aggregate " + aggregator_resource["name"],
             "aggregator " + aggregator["name"],
@@ -435,34 +465,75 @@ def _transform(*values):

     if validate:
         transformed_column = ctx.transformed_columns[column_name]
+        column_type = ctx.get_inferred_column_type(column_name)

         def _transform_and_validate(*values):
             result = _transform(*values)
-            if not util.validate_column_type(result, transformed_column["type"]):
+            if not util.validate_column_type(result, column_type):
                 raise UserException(
                     "transformed column " + column_name,
                     "transformation " + transformed_column["transformer"],
-                    "type of {} is not {}".format(result, transformed_column["type"]),
+                    "type of {} is not {}".format(result, column_type),
                 )

             return result

         transform_python_func = _transform_and_validate

-    column_data_type_str = ctx.transformed_columns[column_name]["type"]
+    column_data_type_str = ctx.get_inferred_column_type(column_name)
     transform_udf = F.udf(transform_python_func, CORTEX_TYPE_TO_SPARK_TYPE[column_data_type_str])
     return df.withColumn(column_name, transform_udf(*required_columns_sorted))


-def validate_transformer(column_name, df, ctx, spark):
-    transformed_column = ctx.transformed_columns[column_name]
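+# map a sample value to its cortex column type; lists are typed by their first element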
+def infer_type(obj):
+    obj_type = type(obj)
+
+    if obj_type == list:
+        obj_type = type(obj[0])
+        return PYTHON_TYPE_TO_CORTEX_LIST_TYPE[obj_type]
+
+    return PYTHON_TYPE_TO_CORTEX_TYPE[obj_type]
+
+
+def validate_transformer(column_name, test_df, ctx, spark):
+    transformed_column = ctx.transformed_columns[column_name]
+    transformer = ctx.transformers[transformed_column["transformer"]]
     trans_impl, _ = ctx.get_transformer_impl(column_name)

+    inferred_python_type = None
+    inferred_spark_type = None
+
     if hasattr(trans_impl, "transform_python"):
         try:
+            if transformer["output_type"] == "unknown":
+                sample_df = test_df.collect()
+                sample = sample_df[0]
+                inputs = ctx.create_column_inputs_map(sample, column_name)
+                _, impl_args = extract_inputs(column_name, ctx)
+                initial_transformed_sample = trans_impl.transform_python(inputs, impl_args)
+                inferred_python_type = infer_type(initial_transformed_sample)
+
+                for row in sample_df:
+                    inputs = ctx.create_column_inputs_map(row, column_name)
+                    transformed_sample = trans_impl.transform_python(inputs, impl_args)
+                    if inferred_python_type != infer_type(transformed_sample):
+                        raise UserRuntimeException(
+                            "transformed column " + column_name,
+                            "type inference failed, mixed data types in dataframe.",
+                            'expected type of "'
+                            + str(transformed_sample)
+                            + '" to be '
+                            + inferred_python_type,
+                        )
+
+                ctx.write_metadata(
+                    transformed_column["id"],
+                    transformed_column["metadata_key"],
+                    {"type": inferred_python_type},
+                )
+
             transform_python_collect = execute_transform_python(
-                column_name, df, ctx, spark, validate=True
+                column_name, test_df, ctx, spark, validate=True
             ).collect()
         except Exception as e:
             raise UserRuntimeException(
@@ -471,9 +542,8 @@ def validate_transformer(column_name, df, ctx, spark):
             ) from e

     if hasattr(trans_impl, "transform_spark"):
-
         try:
-            transform_spark_df = execute_transform_spark(column_name, df, ctx, spark)
+            transform_spark_df = execute_transform_spark(column_name, test_df, ctx, spark)

             # check that the return object is a dataframe
             if type(transform_spark_df) is not DataFrame:
@@ -505,33 +575,43 @@ def validate_transformer(column_name, df, ctx, spark):
             # check that expected output column has the correct data type
             if (
                 actual_structfield.dataType
-                not in CORTEX_TYPE_TO_ACCEPTABLE_SPARK_TYPES[transformed_column["type"]]
+                not in CORTEX_TYPE_TO_ACCEPTABLE_SPARK_TYPES[
+                    ctx.get_inferred_column_type(column_name)
+                ]
             ):
                 raise UserException(
                     "incorrect column type, expected {}, found {}.".format(
                         " or ".join(
                             str(t)
                             for t in CORTEX_TYPE_TO_ACCEPTABLE_SPARK_TYPES[
-                                transformed_column["type"]
+                                ctx.get_inferred_column_type(column_name)
                             ]
                         ),
                         actual_structfield.dataType,
                     )
                 )

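+            # persist the Spark-inferred type so other workloads can look up this column's type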
+            if transformer["output_type"] == "unknown":
+                inferred_spark_type = SPARK_TYPE_TO_CORTEX_TYPE[
+                    transform_spark_df.select(column_name).schema[0].dataType
+                ]
+                ctx.write_metadata(
+                    transformed_column["id"],
+                    transformed_column["metadata_key"],
+                    {"type": inferred_spark_type},
+                )
+
             # perform the necessary upcast/downcast for the column e.g INT -> LONG or DOUBLE -> FLOAT
             transform_spark_df = transform_spark_df.withColumn(
                 column_name,
                 F.col(column_name).cast(
-                    CORTEX_TYPE_TO_SPARK_TYPE[ctx.transformed_columns[column_name]["type"]]
+                    CORTEX_TYPE_TO_SPARK_TYPE[ctx.get_inferred_column_type(column_name)]
                 ),
             )

             # check that the function doesn't modify the schema of the other columns in the input dataframe
-            if set(transform_spark_df.columns) - set([column_name]) != set(df.columns):
+            if set(transform_spark_df.columns) - set([column_name]) != set(test_df.columns):
                 logger.error("expected schema:")
-                log_df_schema(df, logger.error)
+                log_df_schema(test_df, logger.error)
                 logger.error("found schema (with {} dropped):".format(column_name))
                 log_df_schema(transform_spark_df.drop(column_name), logger.error)
@@ -546,28 +626,36 @@ def validate_transformer(column_name, df, ctx, spark):
             )
             raise

-    if hasattr(trans_impl, "transform_spark") and hasattr(trans_impl, "transform_python"):
-        name_type_map = [(s.name, s.dataType) for s in transform_spark_df.schema]
-        transform_spark_collect = transform_spark_df.collect()
-
-        for tp_row, ts_row in zip(transform_python_collect, transform_spark_collect):
-            tp_dict = tp_row.asDict()
-            ts_dict = ts_row.asDict()
-
-            for name, dataType in name_type_map:
-                if tp_dict[name] == ts_dict[name]:
-                    continue
-                elif dataType == FloatType() and util.isclose(
-                    tp_dict[name], ts_dict[name], FLOAT_PRECISION
-                ):
-                    continue
-                raise UserException(
-                    column_name,
-                    "{0}.transform_spark and {0}.transform_python had differing values".format(
-                        transformed_column["transformer"]
-                    ),
-                    "{} != {}".format(ts_row, tp_row),
-                )
+    if hasattr(trans_impl, "transform_spark") and hasattr(trans_impl, "transform_python"):
+        if transformer["output_type"] == "unknown" and inferred_spark_type != inferred_python_type:
+            raise UserRuntimeException(
+                "transformed column " + column_name,
+                "type inference failed, transform_spark and transform_python had differing types.",
+                "transform_python: " + inferred_python_type,
+                "transform_spark: " + inferred_spark_type,
+            )
+
+        name_type_map = [(s.name, s.dataType) for s in transform_spark_df.schema]
+        transform_spark_collect = transform_spark_df.collect()
+
+        for tp_row, ts_row in zip(transform_python_collect, transform_spark_collect):
+            tp_dict = tp_row.asDict()
+            ts_dict = ts_row.asDict()
+
+            for name, dataType in name_type_map:
+                if tp_dict[name] == ts_dict[name]:
+                    continue
+                elif dataType == FloatType() and util.isclose(
+                    tp_dict[name], ts_dict[name], FLOAT_PRECISION
+                ):
+                    continue
+                raise UserException(
+                    column_name,
+                    "{0}.transform_spark and {0}.transform_python had differing values".format(
+                        transformed_column["transformer"]
+                    ),
+                    "{} != {}".format(ts_row, tp_row),
+                )


 def transform_column(column_name, df, ctx, spark):
@@ -575,16 +663,14 @@ def transform_column(column_name, df, ctx, spark):
         return df
     if column_name in df.columns:
         return df
+
     transformed_column = ctx.transformed_columns[column_name]
+    trans_impl, _ = ctx.get_transformer_impl(column_name)

-    trans_impl, trans_impl_path = ctx.get_transformer_impl(column_name)
     if hasattr(trans_impl, "transform_spark"):
-        return execute_transform_spark(column_name, df, ctx, spark).withColumn(
-            column_name,
-            F.col(column_name).cast(
-                CORTEX_TYPE_TO_SPARK_TYPE[ctx.transformed_columns[column_name]["type"]]
-            ),
-        )
+        column_type = CORTEX_TYPE_TO_SPARK_TYPE[ctx.get_inferred_column_type(column_name)]
+        df = execute_transform_spark(column_name, df, ctx, spark)
+        return df.withColumn(column_name, F.col(column_name).cast(column_type))
     elif hasattr(trans_impl, "transform_python"):
         return execute_transform_python(column_name, df, ctx, spark)
     else:
diff --git a/pkg/workloads/spark_job/test/integration/iris_test.py b/pkg/workloads/spark_job/test/integration/iris_test.py
index 73ff1e9ec1..35eca5e371 100644
--- a/pkg/workloads/spark_job/test/integration/iris_test.py
+++ b/pkg/workloads/spark_job/test/integration/iris_test.py
@@ -77,7 +77,10 @@ def test_simple_end_to_end(spark):
     raw_df = spark_job.ingest_raw_dataset(spark, ctx, cols_to_validate, should_ingest)

     assert raw_df.count() == 15
-    assert storage.get_json(ctx.raw_dataset["metadata_key"])["dataset_size"] == 15
+    assert (
+        ctx.get_metadata(ctx.raw_dataset["key"], ctx.raw_dataset["metadata_key"])["dataset_size"]
+        == 15
+    )
     for raw_column_id in cols_to_validate:
         path = os.path.join(raw_ctx["status_prefix"], raw_column_id, workload_id)
         status = storage.get_json(str(path))
@@ -117,7 +120,7 @@ def test_simple_end_to_end(spark):
         status["exist_code"] = "succeeded"

     dataset = raw_ctx["models"]["dnn"]["dataset"]
-    metadata_key = storage.get_json(dataset["metadata_key"])
-    assert metadata_key["training_size"] + metadata_key["eval_size"] == 15
+    metadata = ctx.get_metadata(dataset["id"], dataset["metadata_key"])
+    assert metadata["training_size"] + metadata["eval_size"] == 15
     assert local_storage_path.joinpath(dataset["train_key"], "_SUCCESS").exists()
     assert local_storage_path.joinpath(dataset["eval_key"], "_SUCCESS").exists()
diff --git a/pkg/workloads/spark_job/test/unit/spark_util_test.py b/pkg/workloads/spark_job/test/unit/spark_util_test.py
index 5fe7840e85..7b6b71fda0 100644
--- a/pkg/workloads/spark_job/test/unit/spark_util_test.py
+++ b/pkg/workloads/spark_job/test/unit/spark_util_test.py
@@ -14,6 +14,7 @@

 import math
 import spark_util
+import consts
 from lib.exceptions import UserException

 import pytest
@@ -577,3 +578,13 @@ def test_run_builtin_aggregators_error(spark, ctx_obj, get_context):

     ctx.store_aggregate_result.assert_not_called()
     ctx.populate_args.assert_called_once_with({"ignoreNulls": "some_constant"})
+
+
+def test_infer_type():
+    assert spark_util.infer_type(1) == consts.COLUMN_TYPE_INT
+    assert spark_util.infer_type(1.0) == consts.COLUMN_TYPE_FLOAT
+    assert spark_util.infer_type("cortex") == consts.COLUMN_TYPE_STRING
+
+    assert spark_util.infer_type([1]) == consts.COLUMN_TYPE_INT_LIST
+    assert spark_util.infer_type([1.0]) == consts.COLUMN_TYPE_FLOAT_LIST
+    assert spark_util.infer_type(["cortex"]) == consts.COLUMN_TYPE_STRING_LIST
diff --git a/pkg/workloads/tf_api/api.py b/pkg/workloads/tf_api/api.py
index 3abb9ed08b..28eb1aab5e 100644
--- a/pkg/workloads/tf_api/api.py
+++ b/pkg/workloads/tf_api/api.py
@@ -95,7 +95,8 @@ def create_prediction_request(transformed_sample):
     prediction_request.model_spec.signature_name = signature_key

     for column_name, value in transformed_sample.items():
-        data_type = tf_lib.CORTEX_TYPE_TO_TF_TYPE[ctx.columns[column_name]["type"]]
+        column_type = ctx.get_inferred_column_type(column_name)
+        data_type = tf_lib.CORTEX_TYPE_TO_TF_TYPE[column_type]
         shape = [1]
         if util.is_list(value):
             shape = [len(value)]
diff --git a/pkg/workloads/tf_train/train_util.py b/pkg/workloads/tf_train/train_util.py
index 894e5066b0..d41a3b4995 100644
--- a/pkg/workloads/tf_train/train_util.py
+++ b/pkg/workloads/tf_train/train_util.py
@@ -148,8 +148,8 @@ def train(model_name, model_impl, ctx, model_dir):
     serving_input_fn = generate_json_serving_input_fn(model_name, ctx, model_impl)
     exporter = tf.estimator.FinalExporter("estimator", serving_input_fn, as_text=False)

-    dataset_metadata = ctx.storage.get_json(model["dataset"]["metadata_key"])
     train_num_steps = model["training"]["num_steps"]
+    dataset_metadata = ctx.get_metadata(model["dataset"]["id"], model["dataset"]["metadata_key"])
     if model["training"]["num_epochs"]:
         train_num_steps = (
             math.ceil(dataset_metadata["training_size"] / float(model["training"]["batch_size"]))