
Training dataset resource bug #86


Merged: 9 commits, Apr 26, 2019
12 changes: 11 additions & 1 deletion docs/applications/resources/models.md
@@ -41,11 +41,21 @@ Train custom TensorFlow models at scale.
start_delay_secs: <int> # start evaluating after waiting for this many seconds (default: 120)
throttle_secs: <int> # do not re-evaluate unless the last evaluation was started at least this many seconds ago (default: 600)

compute:
compute: # Resources for training and evaluation steps (TensorFlow)
cpu: <string> # CPU request (default: Null)
mem: <string> # memory request (default: Null)
gpu: <string> # GPU request (default: Null)

dataset_compute: # Resources for constructing the training dataset (Spark)
executors: <int> # number of Spark executors (default: 1)
driver_cpu: <string> # CPU request for the Spark driver (default: 1)
driver_mem: <string> # memory request for the Spark driver (default: 500Mi)
driver_mem_overhead: <string> # off-heap (non-JVM) memory allocated to the driver (overrides mem_overhead_factor) (default: min(driver_mem * 0.4, 384Mi))
executor_cpu: <string> # CPU request for each Spark executor (default: 1)
executor_mem: <string> # memory request for each Spark executor (default: 500Mi)
executor_mem_overhead: <string> # off-heap (non-JVM) memory allocated to each executor (overrides mem_overhead_factor) (default: min(executor_mem * 0.4, 384Mi))
mem_overhead_factor: <float> # the proportion of driver_mem/executor_mem that will be additionally allocated for off-heap (non-JVM) memory (default: 0.4)

tags:
<string>: <scalar> # arbitrary key/value pairs to attach to the resource (optional)
...
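With the defaults above, executor_mem is 500Mi and mem_overhead_factor is 0.4, so executor_mem_overhead defaults to min(500Mi * 0.4, 384Mi) = 200Mi. As a usage sketch (values illustrative, not taken from this PR), a model that needs a larger Spark cluster to build its training dataset could request:

dataset_compute:
  executors: 4
  executor_cpu: 2       # must be a whole number (validated with Int: true below)
  executor_mem: 1500Mi  # overhead then defaults to min(1500Mi * 0.4, 384Mi) = 384Mi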
2 changes: 1 addition & 1 deletion pkg/operator/api/userconfig/aggregates.go
@@ -51,7 +51,7 @@ var aggregateValidation = &cr.StructValidation{
},
},
inputValuesFieldValidation,
sparkComputeFieldValidation,
sparkComputeFieldValidation("Compute"),
tagsFieldValidation,
typeFieldValidation,
},
136 changes: 70 additions & 66 deletions pkg/operator/api/userconfig/compute.go
@@ -45,84 +45,88 @@ type SparkCompute struct {
MemOverheadFactor *float64 `json:"mem_overhead_factor" yaml:"mem_overhead_factor"`
}

var sparkComputeFieldValidation = &cr.StructFieldValidation{
	StructField: "Compute",
	StructValidation: &cr.StructValidation{
		StructFieldValidations: []*cr.StructFieldValidation{
var sparkComputeStructValidation = &cr.StructValidation{
	StructFieldValidations: []*cr.StructFieldValidation{
		{
			StructField: "Executors",
			Int32Validation: &cr.Int32Validation{
				Default:     1,
				GreaterThan: pointer.Int32(0),
			},
		},
		{
			StructField: "DriverCPU",
			StringValidation: &cr.StringValidation{
				Default: "1",
			},
			Parser: QuantityParser(&QuantityValidation{
				Min: k8sresource.MustParse("1"),
			}),
		},
		{
			StructField: "ExecutorCPU",
			StringValidation: &cr.StringValidation{
				Default: "1",
			},
			Parser: QuantityParser(&QuantityValidation{
				Min: k8sresource.MustParse("1"),
				Int: true,
			}),
		},
		{
			StructField: "DriverMem",
			StringValidation: &cr.StringValidation{
				Default: "500Mi",
			},
			Parser: QuantityParser(&QuantityValidation{
				Min: k8sresource.MustParse("500Mi"),
			}),
		},
		{
			StructField: "ExecutorMem",
			StringValidation: &cr.StringValidation{
				Default: "500Mi",
			},
			Parser: QuantityParser(&QuantityValidation{
				Min: k8sresource.MustParse("500Mi"),
			}),
		},
		{
			StructField: "DriverMemOverhead",
			StringPtrValidation: &cr.StringPtrValidation{
				Default: nil, // min(DriverMem * 0.4, 384Mi)
			},
			Parser: QuantityParser(&QuantityValidation{
				Min: k8sresource.MustParse("0"),
			}),
		},
		{
			StructField: "ExecutorMemOverhead",
			StringPtrValidation: &cr.StringPtrValidation{
				Default: nil, // min(ExecutorMem * 0.4, 384Mi)
			},
			Parser: QuantityParser(&QuantityValidation{
				Min: k8sresource.MustParse("0"),
			}),
		},
		{
			StructField: "MemOverheadFactor",
			Float64PtrValidation: &cr.Float64PtrValidation{
				Default:              nil, // set to 0.4 by Spark
				GreaterThanOrEqualTo: pointer.Float64(0),
				LessThan:             pointer.Float64(1),
			},
		},
	},
}

func sparkComputeFieldValidation(fieldName string) *cr.StructFieldValidation {
return &cr.StructFieldValidation{
StructField: fieldName,
StructValidation: sparkComputeStructValidation,
}
}
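This refactor lets the same Spark compute schema be attached under any config key instead of being hard-wired to "Compute". A minimal sketch of the resulting call-site pattern (the variable names here are illustrative, not from the diff):

// Columns, aggregates, and transformed columns validate a `compute` block,
// while models validate a separate `dataset_compute` block, all against
// the single shared sparkComputeStructValidation.
var (
	computeField        = sparkComputeFieldValidation("Compute")
	datasetComputeField = sparkComputeFieldValidation("DatasetCompute")
)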

func (sparkCompute *SparkCompute) ID() string {
var buf bytes.Buffer
buf.WriteString(s.Int32(sparkCompute.Executors))
2 changes: 2 additions & 0 deletions pkg/operator/api/userconfig/models.go
@@ -42,6 +42,7 @@ type Model struct {
Training *ModelTraining `json:"training" yaml:"training"`
Evaluation *ModelEvaluation `json:"evaluation" yaml:"evaluation"`
Compute *TFCompute `json:"compute" yaml:"compute"`
DatasetCompute *SparkCompute `json:"dataset_compute" yaml:"dataset_compute"`
Tags Tags `json:"tags" yaml:"tags"`
}

@@ -127,6 +128,7 @@ var modelValidation = &cr.StructValidation{
StructValidation: modelEvaluationValidation,
},
tfComputeFieldValidation,
sparkComputeFieldValidation("DatasetCompute"),
tagsFieldValidation,
typeFieldValidation,
},
6 changes: 3 additions & 3 deletions pkg/operator/api/userconfig/raw_columns.go
@@ -96,7 +96,7 @@ var rawIntColumnFieldValidations = []*cr.StructFieldValidation{
AllowNull: true,
},
},
sparkComputeFieldValidation,
sparkComputeFieldValidation("Compute"),
tagsFieldValidation,
typeFieldValidation,
}
@@ -145,7 +145,7 @@ var rawFloatColumnFieldValidations = []*cr.StructFieldValidation{
AllowNull: true,
},
},
sparkComputeFieldValidation,
sparkComputeFieldValidation("Compute"),
tagsFieldValidation,
typeFieldValidation,
}
@@ -182,7 +182,7 @@ var rawStringColumnFieldValidations = []*cr.StructFieldValidation{
AllowNull: true,
},
},
sparkComputeFieldValidation,
sparkComputeFieldValidation("Compute"),
tagsFieldValidation,
typeFieldValidation,
}
2 changes: 1 addition & 1 deletion pkg/operator/api/userconfig/transformed_columns.go
@@ -51,7 +51,7 @@ var transformedColumnValidation = &cr.StructValidation{
},
},
inputValuesFieldValidation,
sparkComputeFieldValidation,
sparkComputeFieldValidation("Compute"),
tagsFieldValidation,
typeFieldValidation,
},
1 change: 1 addition & 0 deletions pkg/operator/workloads/data_job.go
@@ -257,6 +257,7 @@ func dataWorkloadSpecs(ctx *context.Context) ([]*WorkloadSpec, error) {
allComputes = append(allComputes, transformedColumn.Compute)
}
}
allComputes = append(allComputes, model.DatasetCompute)
}

resourceIDSet := strset.Union(rawColumnIDs, aggregateIDs, transformedColumnIDs, trainingDatasetIDs)
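This one-line fix is the bug named in the PR title: model.DatasetCompute was never appended to allComputes, so a model's dataset_compute requests were silently ignored when sizing the Spark workload. A hypothetical sketch of the kind of downstream aggregation that depends on collecting every compute (the operator's real sizing logic is outside this diff):

// maxExecutors illustrates why every SparkCompute must reach allComputes:
// if model.DatasetCompute is missing from the slice, its executor count
// can never influence the size of the shared Spark job.
func maxExecutors(computes []*userconfig.SparkCompute) int32 {
	executors := int32(1)
	for _, compute := range computes {
		if compute != nil && compute.Executors > executors {
			executors = compute.Executors
		}
	}
	return executors
}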