docs/applications/resources/models.md (11 additions, 1 deletion)
@@ -41,11 +41,21 @@ Train custom TensorFlow models at scale.
   start_delay_secs: <int>  # start evaluating after waiting for this many seconds (default: 120)
   throttle_secs: <int>  # do not re-evaluate unless the last evaluation was started at least this many seconds ago (default: 600)
 
-compute:
+compute:  # Resources for training and evaluation steps (TensorFlow)
   cpu: <string>  # CPU request (default: Null)
   mem: <string>  # memory request (default: Null)
   gpu: <string>  # GPU request (default: Null)
 
+dataset_compute:  # Resources for constructing the training dataset (Spark)
+  executors: <int>  # number of Spark executors (default: 1)
+  driver_cpu: <string>  # CPU request for the Spark driver (default: 1)
+  driver_mem: <string>  # memory request for the Spark driver (default: 500Mi)
+  driver_mem_overhead: <string>  # off-heap (non-JVM) memory allocated to the driver (overrides mem_overhead_factor) (default: min[driver_mem * 0.4, 384Mi])
+  executor_cpu: <string>  # CPU request for each Spark executor (default: 1)
+  executor_mem: <string>  # memory request for each Spark executor (default: 500Mi)
+  executor_mem_overhead: <string>  # off-heap (non-JVM) memory allocated to each executor (overrides mem_overhead_factor) (default: min[executor_mem * 0.4, 384Mi])
+  mem_overhead_factor: <float>  # the proportion of driver_mem/executor_mem additionally allocated for off-heap (non-JVM) memory (default: 0.4)
 
 tags:
   <string>: <scalar>  # arbitrary key/value pairs to attach to the resource (optional)
 ...
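For illustration, a hypothetical dataset_compute block (values are arbitrary, not defaults from this PR; quantities use Kubernetes notation, matching the k8sresource parsing below) might look like:

dataset_compute:
  executors: 3        # run the dataset job on three Spark executors
  executor_cpu: "2"   # must be an integer quantity, per the validation below
  executor_mem: 1G    # heap memory per executor
  driver_mem: 1G      # heap memory for the driver

Note how the documented overhead default plays out: with executor_mem left at its 500Mi default, executor_mem_overhead defaults to min(500Mi * 0.4, 384Mi) = 200Mi, whereas with executor_mem: 1G it is capped at 384Mi.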
pkg/operator/api/userconfig/aggregates.go (1 addition, 1 deletion)
@@ -51,7 +51,7 @@ var aggregateValidation = &cr.StructValidation{
 			},
 		},
 		inputValuesFieldValidation,
-		sparkComputeFieldValidation,
+		sparkComputeFieldValidation("Compute"),
 		tagsFieldValidation,
 		typeFieldValidation,
 	},
pkg/operator/api/userconfig/compute.go (70 additions, 66 deletions)
@@ -45,84 +45,88 @@ type SparkCompute struct {
 	MemOverheadFactor *float64 `json:"mem_overhead_factor" yaml:"mem_overhead_factor"`
 }
 
-var sparkComputeFieldValidation = &cr.StructFieldValidation{
-	StructField: "Compute",
-	StructValidation: &cr.StructValidation{
-		StructFieldValidations: []*cr.StructFieldValidation{
-			{
-				StructField: "Executors",
-				Int32Validation: &cr.Int32Validation{
-					Default:     1,
-					GreaterThan: pointer.Int32(0),
-				},
-			},
-			{
-				StructField: "DriverCPU",
-				StringValidation: &cr.StringValidation{
-					Default: "1",
-				},
-				Parser: QuantityParser(&QuantityValidation{
-					Min: k8sresource.MustParse("1"),
-				}),
-			},
-			{
-				StructField: "ExecutorCPU",
-				StringValidation: &cr.StringValidation{
-					Default: "1",
-				},
-				Parser: QuantityParser(&QuantityValidation{
-					Min: k8sresource.MustParse("1"),
-					Int: true,
-				}),
-			},
-			{
-				StructField: "DriverMem",
-				StringValidation: &cr.StringValidation{
-					Default: "500Mi",
-				},
-				Parser: QuantityParser(&QuantityValidation{
-					Min: k8sresource.MustParse("500Mi"),
-				}),
-			},
-			{
-				StructField: "ExecutorMem",
-				StringValidation: &cr.StringValidation{
-					Default: "500Mi",
-				},
-				Parser: QuantityParser(&QuantityValidation{
-					Min: k8sresource.MustParse("500Mi"),
-				}),
-			},
-			{
-				StructField: "DriverMemOverhead",
-				StringPtrValidation: &cr.StringPtrValidation{
-					Default: nil, // min(DriverMem * 0.4, 384Mi)
-				},
-				Parser: QuantityParser(&QuantityValidation{
-					Min: k8sresource.MustParse("0"),
-				}),
-			},
-			{
-				StructField: "ExecutorMemOverhead",
-				StringPtrValidation: &cr.StringPtrValidation{
-					Default: nil, // min(ExecutorMem * 0.4, 384Mi)
-				},
-				Parser: QuantityParser(&QuantityValidation{
-					Min: k8sresource.MustParse("0"),
-				}),
-			},
-			{
-				StructField: "MemOverheadFactor",
-				Float64PtrValidation: &cr.Float64PtrValidation{
-					Default:              nil, // set to 0.4 by Spark
-					GreaterThanOrEqualTo: pointer.Float64(0),
-					LessThan:             pointer.Float64(1),
-				},
-			},
-		},
-	},
-}
+var sparkComputeStructValidation = &cr.StructValidation{
+	StructFieldValidations: []*cr.StructFieldValidation{
+		{
+			StructField: "Executors",
+			Int32Validation: &cr.Int32Validation{
+				Default:     1,
+				GreaterThan: pointer.Int32(0),
+			},
+		},
+		{
+			StructField: "DriverCPU",
+			StringValidation: &cr.StringValidation{
+				Default: "1",
+			},
+			Parser: QuantityParser(&QuantityValidation{
+				Min: k8sresource.MustParse("1"),
+			}),
+		},
+		{
+			StructField: "ExecutorCPU",
+			StringValidation: &cr.StringValidation{
+				Default: "1",
+			},
+			Parser: QuantityParser(&QuantityValidation{
+				Min: k8sresource.MustParse("1"),
+				Int: true,
+			}),
+		},
+		{
+			StructField: "DriverMem",
+			StringValidation: &cr.StringValidation{
+				Default: "500Mi",
+			},
+			Parser: QuantityParser(&QuantityValidation{
+				Min: k8sresource.MustParse("500Mi"),
+			}),
+		},
+		{
+			StructField: "ExecutorMem",
+			StringValidation: &cr.StringValidation{
+				Default: "500Mi",
+			},
+			Parser: QuantityParser(&QuantityValidation{
+				Min: k8sresource.MustParse("500Mi"),
+			}),
+		},
+		{
+			StructField: "DriverMemOverhead",
+			StringPtrValidation: &cr.StringPtrValidation{
+				Default: nil, // min(DriverMem * 0.4, 384Mi)
+			},
+			Parser: QuantityParser(&QuantityValidation{
+				Min: k8sresource.MustParse("0"),
+			}),
+		},
+		{
+			StructField: "ExecutorMemOverhead",
+			StringPtrValidation: &cr.StringPtrValidation{
+				Default: nil, // min(ExecutorMem * 0.4, 384Mi)
+			},
+			Parser: QuantityParser(&QuantityValidation{
+				Min: k8sresource.MustParse("0"),
+			}),
+		},
+		{
+			StructField: "MemOverheadFactor",
+			Float64PtrValidation: &cr.Float64PtrValidation{
+				Default:              nil, // set to 0.4 by Spark
+				GreaterThanOrEqualTo: pointer.Float64(0),
+				LessThan:             pointer.Float64(1),
+			},
+		},
+	},
+}
+
+func sparkComputeFieldValidation(fieldName string) *cr.StructFieldValidation {
+	return &cr.StructFieldValidation{
+		StructField:      fieldName,
+		StructValidation: sparkComputeStructValidation,
+	}
+}
 
 func (sparkCompute *SparkCompute) ID() string {
 	var buf bytes.Buffer
 	buf.WriteString(s.Int32(sparkCompute.Executors))
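The point of this refactor, in brief: the Spark compute schema used to be hard-wired to a field named "Compute"; factoring it out into sparkComputeStructValidation and generating the field binding on demand lets any config struct reuse it under its own name. A minimal sketch of the resulting pattern (using this diff's identifiers; assumes the surrounding userconfig package and its cr import):

	// One shared schema, bound to differently named fields at each call site.
	// Both values point at the same sparkComputeStructValidation, so the
	// schema is defined (and kept in sync) in exactly one place.
	var (
		computeField        = sparkComputeFieldValidation("Compute")        // aggregates, columns
		datasetComputeField = sparkComputeFieldValidation("DatasetCompute") // models
	)

The YAML keys themselves ("compute", "dataset_compute") come from the struct tags on the target structs, as in the models.go change below.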
pkg/operator/api/userconfig/models.go (2 additions, 0 deletions)
@@ -42,6 +42,7 @@ type Model struct {
 	Training       *ModelTraining   `json:"training" yaml:"training"`
 	Evaluation     *ModelEvaluation `json:"evaluation" yaml:"evaluation"`
 	Compute        *TFCompute       `json:"compute" yaml:"compute"`
+	DatasetCompute *SparkCompute    `json:"dataset_compute" yaml:"dataset_compute"`
 	Tags           Tags             `json:"tags" yaml:"tags"`
 }

@@ -127,6 +128,7 @@ var modelValidation = &cr.StructValidation{
 			StructValidation: modelEvaluationValidation,
 		},
 		tfComputeFieldValidation,
+		sparkComputeFieldValidation("DatasetCompute"),
 		tagsFieldValidation,
 		typeFieldValidation,
 	},
pkg/operator/api/userconfig/raw_columns.go (3 additions, 3 deletions)
@@ -96,7 +96,7 @@ var rawIntColumnFieldValidations = []*cr.StructFieldValidation{
 			AllowNull: true,
 		},
 	},
-	sparkComputeFieldValidation,
+	sparkComputeFieldValidation("Compute"),
 	tagsFieldValidation,
 	typeFieldValidation,
 }
@@ -145,7 +145,7 @@ var rawFloatColumnFieldValidations = []*cr.StructFieldValidation{
 			AllowNull: true,
 		},
 	},
-	sparkComputeFieldValidation,
+	sparkComputeFieldValidation("Compute"),
 	tagsFieldValidation,
 	typeFieldValidation,
 }
@@ -182,7 +182,7 @@ var rawStringColumnFieldValidations = []*cr.StructFieldValidation{
 			AllowNull: true,
 		},
 	},
-	sparkComputeFieldValidation,
+	sparkComputeFieldValidation("Compute"),
 	tagsFieldValidation,
 	typeFieldValidation,
 }
pkg/operator/api/userconfig/transformed_columns.go (1 addition, 1 deletion)
@@ -51,7 +51,7 @@ var transformedColumnValidation = &cr.StructValidation{
 			},
 		},
 		inputValuesFieldValidation,
-		sparkComputeFieldValidation,
+		sparkComputeFieldValidation("Compute"),
 		tagsFieldValidation,
 		typeFieldValidation,
 	},
pkg/operator/workloads/data_job.go (1 addition, 0 deletions)
@@ -257,6 +257,7 @@ func dataWorkloadSpecs(ctx *context.Context) ([]*WorkloadSpec, error) {
 				allComputes = append(allComputes, transformedColumn.Compute)
 			}
 		}
+		allComputes = append(allComputes, model.DatasetCompute)
 	}
 
 	resourceIDSet := strset.Union(rawColumnIDs, aggregateIDs, transformedColumnIDs, trainingDatasetIDs)
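A note on this hunk: the rest of the function isn't shown, but allComputes presumably feeds the resource sizing of the generated Spark workload, so this append is what makes a model's dataset_compute settings actually influence the data job's requested resources.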