|
94 | 94 | train_datapipe = SST2(split="train")
|
95 | 95 | dev_datapipe = SST2(split="dev")
|
96 | 96 |
|
| 97 | + |
97 | 98 | # Transform the raw dataset using the non-batched API (i.e., apply the transformation line by line)
|
98 |
| -train_datapipe = train_datapipe.map(lambda x: (text_transform(x[0]), x[1])) |
| 99 | +def apply_transform(x): |
| 100 | + return text_transform(x[0]), x[1] |
| 101 | + |
| 102 | + |
| 103 | +train_datapipe = train_datapipe.map(apply_transform) |
99 | 104 | train_datapipe = train_datapipe.batch(batch_size)
|
100 | 105 | train_datapipe = train_datapipe.rows2columnar(["token_ids", "target"])
|
101 | 106 | train_dataloader = DataLoader(train_datapipe, batch_size=None)
|
102 | 107 |
|
103 |
| -dev_datapipe = dev_datapipe.map(lambda x: (text_transform(x[0]), x[1])) |
| 108 | +dev_datapipe = dev_datapipe.map(apply_transform) |
104 | 109 | dev_datapipe = dev_datapipe.batch(batch_size)
|
105 | 110 | dev_datapipe = dev_datapipe.rows2columnar(["token_ids", "target"])
|
106 | 111 | dev_dataloader = DataLoader(dev_datapipe, batch_size=None)
|
|
111 | 116 | #
|
112 | 117 | # ::
|
113 | 118 | #
|
114 |
| -# train_datapipe = train_datapipe.batch(batch_size).rows2columnar(["text", "label"]) |
115 |
| -# train_datapipe = train_datapipe.map(lambda x: {"token_ids": text_transform(x["text"]), "target": label_transform(x["label"])}) |
116 |
| -# dev_datapipe = dev_datapipe.batch(batch_size).rows2columnar(["text", "label"]) |
117 |
| -# dev_datapipe = dev_datapipe.map(lambda x: {"token_ids": text_transform(x["text"]), "target": label_transform(x["label"])}) |
| 119 | +# def batch_transform(x): |
| 120 | +# return {"token_ids": text_transform(x["text"]), "target": x["label"]} |
| 121 | +# |
| 122 | +# |
| 123 | +# train_datapipe = train_datapipe.batch(batch_size).rows2columnar(["text", "label"]) |
| 124 | +# train_datapipe = train_datapipe.map(batch_transform) |
| 125 | +# dev_datapipe = dev_datapipe.batch(batch_size).rows2columnar(["text", "label"]) |
| 126 | +# dev_datapipe = dev_datapipe.map(batch_transform) |
118 | 127 | #
|
119 | 128 |
|
120 | 129 | ######################################################################
|
|
0 commit comments