Skip to content

Commit ab76a04

Browse files
authored
Updating sst2 tutorial to replace lambda usage (#1722)
Updating sst2 tutorial to replace lambda usage
1 parent 6689502 commit ab76a04

File tree

1 file changed

+15
-6
lines changed

1 file changed

+15
-6
lines changed

examples/tutorials/sst2_classification_non_distributed.py

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -94,13 +94,18 @@
9494
train_datapipe = SST2(split="train")
9595
dev_datapipe = SST2(split="dev")
9696

97+
9798
# Transform the raw dataset using the non-batched API (i.e., apply the transformation line by line)
98-
train_datapipe = train_datapipe.map(lambda x: (text_transform(x[0]), x[1]))
99+
def apply_transform(x):
100+
return text_transform(x[0]), x[1]
101+
102+
103+
train_datapipe = train_datapipe.map(apply_transform)
99104
train_datapipe = train_datapipe.batch(batch_size)
100105
train_datapipe = train_datapipe.rows2columnar(["token_ids", "target"])
101106
train_dataloader = DataLoader(train_datapipe, batch_size=None)
102107

103-
dev_datapipe = dev_datapipe.map(lambda x: (text_transform(x[0]), x[1]))
108+
dev_datapipe = dev_datapipe.map(apply_transform)
104109
dev_datapipe = dev_datapipe.batch(batch_size)
105110
dev_datapipe = dev_datapipe.rows2columnar(["token_ids", "target"])
106111
dev_dataloader = DataLoader(dev_datapipe, batch_size=None)
@@ -111,10 +116,14 @@
111116
#
112117
# ::
113118
#
114-
# train_datapipe = train_datapipe.batch(batch_size).rows2columnar(["text", "label"])
115-
# train_datapipe = train_datapipe.map(lambda x: {"token_ids": text_transform(x["text"]), "target": label_transform(x["label"])})
116-
# dev_datapipe = dev_datapipe.batch(batch_size).rows2columnar(["text", "label"])
117-
# dev_datapipe = dev_datapipe.map(lambda x: {"token_ids": text_transform(x["text"]), "target": label_transform(x["label"])})
119+
# def batch_transform(x):
120+
#     return {"token_ids": text_transform(x["text"]), "target": x["label"]}
121+
#
122+
#
123+
# train_datapipe = train_datapipe.batch(batch_size).rows2columnar(["text", "label"])
124+
# train_datapipe = train_datapipe.map(batch_transform)
125+
# dev_datapipe = dev_datapipe.batch(batch_size).rows2columnar(["text", "label"])
126+
# dev_datapipe = dev_datapipe.map(batch_transform)
118127
#
119128

120129
######################################################################

0 commit comments

Comments
 (0)