Skip to content

Commit 8bbb6ac

Browse files
authored
Fix STSB and WikiTexts tests (#1737)
1 parent e548d3f commit 8bbb6ac

File tree

2 files changed: 7 additions and 5 deletions.

2 files changed: 7 additions and 5 deletions.

test/datasets/test_stsb.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -21,7 +21,7 @@ def _get_mock_dataset(root_dir):
2121

2222
seed = 1
2323
mocked_data = defaultdict(list)
24-
for file_name, name in zip(["sts-train.csv", "sts-dev.csv" "sts-test.csv"], ["train", "dev", "test"]):
24+
for file_name, name in zip(["sts-train.csv", "sts-dev.csv", "sts-test.csv"], ["train", "dev", "test"]):
2525
txt_file = os.path.join(temp_dataset_dir, file_name)
2626
with open(txt_file, "w", encoding="utf-8") as f:
2727
for i in range(5):

test/datasets/test_wikitexts.py

Lines changed: 6 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -25,28 +25,30 @@ def _get_mock_dataset(root_dir, base_dir_name):
2525
file_names = ("wiki.train.tokens", "wiki.valid.tokens", "wiki.test.tokens")
2626
for file_name in file_names:
2727
csv_file = os.path.join(temp_dataset_dir, file_name)
28-
mocked_lines = mocked_data[os.path.splitext(file_name)[0]]
28+
mocked_lines = mocked_data[file_name.split(".")[1]]
2929
with open(csv_file, "w", encoding="utf-8") as f:
3030
for i in range(5):
3131
rand_string = get_random_unicode(seed)
32-
dataset_line = rand_string
33-
f.write(f"{rand_string}\n")
32+
dataset_line = f"{rand_string}\n"
33+
f.write(dataset_line)
3434

3535
# append line to correct dataset split
3636
mocked_lines.append(dataset_line)
3737
seed += 1
3838

3939
if base_dir_name == WikiText103.__name__:
4040
compressed_file = "wikitext-103-v1"
41+
arcname_folder = "wikitext-103"
4142
else:
4243
compressed_file = "wikitext-2-v1"
44+
arcname_folder = "wikitext-2"
4345

4446
compressed_dataset_path = os.path.join(base_dir, compressed_file + ".zip")
4547
# create zip file from dataset folder
4648
with zipfile.ZipFile(compressed_dataset_path, "w") as zip_file:
4749
for file_name in file_names:
4850
txt_file = os.path.join(temp_dataset_dir, file_name)
49-
zip_file.write(txt_file, arcname=compressed_file)
51+
zip_file.write(txt_file, arcname=os.path.join(arcname_folder, file_name))
5052

5153
return mocked_data
5254

0 commit comments

Comments
 (0)