diff --git a/cli/cmd/init.go b/cli/cmd/init.go
index 9ac68e19a3..4acac77d34 100644
--- a/cli/cmd/init.go
+++ b/cli/cmd/init.go
@@ -91,7 +91,8 @@ func appInitFiles(appName string) map[string]string {
 # data:
 #   type: csv
 #   path: s3a://my-bucket/data.csv
-#   skip_header: true
+#   csv_config:
+#     header: true
 #   schema:
 #     - column1
 #     - column2
diff --git a/docs/applications/resources/environments.md b/docs/applications/resources/environments.md
index 4774bfcc35..76c6dd786c 100644
--- a/docs/applications/resources/environments.md
+++ b/docs/applications/resources/environments.md
@@ -20,13 +20,38 @@ Transfer data at scale from data warehouses like S3 into the Cortex cluster. Onc
 data:
   type: csv  # file type (required)
   path: s3a://<bucket_name>/<file_name>  # S3 is currently supported (required)
-  skip_header: <bool>  # skips a single header line (default: false)
   drop_null: <bool>  # drop any rows that contain at least 1 null value (default: false)
+  csv_config: <csv_config>  # optional configuration that can be provided
   schema:
     - <string>  # raw column names listed in the CSV columns' order (required)
     ...
 ```
 
+#### CSV Config
+
+To help ingest different styles of CSV files, Cortex supports the parameters listed below. All of these parameters are optional; a description and the default value of each can be found in the [PySpark CSV documentation](https://spark.apache.org/docs/2.4.0/api/python/pyspark.sql.html#pyspark.sql.DataFrameReader.csv).
+
+```yaml
+csv_config:
+  sep: <string>
+  encoding: <string>
+  quote: <string>
+  escape: <string>
+  comment: <string>
+  header: <bool>
+  ignore_leading_white_space: <bool>
+  ignore_trailing_white_space: <bool>
+  null_value: <string>
+  nan_value: <string>
+  positive_inf: <string>
+  negative_inf: <string>
+  max_columns: <int>
+  max_chars_per_column: <int>
+  multiline: <bool>
+  char_to_escape_quote_escaping: <string>
+  empty_value: <string>
+```
+
 ### Parquet Data Config
 
 ```yaml
diff --git a/examples/fraud/resources/environments.yaml b/examples/fraud/resources/environments.yaml
index 9855fd4c06..bb30dd078f 100644
--- a/examples/fraud/resources/environments.yaml
+++ b/examples/fraud/resources/environments.yaml
@@ -3,7 +3,8 @@
   data:
     type: csv
     path: s3a://cortex-examples/fraud.csv
-    skip_header: true
+    csv_config:
+      header: true
     schema:
       - time
       - v1
diff --git a/examples/mnist/resources/environments.yaml b/examples/mnist/resources/environments.yaml
index 456ed158fa..660927f011 100644
--- a/examples/mnist/resources/environments.yaml
+++ b/examples/mnist/resources/environments.yaml
@@ -3,7 +3,8 @@
   data:
     type: csv
     path: s3a://cortex-examples/mnist.csv
-    skip_header: true
+    csv_config:
+      header: true
     schema:
       - image
       - label
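For context before the implementation files: the fraud and mnist environments above translate into a PySpark read along the following lines. This is a minimal sketch rather than code from this diff; the local SparkSession and file path are stand-ins, and the real schema is derived from the app's raw column types, but `header` is forwarded to `DataFrameReader.csv` just like this by `read_csv` in spark_util.py below.

```python
from pyspark.sql import SparkSession
from pyspark.sql.types import StringType, StructField, StructType

spark = SparkSession.builder.master("local[*]").getOrCreate()

# Stand-in schema; Cortex builds the real one from the declared raw columns.
schema = StructType([StructField("image", StringType()), StructField("label", StringType())])

# csv_config.header surfaces as the reader's `header` option.
df = spark.read.csv("mnist.csv", schema=schema, header=True, mode="FAILFAST")
```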
diff --git a/pkg/api/context/serialize.go b/pkg/api/context/serialize.go
index 378e880271..d041eac962 100644
--- a/pkg/api/context/serialize.go
+++ b/pkg/api/context/serialize.go
@@ -33,7 +33,7 @@ type RawColumnsTypeSplit struct {
 }
 
 type DataSplit struct {
-    CsvData     *userconfig.CsvData     `json:"csv_data"`
+    CSVData     *userconfig.CSVData     `json:"csv_data"`
     ParquetData *userconfig.ParquetData `json:"parquet_data"`
 }
 
@@ -84,8 +84,8 @@ func (ctx ContextSerial) collectRawColumns() RawColumns {
 func (ctx Context) splitEnvironment() *DataSplit {
     var split DataSplit
     switch typedData := ctx.Environment.Data.(type) {
-    case *userconfig.CsvData:
-        split.CsvData = typedData
+    case *userconfig.CSVData:
+        split.CSVData = typedData
     case *userconfig.ParquetData:
         split.ParquetData = typedData
     }
@@ -94,10 +94,10 @@ func (ctxSerial *ContextSerial) collectEnvironment() (*Environment, error) {
-    if ctxSerial.DataSplit.ParquetData != nil && ctxSerial.DataSplit.CsvData == nil {
+    if ctxSerial.DataSplit.ParquetData != nil && ctxSerial.DataSplit.CSVData == nil {
         ctxSerial.Environment.Data = ctxSerial.DataSplit.ParquetData
-    } else if ctxSerial.DataSplit.CsvData != nil && ctxSerial.DataSplit.ParquetData == nil {
-        ctxSerial.Environment.Data = ctxSerial.DataSplit.CsvData
+    } else if ctxSerial.DataSplit.CSVData != nil && ctxSerial.DataSplit.ParquetData == nil {
+        ctxSerial.Environment.Data = ctxSerial.DataSplit.CSVData
     } else {
         return nil, errors.Wrap(userconfig.ErrorSpecifyOnlyOne("CSV", "PARQUET"), ctxSerial.App.Name, resource.EnvironmentType.String(), userconfig.DataKey)
     }
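`Environment.Data` is an interface-typed field, so it cannot be deserialized directly; `DataSplit` stores the two concrete pointers instead, and `collectEnvironment` requires exactly one of them to be set on the way back. A hypothetical Python analogue of that one-of check, for illustration only:

```python
def collect_environment(csv_data, parquet_data):
    # Mirrors collectEnvironment: exactly one concrete payload may be set.
    if parquet_data is not None and csv_data is None:
        return parquet_data
    if csv_data is not None and parquet_data is None:
        return csv_data
    raise ValueError("specify exactly one of: CSV, PARQUET")
```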
diff --git a/pkg/api/userconfig/environments.go b/pkg/api/userconfig/environments.go
index 6bff10169b..52cc00c696 100644
--- a/pkg/api/userconfig/environments.go
+++ b/pkg/api/userconfig/environments.go
@@ -89,7 +89,7 @@ var dataValidation = &cr.InterfaceStructValidation{
     TypeStructField: "Type",
     InterfaceStructTypes: map[string]*cr.InterfaceStructType{
         "csv": &cr.InterfaceStructType{
-            Type:                   (*CsvData)(nil),
+            Type:                   (*CSVData)(nil),
             StructFieldValidations: csvDataFieldValidations,
         },
         "parquet": &cr.InterfaceStructType{
@@ -99,12 +99,33 @@ var dataValidation = &cr.InterfaceStructValidation{
     },
 }
 
-type CsvData struct {
-    Type       string   `json:"type" yaml:"type"`
-    Path       string   `json:"path" yaml:"path"`
-    Schema     []string `json:"schema" yaml:"schema"`
-    DropNull   bool     `json:"drop_null" yaml:"drop_null"`
-    SkipHeader bool     `json:"skip_header" yaml:"skip_header"`
+type CSVData struct {
+    Type      string     `json:"type" yaml:"type"`
+    Path      string     `json:"path" yaml:"path"`
+    Schema    []string   `json:"schema" yaml:"schema"`
+    DropNull  bool       `json:"drop_null" yaml:"drop_null"`
+    CSVConfig *CSVConfig `json:"csv_config" yaml:"csv_config"`
+}
+
+// SPARK_VERSION dependent
+type CSVConfig struct {
+    Sep                       *string `json:"sep" yaml:"sep"`
+    Encoding                  *string `json:"encoding" yaml:"encoding"`
+    Quote                     *string `json:"quote" yaml:"quote"`
+    Escape                    *string `json:"escape" yaml:"escape"`
+    Comment                   *string `json:"comment" yaml:"comment"`
+    Header                    *bool   `json:"header" yaml:"header"`
+    IgnoreLeadingWhiteSpace   *bool   `json:"ignore_leading_white_space" yaml:"ignore_leading_white_space"`
+    IgnoreTrailingWhiteSpace  *bool   `json:"ignore_trailing_white_space" yaml:"ignore_trailing_white_space"`
+    NullValue                 *string `json:"null_value" yaml:"null_value"`
+    NanValue                  *string `json:"nan_value" yaml:"nan_value"`
+    PositiveInf               *string `json:"positive_inf" yaml:"positive_inf"`
+    NegativeInf               *string `json:"negative_inf" yaml:"negative_inf"`
+    MaxColumns                *int32  `json:"max_columns" yaml:"max_columns"`
+    MaxCharsPerColumn         *int32  `json:"max_chars_per_column" yaml:"max_chars_per_column"`
+    Multiline                 *bool   `json:"multiline" yaml:"multiline"`
+    CharToEscapeQuoteEscaping *string `json:"char_to_escape_quote_escaping" yaml:"char_to_escape_quote_escaping"`
+    EmptyValue                *string `json:"empty_value" yaml:"empty_value"`
 }
 
 var csvDataFieldValidations = []*cr.StructFieldValidation{
@@ -127,9 +148,82 @@ var csvDataFieldValidations = []*cr.StructFieldValidation{
         },
     },
     &cr.StructFieldValidation{
-        StructField: "SkipHeader",
-        BoolValidation: &cr.BoolValidation{
-            Default: false,
+        StructField: "CSVConfig",
+        StructValidation: &cr.StructValidation{
+            StructFieldValidations: []*cr.StructFieldValidation{
+                &cr.StructFieldValidation{
+                    StructField:         "Sep",
+                    StringPtrValidation: &cr.StringPtrValidation{},
+                },
+                &cr.StructFieldValidation{
+                    StructField:         "Encoding",
+                    StringPtrValidation: &cr.StringPtrValidation{},
+                },
+                &cr.StructFieldValidation{
+                    StructField:         "Quote",
+                    StringPtrValidation: &cr.StringPtrValidation{},
+                },
+                &cr.StructFieldValidation{
+                    StructField:         "Escape",
+                    StringPtrValidation: &cr.StringPtrValidation{},
+                },
+                &cr.StructFieldValidation{
+                    StructField:         "Comment",
+                    StringPtrValidation: &cr.StringPtrValidation{},
+                },
+                &cr.StructFieldValidation{
+                    StructField:       "Header",
+                    BoolPtrValidation: &cr.BoolPtrValidation{},
+                },
+                &cr.StructFieldValidation{
+                    StructField:       "IgnoreLeadingWhiteSpace",
+                    BoolPtrValidation: &cr.BoolPtrValidation{},
+                },
+                &cr.StructFieldValidation{
+                    StructField:       "IgnoreTrailingWhiteSpace",
+                    BoolPtrValidation: &cr.BoolPtrValidation{},
+                },
+                &cr.StructFieldValidation{
+                    StructField:         "NullValue",
+                    StringPtrValidation: &cr.StringPtrValidation{},
+                },
+                &cr.StructFieldValidation{
+                    StructField:         "NanValue",
+                    StringPtrValidation: &cr.StringPtrValidation{},
+                },
+                &cr.StructFieldValidation{
+                    StructField:         "PositiveInf",
+                    StringPtrValidation: &cr.StringPtrValidation{},
+                },
+                &cr.StructFieldValidation{
+                    StructField:         "NegativeInf",
+                    StringPtrValidation: &cr.StringPtrValidation{},
+                },
+                &cr.StructFieldValidation{
+                    StructField: "MaxColumns",
+                    Int32PtrValidation: &cr.Int32PtrValidation{
+                        GreaterThan: util.Int32Ptr(0),
+                    },
+                },
+                &cr.StructFieldValidation{
+                    StructField: "MaxCharsPerColumn",
+                    Int32PtrValidation: &cr.Int32PtrValidation{
+                        GreaterThanOrEqualTo: util.Int32Ptr(-1),
+                    },
+                },
+                &cr.StructFieldValidation{
+                    StructField:       "Multiline",
+                    BoolPtrValidation: &cr.BoolPtrValidation{},
+                },
+                &cr.StructFieldValidation{
+                    StructField:         "CharToEscapeQuoteEscaping",
+                    StringPtrValidation: &cr.StringPtrValidation{},
+                },
+                &cr.StructFieldValidation{
+                    StructField:         "EmptyValue",
+                    StringPtrValidation: &cr.StringPtrValidation{},
+                },
+            },
         },
     },
 }
@@ -212,7 +306,7 @@ func (env *Environment) Validate() error {
     return nil
 }
 
-func (csvData *CsvData) Validate() error {
+func (csvData *CSVData) Validate() error {
     return nil
 }
 
@@ -220,7 +314,7 @@ func (parqData *ParquetData) Validate() error {
     return nil
 }
 
-func (csvData *CsvData) GetExternalPath() string {
+func (csvData *CSVData) GetExternalPath() string {
     return csvData.Path
 }
 
@@ -228,7 +322,7 @@ func (parqData *ParquetData) GetExternalPath() string {
     return parqData.Path
 }
 
-func (csvData *CsvData) GetIngestedColumns() []string {
+func (csvData *CSVData) GetIngestedColumns() []string {
     return csvData.Schema
 }
diff --git a/pkg/operator/context/environment.go b/pkg/operator/context/environment.go
index dff1e70ccd..311ce0a2a2 100644
--- a/pkg/operator/context/environment.go
+++ b/pkg/operator/context/environment.go
@@ -44,7 +44,7 @@ func dataID(config *userconfig.Config, datasetVersion string) string {
     data := config.Environment.Data
 
     switch typedData := data.(type) {
-    case *userconfig.CsvData:
+    case *userconfig.CSVData:
         buf.WriteString(s.Obj(typedData))
     case *userconfig.ParquetData:
         buf.WriteString(typedData.Type)
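Note that `dataID` writes the whole typed config into the hash buffer via `s.Obj(typedData)`, so changing any `csv_config` option yields a new dataset ID and triggers a fresh ingest rather than reusing a cached dataset. A rough Python sketch of that idea (the hashing scheme here is an assumption, not Cortex's actual implementation):

```python
import hashlib
import json

def data_id(data_config, dataset_version):
    # Hash the full data config so that editing e.g. csv_config.sep
    # invalidates any dataset ingested under the old settings.
    buf = json.dumps(data_config, sort_keys=True) + dataset_version
    return hashlib.sha256(buf.encode()).hexdigest()[:8]
```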
diff --git a/pkg/workloads/lib/test/util_test.py b/pkg/workloads/lib/test/util_test.py
index 51ed21b133..d9a75a903f 100644
--- a/pkg/workloads/lib/test/util_test.py
+++ b/pkg/workloads/lib/test/util_test.py
@@ -21,6 +21,15 @@ import logging
 
 
+def test_snake_to_camel():
+    assert util.snake_to_camel("ONE_TWO_THREE") == "oneTwoThree"
+    assert util.snake_to_camel("ONE_TWO_THREE", lower=False) == "OneTwoThree"
+    assert util.snake_to_camel("ONE_TWO_THREE", sep="-") == "one_two_three"
+    assert util.snake_to_camel("ONE-TWO-THREE", sep="-") == "oneTwoThree"
+    assert util.snake_to_camel("ONE") == "one"
+    assert util.snake_to_camel("ONE", lower=False) == "One"
+
+
 def test_flatten_all_values():
     obj = "v"
     expected = ["v"]
diff --git a/pkg/workloads/lib/util.py b/pkg/workloads/lib/util.py
index f24504d9b3..90977cd1db 100644
--- a/pkg/workloads/lib/util.py
+++ b/pkg/workloads/lib/util.py
@@ -88,6 +88,16 @@ def pluralize(num, singular, plural):
         return str(num) + " " + plural
 
 
+def snake_to_camel(input, sep="_", lower=True):
+    output = ""
+    for idx, word in enumerate(input.lower().split(sep)):
+        if idx == 0 and lower:
+            output += word
+        else:
+            output += word[0].upper() + word[1:]
+    return output
+
+
 def mkdir_p(dir_path):
     try:
         os.makedirs(dir_path)
diff --git a/pkg/workloads/spark_job/spark_util.py b/pkg/workloads/spark_job/spark_util.py
index a844df271e..0c563578c0 100644
--- a/pkg/workloads/spark_job/spark_util.py
+++ b/pkg/workloads/spark_job/spark_util.py
@@ -239,10 +239,16 @@ def ingest(ctx, spark):
 
 
 def read_csv(ctx, spark):
-    csv_config = ctx.environment["data"]
+    data_config = ctx.environment["data"]
     schema = expected_schema_from_context(ctx)
-    header = csv_config.get("skip_header", False)
-    return spark.read.csv(csv_config["path"], header=header, schema=schema, mode="FAILFAST")
+
+    csv_config = {
+        util.snake_to_camel(param_name): val
+        for param_name, val in data_config.get("csv_config", {}).items()
+        if val is not None
+    }
+
+    return spark.read.csv(data_config["path"], schema=schema, mode="FAILFAST", **csv_config)
 
 
 def read_parquet(ctx, spark):
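The dict comprehension in `read_csv` is the bridge between the YAML layer and Spark: snake_case keys become the camelCase option names that `DataFrameReader.csv` expects, and options left unset (`None`) are dropped so that Spark's own defaults apply. With a hypothetical config (assuming `lib.util` from this tree is importable):

```python
from lib import util

cfg = {"header": True, "null_value": "NULL", "max_columns": None}

kwargs = {util.snake_to_camel(k): v for k, v in cfg.items() if v is not None}
# kwargs == {"header": True, "nullValue": "NULL"}; max_columns stays unset,
# so Spark's default for maxColumns applies.
```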
diff --git a/pkg/workloads/spark_job/test/spark_util_test.py b/pkg/workloads/spark_job/test/spark_util_test.py
index e1ee251111..8eaecc4e7f 100644
--- a/pkg/workloads/spark_job/test/spark_util_test.py
+++ b/pkg/workloads/spark_job/test/spark_util_test.py
@@ -11,19 +11,20 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
-import pytest
+import math
 
 import spark_util
+from lib.exceptions import UserException
 
+import pytest
 from pyspark.sql.types import *
 from pyspark.sql import Row
+import pyspark.sql.functions as F
+from mock import MagicMock, call
 from py4j.protocol import Py4JJavaError
 
+
 pytestmark = pytest.mark.usefixtures("spark")
 
-import pyspark.sql.functions as F
-from mock import MagicMock, call
-from lib.exceptions import UserException
 
 
 def test_compare_column_schemas():
@@ -164,6 +165,55 @@ def test_read_csv_invalid_type(spark, write_csv_file, ctx_obj, get_context):
         spark_util.ingest(get_context(ctx_obj), spark).collect()
 
 
+def test_read_csv_valid_options(spark, write_csv_file, ctx_obj, get_context):
+    csv_str = "\n".join(
+        [
+            "a_str|b_float|c_long",
+            " a |1|",
+            "|NaN|1",
+            '"""weird"" having a | inside the string"|-Infini|NULL',
+        ]
+    )
+    path_to_file = write_csv_file(csv_str)
+
+    ctx_obj["environment"] = {
+        "data": {
+            "type": "csv",
+            "path": path_to_file,
+            "schema": ["a_str", "b_float", "c_long"],
+            "csv_config": {
+                "header": True,
+                "sep": "|",
+                "ignore_leading_white_space": False,
+                "ignore_trailing_white_space": False,
+                "nan_value": "NaN",
+                "escape": '"',
+                "negative_inf": "-Infini",
+                "null_value": "NULL",
+            },
+        }
+    }
+
+    ctx_obj["raw_columns"] = {
+        "a_str": {"name": "a_str", "type": "STRING_COLUMN", "required": True, "id": "-"},
+        "b_float": {"name": "b_float", "type": "FLOAT_COLUMN", "required": True, "id": "-"},
+        "c_long": {"name": "c_long", "type": "INT_COLUMN", "required": False, "id": "-"},
+    }
+
+    actual_results = spark_util.read_csv(get_context(ctx_obj), spark).collect()
+
+    assert len(actual_results) == 3
+    assert actual_results[0] == Row(a_str=" a ", b_float=float(1), c_long=None)
+    assert actual_results[1].a_str is None
+    assert math.isnan(
+        actual_results[1].b_float
+    )  # nan != nan, so a row-wise comparison can't be done
+    assert actual_results[1].c_long == 1
+    assert actual_results[2] == Row(
+        a_str='"weird" having a | inside the string', b_float=float("-Inf"), c_long=None
+    )
+
+
 def test_value_checker_required():
     raw_column_config = {"name": "a_str", "type": "STRING_COLUMN", "required": True}
     results = list(spark_util.value_checker(raw_column_config))
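One detail worth calling out in `test_read_csv_valid_options`: IEEE 754 NaN compares unequal to everything, including itself, which is why the middle row's `b_float` is checked with `math.isnan` instead of a whole-`Row` equality:

```python
import math

nan = float("nan")
assert nan != nan       # NaN is never equal to anything, not even itself
assert math.isnan(nan)  # so NaN-valued fields must be checked individually
```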