From 214e1cc9f78155a799ed26c39f7a7f1df3ea0267 Mon Sep 17 00:00:00 2001 From: Ivan Zhang Date: Sun, 5 May 2019 10:38:53 -0400 Subject: [PATCH 01/48] ResourceConfigField -> ResourceField --- pkg/operator/api/context/python_packages.go | 2 +- pkg/operator/api/context/transformers.go | 5 ++-- pkg/operator/api/userconfig/aggregates.go | 2 +- pkg/operator/api/userconfig/aggregators.go | 2 +- pkg/operator/api/userconfig/apis.go | 2 +- pkg/operator/api/userconfig/constants.go | 2 +- pkg/operator/api/userconfig/embed.go | 2 +- pkg/operator/api/userconfig/environments.go | 2 +- pkg/operator/api/userconfig/models.go | 2 +- pkg/operator/api/userconfig/raw_columns.go | 6 ++--- pkg/operator/api/userconfig/resource.go | 30 ++++++++++----------- pkg/operator/api/userconfig/templates.go | 2 +- pkg/operator/api/userconfig/transformers.go | 2 +- pkg/operator/context/models.go | 2 +- pkg/operator/context/python_packages.go | 4 +-- 15 files changed, 34 insertions(+), 33 deletions(-) diff --git a/pkg/operator/api/context/python_packages.go b/pkg/operator/api/context/python_packages.go index 9818e0082b..4b5627c7c9 100644 --- a/pkg/operator/api/context/python_packages.go +++ b/pkg/operator/api/context/python_packages.go @@ -24,7 +24,7 @@ import ( type PythonPackages map[string]*PythonPackage type PythonPackage struct { - userconfig.ResourceConfigFields + userconfig.ResourceFields *ComputedResourceFields SrcKey string `json:"src_key"` PackageKey string `json:"package_key"` diff --git a/pkg/operator/api/context/transformers.go b/pkg/operator/api/context/transformers.go index 96d127fab7..65bf53522c 100644 --- a/pkg/operator/api/context/transformers.go +++ b/pkg/operator/api/context/transformers.go @@ -25,8 +25,9 @@ type Transformers map[string]*Transformer type Transformer struct { *userconfig.Transformer *ResourceFields - Namespace *string `json:"namespace"` - ImplKey string `json:"impl_key"` + Namespace *string `json:"namespace"` + ImplKey string `json:"impl_key"` + SkipValidation bool `json:"skip_validation"` } func (transformers Transformers) OneByID(id string) *Transformer { diff --git a/pkg/operator/api/userconfig/aggregates.go b/pkg/operator/api/userconfig/aggregates.go index cae31fe5c1..aa1606acae 100644 --- a/pkg/operator/api/userconfig/aggregates.go +++ b/pkg/operator/api/userconfig/aggregates.go @@ -26,7 +26,7 @@ import ( type Aggregates []*Aggregate type Aggregate struct { - ResourceConfigFields + ResourceFields Aggregator string `json:"aggregator" yaml:"aggregator"` Inputs *Inputs `json:"inputs" yaml:"inputs"` Compute *SparkCompute `json:"compute" yaml:"compute"` diff --git a/pkg/operator/api/userconfig/aggregators.go b/pkg/operator/api/userconfig/aggregators.go index 226606cdbd..3eefddf752 100644 --- a/pkg/operator/api/userconfig/aggregators.go +++ b/pkg/operator/api/userconfig/aggregators.go @@ -24,7 +24,7 @@ import ( type Aggregators []*Aggregator type Aggregator struct { - ResourceConfigFields + ResourceFields Inputs *Inputs `json:"inputs" yaml:"inputs"` OutputType interface{} `json:"output_type" yaml:"output_type"` Path string `json:"path" yaml:"path"` diff --git a/pkg/operator/api/userconfig/apis.go b/pkg/operator/api/userconfig/apis.go index d7add9a17e..854dfee201 100644 --- a/pkg/operator/api/userconfig/apis.go +++ b/pkg/operator/api/userconfig/apis.go @@ -24,7 +24,7 @@ import ( type APIs []*API type API struct { - ResourceConfigFields + ResourceFields ModelName string `json:"model_name" yaml:"model_name"` Compute *APICompute `json:"compute" yaml:"compute"` Tags Tags `json:"tags" yaml:"tags"` 
diff --git a/pkg/operator/api/userconfig/constants.go b/pkg/operator/api/userconfig/constants.go index b1217622a7..a84c678940 100644 --- a/pkg/operator/api/userconfig/constants.go +++ b/pkg/operator/api/userconfig/constants.go @@ -25,7 +25,7 @@ import ( type Constants []*Constant type Constant struct { - ResourceConfigFields + ResourceFields Type interface{} `json:"type" yaml:"type"` Value interface{} `json:"value" yaml:"value"` Tags Tags `json:"tags" yaml:"tags"` diff --git a/pkg/operator/api/userconfig/embed.go b/pkg/operator/api/userconfig/embed.go index 8713130b35..3778e189d0 100644 --- a/pkg/operator/api/userconfig/embed.go +++ b/pkg/operator/api/userconfig/embed.go @@ -24,7 +24,7 @@ import ( type Embeds []*Embed type Embed struct { - ResourceConfigFields + ResourceFields Template string `json:"template" yaml:"template"` Args map[string]interface{} `json:"args" yaml:"args"` } diff --git a/pkg/operator/api/userconfig/environments.go b/pkg/operator/api/userconfig/environments.go index 26c06534c2..940765f080 100644 --- a/pkg/operator/api/userconfig/environments.go +++ b/pkg/operator/api/userconfig/environments.go @@ -28,7 +28,7 @@ import ( type Environments []*Environment type Environment struct { - ResourceConfigFields + ResourceFields LogLevel *LogLevel `json:"log_level" yaml:"log_level"` Limit *Limit `json:"limit" yaml:"limit"` Data Data `json:"-" yaml:"-"` diff --git a/pkg/operator/api/userconfig/models.go b/pkg/operator/api/userconfig/models.go index 6e30a3b79d..b1f7acfea1 100644 --- a/pkg/operator/api/userconfig/models.go +++ b/pkg/operator/api/userconfig/models.go @@ -29,7 +29,7 @@ import ( type Models []*Model type Model struct { - ResourceConfigFields + ResourceFields Type ModelType `json:"type" yaml:"type"` Path string `json:"path" yaml:"path"` TargetColumn string `json:"target_column" yaml:"target_column"` diff --git a/pkg/operator/api/userconfig/raw_columns.go b/pkg/operator/api/userconfig/raw_columns.go index c321c38199..1f04b7dad7 100644 --- a/pkg/operator/api/userconfig/raw_columns.go +++ b/pkg/operator/api/userconfig/raw_columns.go @@ -53,7 +53,7 @@ var rawColumnValidation = &cr.InterfaceStructValidation{ } type RawIntColumn struct { - ResourceConfigFields + ResourceFields Type ColumnType `json:"type" yaml:"type"` Required bool `json:"required" yaml:"required"` Min *int64 `json:"min" yaml:"min"` @@ -100,7 +100,7 @@ var rawIntColumnFieldValidations = []*cr.StructFieldValidation{ } type RawFloatColumn struct { - ResourceConfigFields + ResourceFields Type ColumnType `json:"type" yaml:"type"` Required bool `json:"required" yaml:"required"` Min *float32 `json:"min" yaml:"min"` @@ -147,7 +147,7 @@ var rawFloatColumnFieldValidations = []*cr.StructFieldValidation{ } type RawStringColumn struct { - ResourceConfigFields + ResourceFields Type ColumnType `json:"type" yaml:"type"` Required bool `json:"required" yaml:"required"` Values []string `json:"values" yaml:"values"` diff --git a/pkg/operator/api/userconfig/resource.go b/pkg/operator/api/userconfig/resource.go index a996845604..45f72fd68a 100644 --- a/pkg/operator/api/userconfig/resource.go +++ b/pkg/operator/api/userconfig/resource.go @@ -34,39 +34,39 @@ type Resource interface { SetEmbed(*Embed) } -type ResourceConfigFields struct { +type ResourceFields struct { Name string `json:"name" yaml:"name"` Index int `json:"index" yaml:"-"` FilePath string `json:"file_path" yaml:"-"` Embed *Embed `json:"embed" yaml:"-"` } -func (resourceConfigFields *ResourceConfigFields) GetName() string { - return resourceConfigFields.Name +func 
(ResourceFields *ResourceFields) GetName() string { + return ResourceFields.Name } -func (resourceConfigFields *ResourceConfigFields) GetIndex() int { - return resourceConfigFields.Index +func (ResourceFields *ResourceFields) GetIndex() int { + return ResourceFields.Index } -func (resourceConfigFields *ResourceConfigFields) SetIndex(index int) { - resourceConfigFields.Index = index +func (ResourceFields *ResourceFields) SetIndex(index int) { + ResourceFields.Index = index } -func (resourceConfigFields *ResourceConfigFields) GetFilePath() string { - return resourceConfigFields.FilePath +func (ResourceFields *ResourceFields) GetFilePath() string { + return ResourceFields.FilePath } -func (resourceConfigFields *ResourceConfigFields) SetFilePath(filePath string) { - resourceConfigFields.FilePath = filePath +func (ResourceFields *ResourceFields) SetFilePath(filePath string) { + ResourceFields.FilePath = filePath } -func (resourceConfigFields *ResourceConfigFields) GetEmbed() *Embed { - return resourceConfigFields.Embed +func (ResourceFields *ResourceFields) GetEmbed() *Embed { + return ResourceFields.Embed } -func (resourceConfigFields *ResourceConfigFields) SetEmbed(embed *Embed) { - resourceConfigFields.Embed = embed +func (ResourceFields *ResourceFields) SetEmbed(embed *Embed) { + ResourceFields.Embed = embed } func Identify(r Resource) string { diff --git a/pkg/operator/api/userconfig/templates.go b/pkg/operator/api/userconfig/templates.go index c0149a75b1..85f728eae4 100644 --- a/pkg/operator/api/userconfig/templates.go +++ b/pkg/operator/api/userconfig/templates.go @@ -31,7 +31,7 @@ var templateVarRegex = regexp.MustCompile("\\{\\s*([a-zA-Z0-9_-]+)\\s*\\}") type Templates []*Template type Template struct { - ResourceConfigFields + ResourceFields YAML string `json:"yaml" yaml:"yaml"` } diff --git a/pkg/operator/api/userconfig/transformers.go b/pkg/operator/api/userconfig/transformers.go index cb7b09744c..6eada9effa 100644 --- a/pkg/operator/api/userconfig/transformers.go +++ b/pkg/operator/api/userconfig/transformers.go @@ -24,7 +24,7 @@ import ( type Transformers []*Transformer type Transformer struct { - ResourceConfigFields + ResourceFields Inputs *Inputs `json:"inputs" yaml:"inputs"` OutputType ColumnType `json:"output_type" yaml:"output_type"` Path string `json:"path" yaml:"path"` diff --git a/pkg/operator/context/models.go b/pkg/operator/context/models.go index 8397dd09fc..10be49b448 100644 --- a/pkg/operator/context/models.go +++ b/pkg/operator/context/models.go @@ -104,7 +104,7 @@ func getModels( ImplID: modelImplID, ImplKey: modelImplKey, Dataset: &context.TrainingDataset{ - ResourceConfigFields: userconfig.ResourceConfigFields{ + ResourceFields: userconfig.ResourceFields{ Name: trainingDatasetName, FilePath: modelConfig.FilePath, Embed: modelConfig.Embed, diff --git a/pkg/operator/context/python_packages.go b/pkg/operator/context/python_packages.go index adc0afa7e9..e87a1753fe 100644 --- a/pkg/operator/context/python_packages.go +++ b/pkg/operator/context/python_packages.go @@ -57,7 +57,7 @@ func loadPythonPackages(files map[string][]byte, datasetVersion string) (context buf.WriteString(datasetVersion) id := hash.Bytes(buf.Bytes()) pythonPackage := context.PythonPackage{ - ResourceConfigFields: userconfig.ResourceConfigFields{ + ResourceFields: userconfig.ResourceFields{ Name: consts.RequirementsTxt, }, ComputedResourceFields: &context.ComputedResourceFields{ @@ -95,7 +95,7 @@ func loadPythonPackages(files map[string][]byte, datasetVersion string) (context } id := 
hash.Bytes(buf.Bytes()) pythonPackage := context.PythonPackage{ - ResourceConfigFields: userconfig.ResourceConfigFields{ + ResourceFields: userconfig.ResourceFields{ Name: packageName, }, ComputedResourceFields: &context.ComputedResourceFields{ From 719ed13ae4dea6fa9e0f7044cad177b15fa04bf6 Mon Sep 17 00:00:00 2001 From: Ivan Zhang Date: Mon, 6 May 2019 20:16:31 +0000 Subject: [PATCH 02/48] go transformer changes aggregates go + python yes it works propagate metadata --- examples/fraud/resources/weight_column.yaml | 11 +-- .../mnist/resources/transformed_columns.yaml | 2 +- examples/mnist/resources/transformers.yaml | 6 -- examples/reviews/resources/max_length.yaml | 9 +-- .../reviews/resources/tokenized_columns.yaml | 12 +-- pkg/consts/consts.go | 33 ++++---- pkg/lib/errors/errors.go | 2 +- pkg/lib/strings/operations.go | 4 + pkg/operator/api/context/aggregators.go | 5 +- pkg/operator/api/context/context.go | 11 +++ pkg/operator/api/context/models.go | 9 +-- pkg/operator/api/userconfig/aggregates.go | 16 ++-- pkg/operator/api/userconfig/config.go | 12 ++- .../api/userconfig/transformed_columns.go | 18 +++-- pkg/operator/context/aggregates.go | 22 +++++- pkg/operator/context/aggregators.go | 50 ++++++++++-- pkg/operator/context/apis.go | 3 + pkg/operator/context/autogenerator.go | 26 ++++++- pkg/operator/context/constants.go | 1 + pkg/operator/context/context.go | 8 +- pkg/operator/context/models.go | 8 +- pkg/operator/context/python_packages.go | 2 + pkg/operator/context/raw_columns.go | 5 ++ pkg/operator/context/transformed_columns.go | 18 ++++- pkg/operator/context/transformers.go | 50 ++++++++++-- pkg/workloads/lib/context.py | 76 ++++++++++++++++++- pkg/workloads/lib/storage/s3.py | 4 +- pkg/workloads/lib/tf_lib.py | 17 +++-- pkg/workloads/spark_job/spark_job.py | 4 +- pkg/workloads/spark_job/spark_util.py | 35 ++++++++- .../spark_job/test/integration/iris_test.py | 4 +- pkg/workloads/tf_api/api.py | 5 +- pkg/workloads/tf_train/train_util.py | 5 +- 33 files changed, 368 insertions(+), 125 deletions(-) delete mode 100644 examples/mnist/resources/transformers.yaml diff --git a/examples/fraud/resources/weight_column.yaml b/examples/fraud/resources/weight_column.yaml index a6f87aa58b..34654ab4fb 100644 --- a/examples/fraud/resources/weight_column.yaml +++ b/examples/fraud/resources/weight_column.yaml @@ -5,18 +5,9 @@ columns: col: class -- kind: transformer - name: weight - inputs: - columns: - col: INT_COLUMN - args: - class_distribution: {INT: FLOAT} - output_type: FLOAT_COLUMN - - kind: transformed_column name: weight_column - transformer: weight + transformer_path: implementations/transformers/weight.py inputs: columns: col: class diff --git a/examples/mnist/resources/transformed_columns.yaml b/examples/mnist/resources/transformed_columns.yaml index f46f08aba4..4c736b0880 100644 --- a/examples/mnist/resources/transformed_columns.yaml +++ b/examples/mnist/resources/transformed_columns.yaml @@ -1,6 +1,6 @@ - kind: transformed_column name: image_pixels - transformer: decode_and_normalize + transformer_path: implementations/transformers/decode_and_normalize.py inputs: columns: image: image diff --git a/examples/mnist/resources/transformers.yaml b/examples/mnist/resources/transformers.yaml deleted file mode 100644 index 3d13b2d5ee..0000000000 --- a/examples/mnist/resources/transformers.yaml +++ /dev/null @@ -1,6 +0,0 @@ -- kind: transformer - name: decode_and_normalize - output_type: FLOAT_LIST_COLUMN - inputs: - columns: - image: STRING_COLUMN diff --git 
a/examples/reviews/resources/max_length.yaml b/examples/reviews/resources/max_length.yaml index d5f642e096..168190f03f 100644 --- a/examples/reviews/resources/max_length.yaml +++ b/examples/reviews/resources/max_length.yaml @@ -1,13 +1,6 @@ -- kind: aggregator - name: max_length - inputs: - columns: - col: STRING_COLUMN - output_type: INT - - kind: aggregate name: max_review_length - aggregator: max_length + aggregator_path: implementations/aggregators/max_length.py inputs: columns: col: review diff --git a/examples/reviews/resources/tokenized_columns.yaml b/examples/reviews/resources/tokenized_columns.yaml index 225bd38f60..cc2715e76f 100644 --- a/examples/reviews/resources/tokenized_columns.yaml +++ b/examples/reviews/resources/tokenized_columns.yaml @@ -1,16 +1,6 @@ -- kind: transformer - name: tokenize_string_to_int - output_type: INT_LIST_COLUMN - inputs: - columns: - col: STRING_COLUMN - args: - max_len: INT - vocab: {STRING: INT} - - kind: transformed_column name: embedding_input - transformer: tokenize_string_to_int + transformer_path: implementations/transformers/tokenize_string_to_int.py inputs: columns: col: review diff --git a/pkg/consts/consts.go b/pkg/consts/consts.go index daf5e39737..e0835c7c7d 100644 --- a/pkg/consts/consts.go +++ b/pkg/consts/consts.go @@ -36,21 +36,24 @@ var ( RequirementsTxt = "requirements.txt" PackageDir = "packages" - AppsDir = "apps" - DataDir = "data" - RawDataDir = "data_raw" - TrainingDataDir = "data_training" - AggregatorsDir = "aggregators" - AggregatesDir = "aggregates" - TransformersDir = "transformers" - ModelImplsDir = "model_implementations" - PythonPackagesDir = "python_packages" - ModelsDir = "models" - ConstantsDir = "constants" - ContextsDir = "contexts" - ResourceStatusesDir = "resource_statuses" - WorkloadSpecsDir = "workload_specs" - LogPrefixesDir = "log_prefixes" + AppsDir = "apps" + APIsDir = "apis" + DataDir = "data" + RawDataDir = "data_raw" + TrainingDataDir = "data_training" + AggregatorsDir = "aggregators" + AggregatesDir = "aggregates" + TransformersDir = "transformers" + ModelImplsDir = "model_implementations" + PythonPackagesDir = "python_packages" + ModelsDir = "models" + ConstantsDir = "constants" + ContextsDir = "contexts" + ResourceStatusesDir = "resource_statuses" + WorkloadSpecsDir = "workload_specs" + LogPrefixesDir = "log_prefixes" + RawColumnsDir = "raw_columns" + TransformedColumnsDir = "transformed_columns" TelemetryURL = "https://telemetry.cortexlabs.dev" ) diff --git a/pkg/lib/errors/errors.go b/pkg/lib/errors/errors.go index 342e07a307..c3dc2de155 100644 --- a/pkg/lib/errors/errors.go +++ b/pkg/lib/errors/errors.go @@ -151,7 +151,7 @@ func Panic(items ...interface{}) { func PrintError(err error, strs ...string) { wrappedErr := Wrap(err, strs...) 
fmt.Println("error:", wrappedErr.Error()) - // PrintStacktrace(wrappedErr) + PrintStacktrace(wrappedErr) } func PrintStacktrace(err error) { diff --git a/pkg/lib/strings/operations.go b/pkg/lib/strings/operations.go index 60cb18648e..11338c4f16 100644 --- a/pkg/lib/strings/operations.go +++ b/pkg/lib/strings/operations.go @@ -128,3 +128,7 @@ func StrsSentence(strs []string, lastJoinWord string) string { return strings.Join(strs[:lastIndex], ", ") + ", " + lastJoinWord + " " + strs[lastIndex] } } + +func PathToName(path string) string { + return strings.Replace(strings.Replace(path, "/", "_", -1), ".", "_", -1) +} diff --git a/pkg/operator/api/context/aggregators.go b/pkg/operator/api/context/aggregators.go index 16edf1e807..635151adb0 100644 --- a/pkg/operator/api/context/aggregators.go +++ b/pkg/operator/api/context/aggregators.go @@ -25,8 +25,9 @@ type Aggregators map[string]*Aggregator type Aggregator struct { *userconfig.Aggregator *ResourceFields - Namespace *string `json:"namespace"` - ImplKey string `json:"impl_key"` + Namespace *string `json:"namespace"` + ImplKey string `json:"impl_key"` + SkipValidation bool `json:"skip_validation"` } func (aggregators Aggregators) OneByID(id string) *Aggregator { diff --git a/pkg/operator/api/context/context.go b/pkg/operator/api/context/context.go index ab062a32ed..72dcddf632 100644 --- a/pkg/operator/api/context/context.go +++ b/pkg/operator/api/context/context.go @@ -55,6 +55,8 @@ type Resource interface { GetID() string GetIDWithTags() string GetResourceFields() *ResourceFields + GetMetadataKey() string + SetMetadataKey(string) } type ComputedResource interface { @@ -72,6 +74,7 @@ type ResourceFields struct { ID string `json:"id"` IDWithTags string `json:"id_with_tags"` ResourceType resource.Type `json:"resource_type"` + MetadataKey string `json:"metadata_key"` } type ComputedResourceFields struct { @@ -91,6 +94,14 @@ func (r *ResourceFields) GetResourceFields() *ResourceFields { return r } +func (r *ResourceFields) GetMetadataKey() string { + return r.MetadataKey +} + +func (r *ResourceFields) SetMetadataKey(metadataKey string) { + r.MetadataKey = metadataKey +} + func (r *ComputedResourceFields) GetWorkloadID() string { return r.WorkloadID } diff --git a/pkg/operator/api/context/models.go b/pkg/operator/api/context/models.go index dd45711a76..9df19a0bc2 100644 --- a/pkg/operator/api/context/models.go +++ b/pkg/operator/api/context/models.go @@ -38,12 +38,11 @@ type Model struct { } type TrainingDataset struct { - userconfig.ResourceConfigFields + userconfig.ResourceFields *ComputedResourceFields - ModelName string `json:"model_name"` - TrainKey string `json:"train_key"` - EvalKey string `json:"eval_key"` - MetadataKey string `json:"metadata_key"` + ModelName string `json:"model_name"` + TrainKey string `json:"train_key"` + EvalKey string `json:"eval_key"` } func (trainingDataset *TrainingDataset) GetResourceType() resource.Type { diff --git a/pkg/operator/api/userconfig/aggregates.go b/pkg/operator/api/userconfig/aggregates.go index aa1606acae..d83a717772 100644 --- a/pkg/operator/api/userconfig/aggregates.go +++ b/pkg/operator/api/userconfig/aggregates.go @@ -27,10 +27,11 @@ type Aggregates []*Aggregate type Aggregate struct { ResourceFields - Aggregator string `json:"aggregator" yaml:"aggregator"` - Inputs *Inputs `json:"inputs" yaml:"inputs"` - Compute *SparkCompute `json:"compute" yaml:"compute"` - Tags Tags `json:"tags" yaml:"tags"` + Aggregator *string `json:"aggregator" yaml:"aggregator"` + AggregatorPath *string 
`json:"aggregator_path" yaml:"aggregator_path"` + Inputs *Inputs `json:"inputs" yaml:"inputs"` + Compute *SparkCompute `json:"compute" yaml:"compute"` + Tags Tags `json:"tags" yaml:"tags"` } var aggregateValidation = &configreader.StructValidation{ @@ -44,11 +45,14 @@ var aggregateValidation = &configreader.StructValidation{ }, { StructField: "Aggregator", - StringValidation: &configreader.StringValidation{ - Required: true, + StringPtrValidation: &configreader.StringPtrValidation{ AlphaNumericDashDotUnderscore: true, }, }, + { + StructField: "AggregatorPath", + StringPtrValidation: &configreader.StringPtrValidation{}, + }, inputValuesFieldValidation, sparkComputeFieldValidation("Compute"), tagsFieldValidation, diff --git a/pkg/operator/api/userconfig/config.go b/pkg/operator/api/userconfig/config.go index f245f6436b..a31d3bd73e 100644 --- a/pkg/operator/api/userconfig/config.go +++ b/pkg/operator/api/userconfig/config.go @@ -203,16 +203,20 @@ func (config *Config) Validate(envName string) error { // Check local aggregators exist aggregatorNames := config.Aggregators.Names() for _, aggregate := range config.Aggregates { - if !strings.Contains(aggregate.Aggregator, ".") && !slices.HasString(aggregatorNames, aggregate.Aggregator) { - return errors.Wrap(ErrorUndefinedResource(aggregate.Aggregator, resource.AggregatorType), Identify(aggregate), AggregatorKey) + if aggregate.Aggregator != nil && + !strings.Contains(*aggregate.Aggregator, ".") && + !slices.HasString(aggregatorNames, *aggregate.Aggregator) { + return errors.Wrap(ErrorUndefinedResource(*aggregate.Aggregator, resource.AggregatorType), Identify(aggregate), AggregatorKey) } } // Check local transformers exist transformerNames := config.Transformers.Names() for _, transformedColumn := range config.TransformedColumns { - if !strings.Contains(transformedColumn.Transformer, ".") && !slices.HasString(transformerNames, transformedColumn.Transformer) { - return errors.Wrap(ErrorUndefinedResource(transformedColumn.Transformer, resource.TransformerType), Identify(transformedColumn), TransformerKey) + if transformedColumn.Transformer != nil && + !strings.Contains(*transformedColumn.Transformer, ".") && + !slices.HasString(transformerNames, *transformedColumn.Transformer) { + return errors.Wrap(ErrorUndefinedResource(*transformedColumn.Transformer, resource.TransformerType), Identify(transformedColumn), TransformerKey) } } diff --git a/pkg/operator/api/userconfig/transformed_columns.go b/pkg/operator/api/userconfig/transformed_columns.go index a5118bb3a1..9de4000134 100644 --- a/pkg/operator/api/userconfig/transformed_columns.go +++ b/pkg/operator/api/userconfig/transformed_columns.go @@ -26,11 +26,12 @@ import ( type TransformedColumns []*TransformedColumn type TransformedColumn struct { - ResourceConfigFields - Transformer string `json:"transformer" yaml:"transformer"` - Inputs *Inputs `json:"inputs" yaml:"inputs"` - Compute *SparkCompute `json:"compute" yaml:"compute"` - Tags Tags `json:"tags" yaml:"tags"` + ResourceFields + Transformer *string `json:"transformer" yaml:"transformer"` + TransformerPath *string `json:"transformer_path" yaml:"transformer_path"` + Inputs *Inputs `json:"inputs" yaml:"inputs"` + Compute *SparkCompute `json:"compute" yaml:"compute"` + Tags Tags `json:"tags" yaml:"tags"` } var transformedColumnValidation = &configreader.StructValidation{ @@ -44,11 +45,14 @@ var transformedColumnValidation = &configreader.StructValidation{ }, { StructField: "Transformer", - StringValidation: &configreader.StringValidation{ - 
Required: true, + StringPtrValidation: &configreader.StringPtrValidation{ AlphaNumericDashDotUnderscore: true, }, }, + { + StructField: "TransformerPath", + StringPtrValidation: &configreader.StringPtrValidation{}, + }, inputValuesFieldValidation, sparkComputeFieldValidation("Compute"), tagsFieldValidation, diff --git a/pkg/operator/context/aggregates.go b/pkg/operator/context/aggregates.go index f5541f9741..111d66a85a 100644 --- a/pkg/operator/context/aggregates.go +++ b/pkg/operator/context/aggregates.go @@ -44,7 +44,17 @@ func getAggregates( return nil, userconfig.ErrorDuplicateResourceName(aggregateConfig, constants[aggregateConfig.Name]) } - aggregator, err := getAggregator(aggregateConfig.Aggregator, userAggregators) + var aggName string + if aggregateConfig.Aggregator != nil { + aggName = *aggregateConfig.Aggregator + } + + if aggregateConfig.AggregatorPath != nil { + aggName = *aggregateConfig.AggregatorPath + aggregateConfig.Aggregator = &aggName + } + + aggregator, err := getAggregator(aggName, userAggregators) if err != nil { return nil, errors.Wrap(err, userconfig.Identify(aggregateConfig), userconfig.AggregatorKey) } @@ -80,11 +90,13 @@ func getAggregates( buf.WriteString(aggregateConfig.Tags.ID()) idWithTags := hash.Bytes(buf.Bytes()) - aggregateKey := filepath.Join( + aggregateRootKey := filepath.Join( root, consts.AggregatesDir, - id+".msgpack", + id, ) + aggregateKey := aggregateRootKey + ".msgpack" + aggregateMetadataKey := aggregateRootKey + "_metadata.json" aggregates[aggregateConfig.Name] = &context.Aggregate{ ComputedResourceFields: &context.ComputedResourceFields{ @@ -92,6 +104,7 @@ func getAggregates( ID: id, IDWithTags: idWithTags, ResourceType: resource.AggregateType, + MetadataKey: aggregateMetadataKey, }, }, Aggregate: aggregateConfig, @@ -109,6 +122,9 @@ func validateAggregateInputs( rawColumns context.RawColumns, aggregator *context.Aggregator, ) error { + if aggregator.SkipValidation { + return nil + } columnRuntimeTypes, err := context.GetColumnRuntimeTypes(aggregateConfig.Inputs.Columns, rawColumns) if err != nil { diff --git a/pkg/operator/context/aggregators.go b/pkg/operator/context/aggregators.go index bfe2f48c27..e219b69423 100644 --- a/pkg/operator/context/aggregators.go +++ b/pkg/operator/context/aggregators.go @@ -23,6 +23,7 @@ import ( "github.com/cortexlabs/cortex/pkg/consts" "github.com/cortexlabs/cortex/pkg/lib/errors" "github.com/cortexlabs/cortex/pkg/lib/hash" + s "github.com/cortexlabs/cortex/pkg/lib/strings" "github.com/cortexlabs/cortex/pkg/operator/api/context" "github.com/cortexlabs/cortex/pkg/operator/api/resource" "github.com/cortexlabs/cortex/pkg/operator/api/userconfig" @@ -30,24 +31,47 @@ import ( ) func loadUserAggregators( - aggregatorConfigs userconfig.Aggregators, + config *userconfig.Config, impls map[string][]byte, pythonPackages context.PythonPackages, ) (map[string]*context.Aggregator, error) { userAggregators := make(map[string]*context.Aggregator) - for _, aggregatorConfig := range aggregatorConfigs { + for _, aggregatorConfig := range config.Aggregators { impl, ok := impls[aggregatorConfig.Path] if !ok { return nil, errors.Wrap(ErrorImplDoesNotExist(aggregatorConfig.Path), userconfig.Identify(aggregatorConfig)) } - aggregator, err := newAggregator(*aggregatorConfig, impl, nil, pythonPackages) + aggregator, err := newAggregator(*aggregatorConfig, impl, nil, pythonPackages, false) if err 
!= nil { return nil, err } userAggregators[aggregator.Name] = aggregator } + for _, aggregateConfig := range config.Aggregates { + if aggregateConfig.AggregatorPath == nil { + continue + } + + impl, ok := impls[*aggregateConfig.AggregatorPath] + if !ok { + return nil, errors.Wrap(ErrorImplDoesNotExist(*aggregateConfig.AggregatorPath), userconfig.Identify(aggregateConfig)) + } + + anonAggregatorConfig := &userconfig.Aggregator{ + ResourceFields: userconfig.ResourceFields{ + Name: s.PathToName(*aggregateConfig.AggregatorPath), + }, + Path: *aggregateConfig.AggregatorPath, + } + aggregator, err := newAggregator(*anonAggregatorConfig, impl, nil, pythonPackages, true) + if err != nil { + return nil, err + } + userAggregators[*aggregateConfig.AggregatorPath] = aggregator + } + return userAggregators, nil } @@ -56,6 +80,7 @@ func newAggregator( impl []byte, namespace *string, pythonPackages context.PythonPackages, + skipValidation bool, ) (*context.Aggregator, error) { implID := hash.Bytes(impl) @@ -76,10 +101,12 @@ func newAggregator( ID: id, IDWithTags: id, ResourceType: resource.AggregatorType, + MetadataKey: filepath.Join(consts.AggregatorsDir, id+"_metadata.json"), }, - Aggregator: &aggregatorConfig, - Namespace: namespace, - ImplKey: filepath.Join(consts.AggregatorsDir, implID+".py"), + Aggregator: &aggregatorConfig, + Namespace: namespace, + ImplKey: filepath.Join(consts.AggregatorsDir, implID+".py"), + SkipValidation: skipValidation, } aggregator.Aggregator.Path = "" @@ -132,7 +159,16 @@ func getAggregators( aggregators := context.Aggregators{} for _, aggregateConfig := range config.Aggregates { - aggregatorName := aggregateConfig.Aggregator + + var aggregatorName string + if aggregateConfig.Aggregator != nil { + aggregatorName = *aggregateConfig.Aggregator + } + + if aggregateConfig.AggregatorPath != nil { + aggregatorName = *aggregateConfig.AggregatorPath + } + if _, ok := aggregators[aggregatorName]; ok { continue } diff --git a/pkg/operator/context/apis.go b/pkg/operator/context/apis.go index 1ff5045575..9f0ef58bae 100644 --- a/pkg/operator/context/apis.go +++ b/pkg/operator/context/apis.go @@ -18,7 +18,9 @@ package context import ( "bytes" + "path/filepath" + "github.com/cortexlabs/cortex/pkg/consts" "github.com/cortexlabs/cortex/pkg/lib/hash" "github.com/cortexlabs/cortex/pkg/operator/api/context" "github.com/cortexlabs/cortex/pkg/operator/api/resource" @@ -48,6 +50,7 @@ func getAPIs(config *userconfig.Config, ID: id, IDWithTags: idWithTags, ResourceType: resource.APIType, + MetadataKey: filepath.Join(consts.APIsDir, id+"_metadata.json"), }, }, API: apiConfig, diff --git a/pkg/operator/context/autogenerator.go b/pkg/operator/context/autogenerator.go index e57259eab8..b3f1b1ab90 100644 --- a/pkg/operator/context/autogenerator.go +++ b/pkg/operator/context/autogenerator.go @@ -43,7 +43,16 @@ func autoGenerateConfig( } } - aggregator, err := getAggregator(aggregate.Aggregator, userAggregators) + var name string + if aggregate.Aggregator != nil { + name = *aggregate.Aggregator + } + + if aggregate.AggregatorPath != nil { + name = *aggregate.AggregatorPath + } + + aggregator, err := getAggregator(name, userAggregators) if err != nil { return errors.Wrap(err, userconfig.Identify(aggregate), userconfig.AggregatorKey) } @@ -61,7 +70,7 @@ func autoGenerateConfig( }, "/") constant := &userconfig.Constant{ - ResourceConfigFields: userconfig.ResourceConfigFields{ + ResourceFields: userconfig.ResourceFields{ Name: 
constantName, }, Type: argType, @@ -84,7 +93,16 @@ func autoGenerateConfig( } } - transformer, err := getTransformer(transformedColumn.Transformer, userTransformers) + var name string + if transformedColumn.Transformer != nil { + name = *transformedColumn.Transformer + } + + if transformedColumn.TransformerPath != nil { + name = s.PathToName(*transformedColumn.TransformerPath) + } + + transformer, err := getTransformer(name, userTransformers) if err != nil { return errors.Wrap(err, userconfig.Identify(transformedColumn), userconfig.TransformerKey) } @@ -102,7 +120,7 @@ func autoGenerateConfig( }, "/") constant := &userconfig.Constant{ - ResourceConfigFields: userconfig.ResourceConfigFields{ + ResourceFields: userconfig.ResourceFields{ Name: constantName, }, Type: argType, diff --git a/pkg/operator/context/constants.go b/pkg/operator/context/constants.go index bdb2edcba1..ecc1836d8b 100644 --- a/pkg/operator/context/constants.go +++ b/pkg/operator/context/constants.go @@ -59,6 +59,7 @@ func newConstant(constantConfig userconfig.Constant) (*context.Constant, error) ID: id, IDWithTags: idWithTags, ResourceType: resource.ConstantType, + MetadataKey: filepath.Join(consts.ConstantsDir, id+"_metadata.json"), }, Constant: &constantConfig, Key: filepath.Join(consts.ConstantsDir, id+".msgpack"), diff --git a/pkg/operator/context/context.go b/pkg/operator/context/context.go index 3c5babff09..f0f845e5b3 100644 --- a/pkg/operator/context/context.go +++ b/pkg/operator/context/context.go @@ -62,7 +62,7 @@ func Init() error { if err != nil { return errors.Wrap(err, userconfig.Identify(aggregatorConfig)) } - aggregator, err := newAggregator(*aggregatorConfig, impl, pointer.String("cortex"), nil) + aggregator, err := newAggregator(*aggregatorConfig, impl, pointer.String("cortex"), nil, false) if err != nil { return err } @@ -81,7 +81,7 @@ func Init() error { if err != nil { return errors.Wrap(err, userconfig.Identify(transConfig)) } - transformer, err := newTransformer(*transConfig, impl, pointer.String("cortex"), nil) + transformer, err := newTransformer(*transConfig, impl, pointer.String("cortex"), nil, false) if err != nil { return err } @@ -130,12 +130,12 @@ func New( } ctx.PythonPackages = pythonPackages - userTransformers, err := loadUserTransformers(userconf.Transformers, files, pythonPackages) + userTransformers, err := loadUserTransformers(userconf, files, pythonPackages) if err != nil { return nil, err } - userAggregators, err := loadUserAggregators(userconf.Aggregators, files, pythonPackages) + userAggregators, err := loadUserAggregators(userconf, files, pythonPackages) if err != nil { return nil, err } diff --git a/pkg/operator/context/models.go b/pkg/operator/context/models.go index 10be49b448..0b69ae5dc9 100644 --- a/pkg/operator/context/models.go +++ b/pkg/operator/context/models.go @@ -97,6 +97,7 @@ func getModels( ID: modelID, IDWithTags: modelID, ResourceType: resource.ModelType, + MetadataKey: filepath.Join(datasetRoot, "metadata.json"), }, }, Model: modelConfig, @@ -116,10 +117,9 @@ func getModels( ResourceType: resource.TrainingDatasetType, }, }, - ModelName: modelConfig.Name, - TrainKey: filepath.Join(datasetRoot, "train.tfrecord"), - EvalKey: filepath.Join(datasetRoot, "eval.tfrecord"), - MetadataKey: filepath.Join(datasetRoot, "metadata.json"), + ModelName: modelConfig.Name, + TrainKey: filepath.Join(datasetRoot, "train.tfrecord"), + EvalKey: filepath.Join(datasetRoot, "eval.tfrecord"), }, } } diff --git a/pkg/operator/context/python_packages.go 
b/pkg/operator/context/python_packages.go index e87a1753fe..8cf89da0a5 100644 --- a/pkg/operator/context/python_packages.go +++ b/pkg/operator/context/python_packages.go @@ -64,6 +64,7 @@ func loadPythonPackages(files map[string][]byte, datasetVersion string) (context ResourceFields: &context.ResourceFields{ ID: id, ResourceType: resource.PythonPackageType, + MetadataKey: filepath.Join(consts.PythonPackagesDir, id, "metadata.json"), }, }, SrcKey: filepath.Join(consts.PythonPackagesDir, id, "src.txt"), @@ -102,6 +103,7 @@ func loadPythonPackages(files map[string][]byte, datasetVersion string) (context ResourceFields: &context.ResourceFields{ ID: id, ResourceType: resource.PythonPackageType, + MetadataKey: filepath.Join(consts.PythonPackagesDir, id, "metadata.json"), }, }, SrcKey: filepath.Join(consts.PythonPackagesDir, id, "src.zip"), diff --git a/pkg/operator/context/raw_columns.go b/pkg/operator/context/raw_columns.go index 4ccf432aef..a3102cff45 100644 --- a/pkg/operator/context/raw_columns.go +++ b/pkg/operator/context/raw_columns.go @@ -18,7 +18,9 @@ package context import ( "bytes" + "path/filepath" + "github.com/cortexlabs/cortex/pkg/consts" "github.com/cortexlabs/cortex/pkg/lib/configreader" "github.com/cortexlabs/cortex/pkg/lib/errors" "github.com/cortexlabs/cortex/pkg/lib/hash" @@ -57,6 +59,7 @@ func getRawColumns( ID: id, IDWithTags: idWithTags, ResourceType: resource.RawColumnType, + MetadataKey: filepath.Join(consts.RawColumnsDir, id+"_metadata.json"), }, }, RawIntColumn: typedColumnConfig, @@ -74,6 +77,7 @@ func getRawColumns( ID: id, IDWithTags: idWithTags, ResourceType: resource.RawColumnType, + MetadataKey: filepath.Join(consts.RawColumnsDir, id+"_metadata.json"), }, }, RawFloatColumn: typedColumnConfig, @@ -89,6 +93,7 @@ func getRawColumns( ID: id, IDWithTags: idWithTags, ResourceType: resource.RawColumnType, + MetadataKey: filepath.Join(consts.RawColumnsDir, id+"_metadata.json"), }, }, RawStringColumn: typedColumnConfig, diff --git a/pkg/operator/context/transformed_columns.go b/pkg/operator/context/transformed_columns.go index c6bf707c89..a6267e0885 100644 --- a/pkg/operator/context/transformed_columns.go +++ b/pkg/operator/context/transformed_columns.go @@ -18,7 +18,9 @@ package context import ( "bytes" + "path/filepath" + "github.com/cortexlabs/cortex/pkg/consts" "github.com/cortexlabs/cortex/pkg/lib/errors" "github.com/cortexlabs/cortex/pkg/lib/hash" s "github.com/cortexlabs/cortex/pkg/lib/strings" @@ -39,7 +41,17 @@ func getTransformedColumns( transformedColumns := context.TransformedColumns{} for _, transformedColumnConfig := range config.TransformedColumns { - transformer, err := getTransformer(transformedColumnConfig.Transformer, userTransformers) + var transName string + if transformedColumnConfig.Transformer != nil { + transName = *transformedColumnConfig.Transformer + } + + if transformedColumnConfig.TransformerPath != nil { + transName = s.PathToName(*transformedColumnConfig.TransformerPath) + transformedColumnConfig.Transformer = &transName + } + + transformer, err := getTransformer(transName, userTransformers) if err != nil { return nil, errors.Wrap(err, userconfig.Identify(transformedColumnConfig), userconfig.TransformerKey) } @@ -80,6 +92,7 @@ func getTransformedColumns( ID: id, IDWithTags: idWithTags, ResourceType: resource.TransformedColumnType, + MetadataKey: 
filepath.Join(consts.TransformedColumnsDir, id+"_metadata.json"), }, }, TransformedColumn: transformedColumnConfig, @@ -97,6 +110,9 @@ func validateTransformedColumnInputs( aggregates context.Aggregates, transformer *context.Transformer, ) error { + if transformer.SkipValidation { + return nil + } columnRuntimeTypes, err := context.GetColumnRuntimeTypes(transformedColumnConfig.Inputs.Columns, rawColumns) if err != nil { diff --git a/pkg/operator/context/transformers.go b/pkg/operator/context/transformers.go index ff4bad3b2c..c513644683 100644 --- a/pkg/operator/context/transformers.go +++ b/pkg/operator/context/transformers.go @@ -23,6 +23,7 @@ import ( "github.com/cortexlabs/cortex/pkg/consts" "github.com/cortexlabs/cortex/pkg/lib/errors" "github.com/cortexlabs/cortex/pkg/lib/hash" + s "github.com/cortexlabs/cortex/pkg/lib/strings" "github.com/cortexlabs/cortex/pkg/operator/api/context" "github.com/cortexlabs/cortex/pkg/operator/api/resource" "github.com/cortexlabs/cortex/pkg/operator/api/userconfig" @@ -30,24 +31,46 @@ import ( ) func loadUserTransformers( - transConfigs userconfig.Transformers, + config *userconfig.Config, impls map[string][]byte, pythonPackages context.PythonPackages, ) (map[string]*context.Transformer, error) { userTransformers := make(map[string]*context.Transformer) - for _, transConfig := range transConfigs { + for _, transConfig := range config.Transformers { impl, ok := impls[transConfig.Path] if !ok { return nil, errors.Wrap(ErrorImplDoesNotExist(transConfig.Path), userconfig.Identify(transConfig)) } - transformer, err := newTransformer(*transConfig, impl, nil, pythonPackages) + transformer, err := newTransformer(*transConfig, impl, nil, pythonPackages, false) if err != nil { return nil, err } userTransformers[transformer.Name] = transformer } + for _, transColConfig := range config.TransformedColumns { + if transColConfig.TransformerPath == nil { + continue + } + + impl, ok := impls[*transColConfig.TransformerPath] + if !ok { + return nil, errors.Wrap(ErrorImplDoesNotExist(*transColConfig.TransformerPath), userconfig.Identify(transColConfig)) + } + + anonTransformerConfig := &userconfig.Transformer{ + ResourceFields: userconfig.ResourceFields{ + Name: s.PathToName(*transColConfig.TransformerPath), + }, + Path: *transColConfig.TransformerPath, + } + transformer, err := newTransformer(*anonTransformerConfig, impl, nil, pythonPackages, true) + if err != nil { + return nil, err + } + userTransformers[transformer.Name] = transformer + } return userTransformers, nil } @@ -56,6 +79,7 @@ func newTransformer( impl []byte, namespace *string, pythonPackages context.PythonPackages, + skipValidation bool, ) (*context.Transformer, error) { implID := hash.Bytes(impl) @@ -75,10 +99,12 @@ func newTransformer( ID: id, IDWithTags: id, ResourceType: resource.TransformerType, + MetadataKey: filepath.Join(consts.TransformersDir, id+"_metadata.json"), }, - Transformer: &transConfig, - Namespace: namespace, - ImplKey: filepath.Join(consts.TransformersDir, implID+".py"), + Transformer: &transConfig, + Namespace: namespace, + ImplKey: filepath.Join(consts.TransformersDir, implID+".py"), + SkipValidation: skipValidation, } transformer.Transformer.Path = "" @@ -114,7 +140,6 @@ func getTransformer( name string, userTransformers map[string]*context.Transformer, ) (*context.Transformer, error) { - if transformer, ok := builtinTransformers[name]; ok { 
return transformer, nil } @@ -131,10 +156,19 @@ func getTransformers( transformers := context.Transformers{} for _, transformedColumnConfig := range config.TransformedColumns { - transformerName := transformedColumnConfig.Transformer + var transformerName string + if transformedColumnConfig.Transformer != nil { + transformerName = *transformedColumnConfig.Transformer + } + + if transformedColumnConfig.TransformerPath != nil { + transformerName = s.PathToName(*transformedColumnConfig.TransformerPath) + } + if _, ok := transformers[transformerName]; ok { continue } + transformer, err := getTransformer(transformerName, userTransformers) if err != nil { return nil, errors.Wrap(err, userconfig.Identify(transformedColumnConfig), userconfig.TransformerKey) diff --git a/pkg/workloads/lib/context.py b/pkg/workloads/lib/context.py index 7dc2de1ab6..94b0f79b18 100644 --- a/pkg/workloads/lib/context.py +++ b/pkg/workloads/lib/context.py @@ -80,7 +80,6 @@ def __init__(self, **kwargs): self.models = self.ctx["models"] self.apis = self.ctx["apis"] self.training_datasets = {k: v["dataset"] for k, v in self.models.items()} - self.api_version = self.cortex_config["api_version"] if "local_storage_path" in kwargs: @@ -99,6 +98,8 @@ def __init__(self, **kwargs): ) ) + self.fetch_metadata() + self.columns = util.merge_dicts_overwrite( self.raw_columns, self.transformed_columns # self.aggregates ) @@ -138,6 +139,7 @@ def __init__(self, **kwargs): self.constants_id_map, ) + def is_raw_column(self, name): return name in self.raw_columns @@ -228,6 +230,8 @@ def get_transformer_impl(self, column_name): return None, None transformer_name = self.transformed_columns[column_name]["transformer"] + if not transformer_name: + transformer_name = self.transformed_columns[column_name]["transformer_path"] if transformer_name in self._transformer_impls: return self._transformer_impls[transformer_name] @@ -469,6 +473,76 @@ def upload_resource_status_end(self, exit_code, *resources): def resource_status_key(self, resource): return os.path.join(self.status_prefix, resource["id"], resource["workload_id"]) + def update_metadata(self, metadata, context_key, context_item=""): + if context_item == "": + self.ctx[context_key]["metadata"] = metadata + self.storage.put_json(metadata, self.ctx[context_key]["metadata_key"]) + return + + self.ctx[context_key][context_item]["metadata"] = metadata + self.storage.put_json(metadata, self.ctx[context_key][context_item]["metadata_key"]) + + + def fetch_metadata(self): + for k, v in self.python_packages.items(): + metadata = self.storage.get_json(v["metadata_key"], allow_missing=True) + if not metadata: + metadata = {} + self.python_packages[k]["metadata"] = metadata + + for k, v in self.raw_columns.items(): + metadata = self.storage.get_json(v["metadata_key"], allow_missing=True) + if not metadata: + metadata = {} + self.raw_columns[k]["metadata"] = metadata + + for k, v in self.transformed_columns.items(): + metadata = self.storage.get_json(v["metadata_key"], allow_missing=True) + if not metadata: + metadata = {} + self.transformed_columns[k]["metadata"] = metadata + + for k, v in self.transformers.items(): + metadata = self.storage.get_json(v["metadata_key"], allow_missing=True) + if not metadata: + metadata = {} + self.transformers[k]["metadata"] = metadata + + for k, v in self.aggregators.items(): + metadata = self.storage.get_json(v["metadata_key"], allow_missing=True) + if not metadata: + metadata = {} + self.aggregators[k]["metadata"] = metadata + + for k, v in self.aggregates.items(): + 
metadata = self.storage.get_json(v["metadata_key"], allow_missing=True) + if not metadata: + metadata = {} + self.aggregates[k]["metadata"] = metadata + + for k, v in self.constants.items(): + metadata = self.storage.get_json(v["metadata_key"], allow_missing=True) + if not metadata: + metadata = {} + self.constants[k]["metadata"] = metadata + + for k, v in self.models.items(): + metadata = self.storage.get_json(v["metadata_key"], allow_missing=True) + if not metadata: + metadata = {} + self.models[k]["metadata"] = metadata + + for k, v in self.apis.items(): + metadata = self.storage.get_json(v["metadata_key"], allow_missing=True) + if not metadata: + metadata = {} + self.apis[k]["metadata"] = metadata + + metadata = self.storage.get_json(self.raw_dataset["metadata_key"], allow_missing=True) + if not metadata: + metadata = {} + self.raw_dataset["metadata"] = metadata + MODEL_IMPL_VALIDATION = { "required": [{"name": "create_estimator", "args": ["run_config", "model_config"]}], diff --git a/pkg/workloads/lib/storage/s3.py b/pkg/workloads/lib/storage/s3.py index 2d7f77ac9f..90348dff04 100644 --- a/pkg/workloads/lib/storage/s3.py +++ b/pkg/workloads/lib/storage/s3.py @@ -131,10 +131,10 @@ def put_json(self, obj, key): self._upload_string_to_s3(json.dumps(obj), key) def get_json(self, key, allow_missing=False): - obj = self._read_bytes_from_s3(key, allow_missing).decode("utf-8") + obj = self._read_bytes_from_s3(key, allow_missing) if obj is None: return None - return json.loads(obj) + return json.loads(obj.decode("utf-8")) def put_msgpack(self, obj, key): self._upload_string_to_s3(msgpack.dumps(obj), key) diff --git a/pkg/workloads/lib/tf_lib.py b/pkg/workloads/lib/tf_lib.py index ce21da7550..f120bd6de8 100644 --- a/pkg/workloads/lib/tf_lib.py +++ b/pkg/workloads/lib/tf_lib.py @@ -57,10 +57,13 @@ def get_column_tf_types(model_name, ctx, training=True): """Generate a dict {column name -> tf_type}""" model = ctx.models[model_name] - column_types = { - column_name: CORTEX_TYPE_TO_TF_TYPE[ctx.columns[column_name]["type"]] - for column_name in model["feature_columns"] - } + column_types = {} + for column_name in model["feature_columns"]: + columnType = ctx.columns[column_name]["type"] + if columnType == "unknown": + columnType = ctx.columns[column_name]["metadata"]["type"] + + column_types[column_name] = CORTEX_TYPE_TO_TF_TYPE[columnType] if training: target_column_name = model["target_column"] @@ -79,7 +82,11 @@ def get_feature_spec(model_name, ctx, training=True): column_types = get_column_tf_types(model_name, ctx, training) feature_spec = {} for column_name, tf_type in column_types.items(): - if ctx.columns[column_name]["type"] in consts.COLUMN_LIST_TYPES: + columnType = ctx.columns[column_name]["type"] + if columnType == "unknown": + columnType = ctx.columns[column_name]["metadata"]["type"] + + if columnType in consts.COLUMN_LIST_TYPES: feature_spec[column_name] = tf.FixedLenSequenceFeature( shape=(), dtype=tf_type, allow_missing=True ) diff --git a/pkg/workloads/spark_job/spark_job.py b/pkg/workloads/spark_job/spark_job.py index 3c7472d676..96011d276e 100644 --- a/pkg/workloads/spark_job/spark_job.py +++ b/pkg/workloads/spark_job/spark_job.py @@ -91,7 +91,7 @@ def parse_args(args): def validate_dataset(ctx, raw_df, cols_to_validate): - total_row_count = ctx.storage.get_json(ctx.raw_dataset["metadata_key"])["dataset_size"] + total_row_count = ctx.raw_dataset["metadata"]["dataset_size"] conditions_dict = spark_util.value_check_data(ctx, raw_df, cols_to_validate) if len(conditions_dict) > 0: @@ 
-161,7 +161,7 @@ def ingest_raw_dataset(spark, ctx, cols_to_validate, should_ingest): written_count = write_raw_dataset(ingest_df, ctx, spark) metadata = {"dataset_size": written_count} - ctx.storage.put_json(metadata, ctx.raw_dataset["metadata_key"]) + ctx.update_metadata(metadata, "raw_dataset") if written_count != full_dataset_size: logger.info( "{} rows read, {} rows dropped, {} rows ingested".format( diff --git a/pkg/workloads/spark_job/spark_util.py b/pkg/workloads/spark_job/spark_util.py index c6c75c4d38..5dab66bfa1 100644 --- a/pkg/workloads/spark_job/spark_util.py +++ b/pkg/workloads/spark_job/spark_util.py @@ -48,6 +48,19 @@ consts.COLUMN_TYPE_STRING_LIST: [ArrayType(StringType(), True)], } +PYTHON_TYPE_TO_CORTEX_TYPE = { + int: consts.COLUMN_TYPE_INT, + float: consts.COLUMN_TYPE_FLOAT, + str: consts.COLUMN_TYPE_STRING, +} + +PYTHON_TYPE_TO_CORTEX_LIST_TYPE = { + int: consts.COLUMN_TYPE_INT_LIST, + float: consts.COLUMN_TYPE_FLOAT_LIST, + str: consts.COLUMN_TYPE_STRING_LIST, +} + + def accumulate_count(df, spark): acc = df._sc.accumulator(0) @@ -95,7 +108,7 @@ def write_training_data(model_name, df, ctx, spark): ) metadata = {"training_size": train_df_acc.value, "eval_size": eval_df_acc.value} - ctx.storage.put_json(metadata, training_dataset["metadata_key"]) + ctx.update_metadata(metadata, "models", model_name) return df @@ -384,7 +397,7 @@ def run_custom_aggregator(aggregator_resource, df, ctx, spark): "function aggregate_spark", ) from e - if not util.validate_value_type(result, aggregator["output_type"]): + if not aggregator["skip_validation"] and not util.validate_value_type(result, aggregator["output_type"]): raise UserException( "aggregate " + aggregator_resource["name"], "aggregator " + aggregator["name"], @@ -460,6 +473,24 @@ def validate_transformer(column_name, df, ctx, spark): trans_impl, _ = ctx.get_transformer_impl(column_name) if hasattr(trans_impl, "transform_python"): + sample_df = df.collect() + sample = sample_df[0] + inputs = ctx.create_column_inputs_map(sample, column_name) + _, impl_args = extract_inputs(column_name, ctx) + transformedSample = trans_impl.transform_python(inputs, impl_args) + rowType = type(transformedSample) + isList = rowType == list + typeConversionDict = PYTHON_TYPE_TO_CORTEX_TYPE + if isList: + rowType = type(transformedSample[0]) + typeConversionDict = PYTHON_TYPE_TO_CORTEX_LIST_TYPE + + # for downstream operations on this job + ctx.transformed_columns[column_name]["type"] = typeConversionDict[rowType] + + # for downstream operations on other jobs + ctx.update_metadata({"type": typeConversionDict[rowType]}, "transformed_columns", column_name) + try: transform_python_collect = execute_transform_python( column_name, df, ctx, spark, validate=True diff --git a/pkg/workloads/spark_job/test/integration/iris_test.py b/pkg/workloads/spark_job/test/integration/iris_test.py index 73ff1e9ec1..481a57d3ea 100644 --- a/pkg/workloads/spark_job/test/integration/iris_test.py +++ b/pkg/workloads/spark_job/test/integration/iris_test.py @@ -117,7 +117,7 @@ def test_simple_end_to_end(spark): status["exist_code"] = "succeeded" dataset = raw_ctx["models"]["dnn"]["dataset"] - metadata_key = storage.get_json(dataset["metadata_key"]) - assert metadata_key["training_size"] + metadata_key["eval_size"] == 15 + metadata = storage.get_json(raw_ctx["models"]["dnn"]["metadata_key"]) + assert metadata["training_size"] + metadata["eval_size"] == 15 assert local_storage_path.joinpath(dataset["train_key"], "_SUCCESS").exists() assert 
local_storage_path.joinpath(dataset["eval_key"], "_SUCCESS").exists() diff --git a/pkg/workloads/tf_api/api.py b/pkg/workloads/tf_api/api.py index 3abb9ed08b..31982a6512 100644 --- a/pkg/workloads/tf_api/api.py +++ b/pkg/workloads/tf_api/api.py @@ -95,7 +95,10 @@ def create_prediction_request(transformed_sample): prediction_request.model_spec.signature_name = signature_key for column_name, value in transformed_sample.items(): - data_type = tf_lib.CORTEX_TYPE_TO_TF_TYPE[ctx.columns[column_name]["type"]] + columnType = ctx.columns[column_name]["type"] + if columnType == "unknown": + columnType = ctx.columns[column_name]["metadata"]["type"] + data_type = tf_lib.CORTEX_TYPE_TO_TF_TYPE[columnType] shape = [1] if util.is_list(value): shape = [len(value)] diff --git a/pkg/workloads/tf_train/train_util.py b/pkg/workloads/tf_train/train_util.py index 894e5066b0..d6f1fb607e 100644 --- a/pkg/workloads/tf_train/train_util.py +++ b/pkg/workloads/tf_train/train_util.py @@ -148,11 +148,10 @@ def train(model_name, model_impl, ctx, model_dir): serving_input_fn = generate_json_serving_input_fn(model_name, ctx, model_impl) exporter = tf.estimator.FinalExporter("estimator", serving_input_fn, as_text=False) - dataset_metadata = ctx.storage.get_json(model["dataset"]["metadata_key"]) train_num_steps = model["training"]["num_steps"] if model["training"]["num_epochs"]: train_num_steps = ( - math.ceil(dataset_metadata["training_size"] / float(model["training"]["batch_size"])) + math.ceil(model["metadata"]["training_size"] / float(model["training"]["batch_size"])) * model["training"]["num_epochs"] ) @@ -161,7 +160,7 @@ def train(model_name, model_impl, ctx, model_dir): eval_num_steps = model["evaluation"]["num_steps"] if model["evaluation"]["num_epochs"]: eval_num_steps = ( - math.ceil(dataset_metadata["eval_size"] / float(model["evaluation"]["batch_size"])) + math.ceil(model["metadata"]["eval_size"] / float(model["evaluation"]["batch_size"])) * model["evaluation"]["num_epochs"] ) From 5c835c2ec545c2c3f98a640cf1ca8bdff07de7cb Mon Sep 17 00:00:00 2001 From: Ivan Zhang Date: Tue, 7 May 2019 18:00:19 -0400 Subject: [PATCH 03/48] comment out error stack printing --- pkg/lib/errors/errors.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pkg/lib/errors/errors.go b/pkg/lib/errors/errors.go index c3dc2de155..342e07a307 100644 --- a/pkg/lib/errors/errors.go +++ b/pkg/lib/errors/errors.go @@ -151,7 +151,7 @@ func Panic(items ...interface{}) { func PrintError(err error, strs ...string) { wrappedErr := Wrap(err, strs...) 
fmt.Println("error:", wrappedErr.Error()) - PrintStacktrace(wrappedErr) + // PrintStacktrace(wrappedErr) } func PrintStacktrace(err error) { From b54ffd3a8b2275f710f380fdafb67a60039da01b Mon Sep 17 00:00:00 2001 From: Ivan Zhang Date: Tue, 7 May 2019 18:01:28 -0400 Subject: [PATCH 04/48] format --- pkg/workloads/lib/context.py | 24 +++++++++++------------- pkg/workloads/spark_job/spark_util.py | 9 ++++++--- 2 files changed, 17 insertions(+), 16 deletions(-) diff --git a/pkg/workloads/lib/context.py b/pkg/workloads/lib/context.py index 94b0f79b18..c55b4dca4a 100644 --- a/pkg/workloads/lib/context.py +++ b/pkg/workloads/lib/context.py @@ -139,7 +139,6 @@ def __init__(self, **kwargs): self.constants_id_map, ) - def is_raw_column(self, name): return name in self.raw_columns @@ -476,12 +475,11 @@ def resource_status_key(self, resource): def update_metadata(self, metadata, context_key, context_item=""): if context_item == "": self.ctx[context_key]["metadata"] = metadata - self.storage.put_json(metadata, self.ctx[context_key]["metadata_key"]) + self.storage.put_json(metadata, self.ctx[context_key]["metadata_key"]) return self.ctx[context_key][context_item]["metadata"] = metadata - self.storage.put_json(metadata, self.ctx[context_key][context_item]["metadata_key"]) - + self.storage.put_json(metadata, self.ctx[context_key][context_item]["metadata_key"]) def fetch_metadata(self): for k, v in self.python_packages.items(): @@ -494,53 +492,53 @@ def fetch_metadata(self): metadata = self.storage.get_json(v["metadata_key"], allow_missing=True) if not metadata: metadata = {} - self.raw_columns[k]["metadata"] = metadata + self.raw_columns[k]["metadata"] = metadata for k, v in self.transformed_columns.items(): metadata = self.storage.get_json(v["metadata_key"], allow_missing=True) if not metadata: metadata = {} - self.transformed_columns[k]["metadata"] = metadata + self.transformed_columns[k]["metadata"] = metadata for k, v in self.transformers.items(): metadata = self.storage.get_json(v["metadata_key"], allow_missing=True) if not metadata: metadata = {} - self.transformers[k]["metadata"] = metadata + self.transformers[k]["metadata"] = metadata for k, v in self.aggregators.items(): metadata = self.storage.get_json(v["metadata_key"], allow_missing=True) if not metadata: metadata = {} - self.aggregators[k]["metadata"] = metadata + self.aggregators[k]["metadata"] = metadata for k, v in self.aggregates.items(): metadata = self.storage.get_json(v["metadata_key"], allow_missing=True) if not metadata: metadata = {} - self.aggregates[k]["metadata"] = metadata + self.aggregates[k]["metadata"] = metadata for k, v in self.constants.items(): metadata = self.storage.get_json(v["metadata_key"], allow_missing=True) if not metadata: metadata = {} - self.constants[k]["metadata"] = metadata + self.constants[k]["metadata"] = metadata for k, v in self.models.items(): metadata = self.storage.get_json(v["metadata_key"], allow_missing=True) if not metadata: metadata = {} - self.models[k]["metadata"] = metadata + self.models[k]["metadata"] = metadata for k, v in self.apis.items(): metadata = self.storage.get_json(v["metadata_key"], allow_missing=True) if not metadata: metadata = {} - self.apis[k]["metadata"] = metadata + self.apis[k]["metadata"] = metadata metadata = self.storage.get_json(self.raw_dataset["metadata_key"], allow_missing=True) if not metadata: - metadata = {} + metadata = {} self.raw_dataset["metadata"] = metadata diff --git a/pkg/workloads/spark_job/spark_util.py b/pkg/workloads/spark_job/spark_util.py index 
5dab66bfa1..b3a0416a3c 100644 --- a/pkg/workloads/spark_job/spark_util.py +++ b/pkg/workloads/spark_job/spark_util.py @@ -61,7 +61,6 @@ } - def accumulate_count(df, spark): acc = df._sc.accumulator(0) first_column_schema = df.schema[0] @@ -397,7 +396,9 @@ def run_custom_aggregator(aggregator_resource, df, ctx, spark): "function aggregate_spark", ) from e - if not aggregator["skip_validation"] and not util.validate_value_type(result, aggregator["output_type"]): + if not aggregator["skip_validation"] and not util.validate_value_type( + result, aggregator["output_type"] + ): raise UserException( "aggregate " + aggregator_resource["name"], "aggregator " + aggregator["name"], @@ -489,7 +490,9 @@ def validate_transformer(column_name, df, ctx, spark): ctx.transformed_columns[column_name]["type"] = typeConversionDict[rowType] # for downstream operations on other jobs - ctx.update_metadata({"type": typeConversionDict[rowType]}, "transformed_columns", column_name) + ctx.update_metadata( + {"type": typeConversionDict[rowType]}, "transformed_columns", column_name + ) try: transform_python_collect = execute_transform_python( From 162bc8685f1a816ef7c3769a476c06d0b15b5d8f Mon Sep 17 00:00:00 2001 From: Ivan Zhang Date: Tue, 7 May 2019 18:11:11 -0400 Subject: [PATCH 05/48] simplify fetch_metadata --- pkg/workloads/lib/context.py | 71 +++++++++--------------------------- 1 file changed, 18 insertions(+), 53 deletions(-) diff --git a/pkg/workloads/lib/context.py b/pkg/workloads/lib/context.py index c55b4dca4a..9e7e70b391 100644 --- a/pkg/workloads/lib/context.py +++ b/pkg/workloads/lib/context.py @@ -482,59 +482,24 @@ def update_metadata(self, metadata, context_key, context_item=""): self.storage.put_json(metadata, self.ctx[context_key][context_item]["metadata_key"]) def fetch_metadata(self): - for k, v in self.python_packages.items(): - metadata = self.storage.get_json(v["metadata_key"], allow_missing=True) - if not metadata: - metadata = {} - self.python_packages[k]["metadata"] = metadata - - for k, v in self.raw_columns.items(): - metadata = self.storage.get_json(v["metadata_key"], allow_missing=True) - if not metadata: - metadata = {} - self.raw_columns[k]["metadata"] = metadata - - for k, v in self.transformed_columns.items(): - metadata = self.storage.get_json(v["metadata_key"], allow_missing=True) - if not metadata: - metadata = {} - self.transformed_columns[k]["metadata"] = metadata - - for k, v in self.transformers.items(): - metadata = self.storage.get_json(v["metadata_key"], allow_missing=True) - if not metadata: - metadata = {} - self.transformers[k]["metadata"] = metadata - - for k, v in self.aggregators.items(): - metadata = self.storage.get_json(v["metadata_key"], allow_missing=True) - if not metadata: - metadata = {} - self.aggregators[k]["metadata"] = metadata - - for k, v in self.aggregates.items(): - metadata = self.storage.get_json(v["metadata_key"], allow_missing=True) - if not metadata: - metadata = {} - self.aggregates[k]["metadata"] = metadata - - for k, v in self.constants.items(): - metadata = self.storage.get_json(v["metadata_key"], allow_missing=True) - if not metadata: - metadata = {} - self.constants[k]["metadata"] = metadata - - for k, v in self.models.items(): - metadata = self.storage.get_json(v["metadata_key"], allow_missing=True) - if not metadata: - metadata = {} - self.models[k]["metadata"] = metadata - - for k, v in self.apis.items(): - metadata = self.storage.get_json(v["metadata_key"], allow_missing=True) - if not metadata: - metadata = {} - self.apis[k]["metadata"] = 
metadata + resources = [ + "python_packages", + "raw_columns", + "transformed_columns", + "transformers", + "aggregators", + "aggregates", + "constants", + "models", + "apis", + ] + + for resource in resources: + for k, v in self.ctx[resource].items(): + metadata = self.storage.get_json(v["metadata_key"], allow_missing=True) + if not metadata: + metadata = {} + self.ctx[resource][k]["metadata"] = metadata metadata = self.storage.get_json(self.raw_dataset["metadata_key"], allow_missing=True) if not metadata: From b8db1387e70d66d50de115ad98af34cfb2fdbd51 Mon Sep 17 00:00:00 2001 From: Ivan Zhang Date: Wed, 8 May 2019 15:45:38 -0400 Subject: [PATCH 06/48] clean up --- pkg/operator/api/context/context.go | 5 ----- pkg/operator/api/userconfig/config.go | 12 +++++++++-- pkg/operator/api/userconfig/errors.go | 21 ++++++++++++++++++- pkg/workloads/lib/context.py | 1 + .../spark_job/test/integration/iris_test.py | 4 ++-- 5 files changed, 33 insertions(+), 10 deletions(-) diff --git a/pkg/operator/api/context/context.go b/pkg/operator/api/context/context.go index 72dcddf632..594df54b87 100644 --- a/pkg/operator/api/context/context.go +++ b/pkg/operator/api/context/context.go @@ -56,7 +56,6 @@ type Resource interface { GetIDWithTags() string GetResourceFields() *ResourceFields GetMetadataKey() string - SetMetadataKey(string) } type ComputedResource interface { @@ -98,10 +97,6 @@ func (r *ResourceFields) GetMetadataKey() string { return r.MetadataKey } -func (r *ResourceFields) SetMetadataKey(metadataKey string) { - r.MetadataKey = metadataKey -} - func (r *ComputedResourceFields) GetWorkloadID() string { return r.WorkloadID } diff --git a/pkg/operator/api/userconfig/config.go b/pkg/operator/api/userconfig/config.go index a31d3bd73e..98daec6d4e 100644 --- a/pkg/operator/api/userconfig/config.go +++ b/pkg/operator/api/userconfig/config.go @@ -200,9 +200,13 @@ func (config *Config) Validate(envName string) error { } } - // Check local aggregators exist + // Check local aggregators exist or a path to one is defined aggregatorNames := config.Aggregators.Names() for _, aggregate := range config.Aggregates { + if aggregate.AggregatorPath == nil && aggregate.Aggregator == nil { + return ErrorMissingAggregator(aggregate) + } + if aggregate.Aggregator != nil && !strings.Contains(*aggregate.Aggregator, ".") && !slices.HasString(aggregatorNames, *aggregate.Aggregator) { @@ -210,9 +214,13 @@ func (config *Config) Validate(envName string) error { } } - // Check local transformers exist + // Check local transformers exist or a path to one is defined transformerNames := config.Transformers.Names() for _, transformedColumn := range config.TransformedColumns { + if transformedColumn.TransformerPath == nil && transformedColumn.Transformer == nil { + return ErrorMissingTransformer(transformedColumn) + } + if transformedColumn.Transformer != nil && !strings.Contains(*transformedColumn.Transformer, ".") && !slices.HasString(transformerNames, *transformedColumn.Transformer) { diff --git a/pkg/operator/api/userconfig/errors.go b/pkg/operator/api/userconfig/errors.go index 5ff2b9511f..3e9b71652f 100644 --- a/pkg/operator/api/userconfig/errors.go +++ b/pkg/operator/api/userconfig/errors.go @@ -58,6 +58,8 @@ const ( ErrK8sQuantityMustBeInt ErrRegressionTargetType ErrClassificationTargetType + ErrMissingAggregator + ErrMissingTransformer ) var errorKinds = []string{ @@ -90,9 +92,11 @@ var errorKinds = []string{ "err_k8s_quantity_must_be_int", "err_regression_target_type", "err_classification_target_type", + 
"err_missing_aggregator", + "err_missing_transformer", } -var _ = [1]int{}[int(ErrClassificationTargetType)-(len(errorKinds)-1)] // Ensure list length matches +var _ = [1]int{}[int(ErrMissingTransformer)-(len(errorKinds)-1)] // Ensure list length matches func (t ErrorKind) String() string { return errorKinds[t] @@ -376,9 +380,24 @@ func ErrorRegressionTargetType() error { message: "regression models can only predict float target values", } } + func ErrorClassificationTargetType() error { return Error{ Kind: ErrClassificationTargetType, message: "classification models can only predict integer target values (i.e. {0, 1, ..., num_classes-1})", } } + +func ErrorMissingAggregator(aggregate *Aggregate) error { + return Error{ + Kind: ErrMissingAggregator, + message: fmt.Sprintf("missing aggregator for aggregate \"%s\", expecting either \"aggregator\" or \"aggregator_path\"", aggregate.Name), + } +} + +func ErrorMissingTransformer(transformedColumn *TransformedColumn) error { + return Error{ + Kind: ErrMissingTransformer, + message: fmt.Sprintf("missing transformer for transformed_column \"%s\", expecting either \"transformer\" or \"transformer_path\"", transformedColumn.Name), + } +} diff --git a/pkg/workloads/lib/context.py b/pkg/workloads/lib/context.py index 9e7e70b391..abab9b82d5 100644 --- a/pkg/workloads/lib/context.py +++ b/pkg/workloads/lib/context.py @@ -80,6 +80,7 @@ def __init__(self, **kwargs): self.models = self.ctx["models"] self.apis = self.ctx["apis"] self.training_datasets = {k: v["dataset"] for k, v in self.models.items()} + self.api_version = self.cortex_config["api_version"] if "local_storage_path" in kwargs: diff --git a/pkg/workloads/spark_job/test/integration/iris_test.py b/pkg/workloads/spark_job/test/integration/iris_test.py index 481a57d3ea..00d305edc6 100644 --- a/pkg/workloads/spark_job/test/integration/iris_test.py +++ b/pkg/workloads/spark_job/test/integration/iris_test.py @@ -77,7 +77,7 @@ def test_simple_end_to_end(spark): raw_df = spark_job.ingest_raw_dataset(spark, ctx, cols_to_validate, should_ingest) assert raw_df.count() == 15 - assert storage.get_json(ctx.raw_dataset["metadata_key"])["dataset_size"] == 15 + assert ctx.raw_dataset["metadata"]["dataset_size"] == 15 for raw_column_id in cols_to_validate: path = os.path.join(raw_ctx["status_prefix"], raw_column_id, workload_id) status = storage.get_json(str(path)) @@ -117,7 +117,7 @@ def test_simple_end_to_end(spark): status["exist_code"] = "succeeded" dataset = raw_ctx["models"]["dnn"]["dataset"] - metadata = storage.get_json(raw_ctx["models"]["dnn"]["metadata"]) + metadata = raw_ctx["models"]["dnn"]["metadata"] assert metadata["training_size"] + metadata["eval_size"] == 15 assert local_storage_path.joinpath(dataset["train_key"], "_SUCCESS").exists() assert local_storage_path.joinpath(dataset["eval_key"], "_SUCCESS").exists() From a35f089a7cbff29d0985aa33e762b8f1f3ab13eb Mon Sep 17 00:00:00 2001 From: Ivan Zhang Date: Wed, 8 May 2019 17:56:38 -0400 Subject: [PATCH 07/48] fix transform spark bug and training_columns --- pkg/workloads/lib/tf_lib.py | 6 +++- pkg/workloads/spark_job/spark_util.py | 46 +++++++++++++++++++++++---- 2 files changed, 44 insertions(+), 8 deletions(-) diff --git a/pkg/workloads/lib/tf_lib.py b/pkg/workloads/lib/tf_lib.py index f120bd6de8..cdbc2358c7 100644 --- a/pkg/workloads/lib/tf_lib.py +++ b/pkg/workloads/lib/tf_lib.py @@ -72,7 +72,11 @@ def get_column_tf_types(model_name, ctx, training=True): ] for column_name in model["training_columns"]: - column_types[column_name] = 
CORTEX_TYPE_TO_TF_TYPE[ctx.columns[column_name]["type"]] + columnType = ctx.columns[column_name]["type"] + if columnType == "unknown": + columnType = ctx.columns[column_name]["metadata"]["type"] + + column_types[column_name] = CORTEX_TYPE_TO_TF_TYPE[columnType] return column_types diff --git a/pkg/workloads/spark_job/spark_util.py b/pkg/workloads/spark_job/spark_util.py index b3a0416a3c..790469a709 100644 --- a/pkg/workloads/spark_job/spark_util.py +++ b/pkg/workloads/spark_job/spark_util.py @@ -60,6 +60,18 @@ str: consts.COLUMN_TYPE_STRING_LIST, } +SPARK_TYPE_TO_CORTEX_TYPE = { + IntegerType(): consts.COLUMN_TYPE_INT, + LongType(): consts.COLUMN_TYPE_INT, + ArrayType(IntegerType(), True): consts.COLUMN_TYPE_INT_LIST, + ArrayType(LongType(), True): consts.COLUMN_TYPE_INT_LIST, + FloatType(): consts.COLUMN_TYPE_FLOAT, + DoubleType(): consts.COLUMN_TYPE_FLOAT, + ArrayType(FloatType(), True): consts.COLUMN_TYPE_FLOAT_LIST, + ArrayType(DoubleType(), True): consts.COLUMN_TYPE_FLOAT_LIST, + StringType(): consts.COLUMN_TYPE_STRING, + ArrayType(StringType(), True): consts.COLUMN_TYPE_STRING_LIST, +} def accumulate_count(df, spark): acc = df._sc.accumulator(0) @@ -536,9 +548,13 @@ def validate_transformer(column_name, df, ctx, spark): actual_structfield = transform_spark_df.select(column_name).schema.fields[0] + transformer = ctx.transformers[transformed_column["transformer"]] + skip_validation = transformer["skip_validation"] + # check that expected output column has the correct data type if ( - actual_structfield.dataType + not skip_validation + and actual_structfield.dataType not in CORTEX_TYPE_TO_ACCEPTABLE_SPARK_TYPES[transformed_column["type"]] ): raise UserException( @@ -554,12 +570,13 @@ def validate_transformer(column_name, df, ctx, spark): ) # perform the necessary upcast/downcast for the column e.g INT -> LONG or DOUBLE -> FLOAT - transform_spark_df = transform_spark_df.withColumn( - column_name, - F.col(column_name).cast( - CORTEX_TYPE_TO_SPARK_TYPE[ctx.transformed_columns[column_name]["type"]] - ), - ) + if not skip_validation: + transform_spark_df = transform_spark_df.withColumn( + column_name, + F.col(column_name).cast( + CORTEX_TYPE_TO_SPARK_TYPE[ctx.transformed_columns[column_name]["type"]] + ), + ) # check that the function doesn't modify the schema of the other columns in the input dataframe if set(transform_spark_df.columns) - set([column_name]) != set(df.columns): @@ -613,6 +630,21 @@ def transform_column(column_name, df, ctx, spark): trans_impl, trans_impl_path = ctx.get_transformer_impl(column_name) if hasattr(trans_impl, "transform_spark"): + skip_validation = ctx.transformers[ctx.transformed_columns[column_name]["transformer"]]["skip_validation"] + if skip_validation: + df = execute_transform_spark(column_name, df, ctx, spark) + column_type = df.select(column_name).schema[0].dataType + # for downstream operations on other jobs + ctx.update_metadata( + {"type": SPARK_TYPE_TO_CORTEX_TYPE[column_type]}, "transformed_columns", column_name + ) + return df.withColumn( + column_name, + F.col(column_name).cast( + column_type + ), + ) + return execute_transform_spark(column_name, df, ctx, spark).withColumn( column_name, F.col(column_name).cast( From fbf9f56a0f2ba4bc487b6d7ba0ae787984662455 Mon Sep 17 00:00:00 2001 From: Ivan Zhang Date: Wed, 8 May 2019 17:56:56 -0400 Subject: [PATCH 08/48] format --- pkg/workloads/spark_job/spark_util.py | 16 +++++++--------- 1 file changed, 7 insertions(+), 9 deletions(-) diff --git a/pkg/workloads/spark_job/spark_util.py 
b/pkg/workloads/spark_job/spark_util.py index 790469a709..b9ab9c810e 100644 --- a/pkg/workloads/spark_job/spark_util.py +++ b/pkg/workloads/spark_job/spark_util.py @@ -70,9 +70,10 @@ ArrayType(FloatType(), True): consts.COLUMN_TYPE_FLOAT_LIST, ArrayType(DoubleType(), True): consts.COLUMN_TYPE_FLOAT_LIST, StringType(): consts.COLUMN_TYPE_STRING, - ArrayType(StringType(), True): consts.COLUMN_TYPE_STRING_LIST, + ArrayType(StringType(), True): consts.COLUMN_TYPE_STRING_LIST, } + def accumulate_count(df, spark): acc = df._sc.accumulator(0) first_column_schema = df.schema[0] @@ -630,20 +631,17 @@ def transform_column(column_name, df, ctx, spark): trans_impl, trans_impl_path = ctx.get_transformer_impl(column_name) if hasattr(trans_impl, "transform_spark"): - skip_validation = ctx.transformers[ctx.transformed_columns[column_name]["transformer"]]["skip_validation"] + skip_validation = ctx.transformers[ctx.transformed_columns[column_name]["transformer"]][ + "skip_validation" + ] if skip_validation: df = execute_transform_spark(column_name, df, ctx, spark) - column_type = df.select(column_name).schema[0].dataType + column_type = df.select(column_name).schema[0].dataType # for downstream operations on other jobs ctx.update_metadata( {"type": SPARK_TYPE_TO_CORTEX_TYPE[column_type]}, "transformed_columns", column_name ) - return df.withColumn( - column_name, - F.col(column_name).cast( - column_type - ), - ) + return df.withColumn(column_name, F.col(column_name).cast(column_type)) return execute_transform_spark(column_name, df, ctx, spark).withColumn( column_name, From 78e00952e192752a9eb678003c97e03eed5d53d8 Mon Sep 17 00:00:00 2001 From: Ivan Zhang Date: Wed, 8 May 2019 18:28:55 -0400 Subject: [PATCH 09/48] fix url --- images/spark-base/Dockerfile | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/images/spark-base/Dockerfile b/images/spark-base/Dockerfile index 9ff1cd3bf3..1e37b61201 100644 --- a/images/spark-base/Dockerfile +++ b/images/spark-base/Dockerfile @@ -30,7 +30,7 @@ RUN curl http://archive.apache.org/dist/hadoop/common/hadoop-${HADOOP_VERSION}/h rm -rf $HADOOP_HOME/share/doc # Spark -RUN curl http://www.us.apache.org/dist/spark/spark-${SPARK_VERSION}/spark-${SPARK_VERSION}-bin-without-hadoop.tgz | tar -zx && \ +RUN curl http://archive.apache.org/dist/spark/spark-${SPARK_VERSION}/spark-${SPARK_VERSION}-bin-without-hadoop.tgz | tar -zx && \ mv spark-${SPARK_VERSION}-bin-without-hadoop $SPARK_HOME # Tensorflow Spark connector From 09992cc51c2e0f1e916aa62ea76950936bd801bc Mon Sep 17 00:00:00 2001 From: Ivan Zhang Date: Wed, 8 May 2019 18:29:30 -0400 Subject: [PATCH 10/48] unfix url --- images/spark-base/Dockerfile | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/images/spark-base/Dockerfile b/images/spark-base/Dockerfile index 9ff1cd3bf3..1e37b61201 100644 --- a/images/spark-base/Dockerfile +++ b/images/spark-base/Dockerfile @@ -30,7 +30,7 @@ RUN curl http://archive.apache.org/dist/hadoop/common/hadoop-${HADOOP_VERSION}/h rm -rf $HADOOP_HOME/share/doc # Spark -RUN curl http://www.us.apache.org/dist/spark/spark-${SPARK_VERSION}/spark-${SPARK_VERSION}-bin-without-hadoop.tgz | tar -zx && \ +RUN curl http://archive.apache.org/dist/spark/spark-${SPARK_VERSION}/spark-${SPARK_VERSION}-bin-without-hadoop.tgz | tar -zx && \ mv spark-${SPARK_VERSION}-bin-without-hadoop $SPARK_HOME # Tensorflow Spark connector From 36c5fcf4e27a7ee9312082766925fdfdb4eed8d7 Mon Sep 17 00:00:00 2001 From: Ivan Zhang Date: Fri, 10 May 2019 17:39:01 -0400 Subject: [PATCH 11/48] address some 
comments --- pkg/lib/strings/operations.go | 4 --- pkg/operator/api/userconfig/aggregates.go | 6 ++-- pkg/operator/api/userconfig/config.go | 36 +++++++++++++------ pkg/operator/api/userconfig/errors.go | 20 ++++++++++- .../api/userconfig/transformed_columns.go | 6 ++-- pkg/operator/context/aggregates.go | 12 +------ pkg/operator/context/aggregators.go | 23 ++++-------- pkg/operator/context/autogenerator.go | 21 +++-------- pkg/operator/context/transformed_columns.go | 12 +------ pkg/operator/context/transformers.go | 19 +++------- 10 files changed, 69 insertions(+), 90 deletions(-) diff --git a/pkg/lib/strings/operations.go b/pkg/lib/strings/operations.go index 11338c4f16..60cb18648e 100644 --- a/pkg/lib/strings/operations.go +++ b/pkg/lib/strings/operations.go @@ -128,7 +128,3 @@ func StrsSentence(strs []string, lastJoinWord string) string { return strings.Join(strs[:lastIndex], ", ") + ", " + lastJoinWord + " " + strs[lastIndex] } } - -func PathToName(path string) string { - return strings.Replace(strings.Replace(path, "/", "_", -1), ".", "_", -1) -} diff --git a/pkg/operator/api/userconfig/aggregates.go b/pkg/operator/api/userconfig/aggregates.go index d83a717772..e79e4101cd 100644 --- a/pkg/operator/api/userconfig/aggregates.go +++ b/pkg/operator/api/userconfig/aggregates.go @@ -27,7 +27,7 @@ type Aggregates []*Aggregate type Aggregate struct { ResourceFields - Aggregator *string `json:"aggregator" yaml:"aggregator"` + Aggregator string `json:"aggregator" yaml:"aggregator"` AggregatorPath *string `json:"aggregator_path" yaml:"aggregator_path"` Inputs *Inputs `json:"inputs" yaml:"inputs"` Compute *SparkCompute `json:"compute" yaml:"compute"` @@ -45,8 +45,8 @@ var aggregateValidation = &configreader.StructValidation{ }, { StructField: "Aggregator", - StringPtrValidation: &configreader.StringPtrValidation{ - AlphaNumericDashDotUnderscore: true, + StringValidation: &configreader.StringValidation{ + AllowEmpty: true, }, }, { diff --git a/pkg/operator/api/userconfig/config.go b/pkg/operator/api/userconfig/config.go index 98daec6d4e..8f9d6f658c 100644 --- a/pkg/operator/api/userconfig/config.go +++ b/pkg/operator/api/userconfig/config.go @@ -203,28 +203,44 @@ func (config *Config) Validate(envName string) error { // Check local aggregators exist or a path to one is defined aggregatorNames := config.Aggregators.Names() for _, aggregate := range config.Aggregates { - if aggregate.AggregatorPath == nil && aggregate.Aggregator == nil { + if aggregate.AggregatorPath == nil && aggregate.Aggregator == "" { return ErrorMissingAggregator(aggregate) } - if aggregate.Aggregator != nil && - !strings.Contains(*aggregate.Aggregator, ".") && - !slices.HasString(aggregatorNames, *aggregate.Aggregator) { - return errors.Wrap(ErrorUndefinedResource(*aggregate.Aggregator, resource.AggregatorType), Identify(aggregate), AggregatorKey) + if aggregate.AggregatorPath != nil && aggregate.Aggregator != "" { + return ErrorMultipleAggregatorSpecified(aggregate) + } + + switch { + case aggregate.AggregatorPath != nil: + continue + case aggregate.Aggregator != "": + if !strings.Contains(aggregate.Aggregator, ".") && + !slices.HasString(aggregatorNames, aggregate.Aggregator) { + return errors.Wrap(ErrorUndefinedResource(aggregate.Aggregator, resource.AggregatorType), Identify(aggregate), AggregatorKey) + } } } // Check local transformers exist or a path to one is defined transformerNames := config.Transformers.Names() for _, transformedColumn := range config.TransformedColumns { - if transformedColumn.TransformerPath == 
nil && transformedColumn.Transformer == nil { + if transformedColumn.TransformerPath == nil && transformedColumn.Transformer == "" { return ErrorMissingTransformer(transformedColumn) } - if transformedColumn.Transformer != nil && - !strings.Contains(*transformedColumn.Transformer, ".") && - !slices.HasString(transformerNames, *transformedColumn.Transformer) { - return errors.Wrap(ErrorUndefinedResource(*transformedColumn.Transformer, resource.TransformerType), Identify(transformedColumn), TransformerKey) + if transformedColumn.TransformerPath != nil && transformedColumn.Transformer != "" { + return ErrorMultipleTransformerSpecified(transformedColumn) + } + + switch { + case transformedColumn.TransformerPath != nil: + continue + case transformedColumn.Transformer != "": + if !strings.Contains(transformedColumn.Transformer, ".") && + !slices.HasString(transformerNames, transformedColumn.Transformer) { + return errors.Wrap(ErrorUndefinedResource(transformedColumn.Transformer, resource.TransformerType), Identify(transformedColumn), TransformerKey) + } } } diff --git a/pkg/operator/api/userconfig/errors.go b/pkg/operator/api/userconfig/errors.go index 3e9b71652f..c59ddd183f 100644 --- a/pkg/operator/api/userconfig/errors.go +++ b/pkg/operator/api/userconfig/errors.go @@ -60,6 +60,8 @@ const ( ErrClassificationTargetType ErrMissingAggregator ErrMissingTransformer + ErrMultipleAggregatorSpecified + ErrMultipleTransformerSpecified ) var errorKinds = []string{ @@ -94,9 +96,11 @@ var errorKinds = []string{ "err_classification_target_type", "err_missing_aggregator", "err_missing_transformer", + "err_multiple_aggregator_specified", + "err_multiple_transformer_specified", } -var _ = [1]int{}[int(ErrMissingTransformer)-(len(errorKinds)-1)] // Ensure list length matches +var _ = [1]int{}[int(ErrMultipleTransformerSpecified)-(len(errorKinds)-1)] // Ensure list length matches func (t ErrorKind) String() string { return errorKinds[t] @@ -401,3 +405,17 @@ func ErrorMissingTransformer(transformedColumn *TransformedColumn) error { message: fmt.Sprintf("missing transformer for transformed_column \"%s\", expecting either \"transformer\" or \"transformer_path\"", transformedColumn.Name), } } + +func ErrorMultipleAggregatorSpecified(aggregate *Aggregate) error { + return Error{ + Kind: ErrMultipleAggregatorSpecified, + message: fmt.Sprintf("aggregate \"%s\" specified both \"aggregator\" and \"aggregator_path\", please specify only one", aggregate.Name), + } +} + +func ErrorMultipleTransformerSpecified(transformedColumn *TransformedColumn) error { + return Error{ + Kind: ErrMultipleTransformerSpecified, + message: fmt.Sprintf("transformed_column \"%s\" specified both \"transformer\" and \"transformer_path\", please specify only one", transformedColumn.Name), + } +} diff --git a/pkg/operator/api/userconfig/transformed_columns.go b/pkg/operator/api/userconfig/transformed_columns.go index 9de4000134..be7a3c8b70 100644 --- a/pkg/operator/api/userconfig/transformed_columns.go +++ b/pkg/operator/api/userconfig/transformed_columns.go @@ -27,7 +27,7 @@ type TransformedColumns []*TransformedColumn type TransformedColumn struct { ResourceFields - Transformer *string `json:"transformer" yaml:"transformer"` + Transformer string `json:"transformer" yaml:"transformer"` TransformerPath *string `json:"transformer_path" yaml:"transformer_path"` Inputs *Inputs `json:"inputs" yaml:"inputs"` Compute *SparkCompute `json:"compute" yaml:"compute"` @@ -45,8 +45,8 @@ var transformedColumnValidation = &configreader.StructValidation{ }, { 
StructField: "Transformer", - StringPtrValidation: &configreader.StringPtrValidation{ - AlphaNumericDashDotUnderscore: true, + StringValidation: &configreader.StringValidation{ + AllowEmpty: true, }, }, { diff --git a/pkg/operator/context/aggregates.go b/pkg/operator/context/aggregates.go index 111d66a85a..67c460b40c 100644 --- a/pkg/operator/context/aggregates.go +++ b/pkg/operator/context/aggregates.go @@ -44,17 +44,7 @@ func getAggregates( return nil, userconfig.ErrorDuplicateResourceName(aggregateConfig, constants[aggregateConfig.Name]) } - var aggName string - if aggregateConfig.Aggregator != nil { - aggName = *aggregateConfig.Aggregator - } - - if aggregateConfig.AggregatorPath != nil { - aggName = *aggregateConfig.AggregatorPath - aggregateConfig.Aggregator = &aggName - } - - aggregator, err := getAggregator(aggName, userAggregators) + aggregator, err := getAggregator(aggregateConfig.Aggregator, userAggregators) if err != nil { return nil, errors.Wrap(err, userconfig.Identify(aggregateConfig), userconfig.AggregatorKey) } diff --git a/pkg/operator/context/aggregators.go b/pkg/operator/context/aggregators.go index e219b69423..b1ac423508 100644 --- a/pkg/operator/context/aggregators.go +++ b/pkg/operator/context/aggregators.go @@ -23,7 +23,6 @@ import ( "github.com/cortexlabs/cortex/pkg/consts" "github.com/cortexlabs/cortex/pkg/lib/errors" "github.com/cortexlabs/cortex/pkg/lib/hash" - s "github.com/cortexlabs/cortex/pkg/lib/strings" "github.com/cortexlabs/cortex/pkg/operator/api/context" "github.com/cortexlabs/cortex/pkg/operator/api/resource" "github.com/cortexlabs/cortex/pkg/operator/api/userconfig" @@ -61,7 +60,7 @@ func loadUserAggregators( anonAggregatorConfig := &userconfig.Aggregator{ ResourceFields: userconfig.ResourceFields{ - Name: s.PathToName(*aggregateConfig.AggregatorPath), + Name: hash.Bytes(impl), }, Path: *aggregateConfig.AggregatorPath, } @@ -69,7 +68,9 @@ func loadUserAggregators( if err != nil { return nil, err } - userAggregators[*aggregateConfig.AggregatorPath] = aggregator + + aggregateConfig.Aggregator = aggregator.Name + userAggregators[anonAggregatorConfig.Name] = aggregator } return userAggregators, nil @@ -159,24 +160,14 @@ func getAggregators( aggregators := context.Aggregators{} for _, aggregateConfig := range config.Aggregates { - - var aggregatorName string - if aggregateConfig.Aggregator != nil { - aggregatorName = *aggregateConfig.Aggregator - } - - if aggregateConfig.AggregatorPath != nil { - aggregatorName = *aggregateConfig.AggregatorPath - } - - if _, ok := aggregators[aggregatorName]; ok { + if _, ok := aggregators[aggregateConfig.Aggregator]; ok { continue } - aggregator, err := getAggregator(aggregatorName, userAggregators) + aggregator, err := getAggregator(aggregateConfig.Aggregator, userAggregators) if err != nil { return nil, errors.Wrap(err, userconfig.Identify(aggregateConfig), userconfig.AggregatorKey) } - aggregators[aggregatorName] = aggregator + aggregators[aggregateConfig.Aggregator] = aggregator } return aggregators, nil diff --git a/pkg/operator/context/autogenerator.go b/pkg/operator/context/autogenerator.go index b3f1b1ab90..1c1904f68f 100644 --- a/pkg/operator/context/autogenerator.go +++ b/pkg/operator/context/autogenerator.go @@ -43,16 +43,11 @@ func autoGenerateConfig( } } - var name string - if aggregate.Aggregator != nil { - name = *aggregate.Aggregator - } - if aggregate.AggregatorPath 
!= nil { - name = *aggregate.AggregatorPath + continue } - aggregator, err := getAggregator(name, userAggregators) + aggregator, err := getAggregator(aggregate.Aggregator, userAggregators) if err != nil { return errors.Wrap(err, userconfig.Identify(aggregate), userconfig.AggregatorKey) } @@ -93,16 +88,11 @@ func autoGenerateConfig( } } - var name string - if transformedColumn.Transformer != nil { - name = *transformedColumn.Transformer - } - if transformedColumn.TransformerPath != nil { - name = s.PathToName(*transformedColumn.TransformerPath) + continue } - transformer, err := getTransformer(name, userTransformers) + transformer, err := getTransformer(transformedColumn.Transformer, userTransformers) if err != nil { return errors.Wrap(err, userconfig.Identify(transformedColumn), userconfig.TransformerKey) } @@ -133,8 +123,5 @@ func autoGenerateConfig( } } - if err := config.Validate(config.Environment.Name); err != nil { - return err - } return nil } diff --git a/pkg/operator/context/transformed_columns.go b/pkg/operator/context/transformed_columns.go index a6267e0885..edbfd7878d 100644 --- a/pkg/operator/context/transformed_columns.go +++ b/pkg/operator/context/transformed_columns.go @@ -41,17 +41,7 @@ func getTransformedColumns( transformedColumns := context.TransformedColumns{} for _, transformedColumnConfig := range config.TransformedColumns { - var transName string - if transformedColumnConfig.Transformer != nil { - transName = *transformedColumnConfig.Transformer - } - - if transformedColumnConfig.TransformerPath != nil { - transName = s.PathToName(*transformedColumnConfig.TransformerPath) - transformedColumnConfig.Transformer = &transName - } - - transformer, err := getTransformer(transName, userTransformers) + transformer, err := getTransformer(transformedColumnConfig.Transformer, userTransformers) if err != nil { return nil, errors.Wrap(err, userconfig.Identify(transformedColumnConfig), userconfig.TransformerKey) } diff --git a/pkg/operator/context/transformers.go b/pkg/operator/context/transformers.go index c513644683..6061619e59 100644 --- a/pkg/operator/context/transformers.go +++ b/pkg/operator/context/transformers.go @@ -23,7 +23,6 @@ import ( "github.com/cortexlabs/cortex/pkg/consts" "github.com/cortexlabs/cortex/pkg/lib/errors" "github.com/cortexlabs/cortex/pkg/lib/hash" - s "github.com/cortexlabs/cortex/pkg/lib/strings" "github.com/cortexlabs/cortex/pkg/operator/api/context" "github.com/cortexlabs/cortex/pkg/operator/api/resource" "github.com/cortexlabs/cortex/pkg/operator/api/userconfig" @@ -61,7 +60,7 @@ func loadUserTransformers( anonTransformerConfig := &userconfig.Transformer{ ResourceFields: userconfig.ResourceFields{ - Name: s.PathToName(*transColConfig.TransformerPath), + Name: hash.Bytes(impl), }, Path: *transColConfig.TransformerPath, } @@ -69,6 +68,7 @@ func loadUserTransformers( if err != nil { return nil, err } + transColConfig.Transformer = transformer.Name userTransformers[transformer.Name] = transformer } return userTransformers, nil @@ -156,24 +156,15 @@ func getTransformers( transformers := context.Transformers{} for _, transformedColumnConfig := range config.TransformedColumns { - var transformerName string - if transformedColumnConfig.Transformer != nil { - transformerName = *transformedColumnConfig.Transformer - } - - if transformedColumnConfig.TransformerPath != nil { - transformerName = 
s.PathToName(*transformedColumnConfig.TransformerPath) - } - - if _, ok := transformers[transformerName]; ok { + if _, ok := transformers[transformedColumnConfig.Transformer]; ok { continue } - transformer, err := getTransformer(transformerName, userTransformers) + transformer, err := getTransformer(transformedColumnConfig.Transformer, userTransformers) if err != nil { return nil, errors.Wrap(err, userconfig.Identify(transformedColumnConfig), userconfig.TransformerKey) } - transformers[transformerName] = transformer + transformers[transformedColumnConfig.Transformer] = transformer } return transformers, nil From 6a3a9bdaee39b50728ba20498eec2f20062e372f Mon Sep 17 00:00:00 2001 From: Ivan Zhang Date: Fri, 10 May 2019 18:12:30 -0400 Subject: [PATCH 12/48] check 5 samples --- pkg/workloads/spark_job/spark_util.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/pkg/workloads/spark_job/spark_util.py b/pkg/workloads/spark_job/spark_util.py index b9ab9c810e..02863b46c8 100644 --- a/pkg/workloads/spark_job/spark_util.py +++ b/pkg/workloads/spark_job/spark_util.py @@ -494,6 +494,17 @@ def validate_transformer(column_name, df, ctx, spark): transformedSample = trans_impl.transform_python(inputs, impl_args) rowType = type(transformedSample) isList = rowType == list + + for row in sample_df: + inputs = ctx.create_column_inputs_map(row, column_name) + transformedSample = trans_impl.transform_python(inputs, impl_args) + if rowType != type(transformedSample): + raise UserRuntimeException( + "transformed column " + column_name, + "type inference failed, mixed data types in dataframe.", + ) + + typeConversionDict = PYTHON_TYPE_TO_CORTEX_TYPE if isList: rowType = type(transformedSample[0]) From c5f760d12fd0d953e6cc6888933bafbdecd194a2 Mon Sep 17 00:00:00 2001 From: Ivan Zhang Date: Mon, 13 May 2019 13:29:00 -0400 Subject: [PATCH 13/48] remove skip validation --- pkg/operator/api/context/aggregators.go | 5 ++--- pkg/operator/api/context/transformers.go | 5 ++--- pkg/operator/context/aggregates.go | 2 +- pkg/operator/context/aggregators.go | 12 +++++------- pkg/operator/context/context.go | 4 ++-- pkg/operator/context/transformed_columns.go | 2 +- pkg/operator/context/transformers.go | 12 +++++------- pkg/workloads/spark_job/spark_util.py | 18 ++++++------------ 8 files changed, 24 insertions(+), 36 deletions(-) diff --git a/pkg/operator/api/context/aggregators.go b/pkg/operator/api/context/aggregators.go index 635151adb0..16edf1e807 100644 --- a/pkg/operator/api/context/aggregators.go +++ b/pkg/operator/api/context/aggregators.go @@ -25,9 +25,8 @@ type Aggregators map[string]*Aggregator type Aggregator struct { *userconfig.Aggregator *ResourceFields - Namespace *string `json:"namespace"` - ImplKey string `json:"impl_key"` - SkipValidation bool `json:"skip_validation"` + Namespace *string `json:"namespace"` + ImplKey string `json:"impl_key"` } func (aggregators Aggregators) OneByID(id string) *Aggregator { diff --git a/pkg/operator/api/context/transformers.go b/pkg/operator/api/context/transformers.go index 65bf53522c..96d127fab7 100644 --- a/pkg/operator/api/context/transformers.go +++ b/pkg/operator/api/context/transformers.go @@ -25,9 +25,8 @@ type Transformers map[string]*Transformer type Transformer struct { *userconfig.Transformer *ResourceFields - Namespace *string `json:"namespace"` - ImplKey string `json:"impl_key"` - SkipValidation bool `json:"skip_validation"` + Namespace *string `json:"namespace"` + ImplKey string `json:"impl_key"` } func (transformers Transformers) OneByID(id 
string) *Transformer { diff --git a/pkg/operator/context/aggregates.go b/pkg/operator/context/aggregates.go index 67c460b40c..d4ea14374e 100644 --- a/pkg/operator/context/aggregates.go +++ b/pkg/operator/context/aggregates.go @@ -112,7 +112,7 @@ func validateAggregateInputs( rawColumns context.RawColumns, aggregator *context.Aggregator, ) error { - if aggregator.SkipValidation { + if aggregateConfig.AggregatorPath != nil { return nil } diff --git a/pkg/operator/context/aggregators.go b/pkg/operator/context/aggregators.go index b1ac423508..a037095407 100644 --- a/pkg/operator/context/aggregators.go +++ b/pkg/operator/context/aggregators.go @@ -41,7 +41,7 @@ func loadUserAggregators( if !ok { return nil, errors.Wrap(ErrorImplDoesNotExist(aggregatorConfig.Path), userconfig.Identify(aggregatorConfig)) } - aggregator, err := newAggregator(*aggregatorConfig, impl, nil, pythonPackages, false) + aggregator, err := newAggregator(*aggregatorConfig, impl, nil, pythonPackages) if err != nil { return nil, err } @@ -64,7 +64,7 @@ func loadUserAggregators( }, Path: *aggregateConfig.AggregatorPath, } - aggregator, err := newAggregator(*anonAggregatorConfig, impl, nil, pythonPackages, true) + aggregator, err := newAggregator(*anonAggregatorConfig, impl, nil, pythonPackages) if err != nil { return nil, err } @@ -81,7 +81,6 @@ func newAggregator( impl []byte, namespace *string, pythonPackages context.PythonPackages, - skipValidation bool, ) (*context.Aggregator, error) { implID := hash.Bytes(impl) @@ -104,10 +103,9 @@ func newAggregator( ResourceType: resource.AggregatorType, MetadataKey: filepath.Join(consts.AggregatorsDir, id+"_metadata.json"), }, - Aggregator: &aggregatorConfig, - Namespace: namespace, - ImplKey: filepath.Join(consts.AggregatorsDir, implID+".py"), - SkipValidation: skipValidation, + Aggregator: &aggregatorConfig, + Namespace: namespace, + ImplKey: filepath.Join(consts.AggregatorsDir, implID+".py"), } aggregator.Aggregator.Path = "" diff --git a/pkg/operator/context/context.go b/pkg/operator/context/context.go index f0f845e5b3..1d9d0d87ce 100644 --- a/pkg/operator/context/context.go +++ b/pkg/operator/context/context.go @@ -62,7 +62,7 @@ func Init() error { if err != nil { return errors.Wrap(err, userconfig.Identify(aggregatorConfig)) } - aggregator, err := newAggregator(*aggregatorConfig, impl, pointer.String("cortex"), nil, false) + aggregator, err := newAggregator(*aggregatorConfig, impl, pointer.String("cortex"), nil) if err != nil { return err } @@ -81,7 +81,7 @@ func Init() error { if err != nil { return errors.Wrap(err, userconfig.Identify(transConfig)) } - transformer, err := newTransformer(*transConfig, impl, pointer.String("cortex"), nil, false) + transformer, err := newTransformer(*transConfig, impl, pointer.String("cortex"), nil) if err != nil { return err } diff --git a/pkg/operator/context/transformed_columns.go b/pkg/operator/context/transformed_columns.go index edbfd7878d..bcc94cea9c 100644 --- a/pkg/operator/context/transformed_columns.go +++ b/pkg/operator/context/transformed_columns.go @@ -100,7 +100,7 @@ func validateTransformedColumnInputs( aggregates context.Aggregates, transformer *context.Transformer, ) error { - if transformer.SkipValidation { + if transformedColumnConfig.TransformerPath != nil { return nil } diff --git a/pkg/operator/context/transformers.go b/pkg/operator/context/transformers.go index 6061619e59..f3340ef92d 100644 --- a/pkg/operator/context/transformers.go +++ b/pkg/operator/context/transformers.go @@ -41,7 +41,7 @@ func loadUserTransformers( if 
!ok { return nil, errors.Wrap(ErrorImplDoesNotExist(transConfig.Path), userconfig.Identify(transConfig)) } - transformer, err := newTransformer(*transConfig, impl, nil, pythonPackages, false) + transformer, err := newTransformer(*transConfig, impl, nil, pythonPackages) if err != nil { return nil, err } @@ -64,7 +64,7 @@ func loadUserTransformers( }, Path: *transColConfig.TransformerPath, } - transformer, err := newTransformer(*anonTransformerConfig, impl, nil, pythonPackages, true) + transformer, err := newTransformer(*anonTransformerConfig, impl, nil, pythonPackages) if err != nil { return nil, err } @@ -79,7 +79,6 @@ func newTransformer( impl []byte, namespace *string, pythonPackages context.PythonPackages, - skipValidation bool, ) (*context.Transformer, error) { implID := hash.Bytes(impl) @@ -101,10 +100,9 @@ func newTransformer( ResourceType: resource.TransformerType, MetadataKey: filepath.Join(consts.TransformersDir, id+"_metadata.json"), }, - Transformer: &transConfig, - Namespace: namespace, - ImplKey: filepath.Join(consts.TransformersDir, implID+".py"), - SkipValidation: skipValidation, + Transformer: &transConfig, + Namespace: namespace, + ImplKey: filepath.Join(consts.TransformersDir, implID+".py"), } transformer.Transformer.Path = "" diff --git a/pkg/workloads/spark_job/spark_util.py b/pkg/workloads/spark_job/spark_util.py index 02863b46c8..53ad1a568c 100644 --- a/pkg/workloads/spark_job/spark_util.py +++ b/pkg/workloads/spark_job/spark_util.py @@ -398,7 +398,7 @@ def run_custom_aggregator(aggregator_resource, df, ctx, spark): aggregator_column_input = input_schema["columns"] args_schema = input_schema["args"] args = {} - if input_schema.get("args", None) is not None and len(input_schema["args"]) > 0: + if input_schema.get("args", None) is not None and len(args_schema) > 0: args = ctx.populate_args(input_schema["args"]) try: result = aggregator_impl.aggregate_spark(df, aggregator_column_input, args) @@ -409,7 +409,7 @@ def run_custom_aggregator(aggregator_resource, df, ctx, spark): "function aggregate_spark", ) from e - if not aggregator["skip_validation"] and not util.validate_value_type( + if aggregator["output_type"] and not util.validate_value_type( result, aggregator["output_type"] ): raise UserException( @@ -560,12 +560,9 @@ def validate_transformer(column_name, df, ctx, spark): actual_structfield = transform_spark_df.select(column_name).schema.fields[0] - transformer = ctx.transformers[transformed_column["transformer"]] - skip_validation = transformer["skip_validation"] - # check that expected output column has the correct data type if ( - not skip_validation + not transformed_column["transformer_path"] and actual_structfield.dataType not in CORTEX_TYPE_TO_ACCEPTABLE_SPARK_TYPES[transformed_column["type"]] ): @@ -582,7 +579,7 @@ def validate_transformer(column_name, df, ctx, spark): ) # perform the necessary upcast/downcast for the column e.g INT -> LONG or DOUBLE -> FLOAT - if not skip_validation: + if not transformed_column["transformer_path"]: transform_spark_df = transform_spark_df.withColumn( column_name, F.col(column_name).cast( @@ -640,12 +637,9 @@ def transform_column(column_name, df, ctx, spark): return df transformed_column = ctx.transformed_columns[column_name] - trans_impl, trans_impl_path = ctx.get_transformer_impl(column_name) + trans_impl, _ = ctx.get_transformer_impl(column_name) if hasattr(trans_impl, "transform_spark"): - skip_validation = ctx.transformers[ctx.transformed_columns[column_name]["transformer"]][ - "skip_validation" - ] - if skip_validation: 
+ if transformed_column["transformer_path"]: df = execute_transform_spark(column_name, df, ctx, spark) column_type = df.select(column_name).schema[0].dataType # for downstream operations on other jobs From 0bd6b73264ddbf65476a4452195de9166da8d8e8 Mon Sep 17 00:00:00 2001 From: Ivan Zhang Date: Mon, 13 May 2019 17:27:07 -0400 Subject: [PATCH 14/48] remove outdated check --- pkg/workloads/lib/context.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/pkg/workloads/lib/context.py b/pkg/workloads/lib/context.py index abab9b82d5..9fa7c0917a 100644 --- a/pkg/workloads/lib/context.py +++ b/pkg/workloads/lib/context.py @@ -230,9 +230,6 @@ def get_transformer_impl(self, column_name): return None, None transformer_name = self.transformed_columns[column_name]["transformer"] - if not transformer_name: - transformer_name = self.transformed_columns[column_name]["transformer_path"] - if transformer_name in self._transformer_impls: return self._transformer_impls[transformer_name] From 081b9b3515a7e44071fc22db7df50d12869766c6 Mon Sep 17 00:00:00 2001 From: Ivan Zhang Date: Mon, 13 May 2019 18:01:14 -0400 Subject: [PATCH 15/48] add AlphaNumericDashDotUnderscoreEmpty --- pkg/lib/configreader/string.go | 25 ++++++++++++------- pkg/operator/api/userconfig/aggregates.go | 3 ++- .../api/userconfig/transformed_columns.go | 3 ++- 3 files changed, 20 insertions(+), 11 deletions(-) diff --git a/pkg/lib/configreader/string.go b/pkg/lib/configreader/string.go index 45e990715f..e0ecec4189 100644 --- a/pkg/lib/configreader/string.go +++ b/pkg/lib/configreader/string.go @@ -28,15 +28,16 @@ import ( ) type StringValidation struct { - Required bool - Default string - AllowEmpty bool - AllowedValues []string - Prefix string - AlphaNumericDashDotUnderscore bool - AlphaNumericDashUnderscore bool - DNS1035 bool - Validator func(string) (string, error) + Required bool + Default string + AllowEmpty bool + AllowedValues []string + Prefix string + AlphaNumericDashDotUnderscoreOrEmpty bool + AlphaNumericDashDotUnderscore bool + AlphaNumericDashUnderscore bool + DNS1035 bool + Validator func(string) (string, error) } func EnvVar(envVarName string) string { @@ -190,6 +191,12 @@ func ValidateStringVal(val string, v *StringValidation) error { } } + if v.AlphaNumericDashDotUnderscoreOrEmpty { + if !regex.CheckAlphaNumericDashDotUnderscore(val) && val != "" { + return ErrorAlphaNumericDashDotUnderscore(val) + } + } + if v.DNS1035 { if err := urls.CheckDNS1035(val); err != nil { return err diff --git a/pkg/operator/api/userconfig/aggregates.go b/pkg/operator/api/userconfig/aggregates.go index e79e4101cd..613ad89b3a 100644 --- a/pkg/operator/api/userconfig/aggregates.go +++ b/pkg/operator/api/userconfig/aggregates.go @@ -46,7 +46,8 @@ var aggregateValidation = &configreader.StructValidation{ { StructField: "Aggregator", StringValidation: &configreader.StringValidation{ - AllowEmpty: true, + AllowEmpty: true, + AlphaNumericDashDotUnderscoreOrEmpty: true, }, }, { diff --git a/pkg/operator/api/userconfig/transformed_columns.go b/pkg/operator/api/userconfig/transformed_columns.go index be7a3c8b70..476b71f368 100644 --- a/pkg/operator/api/userconfig/transformed_columns.go +++ b/pkg/operator/api/userconfig/transformed_columns.go @@ -46,7 +46,8 @@ var transformedColumnValidation = &configreader.StructValidation{ { StructField: "Transformer", StringValidation: &configreader.StringValidation{ - AllowEmpty: true, + AllowEmpty: true, + AlphaNumericDashDotUnderscoreOrEmpty: true, }, }, { From 7fedb6f37bbd3789daa21d0f9440204c42e7b41f Mon Sep 17 
00:00:00 2001 From: Ivan Zhang Date: Mon, 13 May 2019 18:43:51 -0400 Subject: [PATCH 16/48] address some comments - lowercase resourceFields and replace error --- pkg/operator/api/userconfig/config.go | 29 ++++++++------------- pkg/operator/api/userconfig/errors.go | 34 +++++++++++-------------- pkg/operator/api/userconfig/resource.go | 28 ++++++++++---------- 3 files changed, 40 insertions(+), 51 deletions(-) diff --git a/pkg/operator/api/userconfig/config.go b/pkg/operator/api/userconfig/config.go index 8f9d6f658c..2d651b3be4 100644 --- a/pkg/operator/api/userconfig/config.go +++ b/pkg/operator/api/userconfig/config.go @@ -204,21 +204,17 @@ func (config *Config) Validate(envName string) error { aggregatorNames := config.Aggregators.Names() for _, aggregate := range config.Aggregates { if aggregate.AggregatorPath == nil && aggregate.Aggregator == "" { - return ErrorMissingAggregator(aggregate) + return errors.Wrap(ErrorSpecifyOnlyOneMissing("aggregator", "aggregator_path"), Identify(aggregate)) } if aggregate.AggregatorPath != nil && aggregate.Aggregator != "" { return ErrorMultipleAggregatorSpecified(aggregate) } - switch { - case aggregate.AggregatorPath != nil: - continue - case aggregate.Aggregator != "": - if !strings.Contains(aggregate.Aggregator, ".") && - !slices.HasString(aggregatorNames, aggregate.Aggregator) { - return errors.Wrap(ErrorUndefinedResource(aggregate.Aggregator, resource.AggregatorType), Identify(aggregate), AggregatorKey) - } + if aggregate.Aggregator != "" && + !strings.Contains(aggregate.Aggregator, ".") && + !slices.HasString(aggregatorNames, aggregate.Aggregator) { + return errors.Wrap(ErrorUndefinedResource(aggregate.Aggregator, resource.AggregatorType), Identify(aggregate), AggregatorKey) } } @@ -226,21 +222,18 @@ func (config *Config) Validate(envName string) error { transformerNames := config.Transformers.Names() for _, transformedColumn := range config.TransformedColumns { if transformedColumn.TransformerPath == nil && transformedColumn.Transformer == "" { - return ErrorMissingTransformer(transformedColumn) + return errors.Wrap(ErrorSpecifyOnlyOneMissing("transformer", "transformer_path"), Identify(transformedColumn)) } if transformedColumn.TransformerPath != nil && transformedColumn.Transformer != "" { return ErrorMultipleTransformerSpecified(transformedColumn) } - switch { - case transformedColumn.TransformerPath != nil: - continue - case transformedColumn.Transformer != "": - if !strings.Contains(transformedColumn.Transformer, ".") && - !slices.HasString(transformerNames, transformedColumn.Transformer) { - return errors.Wrap(ErrorUndefinedResource(transformedColumn.Transformer, resource.TransformerType), Identify(transformedColumn), TransformerKey) - } + if transformedColumn.Transformer != "" && + !strings.Contains(transformedColumn.Transformer, ".") && + !slices.HasString(transformerNames, transformedColumn.Transformer) { + return errors.Wrap(ErrorUndefinedResource(transformedColumn.Transformer, resource.TransformerType), Identify(transformedColumn), TransformerKey) + } } diff --git a/pkg/operator/api/userconfig/errors.go b/pkg/operator/api/userconfig/errors.go index c59ddd183f..90a3735b27 100644 --- a/pkg/operator/api/userconfig/errors.go +++ b/pkg/operator/api/userconfig/errors.go @@ -58,10 +58,9 @@ const ( ErrK8sQuantityMustBeInt ErrRegressionTargetType ErrClassificationTargetType - ErrMissingAggregator - ErrMissingTransformer ErrMultipleAggregatorSpecified ErrMultipleTransformerSpecified + ErrSpecifyOnlyOneMissing ) var errorKinds = []string{ @@ 
-94,13 +93,12 @@ var errorKinds = []string{ "err_k8s_quantity_must_be_int", "err_regression_target_type", "err_classification_target_type", - "err_missing_aggregator", - "err_missing_transformer", "err_multiple_aggregator_specified", "err_multiple_transformer_specified", + "err_specify_only_one_missing", } -var _ = [1]int{}[int(ErrMultipleTransformerSpecified)-(len(errorKinds)-1)] // Ensure list length matches +var _ = [1]int{}[int(ErrSpecifyOnlyOneMissing)-(len(errorKinds)-1)] // Ensure list length matches func (t ErrorKind) String() string { return errorKinds[t] @@ -392,20 +390,6 @@ func ErrorClassificationTargetType() error { } } -func ErrorMissingAggregator(aggregate *Aggregate) error { - return Error{ - Kind: ErrMissingAggregator, - message: fmt.Sprintf("missing aggregator for aggregate \"%s\", expecting either \"aggregator\" or \"aggregator_path\"", aggregate.Name), - } -} - -func ErrorMissingTransformer(transformedColumn *TransformedColumn) error { - return Error{ - Kind: ErrMissingTransformer, - message: fmt.Sprintf("missing transformer for transformed_column \"%s\", expecting either \"transformer\" or \"transformer_path\"", transformedColumn.Name), - } -} - func ErrorMultipleAggregatorSpecified(aggregate *Aggregate) error { return Error{ Kind: ErrMultipleAggregatorSpecified, @@ -419,3 +403,15 @@ func ErrorMultipleTransformerSpecified(transformedColumn *TransformedColumn) err message: fmt.Sprintf("transformed_column \"%s\" specified both \"transformer\" and \"transformer_path\", please specify only one", transformedColumn.Name), } } + +func ErrorSpecifyOnlyOneMissing(vals ...string) error { + message := fmt.Sprintf("please specify one of %s", s.UserStrsOr(vals)) + if len(vals) == 2 { + message = fmt.Sprintf("please specify either %s or %s", s.UserStr(vals[0]), s.UserStr(vals[1])) + } + + return Error{ + Kind: ErrSpecifyOnlyOneMissing, + message: message, + } +} diff --git a/pkg/operator/api/userconfig/resource.go b/pkg/operator/api/userconfig/resource.go index 45f72fd68a..da91815692 100644 --- a/pkg/operator/api/userconfig/resource.go +++ b/pkg/operator/api/userconfig/resource.go @@ -41,32 +41,32 @@ type ResourceFields struct { Embed *Embed `json:"embed" yaml:"-"` } -func (ResourceFields *ResourceFields) GetName() string { - return ResourceFields.Name +func (resourceFields *ResourceFields) GetName() string { + return resourceFields.Name } -func (ResourceFields *ResourceFields) GetIndex() int { - return ResourceFields.Index +func (resourceFields *ResourceFields) GetIndex() int { + return resourceFields.Index } -func (ResourceFields *ResourceFields) SetIndex(index int) { - ResourceFields.Index = index +func (resourceFields *ResourceFields) SetIndex(index int) { + resourceFields.Index = index } -func (ResourceFields *ResourceFields) GetFilePath() string { - return ResourceFields.FilePath +func (resourceFields *ResourceFields) GetFilePath() string { + return resourceFields.FilePath } -func (ResourceFields *ResourceFields) SetFilePath(filePath string) { - ResourceFields.FilePath = filePath +func (resourceFields *ResourceFields) SetFilePath(filePath string) { + resourceFields.FilePath = filePath } -func (ResourceFields *ResourceFields) GetEmbed() *Embed { - return ResourceFields.Embed +func (resourceFields *ResourceFields) GetEmbed() *Embed { + return resourceFields.Embed } -func (ResourceFields *ResourceFields) SetEmbed(embed *Embed) { - ResourceFields.Embed = embed +func (resourceFields *ResourceFields) SetEmbed(embed *Embed) { + resourceFields.Embed = embed } func Identify(r Resource) 
string { From 944568882f52368e61fd75d8d761491e82308f56 Mon Sep 17 00:00:00 2001 From: Ivan Zhang Date: Mon, 13 May 2019 18:53:41 -0400 Subject: [PATCH 17/48] cache anon specs --- pkg/operator/context/aggregators.go | 7 ++++++- pkg/operator/context/transformers.go | 7 ++++++- 2 files changed, 12 insertions(+), 2 deletions(-) diff --git a/pkg/operator/context/aggregators.go b/pkg/operator/context/aggregators.go index a037095407..0129bb658f 100644 --- a/pkg/operator/context/aggregators.go +++ b/pkg/operator/context/aggregators.go @@ -58,9 +58,14 @@ func loadUserAggregators( return nil, errors.Wrap(ErrorImplDoesNotExist(*aggregateConfig.AggregatorPath), userconfig.Identify(aggregateConfig)) } + implHash := hash.Bytes(impl) + if _, ok := userAggregators[implHash]; ok { + continue + } + anonAggregatorConfig := &userconfig.Aggregator{ ResourceFields: userconfig.ResourceFields{ - Name: hash.Bytes(impl), + Name: implHash, }, Path: *aggregateConfig.AggregatorPath, } diff --git a/pkg/operator/context/transformers.go b/pkg/operator/context/transformers.go index f3340ef92d..a335ddcd52 100644 --- a/pkg/operator/context/transformers.go +++ b/pkg/operator/context/transformers.go @@ -58,9 +58,14 @@ func loadUserTransformers( return nil, errors.Wrap(ErrorImplDoesNotExist(*transColConfig.TransformerPath), userconfig.Identify(transColConfig)) } + implHash := hash.Bytes(impl) + if _, ok := userTransformers[implHash]; ok { + continue + } + anonTransformerConfig := &userconfig.Transformer{ ResourceFields: userconfig.ResourceFields{ - Name: hash.Bytes(impl), + Name: implHash, }, Path: *transColConfig.TransformerPath, } From 14e45ff349cc4f51deea03c6736f7bc8a99659e3 Mon Sep 17 00:00:00 2001 From: Ivan Zhang Date: Mon, 13 May 2019 19:44:28 -0400 Subject: [PATCH 18/48] fix autogen --- pkg/operator/context/autogenerator.go | 30 ++++++++++++++------------- 1 file changed, 16 insertions(+), 14 deletions(-) diff --git a/pkg/operator/context/autogenerator.go b/pkg/operator/context/autogenerator.go index 1c1904f68f..a28e0d7e6a 100644 --- a/pkg/operator/context/autogenerator.go +++ b/pkg/operator/context/autogenerator.go @@ -43,17 +43,18 @@ func autoGenerateConfig( } } - if aggregate.AggregatorPath != nil { - continue - } - aggregator, err := getAggregator(aggregate.Aggregator, userAggregators) if err != nil { return errors.Wrap(err, userconfig.Identify(aggregate), userconfig.AggregatorKey) } - argType, ok := aggregator.Inputs.Args[argName] - if !ok { - return errors.Wrap(configreader.ErrorUnsupportedKey(argName), userconfig.Identify(aggregate), userconfig.InputsKey, userconfig.ArgsKey) + + var argType interface{} + if aggregator.Inputs != nil { + var ok bool + argType, ok = aggregator.Inputs.Args[argName] + if !ok { + return errors.Wrap(configreader.ErrorUnsupportedKey(argName), userconfig.Identify(aggregate), userconfig.InputsKey, userconfig.ArgsKey) + } } constantName := strings.Join([]string{ @@ -88,17 +89,18 @@ func autoGenerateConfig( } } - if transformedColumn.TransformerPath != nil { - continue - } - transformer, err := getTransformer(transformedColumn.Transformer, userTransformers) if err != nil { return errors.Wrap(err, userconfig.Identify(transformedColumn), userconfig.TransformerKey) } - argType, ok := transformer.Inputs.Args[argName] - if !ok { - return errors.Wrap(configreader.ErrorUnsupportedKey(argName), userconfig.Identify(transformedColumn), userconfig.InputsKey, userconfig.ArgsKey) + + var argType interface{} + if transformer.Inputs != nil { + var ok bool + argType, ok = transformer.Inputs.Args[argName] + 
if !ok { + return errors.Wrap(configreader.ErrorUnsupportedKey(argName), userconfig.Identify(transformedColumn), userconfig.InputsKey, userconfig.ArgsKey) + } } constantName := strings.Join([]string{ From 03a22bc2a82a626993b4433d231612dca16bdcf0 Mon Sep 17 00:00:00 2001 From: Ivan Zhang Date: Mon, 13 May 2019 20:43:21 -0400 Subject: [PATCH 19/48] progress --- pkg/operator/context/models.go | 4 ++-- pkg/workloads/lib/context.py | 18 ++++++++++++++++++ pkg/workloads/tf_train/train_util.py | 5 +++-- 3 files changed, 23 insertions(+), 4 deletions(-) diff --git a/pkg/operator/context/models.go b/pkg/operator/context/models.go index 0b69ae5dc9..a1c900e50a 100644 --- a/pkg/operator/context/models.go +++ b/pkg/operator/context/models.go @@ -85,7 +85,6 @@ func getModels( datasetIDWithTags := hash.Bytes(buf.Bytes()) datasetRoot := filepath.Join(root, consts.TrainingDataDir, datasetID) - trainingDatasetName := strings.Join([]string{ modelConfig.Name, resource.TrainingDatasetType.String(), @@ -97,7 +96,7 @@ func getModels( ID: modelID, IDWithTags: modelID, ResourceType: resource.ModelType, - MetadataKey: filepath.Join(datasetRoot, "metadata.json"), + MetadataKey: filepath.Join(root, consts.ModelsDir, modelID+"_metadata.json"), }, }, Model: modelConfig, @@ -115,6 +114,7 @@ func getModels( ID: datasetID, IDWithTags: datasetIDWithTags, ResourceType: resource.TrainingDatasetType, + MetadataKey: filepath.Join(datasetRoot, "metadata.json"), }, }, ModelName: modelConfig.Name, diff --git a/pkg/workloads/lib/context.py b/pkg/workloads/lib/context.py index 9fa7c0917a..522f8f34fb 100644 --- a/pkg/workloads/lib/context.py +++ b/pkg/workloads/lib/context.py @@ -479,6 +479,17 @@ def update_metadata(self, metadata, context_key, context_item=""): self.ctx[context_key][context_item]["metadata"] = metadata self.storage.put_json(metadata, self.ctx[context_key][context_item]["metadata_key"]) + def get_metadata(self, context_key, context_item, use_cache=True): + if use_cache and self.ctx[context_key][context_item]["metadata"]: + return self.ctx[context_key][context_item]["metadata"] + + metadata_uri = self.ctx[context_key][context_item]["metadata_key"] + metadata = self.storage.get_json(metadata_uri, allow_missing=True) + self.ctx[context_key][context_item]["metadata"] = metadata + return metadata + + + def fetch_metadata(self): resources = [ "python_packages", @@ -499,6 +510,13 @@ def fetch_metadata(self): metadata = {} self.ctx[resource][k]["metadata"] = metadata + # fetch dataset metadata for models + for k, v in self.ctx["models"].items(): + metadata = self.storage.get_json(v["dataset"]["metadata_key"], allow_missing=True) + if not metadata: + metadata = {} + self.ctx["models"][k]["dataset"]["metadata"] = metadata + metadata = self.storage.get_json(self.raw_dataset["metadata_key"], allow_missing=True) if not metadata: metadata = {} diff --git a/pkg/workloads/tf_train/train_util.py b/pkg/workloads/tf_train/train_util.py index d6f1fb607e..3a5a6e1c60 100644 --- a/pkg/workloads/tf_train/train_util.py +++ b/pkg/workloads/tf_train/train_util.py @@ -149,9 +149,10 @@ def train(model_name, model_impl, ctx, model_dir): exporter = tf.estimator.FinalExporter("estimator", serving_input_fn, as_text=False) train_num_steps = model["training"]["num_steps"] + dataset_metadata = model["dataset"]["metadata"] if model["training"]["num_epochs"]: train_num_steps = ( - math.ceil(model["metadata"]["training_size"] / float(model["training"]["batch_size"])) + math.ceil(dataset_metadata["training_size"] / float(model["training"]["batch_size"])) 
* model["training"]["num_epochs"] ) @@ -160,7 +161,7 @@ def train(model_name, model_impl, ctx, model_dir): eval_num_steps = model["evaluation"]["num_steps"] if model["evaluation"]["num_epochs"]: eval_num_steps = ( - math.ceil(model["metadata"]["training_size"] / float(model["evaluation"]["batch_size"])) + math.ceil(model["metadata"]["eval_size"] / float(model["evaluation"]["batch_size"])) * model["evaluation"]["num_epochs"] ) From c4a04bbcaa11468ac0deb206da3db719e41c9028 Mon Sep 17 00:00:00 2001 From: Ivan Zhang Date: Tue, 14 May 2019 20:28:02 -0400 Subject: [PATCH 20/48] lazy load metadata --- pkg/workloads/lib/context.py | 30 +++++++++++++++++++++------- pkg/workloads/lib/tf_lib.py | 6 +++--- pkg/workloads/spark_job/spark_job.py | 2 +- pkg/workloads/tf_api/api.py | 2 +- pkg/workloads/tf_train/train_util.py | 4 ++-- 5 files changed, 30 insertions(+), 14 deletions(-) diff --git a/pkg/workloads/lib/context.py b/pkg/workloads/lib/context.py index 522f8f34fb..70eab4f70a 100644 --- a/pkg/workloads/lib/context.py +++ b/pkg/workloads/lib/context.py @@ -99,12 +99,12 @@ def __init__(self, **kwargs): ) ) - self.fetch_metadata() - self.columns = util.merge_dicts_overwrite( self.raw_columns, self.transformed_columns # self.aggregates ) + self.ctx["columns"] = self.columns + self.values = util.merge_dicts_overwrite(self.aggregates, self.constants) self.raw_column_names = list(self.raw_columns.keys()) @@ -479,8 +479,26 @@ def update_metadata(self, metadata, context_key, context_item=""): self.ctx[context_key][context_item]["metadata"] = metadata self.storage.put_json(metadata, self.ctx[context_key][context_item]["metadata_key"]) - def get_metadata(self, context_key, context_item, use_cache=True): - if use_cache and self.ctx[context_key][context_item]["metadata"]: + def get_metadata(self, context_key, context_item="", use_cache=True): + if context_key == "raw_dataset": + if use_cache and self.raw_dataset.get("metadata", None): + return self.raw_dataset["metadata"] + + metadata = self.storage.get_json(self.raw_dataset["metadata_key"], allow_missing=True) + self.raw_dataset["metadata"] = metadata + return metadata + + + if context_key == "training_dataset": + if use_cache and self.ctx["models"][context_item]["dataset"].get("metadata", None): + return self.ctx["models"][context_item]["dataset"]["metadata"] + + metadata_uri = self.ctx["models"][context_item]["dataset"]["metadata_key"] + metadata = self.storage.get_json(metadata_uri, allow_missing=True) + self.ctx["models"][context_item]["dataset"]["metadata"] = metadata + return metadata + + if use_cache and self.ctx[context_key][context_item].get("metadata", None): return self.ctx[context_key][context_item]["metadata"] metadata_uri = self.ctx[context_key][context_item]["metadata_key"] @@ -488,9 +506,7 @@ def get_metadata(self, context_key, context_item, use_cache=True): self.ctx[context_key][context_item]["metadata"] = metadata return metadata - - - def fetch_metadata(self): + def fetch_all_metadata(self): resources = [ "python_packages", "raw_columns", diff --git a/pkg/workloads/lib/tf_lib.py b/pkg/workloads/lib/tf_lib.py index cdbc2358c7..850915b7e1 100644 --- a/pkg/workloads/lib/tf_lib.py +++ b/pkg/workloads/lib/tf_lib.py @@ -61,7 +61,7 @@ def get_column_tf_types(model_name, ctx, training=True): for column_name in model["feature_columns"]: columnType = ctx.columns[column_name]["type"] if columnType == "unknown": - columnType = ctx.columns[column_name]["metadata"]["type"] + columnType = ctx.get_metadata("columns", column_name)["type"] 
column_types[column_name] = CORTEX_TYPE_TO_TF_TYPE[columnType] @@ -74,7 +74,7 @@ def get_column_tf_types(model_name, ctx, training=True): for column_name in model["training_columns"]: columnType = ctx.columns[column_name]["type"] if columnType == "unknown": - columnType = ctx.columns[column_name]["metadata"]["type"] + columnType = ctx.get_metadata("columns", column_name)["type"] column_types[column_name] = CORTEX_TYPE_TO_TF_TYPE[columnType] @@ -88,7 +88,7 @@ def get_feature_spec(model_name, ctx, training=True): for column_name, tf_type in column_types.items(): columnType = ctx.columns[column_name]["type"] if columnType == "unknown": - columnType = ctx.columns[column_name]["metadata"]["type"] + columnType = ctx.get_metadata("columns", column_name)["type"] if columnType in consts.COLUMN_LIST_TYPES: feature_spec[column_name] = tf.FixedLenSequenceFeature( diff --git a/pkg/workloads/spark_job/spark_job.py b/pkg/workloads/spark_job/spark_job.py index 96011d276e..986b3cf4f2 100644 --- a/pkg/workloads/spark_job/spark_job.py +++ b/pkg/workloads/spark_job/spark_job.py @@ -91,7 +91,7 @@ def parse_args(args): def validate_dataset(ctx, raw_df, cols_to_validate): - total_row_count = ctx.raw_dataset["metadata"]["dataset_size"] + total_row_count = ctx.get_metadata("raw_dataset")["dataset_size"] conditions_dict = spark_util.value_check_data(ctx, raw_df, cols_to_validate) if len(conditions_dict) > 0: diff --git a/pkg/workloads/tf_api/api.py b/pkg/workloads/tf_api/api.py index 31982a6512..73ad6c0a0d 100644 --- a/pkg/workloads/tf_api/api.py +++ b/pkg/workloads/tf_api/api.py @@ -97,7 +97,7 @@ def create_prediction_request(transformed_sample): for column_name, value in transformed_sample.items(): columnType = ctx.columns[column_name]["type"] if columnType == "unknown": - columnType = ctx.columns[column_name]["metadata"]["type"] + columnType = ctx.get_metadata("columns", column_name)["type"] data_type = tf_lib.CORTEX_TYPE_TO_TF_TYPE[columnType] shape = [1] if util.is_list(value): diff --git a/pkg/workloads/tf_train/train_util.py b/pkg/workloads/tf_train/train_util.py index 3a5a6e1c60..14c7f0c07e 100644 --- a/pkg/workloads/tf_train/train_util.py +++ b/pkg/workloads/tf_train/train_util.py @@ -149,7 +149,7 @@ def train(model_name, model_impl, ctx, model_dir): exporter = tf.estimator.FinalExporter("estimator", serving_input_fn, as_text=False) train_num_steps = model["training"]["num_steps"] - dataset_metadata = model["dataset"]["metadata"] + dataset_metadata = ctx.get_metadata("training_dataset", model_name) if model["training"]["num_epochs"]: train_num_steps = ( math.ceil(dataset_metadata["training_size"] / float(model["training"]["batch_size"])) @@ -161,7 +161,7 @@ def train(model_name, model_impl, ctx, model_dir): eval_num_steps = model["evaluation"]["num_steps"] if model["evaluation"]["num_epochs"]: eval_num_steps = ( - math.ceil(model["metadata"]["eval_size"] / float(model["evaluation"]["batch_size"])) + math.ceil(dataset_metadata["eval_size"] / float(model["evaluation"]["batch_size"])) * model["evaluation"]["num_epochs"] ) From cd38e52c31311e99c1cf4dddb9f72de9ba0694cb Mon Sep 17 00:00:00 2001 From: Ivan Zhang Date: Wed, 15 May 2019 17:18:45 -0400 Subject: [PATCH 21/48] clean up spark --- pkg/workloads/lib/context.py | 43 ++++----------- pkg/workloads/lib/tf_lib.py | 15 ++---- pkg/workloads/spark_job/spark_job.py | 3 +- pkg/workloads/spark_job/spark_util.py | 76 +++++++++++++-------------- pkg/workloads/tf_api/api.py | 4 +- 5 files changed, 52 insertions(+), 89 deletions(-) diff --git 
a/pkg/workloads/lib/context.py b/pkg/workloads/lib/context.py index 70eab4f70a..a726292392 100644 --- a/pkg/workloads/lib/context.py +++ b/pkg/workloads/lib/context.py @@ -471,9 +471,9 @@ def resource_status_key(self, resource): return os.path.join(self.status_prefix, resource["id"], resource["workload_id"]) def update_metadata(self, metadata, context_key, context_item=""): - if context_item == "": - self.ctx[context_key]["metadata"] = metadata - self.storage.put_json(metadata, self.ctx[context_key]["metadata_key"]) + if context_key == "raw_dataset": + self.raw_dataset["metadata"] = metadata + self.storage.put_json(metadata, self.raw_dataset["metadata_key"]) return self.ctx[context_key][context_item]["metadata"] = metadata @@ -506,37 +506,14 @@ def get_metadata(self, context_key, context_item="", use_cache=True): self.ctx[context_key][context_item]["metadata"] = metadata return metadata - def fetch_all_metadata(self): - resources = [ - "python_packages", - "raw_columns", - "transformed_columns", - "transformers", - "aggregators", - "aggregates", - "constants", - "models", - "apis", - ] + def get_inferred_column_type(self, column_name): + columnType = self.columns[column_name].get("type", None) + if not columnType or columnType == "unknown": + columnType = self.get_metadata("columns", column_name)["type"] + self.columns[column_name]["type"] = columnType + + return columnType - for resource in resources: - for k, v in self.ctx[resource].items(): - metadata = self.storage.get_json(v["metadata_key"], allow_missing=True) - if not metadata: - metadata = {} - self.ctx[resource][k]["metadata"] = metadata - - # fetch dataset metadata for models - for k, v in self.ctx["models"].items(): - metadata = self.storage.get_json(v["dataset"]["metadata_key"], allow_missing=True) - if not metadata: - metadata = {} - self.ctx["models"][k]["dataset"]["metadata"] = metadata - - metadata = self.storage.get_json(self.raw_dataset["metadata_key"], allow_missing=True) - if not metadata: - metadata = {} - self.raw_dataset["metadata"] = metadata MODEL_IMPL_VALIDATION = { diff --git a/pkg/workloads/lib/tf_lib.py b/pkg/workloads/lib/tf_lib.py index 850915b7e1..6c8d14662d 100644 --- a/pkg/workloads/lib/tf_lib.py +++ b/pkg/workloads/lib/tf_lib.py @@ -59,10 +59,7 @@ def get_column_tf_types(model_name, ctx, training=True): column_types = {} for column_name in model["feature_columns"]: - columnType = ctx.columns[column_name]["type"] - if columnType == "unknown": - columnType = ctx.get_metadata("columns", column_name)["type"] - + columnType = ctx.get_inferred_column_type(column_name) column_types[column_name] = CORTEX_TYPE_TO_TF_TYPE[columnType] if training: @@ -72,10 +69,7 @@ def get_column_tf_types(model_name, ctx, training=True): ] for column_name in model["training_columns"]: - columnType = ctx.columns[column_name]["type"] - if columnType == "unknown": - columnType = ctx.get_metadata("columns", column_name)["type"] - + columnType = ctx.get_inferred_column_type(column_name) column_types[column_name] = CORTEX_TYPE_TO_TF_TYPE[columnType] return column_types @@ -86,10 +80,7 @@ def get_feature_spec(model_name, ctx, training=True): column_types = get_column_tf_types(model_name, ctx, training) feature_spec = {} for column_name, tf_type in column_types.items(): - columnType = ctx.columns[column_name]["type"] - if columnType == "unknown": - columnType = ctx.get_metadata("columns", column_name)["type"] - + columnType = ctx.get_inferred_column_type(column_name) if columnType in consts.COLUMN_LIST_TYPES: feature_spec[column_name] = 
tf.FixedLenSequenceFeature( shape=(), dtype=tf_type, allow_missing=True diff --git a/pkg/workloads/spark_job/spark_job.py b/pkg/workloads/spark_job/spark_job.py index 986b3cf4f2..017412376e 100644 --- a/pkg/workloads/spark_job/spark_job.py +++ b/pkg/workloads/spark_job/spark_job.py @@ -160,8 +160,7 @@ def ingest_raw_dataset(spark, ctx, cols_to_validate, should_ingest): ingest_df = limit_dataset(full_dataset_size, ingest_df, ctx.environment["limit"]) written_count = write_raw_dataset(ingest_df, ctx, spark) - metadata = {"dataset_size": written_count} - ctx.update_metadata(metadata, "raw_dataset") + ctx.update_metadata({"dataset_size": written_count}, "raw_dataset") if written_count != full_dataset_size: logger.info( "{} rows read, {} rows dropped, {} rows ingested".format( diff --git a/pkg/workloads/spark_job/spark_util.py b/pkg/workloads/spark_job/spark_util.py index 53ad1a568c..d231f5bb8c 100644 --- a/pkg/workloads/spark_job/spark_util.py +++ b/pkg/workloads/spark_job/spark_util.py @@ -462,65 +462,63 @@ def _transform(*values): if validate: transformed_column = ctx.transformed_columns[column_name] - + columnType = ctx.get_inferred_column_type(column_name) def _transform_and_validate(*values): result = _transform(*values) - if not util.validate_column_type(result, transformed_column["type"]): + if not util.validate_column_type(result, columnType): raise UserException( "transformed column " + column_name, "tranformation " + transformed_column["transformer"], - "type of {} is not {}".format(result, transformed_column["type"]), + "type of {} is not {}".format(result, columnType), ) return result transform_python_func = _transform_and_validate - column_data_type_str = ctx.transformed_columns[column_name]["type"] + column_data_type_str = ctx.get_inferred_column_type(column_name) transform_udf = F.udf(transform_python_func, CORTEX_TYPE_TO_SPARK_TYPE[column_data_type_str]) return df.withColumn(column_name, transform_udf(*required_columns_sorted)) -def validate_transformer(column_name, df, ctx, spark): +def validate_transformer(column_name, test_df, ctx, spark): transformed_column = ctx.transformed_columns[column_name] trans_impl, _ = ctx.get_transformer_impl(column_name) if hasattr(trans_impl, "transform_python"): - sample_df = df.collect() - sample = sample_df[0] - inputs = ctx.create_column_inputs_map(sample, column_name) - _, impl_args = extract_inputs(column_name, ctx) - transformedSample = trans_impl.transform_python(inputs, impl_args) - rowType = type(transformedSample) - isList = rowType == list - - for row in sample_df: - inputs = ctx.create_column_inputs_map(row, column_name) + if transformed_column["transformer_path"]: + sample_df = test_df.collect() + sample = sample_df[0] + inputs = ctx.create_column_inputs_map(sample, column_name) + _, impl_args = extract_inputs(column_name, ctx) transformedSample = trans_impl.transform_python(inputs, impl_args) - if rowType != type(transformedSample): - raise UserRuntimeException( - "transformed column " + column_name, - "type inference failed, mixed data types in dataframe.", - ) - + rowType = type(transformedSample) + isList = (rowType == list) + + for row in sample_df: + inputs = ctx.create_column_inputs_map(row, column_name) + transformedSample = trans_impl.transform_python(inputs, impl_args) + if rowType != type(transformedSample): + raise UserRuntimeException( + "transformed column " + column_name, + "type inference failed, mixed data types in dataframe.", + ) - typeConversionDict = PYTHON_TYPE_TO_CORTEX_TYPE - if isList: - rowType = 
type(transformedSample[0]) - typeConversionDict = PYTHON_TYPE_TO_CORTEX_LIST_TYPE - # for downstream operations on this job - ctx.transformed_columns[column_name]["type"] = typeConversionDict[rowType] + typeConversionDict = PYTHON_TYPE_TO_CORTEX_TYPE + if isList: + rowType = type(transformedSample[0]) + typeConversionDict = PYTHON_TYPE_TO_CORTEX_LIST_TYPE - # for downstream operations on other jobs - ctx.update_metadata( - {"type": typeConversionDict[rowType]}, "transformed_columns", column_name - ) + # for downstream operations on other jobs + ctx.update_metadata( + {"type": typeConversionDict[rowType]}, "transformed_columns", column_name + ) try: transform_python_collect = execute_transform_python( - column_name, df, ctx, spark, validate=True + column_name, test_df, ctx, spark, validate=True ).collect() except Exception as e: raise UserRuntimeException( @@ -531,7 +529,7 @@ def validate_transformer(column_name, df, ctx, spark): if hasattr(trans_impl, "transform_spark"): try: - transform_spark_df = execute_transform_spark(column_name, df, ctx, spark) + transform_spark_df = execute_transform_spark(column_name, test_df, ctx, spark) # check that the return object is a dataframe if type(transform_spark_df) is not DataFrame: @@ -564,14 +562,14 @@ def validate_transformer(column_name, df, ctx, spark): if ( not transformed_column["transformer_path"] and actual_structfield.dataType - not in CORTEX_TYPE_TO_ACCEPTABLE_SPARK_TYPES[transformed_column["type"]] + not in CORTEX_TYPE_TO_ACCEPTABLE_SPARK_TYPES[ctx.get_inferred_column_type(column_name)] ): raise UserException( "incorrect column type, expected {}, found {}.".format( " or ".join( str(t) for t in CORTEX_TYPE_TO_ACCEPTABLE_SPARK_TYPES[ - transformed_column["type"] + ctx.get_inferred_column_type(column_name) ] ), actual_structfield.dataType, @@ -583,15 +581,15 @@ def validate_transformer(column_name, df, ctx, spark): transform_spark_df = transform_spark_df.withColumn( column_name, F.col(column_name).cast( - CORTEX_TYPE_TO_SPARK_TYPE[ctx.transformed_columns[column_name]["type"]] + CORTEX_TYPE_TO_SPARK_TYPE[ctx.get_inferred_column_type(column_name)] ), ) # check that the function doesn't modify the schema of the other columns in the input dataframe - if set(transform_spark_df.columns) - set([column_name]) != set(df.columns): + if set(transform_spark_df.columns) - set([column_name]) != set(test_df.columns): logger.error("expected schema:") - log_df_schema(df, logger.error) + log_df_schema(test_df, logger.error) logger.error("found schema (with {} dropped):".format(column_name)) log_df_schema(transform_spark_df.drop(column_name), logger.error) @@ -651,7 +649,7 @@ def transform_column(column_name, df, ctx, spark): return execute_transform_spark(column_name, df, ctx, spark).withColumn( column_name, F.col(column_name).cast( - CORTEX_TYPE_TO_SPARK_TYPE[ctx.transformed_columns[column_name]["type"]] + CORTEX_TYPE_TO_SPARK_TYPE[ctx.get_inferred_column_type(column_name)] ), ) elif hasattr(trans_impl, "transform_python"): diff --git a/pkg/workloads/tf_api/api.py b/pkg/workloads/tf_api/api.py index 73ad6c0a0d..93718511eb 100644 --- a/pkg/workloads/tf_api/api.py +++ b/pkg/workloads/tf_api/api.py @@ -95,9 +95,7 @@ def create_prediction_request(transformed_sample): prediction_request.model_spec.signature_name = signature_key for column_name, value in transformed_sample.items(): - columnType = ctx.columns[column_name]["type"] - if columnType == "unknown": - columnType = ctx.get_metadata("columns", column_name)["type"] + columnType = 
ctx.get_inferred_column_type(column_name) data_type = tf_lib.CORTEX_TYPE_TO_TF_TYPE[columnType] shape = [1] if util.is_list(value): From 909a0018a5a0c5e813e354c04b545e04150f9fcb Mon Sep 17 00:00:00 2001 From: Ivan Zhang Date: Wed, 15 May 2019 17:55:28 -0400 Subject: [PATCH 22/48] format and lint --- pkg/workloads/lib/context.py | 2 -- pkg/workloads/spark_job/spark_util.py | 25 +++++++++++-------------- 2 files changed, 11 insertions(+), 16 deletions(-) diff --git a/pkg/workloads/lib/context.py b/pkg/workloads/lib/context.py index a726292392..794e03a718 100644 --- a/pkg/workloads/lib/context.py +++ b/pkg/workloads/lib/context.py @@ -488,7 +488,6 @@ def get_metadata(self, context_key, context_item="", use_cache=True): self.raw_dataset["metadata"] = metadata return metadata - if context_key == "training_dataset": if use_cache and self.ctx["models"][context_item]["dataset"].get("metadata", None): return self.ctx["models"][context_item]["dataset"]["metadata"] @@ -515,7 +514,6 @@ def get_inferred_column_type(self, column_name): return columnType - MODEL_IMPL_VALIDATION = { "required": [{"name": "create_estimator", "args": ["run_config", "model_config"]}], "optional": [{"name": "transform_tensorflow", "args": ["features", "labels", "model_config"]}], diff --git a/pkg/workloads/spark_job/spark_util.py b/pkg/workloads/spark_job/spark_util.py index d231f5bb8c..56a7c067ac 100644 --- a/pkg/workloads/spark_job/spark_util.py +++ b/pkg/workloads/spark_job/spark_util.py @@ -463,6 +463,7 @@ def _transform(*values): if validate: transformed_column = ctx.transformed_columns[column_name] columnType = ctx.get_inferred_column_type(column_name) + def _transform_and_validate(*values): result = _transform(*values) if not util.validate_column_type(result, columnType): @@ -494,17 +495,16 @@ def validate_transformer(column_name, test_df, ctx, spark): _, impl_args = extract_inputs(column_name, ctx) transformedSample = trans_impl.transform_python(inputs, impl_args) rowType = type(transformedSample) - isList = (rowType == list) + isList = rowType == list for row in sample_df: inputs = ctx.create_column_inputs_map(row, column_name) transformedSample = trans_impl.transform_python(inputs, impl_args) if rowType != type(transformedSample): raise UserRuntimeException( - "transformed column " + column_name, - "type inference failed, mixed data types in dataframe.", - ) - + "transformed column " + column_name, + "type inference failed, mixed data types in dataframe.", + ) typeConversionDict = PYTHON_TYPE_TO_CORTEX_TYPE if isList: @@ -562,7 +562,9 @@ def validate_transformer(column_name, test_df, ctx, spark): if ( not transformed_column["transformer_path"] and actual_structfield.dataType - not in CORTEX_TYPE_TO_ACCEPTABLE_SPARK_TYPES[ctx.get_inferred_column_type(column_name)] + not in CORTEX_TYPE_TO_ACCEPTABLE_SPARK_TYPES[ + ctx.get_inferred_column_type(column_name) + ] ): raise UserException( "incorrect column type, expected {}, found {}.".format( @@ -637,21 +639,16 @@ def transform_column(column_name, df, ctx, spark): trans_impl, _ = ctx.get_transformer_impl(column_name) if hasattr(trans_impl, "transform_spark"): + column_type = CORTEX_TYPE_TO_SPARK_TYPE[ctx.get_inferred_column_type(column_name)] + df = execute_transform_spark(column_name, df, ctx, spark) if transformed_column["transformer_path"]: - df = execute_transform_spark(column_name, df, ctx, spark) column_type = df.select(column_name).schema[0].dataType # for downstream operations on other jobs ctx.update_metadata( {"type": SPARK_TYPE_TO_CORTEX_TYPE[column_type]}, 
"transformed_columns", column_name ) - return df.withColumn(column_name, F.col(column_name).cast(column_type)) - return execute_transform_spark(column_name, df, ctx, spark).withColumn( - column_name, - F.col(column_name).cast( - CORTEX_TYPE_TO_SPARK_TYPE[ctx.get_inferred_column_type(column_name)] - ), - ) + return df.withColumn(column_name, F.col(column_name).cast(column_type)) elif hasattr(trans_impl, "transform_python"): return execute_transform_python(column_name, df, ctx, spark) else: From 8c93f4c1e9d92f127adb1f5c367254d8ceb0dfc3 Mon Sep 17 00:00:00 2001 From: Ivan Zhang Date: Wed, 15 May 2019 20:35:06 -0400 Subject: [PATCH 23/48] don't skip cast --- pkg/workloads/spark_job/spark_util.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/pkg/workloads/spark_job/spark_util.py b/pkg/workloads/spark_job/spark_util.py index 56a7c067ac..b84031d8c4 100644 --- a/pkg/workloads/spark_job/spark_util.py +++ b/pkg/workloads/spark_job/spark_util.py @@ -579,13 +579,12 @@ def validate_transformer(column_name, test_df, ctx, spark): ) # perform the necessary upcast/downcast for the column e.g INT -> LONG or DOUBLE -> FLOAT - if not transformed_column["transformer_path"]: - transform_spark_df = transform_spark_df.withColumn( - column_name, - F.col(column_name).cast( - CORTEX_TYPE_TO_SPARK_TYPE[ctx.get_inferred_column_type(column_name)] - ), - ) + transform_spark_df = transform_spark_df.withColumn( + column_name, + F.col(column_name).cast( + CORTEX_TYPE_TO_SPARK_TYPE[ctx.get_inferred_column_type(column_name)] + ), + ) # check that the function doesn't modify the schema of the other columns in the input dataframe if set(transform_spark_df.columns) - set([column_name]) != set(test_df.columns): From 106fd584d5e0b4e2545abcde07668106070321a3 Mon Sep 17 00:00:00 2001 From: Ivan Zhang Date: Wed, 15 May 2019 23:12:21 -0400 Subject: [PATCH 24/48] default to None --- pkg/workloads/spark_job/spark_util.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pkg/workloads/spark_job/spark_util.py b/pkg/workloads/spark_job/spark_util.py index b84031d8c4..c17a118011 100644 --- a/pkg/workloads/spark_job/spark_util.py +++ b/pkg/workloads/spark_job/spark_util.py @@ -488,7 +488,7 @@ def validate_transformer(column_name, test_df, ctx, spark): trans_impl, _ = ctx.get_transformer_impl(column_name) if hasattr(trans_impl, "transform_python"): - if transformed_column["transformer_path"]: + if transformed_column.get("transformer_path", None): sample_df = test_df.collect() sample = sample_df[0] inputs = ctx.create_column_inputs_map(sample, column_name) @@ -560,7 +560,7 @@ def validate_transformer(column_name, test_df, ctx, spark): # check that expected output column has the correct data type if ( - not transformed_column["transformer_path"] + not transformed_column.get("transformer_path", None) and actual_structfield.dataType not in CORTEX_TYPE_TO_ACCEPTABLE_SPARK_TYPES[ ctx.get_inferred_column_type(column_name) @@ -640,7 +640,7 @@ def transform_column(column_name, df, ctx, spark): if hasattr(trans_impl, "transform_spark"): column_type = CORTEX_TYPE_TO_SPARK_TYPE[ctx.get_inferred_column_type(column_name)] df = execute_transform_spark(column_name, df, ctx, spark) - if transformed_column["transformer_path"]: + if transformed_column.get("transformer_path", None): column_type = df.select(column_name).schema[0].dataType # for downstream operations on other jobs ctx.update_metadata( From 28fcdc84fabece27e9e9dd6d74989cd3c83b869f Mon Sep 17 00:00:00 2001 From: Ivan Zhang Date: Thu, 16 May 
2019 11:52:00 -0400 Subject: [PATCH 25/48] fix context --- examples/fraud/resources/weight_column.yaml | 11 ++++++++++- pkg/workloads/lib/context.py | 11 ++++++++--- pkg/workloads/spark_job/spark_job.py | 4 ++-- pkg/workloads/spark_job/spark_util.py | 3 +-- pkg/workloads/tf_train/train_util.py | 2 +- 5 files changed, 22 insertions(+), 9 deletions(-) diff --git a/examples/fraud/resources/weight_column.yaml b/examples/fraud/resources/weight_column.yaml index 34654ab4fb..a6f87aa58b 100644 --- a/examples/fraud/resources/weight_column.yaml +++ b/examples/fraud/resources/weight_column.yaml @@ -5,9 +5,18 @@ columns: col: class +- kind: transformer + name: weight + inputs: + columns: + col: INT_COLUMN + args: + class_distribution: {INT: FLOAT} + output_type: FLOAT_COLUMN + - kind: transformed_column name: weight_column - transformer_path: implementations/transformers/weight.py + transformer: weight inputs: columns: col: class diff --git a/pkg/workloads/lib/context.py b/pkg/workloads/lib/context.py index 794e03a718..970a3e001d 100644 --- a/pkg/workloads/lib/context.py +++ b/pkg/workloads/lib/context.py @@ -471,16 +471,21 @@ def resource_status_key(self, resource): return os.path.join(self.status_prefix, resource["id"], resource["workload_id"]) def update_metadata(self, metadata, context_key, context_item=""): - if context_key == "raw_dataset": + if context_key == "raw_datasets": self.raw_dataset["metadata"] = metadata self.storage.put_json(metadata, self.raw_dataset["metadata_key"]) return + if context_key == "training_datasets": + self.ctx["models"][context_item]["dataset"]["metadata"] = metadata + self.storage.put_json(metadata, self.ctx["models"][context_item]["dataset"]["metadata_key"]) + return + self.ctx[context_key][context_item]["metadata"] = metadata self.storage.put_json(metadata, self.ctx[context_key][context_item]["metadata_key"]) def get_metadata(self, context_key, context_item="", use_cache=True): - if context_key == "raw_dataset": + if context_key == "raw_datasets": if use_cache and self.raw_dataset.get("metadata", None): return self.raw_dataset["metadata"] @@ -488,7 +493,7 @@ def get_metadata(self, context_key, context_item="", use_cache=True): self.raw_dataset["metadata"] = metadata return metadata - if context_key == "training_dataset": + if context_key == "training_datasets": if use_cache and self.ctx["models"][context_item]["dataset"].get("metadata", None): return self.ctx["models"][context_item]["dataset"]["metadata"] diff --git a/pkg/workloads/spark_job/spark_job.py b/pkg/workloads/spark_job/spark_job.py index 017412376e..309cea8b6d 100644 --- a/pkg/workloads/spark_job/spark_job.py +++ b/pkg/workloads/spark_job/spark_job.py @@ -91,7 +91,7 @@ def parse_args(args): def validate_dataset(ctx, raw_df, cols_to_validate): - total_row_count = ctx.get_metadata("raw_dataset")["dataset_size"] + total_row_count = ctx.get_metadata("raw_datasets")["dataset_size"] conditions_dict = spark_util.value_check_data(ctx, raw_df, cols_to_validate) if len(conditions_dict) > 0: @@ -160,7 +160,7 @@ def ingest_raw_dataset(spark, ctx, cols_to_validate, should_ingest): ingest_df = limit_dataset(full_dataset_size, ingest_df, ctx.environment["limit"]) written_count = write_raw_dataset(ingest_df, ctx, spark) - ctx.update_metadata({"dataset_size": written_count}, "raw_dataset") + ctx.update_metadata({"dataset_size": written_count}, "raw_datasets") if written_count != full_dataset_size: logger.info( "{} rows read, {} rows dropped, {} rows ingested".format( diff --git a/pkg/workloads/spark_job/spark_util.py 
b/pkg/workloads/spark_job/spark_util.py index c17a118011..b966a0af4e 100644 --- a/pkg/workloads/spark_job/spark_util.py +++ b/pkg/workloads/spark_job/spark_util.py @@ -119,8 +119,7 @@ def write_training_data(model_name, df, ctx, spark): ctx.storage.hadoop_path(training_dataset["eval_key"]) ) - metadata = {"training_size": train_df_acc.value, "eval_size": eval_df_acc.value} - ctx.update_metadata(metadata, "models", model_name) + ctx.update_metadata({"training_size": train_df_acc.value, "eval_size": eval_df_acc.value}, "training_datasets", model_name) return df diff --git a/pkg/workloads/tf_train/train_util.py b/pkg/workloads/tf_train/train_util.py index 14c7f0c07e..b537fdff7d 100644 --- a/pkg/workloads/tf_train/train_util.py +++ b/pkg/workloads/tf_train/train_util.py @@ -149,7 +149,7 @@ def train(model_name, model_impl, ctx, model_dir): exporter = tf.estimator.FinalExporter("estimator", serving_input_fn, as_text=False) train_num_steps = model["training"]["num_steps"] - dataset_metadata = ctx.get_metadata("training_dataset", model_name) + dataset_metadata = ctx.get_metadata("training_datasets", model_name) if model["training"]["num_epochs"]: train_num_steps = ( math.ceil(dataset_metadata["training_size"] / float(model["training"]["batch_size"])) From c5504ac30e2d44ddec2560bc0db94265d8495091 Mon Sep 17 00:00:00 2001 From: Ivan Zhang Date: Thu, 16 May 2019 11:54:16 -0400 Subject: [PATCH 26/48] format --- pkg/workloads/lib/context.py | 4 +++- pkg/workloads/spark_job/spark_util.py | 6 +++++- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/pkg/workloads/lib/context.py b/pkg/workloads/lib/context.py index 970a3e001d..cebd203c95 100644 --- a/pkg/workloads/lib/context.py +++ b/pkg/workloads/lib/context.py @@ -478,7 +478,9 @@ def update_metadata(self, metadata, context_key, context_item=""): if context_key == "training_datasets": self.ctx["models"][context_item]["dataset"]["metadata"] = metadata - self.storage.put_json(metadata, self.ctx["models"][context_item]["dataset"]["metadata_key"]) + self.storage.put_json( + metadata, self.ctx["models"][context_item]["dataset"]["metadata_key"] + ) return self.ctx[context_key][context_item]["metadata"] = metadata diff --git a/pkg/workloads/spark_job/spark_util.py b/pkg/workloads/spark_job/spark_util.py index b966a0af4e..3fdf702381 100644 --- a/pkg/workloads/spark_job/spark_util.py +++ b/pkg/workloads/spark_job/spark_util.py @@ -119,7 +119,11 @@ def write_training_data(model_name, df, ctx, spark): ctx.storage.hadoop_path(training_dataset["eval_key"]) ) - ctx.update_metadata({"training_size": train_df_acc.value, "eval_size": eval_df_acc.value}, "training_datasets", model_name) + ctx.update_metadata( + {"training_size": train_df_acc.value, "eval_size": eval_df_acc.value}, + "training_datasets", + model_name, + ) return df From 5271c3f2c6d28b704bbba1b79b3f77b007b3cd33 Mon Sep 17 00:00:00 2001 From: Ivan Zhang Date: Thu, 16 May 2019 12:36:34 -0400 Subject: [PATCH 27/48] fix tests --- pkg/workloads/lib/context.py | 4 ++-- pkg/workloads/spark_job/spark_job.py | 4 ++-- pkg/workloads/spark_job/test/integration/iris_test.py | 4 ++-- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/pkg/workloads/lib/context.py b/pkg/workloads/lib/context.py index cebd203c95..6f918df44f 100644 --- a/pkg/workloads/lib/context.py +++ b/pkg/workloads/lib/context.py @@ -471,7 +471,7 @@ def resource_status_key(self, resource): return os.path.join(self.status_prefix, resource["id"], resource["workload_id"]) def update_metadata(self, metadata, context_key, 
context_item=""): - if context_key == "raw_datasets": + if context_key == "raw_dataset": self.raw_dataset["metadata"] = metadata self.storage.put_json(metadata, self.raw_dataset["metadata_key"]) return @@ -487,7 +487,7 @@ def update_metadata(self, metadata, context_key, context_item=""): self.storage.put_json(metadata, self.ctx[context_key][context_item]["metadata_key"]) def get_metadata(self, context_key, context_item="", use_cache=True): - if context_key == "raw_datasets": + if context_key == "raw_dataset": if use_cache and self.raw_dataset.get("metadata", None): return self.raw_dataset["metadata"] diff --git a/pkg/workloads/spark_job/spark_job.py b/pkg/workloads/spark_job/spark_job.py index 309cea8b6d..017412376e 100644 --- a/pkg/workloads/spark_job/spark_job.py +++ b/pkg/workloads/spark_job/spark_job.py @@ -91,7 +91,7 @@ def parse_args(args): def validate_dataset(ctx, raw_df, cols_to_validate): - total_row_count = ctx.get_metadata("raw_datasets")["dataset_size"] + total_row_count = ctx.get_metadata("raw_dataset")["dataset_size"] conditions_dict = spark_util.value_check_data(ctx, raw_df, cols_to_validate) if len(conditions_dict) > 0: @@ -160,7 +160,7 @@ def ingest_raw_dataset(spark, ctx, cols_to_validate, should_ingest): ingest_df = limit_dataset(full_dataset_size, ingest_df, ctx.environment["limit"]) written_count = write_raw_dataset(ingest_df, ctx, spark) - ctx.update_metadata({"dataset_size": written_count}, "raw_datasets") + ctx.update_metadata({"dataset_size": written_count}, "raw_dataset") if written_count != full_dataset_size: logger.info( "{} rows read, {} rows dropped, {} rows ingested".format( diff --git a/pkg/workloads/spark_job/test/integration/iris_test.py b/pkg/workloads/spark_job/test/integration/iris_test.py index 00d305edc6..0f013d37fa 100644 --- a/pkg/workloads/spark_job/test/integration/iris_test.py +++ b/pkg/workloads/spark_job/test/integration/iris_test.py @@ -77,7 +77,7 @@ def test_simple_end_to_end(spark): raw_df = spark_job.ingest_raw_dataset(spark, ctx, cols_to_validate, should_ingest) assert raw_df.count() == 15 - assert ctx.raw_dataset["metadata"]["dataset_size"] == 15 + assert ctx.get_metadata("raw_dataset")["dataset_size"] == 15 for raw_column_id in cols_to_validate: path = os.path.join(raw_ctx["status_prefix"], raw_column_id, workload_id) status = storage.get_json(str(path)) @@ -117,7 +117,7 @@ def test_simple_end_to_end(spark): status["exist_code"] = "succeeded" dataset = raw_ctx["models"]["dnn"]["dataset"] - metadata = raw_ctx["models"]["dnn"]["metadata"] + metadata = raw_ctx.get_metadata("models", "dnn") assert metadata["training_size"] + metadata["eval_size"] == 15 assert local_storage_path.joinpath(dataset["train_key"], "_SUCCESS").exists() assert local_storage_path.joinpath(dataset["eval_key"], "_SUCCESS").exists() From 19a69bb1dfe653a00e0a4db33db010d297106f07 Mon Sep 17 00:00:00 2001 From: Ivan Zhang Date: Thu, 16 May 2019 12:52:28 -0400 Subject: [PATCH 28/48] fix test --- pkg/workloads/spark_job/test/integration/iris_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pkg/workloads/spark_job/test/integration/iris_test.py b/pkg/workloads/spark_job/test/integration/iris_test.py index 0f013d37fa..24e2e2a77d 100644 --- a/pkg/workloads/spark_job/test/integration/iris_test.py +++ b/pkg/workloads/spark_job/test/integration/iris_test.py @@ -117,7 +117,7 @@ def test_simple_end_to_end(spark): status["exist_code"] = "succeeded" dataset = raw_ctx["models"]["dnn"]["dataset"] - metadata = raw_ctx.get_metadata("models", "dnn") + metadata = 
ctx.get_metadata("training_datasets", "dnn") assert metadata["training_size"] + metadata["eval_size"] == 15 assert local_storage_path.joinpath(dataset["train_key"], "_SUCCESS").exists() assert local_storage_path.joinpath(dataset["eval_key"], "_SUCCESS").exists() From ad9d8b3ac8dfee68bda5dbe50f3c947c7e03b4c0 Mon Sep 17 00:00:00 2001 From: Ivan Zhang Date: Thu, 16 May 2019 18:34:24 -0400 Subject: [PATCH 29/48] address some python comments --- pkg/workloads/lib/context.py | 4 +-- pkg/workloads/spark_job/spark_util.py | 38 ++++++++++++++++++++------- 2 files changed, 30 insertions(+), 12 deletions(-) diff --git a/pkg/workloads/lib/context.py b/pkg/workloads/lib/context.py index 6f918df44f..3eb3a61861 100644 --- a/pkg/workloads/lib/context.py +++ b/pkg/workloads/lib/context.py @@ -513,8 +513,8 @@ def get_metadata(self, context_key, context_item="", use_cache=True): return metadata def get_inferred_column_type(self, column_name): - columnType = self.columns[column_name].get("type", None) - if not columnType or columnType == "unknown": + columnType = self.columns[column_name].get("type", "unknown") + if columnType == "unknown": columnType = self.get_metadata("columns", column_name)["type"] self.columns[column_name]["type"] = columnType diff --git a/pkg/workloads/spark_job/spark_util.py b/pkg/workloads/spark_job/spark_util.py index 3fdf702381..ae0389c9ec 100644 --- a/pkg/workloads/spark_job/spark_util.py +++ b/pkg/workloads/spark_job/spark_util.py @@ -485,6 +485,16 @@ def _transform_and_validate(*values): return df.withColumn(column_name, transform_udf(*required_columns_sorted)) +def infer_type(obj): + objType = type(obj) + typeConversionDict = PYTHON_TYPE_TO_CORTEX_TYPE + isList = objType == list + if isList: + objType = type(obj[0]) + typeConversionDict = PYTHON_TYPE_TO_CORTEX_LIST_TYPE + + return typeConversionDict[objType] + def validate_transformer(column_name, test_df, ctx, spark): transformed_column = ctx.transformed_columns[column_name] @@ -496,27 +506,35 @@ def validate_transformer(column_name, test_df, ctx, spark): sample = sample_df[0] inputs = ctx.create_column_inputs_map(sample, column_name) _, impl_args = extract_inputs(column_name, ctx) - transformedSample = trans_impl.transform_python(inputs, impl_args) - rowType = type(transformedSample) - isList = rowType == list + initial_transformed_sample = trans_impl.transform_python(inputs, impl_args) + expectedType = type(initial_transformed_sample) + isList = expectedType == list for row in sample_df: inputs = ctx.create_column_inputs_map(row, column_name) - transformedSample = trans_impl.transform_python(inputs, impl_args) - if rowType != type(transformedSample): + transformed_sample = trans_impl.transform_python(inputs, impl_args) + if expectedType != type(transformed_sample): raise UserRuntimeException( "transformed column " + column_name, "type inference failed, mixed data types in dataframe.", + "expected type of \"" + transformed_sample + "\" to be " + expectedType, ) - typeConversionDict = PYTHON_TYPE_TO_CORTEX_TYPE - if isList: - rowType = type(transformedSample[0]) - typeConversionDict = PYTHON_TYPE_TO_CORTEX_LIST_TYPE + if isList: + expectedListType = type(initial_transformed_sample[0]) + if expectedListType != type(transformed_sample[0]): + raise UserRuntimeException( + "transformed column " + column_name, + "type inference failed, mixed data types in list column.", + "expected type of \"" + transformed_sample[0] + "\" to be " + expectedListType, + ) + + + inferredCxType = infer_type(initial_transformed_sample) # for downstream 
operations on other jobs ctx.update_metadata( - {"type": typeConversionDict[rowType]}, "transformed_columns", column_name + {"type": inferredCxType}, "transformed_columns", column_name ) try: From b6d43375cab635123633a1911f0d0a3af08d28cc Mon Sep 17 00:00:00 2001 From: Ivan Zhang Date: Thu, 16 May 2019 18:36:21 -0400 Subject: [PATCH 30/48] format --- pkg/workloads/spark_job/spark_util.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/pkg/workloads/spark_job/spark_util.py b/pkg/workloads/spark_job/spark_util.py index ae0389c9ec..8cc9438a76 100644 --- a/pkg/workloads/spark_job/spark_util.py +++ b/pkg/workloads/spark_job/spark_util.py @@ -495,6 +495,7 @@ def infer_type(obj): return typeConversionDict[objType] + def validate_transformer(column_name, test_df, ctx, spark): transformed_column = ctx.transformed_columns[column_name] @@ -517,7 +518,7 @@ def validate_transformer(column_name, test_df, ctx, spark): raise UserRuntimeException( "transformed column " + column_name, "type inference failed, mixed data types in dataframe.", - "expected type of \"" + transformed_sample + "\" to be " + expectedType, + 'expected type of "' + transformed_sample + '" to be ' + expectedType, ) if isList: @@ -526,16 +527,16 @@ def validate_transformer(column_name, test_df, ctx, spark): raise UserRuntimeException( "transformed column " + column_name, "type inference failed, mixed data types in list column.", - "expected type of \"" + transformed_sample[0] + "\" to be " + expectedListType, + 'expected type of "' + + transformed_sample[0] + + '" to be ' + + expectedListType, ) - inferredCxType = infer_type(initial_transformed_sample) # for downstream operations on other jobs - ctx.update_metadata( - {"type": inferredCxType}, "transformed_columns", column_name - ) + ctx.update_metadata({"type": inferredCxType}, "transformed_columns", column_name) try: transform_python_collect = execute_transform_python( From 181da9a6cb73674f6621ae00afb45e2c53df1fc2 Mon Sep 17 00:00:00 2001 From: Ivan Zhang Date: Fri, 17 May 2019 11:25:26 -0400 Subject: [PATCH 31/48] add test, remove camel case --- pkg/workloads/spark_job/spark_util.py | 12 ++++++------ pkg/workloads/spark_job/test/unit/spark_util_test.py | 10 ++++++++++ 2 files changed, 16 insertions(+), 6 deletions(-) diff --git a/pkg/workloads/spark_job/spark_util.py b/pkg/workloads/spark_job/spark_util.py index 8cc9438a76..d450c79eee 100644 --- a/pkg/workloads/spark_job/spark_util.py +++ b/pkg/workloads/spark_job/spark_util.py @@ -508,17 +508,17 @@ def validate_transformer(column_name, test_df, ctx, spark): inputs = ctx.create_column_inputs_map(sample, column_name) _, impl_args = extract_inputs(column_name, ctx) initial_transformed_sample = trans_impl.transform_python(inputs, impl_args) - expectedType = type(initial_transformed_sample) - isList = expectedType == list + expected_type = type(initial_transformed_sample) + isList = expected_type == list for row in sample_df: inputs = ctx.create_column_inputs_map(row, column_name) transformed_sample = trans_impl.transform_python(inputs, impl_args) - if expectedType != type(transformed_sample): + if expected_type != type(transformed_sample): raise UserRuntimeException( "transformed column " + column_name, "type inference failed, mixed data types in dataframe.", - 'expected type of "' + transformed_sample + '" to be ' + expectedType, + 'expected type of "' + transformed_sample + '" to be ' + expected_type, ) if isList: @@ -533,10 +533,10 @@ def validate_transformer(column_name, test_df, ctx, spark): + 
expectedListType, ) - inferredCxType = infer_type(initial_transformed_sample) + inferred_cx_type = infer_type(initial_transformed_sample) # for downstream operations on other jobs - ctx.update_metadata({"type": inferredCxType}, "transformed_columns", column_name) + ctx.update_metadata({"type": inferred_cx_type}, "transformed_columns", column_name) try: transform_python_collect = execute_transform_python( diff --git a/pkg/workloads/spark_job/test/unit/spark_util_test.py b/pkg/workloads/spark_job/test/unit/spark_util_test.py index 5fe7840e85..b505e1885c 100644 --- a/pkg/workloads/spark_job/test/unit/spark_util_test.py +++ b/pkg/workloads/spark_job/test/unit/spark_util_test.py @@ -14,6 +14,7 @@ import math import spark_util +import consts from lib.exceptions import UserException import pytest @@ -577,3 +578,12 @@ def test_run_builtin_aggregators_error(spark, ctx_obj, get_context): ctx.store_aggregate_result.assert_not_called() ctx.populate_args.assert_called_once_with({"ignoreNulls": "some_constant"}) + +def test_infer_type(): + assert(spark_util.infer_type(1) == consts.COLUMN_TYPE_INT) + assert(spark_util.infer_type(1.0) == consts.COLUMN_TYPE_FLOAT) + assert(spark_util.infer_type("cortex") == consts.COLUMN_TYPE_STRING) + + assert(spark_util.infer_type([1]) == consts.COLUMN_TYPE_INT_LIST) + assert(spark_util.infer_type([1.0]) == consts.COLUMN_TYPE_FLOAT_LIST) + assert(spark_util.infer_type(["cortex"]) == consts.COLUMN_TYPE_STRING_LIST) From 0f7c3630af6f49e80a238c1e5a623f38e88880fc Mon Sep 17 00:00:00 2001 From: Ivan Zhang Date: Fri, 17 May 2019 12:18:35 -0400 Subject: [PATCH 32/48] clean up type checking logic, formatting --- pkg/workloads/spark_job/spark_util.py | 21 +++---------------- .../spark_job/test/unit/spark_util_test.py | 13 ++++++------ 2 files changed, 10 insertions(+), 24 deletions(-) diff --git a/pkg/workloads/spark_job/spark_util.py b/pkg/workloads/spark_job/spark_util.py index d450c79eee..c5e7dd2575 100644 --- a/pkg/workloads/spark_job/spark_util.py +++ b/pkg/workloads/spark_job/spark_util.py @@ -508,35 +508,20 @@ def validate_transformer(column_name, test_df, ctx, spark): inputs = ctx.create_column_inputs_map(sample, column_name) _, impl_args = extract_inputs(column_name, ctx) initial_transformed_sample = trans_impl.transform_python(inputs, impl_args) - expected_type = type(initial_transformed_sample) - isList = expected_type == list + expected_type = infer_type(initial_transformed_sample) for row in sample_df: inputs = ctx.create_column_inputs_map(row, column_name) transformed_sample = trans_impl.transform_python(inputs, impl_args) - if expected_type != type(transformed_sample): + if expected_type != infer_type(transformed_sample): raise UserRuntimeException( "transformed column " + column_name, "type inference failed, mixed data types in dataframe.", 'expected type of "' + transformed_sample + '" to be ' + expected_type, ) - if isList: - expectedListType = type(initial_transformed_sample[0]) - if expectedListType != type(transformed_sample[0]): - raise UserRuntimeException( - "transformed column " + column_name, - "type inference failed, mixed data types in list column.", - 'expected type of "' - + transformed_sample[0] - + '" to be ' - + expectedListType, - ) - - inferred_cx_type = infer_type(initial_transformed_sample) - # for downstream operations on other jobs - ctx.update_metadata({"type": inferred_cx_type}, "transformed_columns", column_name) + ctx.update_metadata({"type": expected_type}, "transformed_columns", column_name) try: transform_python_collect = 
execute_transform_python( diff --git a/pkg/workloads/spark_job/test/unit/spark_util_test.py b/pkg/workloads/spark_job/test/unit/spark_util_test.py index b505e1885c..7b6b71fda0 100644 --- a/pkg/workloads/spark_job/test/unit/spark_util_test.py +++ b/pkg/workloads/spark_job/test/unit/spark_util_test.py @@ -579,11 +579,12 @@ def test_run_builtin_aggregators_error(spark, ctx_obj, get_context): ctx.store_aggregate_result.assert_not_called() ctx.populate_args.assert_called_once_with({"ignoreNulls": "some_constant"}) + def test_infer_type(): - assert(spark_util.infer_type(1) == consts.COLUMN_TYPE_INT) - assert(spark_util.infer_type(1.0) == consts.COLUMN_TYPE_FLOAT) - assert(spark_util.infer_type("cortex") == consts.COLUMN_TYPE_STRING) + assert spark_util.infer_type(1) == consts.COLUMN_TYPE_INT + assert spark_util.infer_type(1.0) == consts.COLUMN_TYPE_FLOAT + assert spark_util.infer_type("cortex") == consts.COLUMN_TYPE_STRING - assert(spark_util.infer_type([1]) == consts.COLUMN_TYPE_INT_LIST) - assert(spark_util.infer_type([1.0]) == consts.COLUMN_TYPE_FLOAT_LIST) - assert(spark_util.infer_type(["cortex"]) == consts.COLUMN_TYPE_STRING_LIST) + assert spark_util.infer_type([1]) == consts.COLUMN_TYPE_INT_LIST + assert spark_util.infer_type([1.0]) == consts.COLUMN_TYPE_FLOAT_LIST + assert spark_util.infer_type(["cortex"]) == consts.COLUMN_TYPE_STRING_LIST From 076c020d7f57f7e956eede9977bf76d62d7ffc9c Mon Sep 17 00:00:00 2001 From: Ivan Zhang Date: Fri, 17 May 2019 12:54:20 -0400 Subject: [PATCH 33/48] remove more camel case --- pkg/workloads/lib/tf_lib.py | 12 ++++++------ pkg/workloads/spark_job/spark_util.py | 12 ++++++------ 2 files changed, 12 insertions(+), 12 deletions(-) diff --git a/pkg/workloads/lib/tf_lib.py b/pkg/workloads/lib/tf_lib.py index 6c8d14662d..9123e69160 100644 --- a/pkg/workloads/lib/tf_lib.py +++ b/pkg/workloads/lib/tf_lib.py @@ -59,8 +59,8 @@ def get_column_tf_types(model_name, ctx, training=True): column_types = {} for column_name in model["feature_columns"]: - columnType = ctx.get_inferred_column_type(column_name) - column_types[column_name] = CORTEX_TYPE_TO_TF_TYPE[columnType] + column_type = ctx.get_inferred_column_type(column_name) + column_types[column_name] = CORTEX_TYPE_TO_TF_TYPE[column_type] if training: target_column_name = model["target_column"] @@ -69,8 +69,8 @@ def get_column_tf_types(model_name, ctx, training=True): ] for column_name in model["training_columns"]: - columnType = ctx.get_inferred_column_type(column_name) - column_types[column_name] = CORTEX_TYPE_TO_TF_TYPE[columnType] + column_type = ctx.get_inferred_column_type(column_name) + column_types[column_name] = CORTEX_TYPE_TO_TF_TYPE[column_type] return column_types @@ -80,8 +80,8 @@ def get_feature_spec(model_name, ctx, training=True): column_types = get_column_tf_types(model_name, ctx, training) feature_spec = {} for column_name, tf_type in column_types.items(): - columnType = ctx.get_inferred_column_type(column_name) - if columnType in consts.COLUMN_LIST_TYPES: + column_type = ctx.get_inferred_column_type(column_name) + if column_type in consts.COLUMN_LIST_TYPES: feature_spec[column_name] = tf.FixedLenSequenceFeature( shape=(), dtype=tf_type, allow_missing=True ) diff --git a/pkg/workloads/spark_job/spark_util.py b/pkg/workloads/spark_job/spark_util.py index c5e7dd2575..8921014ff3 100644 --- a/pkg/workloads/spark_job/spark_util.py +++ b/pkg/workloads/spark_job/spark_util.py @@ -486,14 +486,14 @@ def _transform_and_validate(*values): def infer_type(obj): - objType = type(obj) - typeConversionDict = 
PYTHON_TYPE_TO_CORTEX_TYPE - isList = objType == list + obj_type = type(obj) + type_conversion_dict = PYTHON_TYPE_TO_CORTEX_TYPE + isList = obj_type == list if isList: - objType = type(obj[0]) - typeConversionDict = PYTHON_TYPE_TO_CORTEX_LIST_TYPE + obj_type = type(obj[0]) + type_conversion_dict = PYTHON_TYPE_TO_CORTEX_LIST_TYPE - return typeConversionDict[objType] + return type_conversion_dict[obj_type] def validate_transformer(column_name, test_df, ctx, spark): From 486c0262c3c1f716ed854141cde834a26fe96470 Mon Sep 17 00:00:00 2001 From: Ivan Zhang Date: Fri, 17 May 2019 12:57:47 -0400 Subject: [PATCH 34/48] fix more camel case --- pkg/workloads/spark_job/spark_util.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pkg/workloads/spark_job/spark_util.py b/pkg/workloads/spark_job/spark_util.py index 8921014ff3..1b01fd3f28 100644 --- a/pkg/workloads/spark_job/spark_util.py +++ b/pkg/workloads/spark_job/spark_util.py @@ -465,15 +465,15 @@ def _transform(*values): if validate: transformed_column = ctx.transformed_columns[column_name] - columnType = ctx.get_inferred_column_type(column_name) + column_type = ctx.get_inferred_column_type(column_name) def _transform_and_validate(*values): result = _transform(*values) - if not util.validate_column_type(result, columnType): + if not util.validate_column_type(result, column_type): raise UserException( "transformed column " + column_name, "tranformation " + transformed_column["transformer"], - "type of {} is not {}".format(result, columnType), + "type of {} is not {}".format(result, column_type), ) return result From 79657c42cc43c85dcc3b61daaf2f3b2a60947b42 Mon Sep 17 00:00:00 2001 From: Ivan Zhang Date: Fri, 17 May 2019 14:57:12 -0400 Subject: [PATCH 35/48] address some comments --- pkg/operator/api/userconfig/config.go | 5 ++--- pkg/operator/api/userconfig/errors.go | 18 ------------------ pkg/workloads/lib/context.py | 10 +++++----- pkg/workloads/tf_api/api.py | 4 ++-- 4 files changed, 9 insertions(+), 28 deletions(-) diff --git a/pkg/operator/api/userconfig/config.go b/pkg/operator/api/userconfig/config.go index 2d651b3be4..99a3086b17 100644 --- a/pkg/operator/api/userconfig/config.go +++ b/pkg/operator/api/userconfig/config.go @@ -208,7 +208,7 @@ func (config *Config) Validate(envName string) error { } if aggregate.AggregatorPath != nil && aggregate.Aggregator != "" { - return ErrorMultipleAggregatorSpecified(aggregate) + return errors.Wrap(ErrorSpecifyOnlyOne("aggregator", "aggregator_path"), Identify(aggregate)) } if aggregate.Aggregator != "" && @@ -226,14 +226,13 @@ func (config *Config) Validate(envName string) error { } if transformedColumn.TransformerPath != nil && transformedColumn.Transformer != "" { - return ErrorMultipleTransformerSpecified(transformedColumn) + return errors.Wrap(ErrorSpecifyOnlyOne("transformer", "transformer_path"), Identify(transformedColumn)) } if transformedColumn.Transformer != "" && !strings.Contains(transformedColumn.Transformer, ".") && !slices.HasString(transformerNames, transformedColumn.Transformer) { return errors.Wrap(ErrorUndefinedResource(transformedColumn.Transformer, resource.TransformerType), Identify(transformedColumn), TransformerKey) - } } diff --git a/pkg/operator/api/userconfig/errors.go b/pkg/operator/api/userconfig/errors.go index 90a3735b27..142f311733 100644 --- a/pkg/operator/api/userconfig/errors.go +++ b/pkg/operator/api/userconfig/errors.go @@ -58,8 +58,6 @@ const ( ErrK8sQuantityMustBeInt ErrRegressionTargetType ErrClassificationTargetType - 
ErrMultipleAggregatorSpecified - ErrMultipleTransformerSpecified ErrSpecifyOnlyOneMissing ) @@ -93,8 +91,6 @@ var errorKinds = []string{ "err_k8s_quantity_must_be_int", "err_regression_target_type", "err_classification_target_type", - "err_multiple_aggregator_specified", - "err_multiple_transformer_specified", "err_specify_only_one_missing", } @@ -390,20 +386,6 @@ func ErrorClassificationTargetType() error { } } -func ErrorMultipleAggregatorSpecified(aggregate *Aggregate) error { - return Error{ - Kind: ErrMultipleAggregatorSpecified, - message: fmt.Sprintf("aggregate \"%s\" specified both \"aggregator\" and \"aggregator_path\", please specify only one", aggregate.Name), - } -} - -func ErrorMultipleTransformerSpecified(transformedColumn *TransformedColumn) error { - return Error{ - Kind: ErrMultipleTransformerSpecified, - message: fmt.Sprintf("transformed_column \"%s\" specified both \"transformer\" and \"transformer_path\", please specify only one", transformedColumn.Name), - } -} - func ErrorSpecifyOnlyOneMissing(vals ...string) error { message := fmt.Sprintf("please specify one of %s", s.UserStrsOr(vals)) if len(vals) == 2 { diff --git a/pkg/workloads/lib/context.py b/pkg/workloads/lib/context.py index 3eb3a61861..b541653422 100644 --- a/pkg/workloads/lib/context.py +++ b/pkg/workloads/lib/context.py @@ -513,12 +513,12 @@ def get_metadata(self, context_key, context_item="", use_cache=True): return metadata def get_inferred_column_type(self, column_name): - columnType = self.columns[column_name].get("type", "unknown") - if columnType == "unknown": - columnType = self.get_metadata("columns", column_name)["type"] - self.columns[column_name]["type"] = columnType + column_type = self.columns[column_name].get("type", "unknown") + if column_type == "unknown": + column_type = self.get_metadata("columns", column_name)["type"] + self.columns[column_name]["type"] = column_type - return columnType + return column_type MODEL_IMPL_VALIDATION = { diff --git a/pkg/workloads/tf_api/api.py b/pkg/workloads/tf_api/api.py index 93718511eb..28eb1aab5e 100644 --- a/pkg/workloads/tf_api/api.py +++ b/pkg/workloads/tf_api/api.py @@ -95,8 +95,8 @@ def create_prediction_request(transformed_sample): prediction_request.model_spec.signature_name = signature_key for column_name, value in transformed_sample.items(): - columnType = ctx.get_inferred_column_type(column_name) - data_type = tf_lib.CORTEX_TYPE_TO_TF_TYPE[columnType] + column_Type = ctx.get_inferred_column_type(column_name) + data_type = tf_lib.CORTEX_TYPE_TO_TF_TYPE[column_Type] shape = [1] if util.is_list(value): shape = [len(value)] From 471e98fcca9ef8398393c37048107d57701f7df7 Mon Sep 17 00:00:00 2001 From: Ivan Zhang Date: Sat, 18 May 2019 17:24:22 -0400 Subject: [PATCH 36/48] refactor metadata --- pkg/workloads/lib/context.py | 53 ++++--------------- pkg/workloads/spark_job/spark_job.py | 10 +++- pkg/workloads/spark_job/spark_util.py | 14 +++-- .../spark_job/test/integration/iris_test.py | 7 ++- pkg/workloads/tf_train/train_util.py | 2 +- 5 files changed, 35 insertions(+), 51 deletions(-) diff --git a/pkg/workloads/lib/context.py b/pkg/workloads/lib/context.py index b541653422..5884fa11cd 100644 --- a/pkg/workloads/lib/context.py +++ b/pkg/workloads/lib/context.py @@ -80,7 +80,7 @@ def __init__(self, **kwargs): self.models = self.ctx["models"] self.apis = self.ctx["apis"] self.training_datasets = {k: v["dataset"] for k, v in self.models.items()} - + self._metadatas = {} self.api_version = self.cortex_config["api_version"] if "local_storage_path" in 
kwargs: @@ -103,8 +103,6 @@ def __init__(self, **kwargs): self.raw_columns, self.transformed_columns # self.aggregates ) - self.ctx["columns"] = self.columns - self.values = util.merge_dicts_overwrite(self.aggregates, self.constants) self.raw_column_names = list(self.raw_columns.keys()) @@ -470,52 +468,23 @@ def upload_resource_status_end(self, exit_code, *resources): def resource_status_key(self, resource): return os.path.join(self.status_prefix, resource["id"], resource["workload_id"]) - def update_metadata(self, metadata, context_key, context_item=""): - if context_key == "raw_dataset": - self.raw_dataset["metadata"] = metadata - self.storage.put_json(metadata, self.raw_dataset["metadata_key"]) - return - - if context_key == "training_datasets": - self.ctx["models"][context_item]["dataset"]["metadata"] = metadata - self.storage.put_json( - metadata, self.ctx["models"][context_item]["dataset"]["metadata_key"] - ) - return - - self.ctx[context_key][context_item]["metadata"] = metadata - self.storage.put_json(metadata, self.ctx[context_key][context_item]["metadata_key"]) - - def get_metadata(self, context_key, context_item="", use_cache=True): - if context_key == "raw_dataset": - if use_cache and self.raw_dataset.get("metadata", None): - return self.raw_dataset["metadata"] - - metadata = self.storage.get_json(self.raw_dataset["metadata_key"], allow_missing=True) - self.raw_dataset["metadata"] = metadata - return metadata - - if context_key == "training_datasets": - if use_cache and self.ctx["models"][context_item]["dataset"].get("metadata", None): - return self.ctx["models"][context_item]["dataset"]["metadata"] - - metadata_uri = self.ctx["models"][context_item]["dataset"]["metadata_key"] - metadata = self.storage.get_json(metadata_uri, allow_missing=True) - self.ctx["models"][context_item]["dataset"]["metadata"] = metadata - return metadata + def update_metadata(self, resource_id, metadata_key, metadata): + self._metadatas[resource_id] = metadata + self.storage.put_json(metadata, metadata_key) - if use_cache and self.ctx[context_key][context_item].get("metadata", None): - return self.ctx[context_key][context_item]["metadata"] + def get_metadata(self, resource_id, metadata_key, use_cache=True): + if use_cache and self._metadatas.get(resource_id, None): + return self._metadatas[resource_id] - metadata_uri = self.ctx[context_key][context_item]["metadata_key"] - metadata = self.storage.get_json(metadata_uri, allow_missing=True) - self.ctx[context_key][context_item]["metadata"] = metadata + metadata = self.storage.get_json(metadata_key, allow_missing=True) + self._metadatas[resource_id] = metadata return metadata def get_inferred_column_type(self, column_name): + column = self.columns[column_name] column_type = self.columns[column_name].get("type", "unknown") if column_type == "unknown": - column_type = self.get_metadata("columns", column_name)["type"] + column_type = self.get_metadata(column["id"], column["metadata_key"])["type"] self.columns[column_name]["type"] = column_type return column_type diff --git a/pkg/workloads/spark_job/spark_job.py b/pkg/workloads/spark_job/spark_job.py index 017412376e..2815bc29cb 100644 --- a/pkg/workloads/spark_job/spark_job.py +++ b/pkg/workloads/spark_job/spark_job.py @@ -91,7 +91,9 @@ def parse_args(args): def validate_dataset(ctx, raw_df, cols_to_validate): - total_row_count = ctx.get_metadata("raw_dataset")["dataset_size"] + total_row_count = ctx.get_metadata(ctx.raw_dataset["id"], ctx.raw_dataset["metadata_key"])[ + "dataset_size" + ] conditions_dict = 
spark_util.value_check_data(ctx, raw_df, cols_to_validate) if len(conditions_dict) > 0: @@ -160,7 +162,11 @@ def ingest_raw_dataset(spark, ctx, cols_to_validate, should_ingest): ingest_df = limit_dataset(full_dataset_size, ingest_df, ctx.environment["limit"]) written_count = write_raw_dataset(ingest_df, ctx, spark) - ctx.update_metadata({"dataset_size": written_count}, "raw_dataset") + ctx.update_metadata( + ctx.raw_dataset["id"], + ctx.raw_dataset["metadata_key"], + {"dataset_size": written_count}, + ) if written_count != full_dataset_size: logger.info( "{} rows read, {} rows dropped, {} rows ingested".format( diff --git a/pkg/workloads/spark_job/spark_util.py b/pkg/workloads/spark_job/spark_util.py index 1b01fd3f28..40590a0dea 100644 --- a/pkg/workloads/spark_job/spark_util.py +++ b/pkg/workloads/spark_job/spark_util.py @@ -120,9 +120,9 @@ def write_training_data(model_name, df, ctx, spark): ) ctx.update_metadata( + training_dataset["id"], + training_dataset["metadata_key"], {"training_size": train_df_acc.value, "eval_size": eval_df_acc.value}, - "training_datasets", - model_name, ) return df @@ -521,7 +521,11 @@ def validate_transformer(column_name, test_df, ctx, spark): ) # for downstream operations on other jobs - ctx.update_metadata({"type": expected_type}, "transformed_columns", column_name) + ctx.update_metadata( + transformed_column["id"], + transformed_column["metadata_key"], + {"type": expected_type}, + ) try: transform_python_collect = execute_transform_python( @@ -651,7 +655,9 @@ def transform_column(column_name, df, ctx, spark): column_type = df.select(column_name).schema[0].dataType # for downstream operations on other jobs ctx.update_metadata( - {"type": SPARK_TYPE_TO_CORTEX_TYPE[column_type]}, "transformed_columns", column_name + transformed_column["id"], + transformed_column["metadata_key"], + {"type": SPARK_TYPE_TO_CORTEX_TYPE[column_type]}, ) return df.withColumn(column_name, F.col(column_name).cast(column_type)) diff --git a/pkg/workloads/spark_job/test/integration/iris_test.py b/pkg/workloads/spark_job/test/integration/iris_test.py index 24e2e2a77d..4ff67f1b81 100644 --- a/pkg/workloads/spark_job/test/integration/iris_test.py +++ b/pkg/workloads/spark_job/test/integration/iris_test.py @@ -77,7 +77,10 @@ def test_simple_end_to_end(spark): raw_df = spark_job.ingest_raw_dataset(spark, ctx, cols_to_validate, should_ingest) assert raw_df.count() == 15 - assert ctx.get_metadata("raw_dataset")["dataset_size"] == 15 + assert ( + ctx.get_metadata(ctx.raw_dataset["id"], ctx.raw_dataset["metadata_key"])["dataset_size"] + == 15 + ) for raw_column_id in cols_to_validate: path = os.path.join(raw_ctx["status_prefix"], raw_column_id, workload_id) status = storage.get_json(str(path)) @@ -117,7 +120,7 @@ def test_simple_end_to_end(spark): status["exist_code"] = "succeeded" dataset = raw_ctx["models"]["dnn"]["dataset"] - metadata = ctx.get_metadata("training_datasets", "dnn") + metadata = ctx.get_metadata(dataset["id"], dataset["metadata_key"]) assert metadata["training_size"] + metadata["eval_size"] == 15 assert local_storage_path.joinpath(dataset["train_key"], "_SUCCESS").exists() assert local_storage_path.joinpath(dataset["eval_key"], "_SUCCESS").exists() diff --git a/pkg/workloads/tf_train/train_util.py b/pkg/workloads/tf_train/train_util.py index b537fdff7d..d41a3b4995 100644 --- a/pkg/workloads/tf_train/train_util.py +++ b/pkg/workloads/tf_train/train_util.py @@ -149,7 +149,7 @@ def train(model_name, model_impl, ctx, model_dir): exporter = tf.estimator.FinalExporter("estimator", 
serving_input_fn, as_text=False) train_num_steps = model["training"]["num_steps"] - dataset_metadata = ctx.get_metadata("training_datasets", model_name) + dataset_metadata = ctx.get_metadata(model["dataset"]["id"], model["dataset"]["metadata_key"]) if model["training"]["num_epochs"]: train_num_steps = ( math.ceil(dataset_metadata["training_size"] / float(model["training"]["batch_size"])) From 2d5602c631e1b4e183278ae47c145ca78a1a892e Mon Sep 17 00:00:00 2001 From: Ivan Zhang Date: Sat, 18 May 2019 17:26:31 -0400 Subject: [PATCH 37/48] use raw_dataset key --- pkg/workloads/spark_job/spark_job.py | 4 ++-- pkg/workloads/spark_job/test/integration/iris_test.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/pkg/workloads/spark_job/spark_job.py b/pkg/workloads/spark_job/spark_job.py index 2815bc29cb..29055182f4 100644 --- a/pkg/workloads/spark_job/spark_job.py +++ b/pkg/workloads/spark_job/spark_job.py @@ -91,7 +91,7 @@ def parse_args(args): def validate_dataset(ctx, raw_df, cols_to_validate): - total_row_count = ctx.get_metadata(ctx.raw_dataset["id"], ctx.raw_dataset["metadata_key"])[ + total_row_count = ctx.get_metadata(ctx.raw_dataset["key"], ctx.raw_dataset["metadata_key"])[ "dataset_size" ] conditions_dict = spark_util.value_check_data(ctx, raw_df, cols_to_validate) @@ -163,7 +163,7 @@ def ingest_raw_dataset(spark, ctx, cols_to_validate, should_ingest): written_count = write_raw_dataset(ingest_df, ctx, spark) ctx.update_metadata( - ctx.raw_dataset["id"], + ctx.raw_dataset["key"], ctx.raw_dataset["metadata_key"], {"dataset_size": written_count}, ) diff --git a/pkg/workloads/spark_job/test/integration/iris_test.py b/pkg/workloads/spark_job/test/integration/iris_test.py index 4ff67f1b81..35eca5e371 100644 --- a/pkg/workloads/spark_job/test/integration/iris_test.py +++ b/pkg/workloads/spark_job/test/integration/iris_test.py @@ -78,7 +78,7 @@ def test_simple_end_to_end(spark): assert raw_df.count() == 15 assert ( - ctx.get_metadata(ctx.raw_dataset["id"], ctx.raw_dataset["metadata_key"])["dataset_size"] + ctx.get_metadata(ctx.raw_dataset["key"], ctx.raw_dataset["metadata_key"])["dataset_size"] == 15 ) for raw_column_id in cols_to_validate: From 8c7fa9b3f3a8941b54d2fad309d6a2efe5687932 Mon Sep 17 00:00:00 2001 From: Ivan Zhang Date: Tue, 21 May 2019 11:49:55 -0400 Subject: [PATCH 38/48] address some comments --- pkg/workloads/lib/context.py | 6 +++--- pkg/workloads/spark_job/spark_job.py | 2 +- pkg/workloads/spark_job/spark_util.py | 28 +++++++++++++-------------- 3 files changed, 18 insertions(+), 18 deletions(-) diff --git a/pkg/workloads/lib/context.py b/pkg/workloads/lib/context.py index 5884fa11cd..2c8e3820fd 100644 --- a/pkg/workloads/lib/context.py +++ b/pkg/workloads/lib/context.py @@ -80,7 +80,6 @@ def __init__(self, **kwargs): self.models = self.ctx["models"] self.apis = self.ctx["apis"] self.training_datasets = {k: v["dataset"] for k, v in self.models.items()} - self._metadatas = {} self.api_version = self.cortex_config["api_version"] if "local_storage_path" in kwargs: @@ -113,6 +112,7 @@ def __init__(self, **kwargs): self._transformer_impls = {} self._aggregator_impls = {} self._model_impls = {} + self._metadatas = {} # This affects Tensorflow S3 access os.environ["AWS_REGION"] = self.cortex_config.get("region", "") @@ -468,12 +468,12 @@ def upload_resource_status_end(self, exit_code, *resources): def resource_status_key(self, resource): return os.path.join(self.status_prefix, resource["id"], resource["workload_id"]) - def update_metadata(self, resource_id, 
metadata_key, metadata): + def write_metadata(self, resource_id, metadata_key, metadata): self._metadatas[resource_id] = metadata self.storage.put_json(metadata, metadata_key) def get_metadata(self, resource_id, metadata_key, use_cache=True): - if use_cache and self._metadatas.get(resource_id, None): + if use_cache and resource_id in self._metadatas: return self._metadatas[resource_id] metadata = self.storage.get_json(metadata_key, allow_missing=True) diff --git a/pkg/workloads/spark_job/spark_job.py b/pkg/workloads/spark_job/spark_job.py index 29055182f4..53247f090d 100644 --- a/pkg/workloads/spark_job/spark_job.py +++ b/pkg/workloads/spark_job/spark_job.py @@ -162,7 +162,7 @@ def ingest_raw_dataset(spark, ctx, cols_to_validate, should_ingest): ingest_df = limit_dataset(full_dataset_size, ingest_df, ctx.environment["limit"]) written_count = write_raw_dataset(ingest_df, ctx, spark) - ctx.update_metadata( + ctx.write_metadata( ctx.raw_dataset["key"], ctx.raw_dataset["metadata_key"], {"dataset_size": written_count}, diff --git a/pkg/workloads/spark_job/spark_util.py b/pkg/workloads/spark_job/spark_util.py index 40590a0dea..f19aa1191e 100644 --- a/pkg/workloads/spark_job/spark_util.py +++ b/pkg/workloads/spark_job/spark_util.py @@ -119,7 +119,7 @@ def write_training_data(model_name, df, ctx, spark): ctx.storage.hadoop_path(training_dataset["eval_key"]) ) - ctx.update_metadata( + ctx.write_metadata( training_dataset["id"], training_dataset["metadata_key"], {"training_size": train_df_acc.value, "eval_size": eval_df_acc.value}, @@ -487,22 +487,21 @@ def _transform_and_validate(*values): def infer_type(obj): obj_type = type(obj) - type_conversion_dict = PYTHON_TYPE_TO_CORTEX_TYPE - isList = obj_type == list - if isList: + + if obj_type == list: obj_type = type(obj[0]) - type_conversion_dict = PYTHON_TYPE_TO_CORTEX_LIST_TYPE + return PYTHON_TYPE_TO_CORTEX_LIST_TYPE[obj_type] - return type_conversion_dict[obj_type] + return PYTHON_TYPE_TO_CORTEX_TYPE[obj_type] def validate_transformer(column_name, test_df, ctx, spark): transformed_column = ctx.transformed_columns[column_name] - + transformer = ctx.transformers[transformed_column["transformer"]] trans_impl, _ = ctx.get_transformer_impl(column_name) if hasattr(trans_impl, "transform_python"): - if transformed_column.get("transformer_path", None): + if transformer["output_type"] == "unknown": sample_df = test_df.collect() sample = sample_df[0] inputs = ctx.create_column_inputs_map(sample, column_name) @@ -521,7 +520,7 @@ def validate_transformer(column_name, test_df, ctx, spark): ) # for downstream operations on other jobs - ctx.update_metadata( + ctx.write_metadata( transformed_column["id"], transformed_column["metadata_key"], {"type": expected_type}, @@ -571,8 +570,7 @@ def validate_transformer(column_name, test_df, ctx, spark): # check that expected output column has the correct data type if ( - not transformed_column.get("transformer_path", None) - and actual_structfield.dataType + actual_structfield.dataType not in CORTEX_TYPE_TO_ACCEPTABLE_SPARK_TYPES[ ctx.get_inferred_column_type(column_name) ] @@ -645,16 +643,18 @@ def transform_column(column_name, df, ctx, spark): return df if column_name in df.columns: return df - transformed_column = ctx.transformed_columns[column_name] + transformed_column = ctx.transformed_columns[column_name] + transformer = ctx.transformers[transformed_column["transformer"]] trans_impl, _ = ctx.get_transformer_impl(column_name) + if hasattr(trans_impl, "transform_spark"): column_type = 
CORTEX_TYPE_TO_SPARK_TYPE[ctx.get_inferred_column_type(column_name)]
         df = execute_transform_spark(column_name, df, ctx, spark)
-        if transformed_column.get("transformer_path", None):
+        if transformer["output_type"] == "unknown":
             column_type = df.select(column_name).schema[0].dataType
             # for downstream operations on other jobs
-            ctx.update_metadata(
+            ctx.write_metadata(
                 transformed_column["id"],
                 transformed_column["metadata_key"],
                 {"type": SPARK_TYPE_TO_CORTEX_TYPE[column_type]},
             )
         return df.withColumn(column_name, F.col(column_name).cast(column_type))

From e493684e055e33734c2aba589f41205387f01552 Mon Sep 17 00:00:00 2001
From: Ivan Zhang
Date: Tue, 21 May 2019 12:15:26 -0400
Subject: [PATCH 39/48] move type inference to validate_transformer

---
 pkg/workloads/spark_job/spark_util.py | 19 ++++++++++---------
 1 file changed, 10 insertions(+), 9 deletions(-)

diff --git a/pkg/workloads/spark_job/spark_util.py b/pkg/workloads/spark_job/spark_util.py
index f19aa1191e..b12da8cca1 100644
--- a/pkg/workloads/spark_job/spark_util.py
+++ b/pkg/workloads/spark_job/spark_util.py
@@ -587,6 +587,16 @@ def validate_transformer(column_name, test_df, ctx, spark):
                 )
             )
 
+
+        if transformer["output_type"] == "unknown":
+            column_type = transform_spark_df.select(column_name).schema[0].dataType
+            # for downstream operations on other jobs
+            ctx.write_metadata(
+                transformed_column["id"],
+                transformed_column["metadata_key"],
+                {"type": SPARK_TYPE_TO_CORTEX_TYPE[column_type]},
+            )
+
         # perform the necessary upcast/downcast for the column e.g INT -> LONG or DOUBLE -> FLOAT
         transform_spark_df = transform_spark_df.withColumn(
             column_name,
@@ -651,15 +661,6 @@ def transform_column(column_name, df, ctx, spark):
     if hasattr(trans_impl, "transform_spark"):
         column_type = CORTEX_TYPE_TO_SPARK_TYPE[ctx.get_inferred_column_type(column_name)]
         df = execute_transform_spark(column_name, df, ctx, spark)
-        if transformer["output_type"] == "unknown":
-            column_type = df.select(column_name).schema[0].dataType
-            # for downstream operations on other jobs
-            ctx.write_metadata(
-                transformed_column["id"],
-                transformed_column["metadata_key"],
-                {"type": SPARK_TYPE_TO_CORTEX_TYPE[column_type]},
-            )
-
         return df.withColumn(column_name, F.col(column_name).cast(column_type))
     elif hasattr(trans_impl, "transform_python"):
         return execute_transform_python(column_name, df, ctx, spark)

From c1551c0331f3f5f8a0b56907c44dde53e05304b0 Mon Sep 17 00:00:00 2001
From: Ivan Zhang
Date: Tue, 21 May 2019 12:33:56 -0400
Subject: [PATCH 40/48] validate inferred types from transform spark and python

---
 pkg/workloads/spark_job/spark_util.py | 50 +++++++++++++++++----------
 1 file changed, 32 insertions(+), 18 deletions(-)

diff --git a/pkg/workloads/spark_job/spark_util.py b/pkg/workloads/spark_job/spark_util.py
index b12da8cca1..8bfad48c99 100644
--- a/pkg/workloads/spark_job/spark_util.py
+++ b/pkg/workloads/spark_job/spark_util.py
@@ -500,6 +500,9 @@ def validate_transformer(column_name, test_df, ctx, spark):
     transformer = ctx.transformers[transformed_column["transformer"]]
     trans_impl, _ = ctx.get_transformer_impl(column_name)
 
+    inferred_python_type = None
+    inferred_spark_type = None
+
     if hasattr(trans_impl, "transform_python"):
         if transformer["output_type"] == "unknown":
             sample_df = test_df.collect()
             sample = sample_df[0]
             inputs = ctx.create_column_inputs_map(sample, column_name)
             _, impl_args = extract_inputs(column_name, ctx)
             initial_transformed_sample = trans_impl.transform_python(inputs, impl_args)
-            expected_type = infer_type(initial_transformed_sample)
+
inferred_python_type = infer_type(initial_transformed_sample) for row in sample_df: inputs = ctx.create_column_inputs_map(row, column_name) transformed_sample = trans_impl.transform_python(inputs, impl_args) - if expected_type != infer_type(transformed_sample): + if inferred_python_type != infer_type(transformed_sample): raise UserRuntimeException( "transformed column " + column_name, "type inference failed, mixed data types in dataframe.", - 'expected type of "' + transformed_sample + '" to be ' + expected_type, + 'expected type of "' + + transformed_sample + + '" to be ' + + inferred_python_type, ) - # for downstream operations on other jobs - ctx.write_metadata( - transformed_column["id"], - transformed_column["metadata_key"], - {"type": expected_type}, - ) - try: transform_python_collect = execute_transform_python( column_name, test_df, ctx, spark, validate=True @@ -587,15 +586,8 @@ def validate_transformer(column_name, test_df, ctx, spark): ) ) - if transformer["output_type"] == "unknown": - column_type = transform_spark_df.select(column_name).schema[0].dataType - # for downstream operations on other jobs - ctx.write_metadata( - transformed_column["id"], - transformed_column["metadata_key"], - {"type": SPARK_TYPE_TO_CORTEX_TYPE[column_type]}, - ) + inferred_spark_type = transform_spark_df.select(column_name).schema[0].dataType # perform the necessary upcast/downcast for the column e.g INT -> LONG or DOUBLE -> FLOAT transform_spark_df = transform_spark_df.withColumn( @@ -647,6 +639,28 @@ def validate_transformer(column_name, test_df, ctx, spark): "{} != {}".format(ts_row, tp_row), ) + if transformer["output_type"] == "unknown": + if ( + inferred_spark_type + and inferred_python_type + and inferred_spark_type != inferred_python_type + ): + raise UserRuntimeException( + "transformed column " + column_name, + "type inference failed, transform_spark and transform_python had differing types.", + "transform_python: " + inferred_python_type, + "transform_spark: " + inferred_spark_type, + ) + + inferred_type = inferred_python_type + if inferred_type == None: + inferred_type = inferred_spark_type + + # for downstream operations on other jobs + ctx.write_metadata( + transformed_column["id"], transformed_column["metadata_key"], {"type": inferred_type} + ) + def transform_column(column_name, df, ctx, spark): if not ctx.is_transformed_column(column_name): From 8abc9dcf4dc39e6d6822785a1fea179027c8bfc2 Mon Sep 17 00:00:00 2001 From: Ivan Zhang Date: Tue, 21 May 2019 14:03:16 -0400 Subject: [PATCH 41/48] pass type downstream --- pkg/workloads/spark_job/spark_util.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/pkg/workloads/spark_job/spark_util.py b/pkg/workloads/spark_job/spark_util.py index 8bfad48c99..982450b30c 100644 --- a/pkg/workloads/spark_job/spark_util.py +++ b/pkg/workloads/spark_job/spark_util.py @@ -447,7 +447,7 @@ def execute_transform_spark(column_name, df, ctx, spark): raise UserRuntimeException("function transform_spark") from e -def execute_transform_python(column_name, df, ctx, spark, validate=False): +def execute_transform_python(column_name, df, ctx, spark, validate=False, inferred_type=None): trans_impl, trans_impl_path = ctx.get_transformer_impl(column_name) columns_input_config, impl_args = extract_inputs(column_name, ctx) @@ -465,7 +465,9 @@ def _transform(*values): if validate: transformed_column = ctx.transformed_columns[column_name] - column_type = ctx.get_inferred_column_type(column_name) + column_type = inferred_type + if not 
column_type: + column_type = ctx.get_inferred_column_type(column_name) def _transform_and_validate(*values): result = _transform(*values) @@ -480,7 +482,10 @@ def _transform_and_validate(*values): transform_python_func = _transform_and_validate - column_data_type_str = ctx.get_inferred_column_type(column_name) + column_data_type_str = inferred_type + if not column_data_type_str: + column_data_type_str = ctx.get_inferred_column_type(column_name) + transform_udf = F.udf(transform_python_func, CORTEX_TYPE_TO_SPARK_TYPE[column_data_type_str]) return df.withColumn(column_name, transform_udf(*required_columns_sorted)) @@ -527,7 +532,7 @@ def validate_transformer(column_name, test_df, ctx, spark): try: transform_python_collect = execute_transform_python( - column_name, test_df, ctx, spark, validate=True + column_name, test_df, ctx, spark, validate=True, inferred_type=inferred_python_type ).collect() except Exception as e: raise UserRuntimeException( @@ -656,7 +661,6 @@ def validate_transformer(column_name, test_df, ctx, spark): if inferred_type == None: inferred_type = inferred_spark_type - # for downstream operations on other jobs ctx.write_metadata( transformed_column["id"], transformed_column["metadata_key"], {"type": inferred_type} ) From b36218e995703b761191285ac05987957678f586 Mon Sep 17 00:00:00 2001 From: Ivan Zhang Date: Tue, 21 May 2019 14:13:50 -0400 Subject: [PATCH 42/48] add comment about transform_python --- pkg/workloads/lib/context.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/pkg/workloads/lib/context.py b/pkg/workloads/lib/context.py index 2c8e3820fd..7d0a2edcd8 100644 --- a/pkg/workloads/lib/context.py +++ b/pkg/workloads/lib/context.py @@ -503,6 +503,8 @@ def get_inferred_column_type(self, column_name): "optional": [ {"name": "transform_spark", "args": ["data", "columns", "args", "transformed_column_name"]}, {"name": "reverse_transform_python", "args": ["transformed_value", "args"]}, + # it is possible to not define transform_python() + # if you are only using the transformation for training columns, and not for inference {"name": "transform_python", "args": ["sample", "args"]}, ] } From 5a7bdd35f19cec329407cad5220f44a03ca32e18 Mon Sep 17 00:00:00 2001 From: Ivan Zhang Date: Tue, 21 May 2019 15:29:57 -0400 Subject: [PATCH 43/48] remove unused --- pkg/workloads/spark_job/spark_util.py | 1 - 1 file changed, 1 deletion(-) diff --git a/pkg/workloads/spark_job/spark_util.py b/pkg/workloads/spark_job/spark_util.py index 982450b30c..299dc27082 100644 --- a/pkg/workloads/spark_job/spark_util.py +++ b/pkg/workloads/spark_job/spark_util.py @@ -673,7 +673,6 @@ def transform_column(column_name, df, ctx, spark): return df transformed_column = ctx.transformed_columns[column_name] - transformer = ctx.transformers[transformed_column["transformer"]] trans_impl, _ = ctx.get_transformer_impl(column_name) if hasattr(trans_impl, "transform_spark"): From 2aa5a99ceaa8e994fa1313f07bbb918beeecb6e5 Mon Sep 17 00:00:00 2001 From: Ivan Zhang Date: Tue, 21 May 2019 16:20:32 -0400 Subject: [PATCH 44/48] address comments --- pkg/workloads/lib/context.py | 3 +++ pkg/workloads/spark_job/spark_util.py | 27 +++++++++++---------------- 2 files changed, 14 insertions(+), 16 deletions(-) diff --git a/pkg/workloads/lib/context.py b/pkg/workloads/lib/context.py index 7d0a2edcd8..d8286c1c63 100644 --- a/pkg/workloads/lib/context.py +++ b/pkg/workloads/lib/context.py @@ -469,6 +469,9 @@ def resource_status_key(self, resource): return os.path.join(self.status_prefix, resource["id"], resource["workload_id"]) 
def write_metadata(self, resource_id, metadata_key, metadata): + if resource_id in self._metadatas and self._metadatas[resource_id] == metadata: + return + self._metadatas[resource_id] = metadata self.storage.put_json(metadata, metadata_key) diff --git a/pkg/workloads/spark_job/spark_util.py b/pkg/workloads/spark_job/spark_util.py index 299dc27082..6cc52f7026 100644 --- a/pkg/workloads/spark_job/spark_util.py +++ b/pkg/workloads/spark_job/spark_util.py @@ -447,7 +447,7 @@ def execute_transform_spark(column_name, df, ctx, spark): raise UserRuntimeException("function transform_spark") from e -def execute_transform_python(column_name, df, ctx, spark, validate=False, inferred_type=None): +def execute_transform_python(column_name, df, ctx, spark, validate=False): trans_impl, trans_impl_path = ctx.get_transformer_impl(column_name) columns_input_config, impl_args = extract_inputs(column_name, ctx) @@ -465,9 +465,7 @@ def _transform(*values): if validate: transformed_column = ctx.transformed_columns[column_name] - column_type = inferred_type - if not column_type: - column_type = ctx.get_inferred_column_type(column_name) + column_type = ctx.get_inferred_column_type(column_name) def _transform_and_validate(*values): result = _transform(*values) @@ -482,10 +480,7 @@ def _transform_and_validate(*values): transform_python_func = _transform_and_validate - column_data_type_str = inferred_type - if not column_data_type_str: - column_data_type_str = ctx.get_inferred_column_type(column_name) - + column_data_type_str = ctx.get_inferred_column_type(column_name) transform_udf = F.udf(transform_python_func, CORTEX_TYPE_TO_SPARK_TYPE[column_data_type_str]) return df.withColumn(column_name, transform_udf(*required_columns_sorted)) @@ -530,9 +525,13 @@ def validate_transformer(column_name, test_df, ctx, spark): + inferred_python_type, ) + ctx.write_metadata( + transformed_column["id"], transformed_column["metadata_key"], {"type": inferred_python_type} + ) + try: transform_python_collect = execute_transform_python( - column_name, test_df, ctx, spark, validate=True, inferred_type=inferred_python_type + column_name, test_df, ctx, spark, validate=True ).collect() except Exception as e: raise UserRuntimeException( @@ -593,6 +592,9 @@ def validate_transformer(column_name, test_df, ctx, spark): if transformer["output_type"] == "unknown": inferred_spark_type = transform_spark_df.select(column_name).schema[0].dataType + ctx.write_metadata( + transformed_column["id"], transformed_column["metadata_key"], {"type": inferred_spark_type} + ) # perform the necessary upcast/downcast for the column e.g INT -> LONG or DOUBLE -> FLOAT transform_spark_df = transform_spark_df.withColumn( @@ -657,13 +659,6 @@ def validate_transformer(column_name, test_df, ctx, spark): "transform_spark: " + inferred_spark_type, ) - inferred_type = inferred_python_type - if inferred_type == None: - inferred_type = inferred_spark_type - - ctx.write_metadata( - transformed_column["id"], transformed_column["metadata_key"], {"type": inferred_type} - ) def transform_column(column_name, df, ctx, spark): From 95d98e12b54a62347dfa0f1a3a29502aa7417f2a Mon Sep 17 00:00:00 2001 From: Ivan Zhang Date: Tue, 21 May 2019 16:55:01 -0400 Subject: [PATCH 45/48] format --- pkg/workloads/spark_job/spark_util.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/pkg/workloads/spark_job/spark_util.py b/pkg/workloads/spark_job/spark_util.py index 6cc52f7026..88e094d33a 100644 --- a/pkg/workloads/spark_job/spark_util.py +++ 
b/pkg/workloads/spark_job/spark_util.py @@ -526,8 +526,10 @@ def validate_transformer(column_name, test_df, ctx, spark): ) ctx.write_metadata( - transformed_column["id"], transformed_column["metadata_key"], {"type": inferred_python_type} - ) + transformed_column["id"], + transformed_column["metadata_key"], + {"type": inferred_python_type}, + ) try: transform_python_collect = execute_transform_python( @@ -593,7 +595,9 @@ def validate_transformer(column_name, test_df, ctx, spark): if transformer["output_type"] == "unknown": inferred_spark_type = transform_spark_df.select(column_name).schema[0].dataType ctx.write_metadata( - transformed_column["id"], transformed_column["metadata_key"], {"type": inferred_spark_type} + transformed_column["id"], + transformed_column["metadata_key"], + {"type": inferred_spark_type}, ) # perform the necessary upcast/downcast for the column e.g INT -> LONG or DOUBLE -> FLOAT @@ -660,7 +664,6 @@ def validate_transformer(column_name, test_df, ctx, spark): ) - def transform_column(column_name, df, ctx, spark): if not ctx.is_transformed_column(column_name): return df From 694b08d1fe561ca8fb1ce78364dbb54f132cbb8a Mon Sep 17 00:00:00 2001 From: Ivan Zhang Date: Tue, 21 May 2019 17:23:36 -0400 Subject: [PATCH 46/48] wrap more code in try --- pkg/workloads/spark_job/spark_util.py | 64 +++++++++++++-------------- 1 file changed, 32 insertions(+), 32 deletions(-) diff --git a/pkg/workloads/spark_job/spark_util.py b/pkg/workloads/spark_job/spark_util.py index 88e094d33a..1eb012724d 100644 --- a/pkg/workloads/spark_job/spark_util.py +++ b/pkg/workloads/spark_job/spark_util.py @@ -503,43 +503,43 @@ def validate_transformer(column_name, test_df, ctx, spark): inferred_python_type = None inferred_spark_type = None - if hasattr(trans_impl, "transform_python"): - if transformer["output_type"] == "unknown": - sample_df = test_df.collect() - sample = sample_df[0] - inputs = ctx.create_column_inputs_map(sample, column_name) - _, impl_args = extract_inputs(column_name, ctx) - initial_transformed_sample = trans_impl.transform_python(inputs, impl_args) - inferred_python_type = infer_type(initial_transformed_sample) - - for row in sample_df: - inputs = ctx.create_column_inputs_map(row, column_name) - transformed_sample = trans_impl.transform_python(inputs, impl_args) - if inferred_python_type != infer_type(transformed_sample): - raise UserRuntimeException( - "transformed column " + column_name, - "type inference failed, mixed data types in dataframe.", - 'expected type of "' - + transformed_sample - + '" to be ' - + inferred_python_type, - ) + try: + if hasattr(trans_impl, "transform_python"): + if transformer["output_type"] == "unknown": + sample_df = test_df.collect() + sample = sample_df[0] + inputs = ctx.create_column_inputs_map(sample, column_name) + _, impl_args = extract_inputs(column_name, ctx) + initial_transformed_sample = trans_impl.transform_python(inputs, impl_args) + inferred_python_type = infer_type(initial_transformed_sample) + + for row in sample_df: + inputs = ctx.create_column_inputs_map(row, column_name) + transformed_sample = trans_impl.transform_python(inputs, impl_args) + if inferred_python_type != infer_type(transformed_sample): + raise UserRuntimeException( + "transformed column " + column_name, + "type inference failed, mixed data types in dataframe.", + 'expected type of "' + + transformed_sample + + '" to be ' + + inferred_python_type, + ) - ctx.write_metadata( - transformed_column["id"], - transformed_column["metadata_key"], - {"type": inferred_python_type}, - 
) + ctx.write_metadata( + transformed_column["id"], + transformed_column["metadata_key"], + {"type": inferred_python_type}, + ) - try: transform_python_collect = execute_transform_python( column_name, test_df, ctx, spark, validate=True ).collect() - except Exception as e: - raise UserRuntimeException( - "transformed column " + column_name, - transformed_column["transformer"] + ".transform_python", - ) from e + except Exception as e: + raise UserRuntimeException( + "transformed column " + column_name, + transformed_column["transformer"] + ".transform_python", + ) from e if hasattr(trans_impl, "transform_spark"): From 5b5f0b6b9c8de19a05d2288689dee7727fb224c6 Mon Sep 17 00:00:00 2001 From: Ivan Zhang Date: Tue, 21 May 2019 17:43:40 -0400 Subject: [PATCH 47/48] move type check before value check --- pkg/workloads/spark_job/spark_util.py | 53 ++++++++++++--------------- 1 file changed, 24 insertions(+), 29 deletions(-) diff --git a/pkg/workloads/spark_job/spark_util.py b/pkg/workloads/spark_job/spark_util.py index 1eb012724d..c9b62477f3 100644 --- a/pkg/workloads/spark_job/spark_util.py +++ b/pkg/workloads/spark_job/spark_util.py @@ -627,35 +627,8 @@ def validate_transformer(column_name, test_df, ctx, spark): ) raise - if hasattr(trans_impl, "transform_spark") and hasattr(trans_impl, "transform_python"): - name_type_map = [(s.name, s.dataType) for s in transform_spark_df.schema] - transform_spark_collect = transform_spark_df.collect() - - for tp_row, ts_row in zip(transform_python_collect, transform_spark_collect): - tp_dict = tp_row.asDict() - ts_dict = ts_row.asDict() - - for name, dataType in name_type_map: - if tp_dict[name] == ts_dict[name]: - continue - elif dataType == FloatType() and util.isclose( - tp_dict[name], ts_dict[name], FLOAT_PRECISION - ): - continue - raise UserException( - column_name, - "{0}.transform_spark and {0}.transform_python had differing values".format( - transformed_column["transformer"] - ), - "{} != {}".format(ts_row, tp_row), - ) - - if transformer["output_type"] == "unknown": - if ( - inferred_spark_type - and inferred_python_type - and inferred_spark_type != inferred_python_type - ): + if hasattr(trans_impl, "transform_spark") and hasattr(trans_impl, "transform_python"): + if transformer["output_type"] == "unknown" and inferred_spark_type != inferred_python_type: raise UserRuntimeException( "transformed column " + column_name, "type inference failed, transform_spark and transform_python had differing types.", @@ -663,6 +636,28 @@ def validate_transformer(column_name, test_df, ctx, spark): "transform_spark: " + inferred_spark_type, ) + name_type_map = [(s.name, s.dataType) for s in transform_spark_df.schema] + transform_spark_collect = transform_spark_df.collect() + + for tp_row, ts_row in zip(transform_python_collect, transform_spark_collect): + tp_dict = tp_row.asDict() + ts_dict = ts_row.asDict() + + for name, dataType in name_type_map: + if tp_dict[name] == ts_dict[name]: + continue + elif dataType == FloatType() and util.isclose( + tp_dict[name], ts_dict[name], FLOAT_PRECISION + ): + continue + raise UserException( + column_name, + "{0}.transform_spark and {0}.transform_python had differing values".format( + transformed_column["transformer"] + ), + "{} != {}".format(ts_row, tp_row), + ) + def transform_column(column_name, df, ctx, spark): if not ctx.is_transformed_column(column_name): From b2fd9d7d118f7067599e063e97f2c849819c460a Mon Sep 17 00:00:00 2001 From: Ivan Zhang Date: Tue, 21 May 2019 18:01:46 -0400 Subject: [PATCH 48/48] address comments --- 
pkg/workloads/spark_job/spark_util.py | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/pkg/workloads/spark_job/spark_util.py b/pkg/workloads/spark_job/spark_util.py index c9b62477f3..558eb8b99a 100644 --- a/pkg/workloads/spark_job/spark_util.py +++ b/pkg/workloads/spark_job/spark_util.py @@ -503,8 +503,8 @@ def validate_transformer(column_name, test_df, ctx, spark): inferred_python_type = None inferred_spark_type = None - try: - if hasattr(trans_impl, "transform_python"): + if hasattr(trans_impl, "transform_python"): + try: if transformer["output_type"] == "unknown": sample_df = test_df.collect() sample = sample_df[0] @@ -535,14 +535,13 @@ def validate_transformer(column_name, test_df, ctx, spark): transform_python_collect = execute_transform_python( column_name, test_df, ctx, spark, validate=True ).collect() - except Exception as e: - raise UserRuntimeException( - "transformed column " + column_name, - transformed_column["transformer"] + ".transform_python", - ) from e + except Exception as e: + raise UserRuntimeException( + "transformed column " + column_name, + transformed_column["transformer"] + ".transform_python", + ) from e if hasattr(trans_impl, "transform_spark"): - try: transform_spark_df = execute_transform_spark(column_name, test_df, ctx, spark)