Skip to content

Commit 6ae9f50

Browse files
committed
Updates the PR to use an op attribute instead of an environment variable
-Originally AVRO_PARSER_NUM_MINIBATCHES was set as an environment variable. Because tensorflow-io rarely uses env vars to fine-tune kernel ops, this was changed to an attribute. See comment here: #1283 (comment)
1 parent 4da7742 commit 6ae9f50

File tree

3 files changed

+18
-4
lines changed

3 files changed

+18
-4
lines changed

tensorflow_io/core/kernels/avro/parse_avro_kernels.cc

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -190,6 +190,7 @@ Status ParseAvro(const AvroParserConfig& config,
190190

191191
// This parameter affects performance in a big and data-dependent way.
192192
const size_t kMiniBatchSizeBytes = 50000;
193+
size_t avro_num_minibatches_;
193194

194195
// Calculate number of minibatches.
195196
// In main regime make each minibatch around kMiniBatchSizeBytes bytes.
@@ -206,10 +207,9 @@ Status ParseAvro(const AvroParserConfig& config,
206207
minibatch_bytes = 0;
207208
}
208209
}
209-
if (const char* n_minibatches =
210-
std::getenv("AVRO_PARSER_NUM_MINIBATCHES")) {
211-
VLOG(5) << "Overriding num_minibatches with " << n_minibatches;
212-
result = std::stoi(n_minibatches);
210+
if (avro_num_minibatches_) {
211+
VLOG(5) << "Overriding num_minibatches with " << avro_num_minibatches_;
212+
result = avro_num_minibatches_;
213213
}
214214
// This is to ensure users can control the num minibatches all the way down
215215
// to size of 1(no parallelism).
@@ -406,6 +406,8 @@ class ParseAvroOp : public OpKernel {
406406
OP_REQUIRES_OK(ctx, ctx->GetAttr("sparse_types", &sparse_types_));
407407
OP_REQUIRES_OK(ctx, ctx->GetAttr("dense_types", &dense_types_));
408408
OP_REQUIRES_OK(ctx, ctx->GetAttr("dense_shapes", &dense_shapes_));
409+
OP_REQUIRES_OK(
410+
ctx, ctx->GetAttr("avro_num_minibatches", &avro_num_minibatches_));
409411

410412
OP_REQUIRES_OK(ctx, ctx->GetAttr("sparse_keys", &sparse_keys_));
411413
OP_REQUIRES_OK(ctx, ctx->GetAttr("dense_keys", &dense_keys_));
@@ -419,6 +421,11 @@ class ParseAvroOp : public OpKernel {
419421
dense_shapes_[d].dims() > 1 && dense_shapes_[d].dim_size(0) == -1;
420422
}
421423

424+
// Check that avro_num_minibatches is non-negative
425+
OP_REQUIRES(ctx, avro_num_minibatches_ >= 0,
426+
errors::InvalidArgument("Need avro_num_minibatches >= 0, got ",
427+
avro_num_minibatches_));
428+
422429
string reader_schema_str;
423430
OP_REQUIRES_OK(ctx, ctx->GetAttr("reader_schema", &reader_schema_str));
424431

@@ -513,6 +520,7 @@ class ParseAvroOp : public OpKernel {
513520
avro::ValidSchema reader_schema_;
514521
size_t num_dense_;
515522
size_t num_sparse_;
523+
int64 avro_num_minibatches_;
516524

517525
private:
518526
std::vector<std::pair<string, DataType>> CreateKeysAndTypes() {

tensorflow_io/core/ops/avro_ops.cc

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,7 @@ REGISTER_OP("IO>ParseAvro")
8383
.Output("sparse_values: sparse_types")
8484
.Output("sparse_shapes: num_sparse * int64")
8585
.Output("dense_values: dense_types")
86+
.Attr("avro_num_minibatches: int >= 0")
8687
.Attr("num_sparse: int >= 0")
8788
.Attr("reader_schema: string")
8889
.Attr("sparse_keys: list(string) >= 0")
@@ -94,6 +95,7 @@ REGISTER_OP("IO>ParseAvro")
9495
.SetShapeFn([](shape_inference::InferenceContext* c) {
9596
size_t num_dense;
9697
size_t num_sparse;
98+
int64 avro_num_minibatches;
9799
int64 num_sparse_from_user;
98100
std::vector<DataType> sparse_types;
99101
std::vector<DataType> dense_types;
@@ -106,6 +108,8 @@ REGISTER_OP("IO>ParseAvro")
106108
TF_RETURN_IF_ERROR(c->GetAttr("sparse_types", &sparse_types));
107109
TF_RETURN_IF_ERROR(c->GetAttr("dense_types", &dense_types));
108110
TF_RETURN_IF_ERROR(c->GetAttr("dense_shapes", &dense_shapes));
111+
TF_RETURN_IF_ERROR(
112+
c->GetAttr("avro_num_minibatches", &avro_num_minibatches));
109113

110114
TF_RETURN_IF_ERROR(c->GetAttr("sparse_keys", &sparse_keys));
111115
TF_RETURN_IF_ERROR(c->GetAttr("sparse_ranks", &sparse_ranks));

tensorflow_io/core/python/experimental/parse_avro_ops.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -130,6 +130,7 @@ def _parse_avro(
130130
dense_defaults=None,
131131
dense_shapes=None,
132132
name=None,
133+
avro_num_minibatches=0,
133134
):
134135
"""Parses Avro records.
135136
@@ -196,6 +197,7 @@ def _parse_avro(
196197
dense_keys=dense_keys,
197198
dense_shapes=dense_shapes,
198199
name=name,
200+
avro_num_minibatches=avro_num_minibatches,
199201
)
200202

201203
(sparse_indices, sparse_values, sparse_shapes, dense_values) = outputs

0 commit comments

Comments
 (0)