Skip to content

Commit f17d093

Browse files
committed
add more tests
1 parent f94ee1d commit f17d093

File tree

2 files changed

+85
-40
lines changed

2 files changed

+85
-40
lines changed

test/nodes/test_csv_dataloader.py

Lines changed: 29 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,9 @@
77
import csv
88
import os
99
import tempfile
10-
from typing import Any, Dict, List, Union
1110

1211
from parameterized import parameterized
1312
from torch.testing._internal.common_utils import TestCase
14-
from torchdata.nodes.base_node import BaseNode
1513

1614
from torchdata.nodes.csv_reader import CSVReader
1715

@@ -61,6 +59,7 @@ def test_basic_read_dict(self):
6159
self.assertEqual(len(results), len(self.test_data) - 1)
6260
self.assertEqual(results[0], {"name": "Alice", "age": "30", "city": "New York"})
6361
self.assertEqual(results[1]["city"], "London")
62+
self.assertEqual(results[-1]["city"], "Bogota")
6463
node.close()
6564

6665
def test_different_delimiters(self):
@@ -70,25 +69,27 @@ def test_different_delimiters(self):
7069

7170
self.assertEqual(len(results), len(self.test_data) - 1)
7271
self.assertEqual(results[2]["city"], "Paris")
72+
self.assertEqual(results[-1]["city"], "Bogota")
7373
node.close()
7474

7575
def test_state_management(self):
7676
path = self._create_temp_csv()
7777
node = CSVReader(path, has_header=True, return_dict=True)
78-
78+
print(f"initial state: {node.state_dict()}")
7979
for _ in range(11):
80-
next(node)
80+
_ = next(node)
81+
print(f"element = {_}, state: {node.state_dict()}")
8182

8283
state = node.state_dict()
8384

8485
node.reset(state)
85-
8686
item = next(node)
87+
8788
with self.assertRaises(StopIteration):
8889
next(node)
8990

9091
self.assertEqual(item["name"], "Lily")
91-
self.assertEqual(state[CSVReader.LINE_NUM_KEY], 11)
92+
self.assertEqual(state[CSVReader.NUM_LINES_YIELDED], 11)
9293
node.close()
9394

9495
@parameterized.expand([3, 5, 7])
@@ -98,6 +99,28 @@ def test_save_load_state(self, midpoint: int):
9899
run_test_save_load_state(self, node, midpoint)
99100
node.close()
100101

102+
def test_load_wrong_state(self):
103+
path = self._create_temp_csv(header=True)
104+
node = CSVReader(path, has_header=True)
105+
106+
state = node.state_dict()
107+
state[CSVReader.HEADER_KEY] = None
108+
with self.assertRaisesRegex(ValueError, "Check if has_header=True matches the state header=None"):
109+
node.reset(state)
110+
111+
node.close()
112+
113+
node = CSVReader(path, has_header=False)
114+
state = node.state_dict()
115+
state[CSVReader.HEADER_KEY] = ["name", "age"]
116+
with self.assertRaisesRegex(
117+
ValueError,
118+
r"Check if has_header=False matches the state header=\['name', 'age'\]",
119+
):
120+
node.reset(state)
121+
122+
node.close()
123+
101124
def test_empty_file(self):
102125
path = self._create_temp_csv()
103126
# Overwrite with empty file

torchdata/nodes/csv_reader.py

Lines changed: 56 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import csv
2+
from itertools import islice
23
from typing import Any, Dict, Iterator, List, Optional, Sequence, TextIO, Union
34

45
from torchdata.nodes.base_node import BaseNode
@@ -13,7 +14,7 @@ class CSVReader(BaseNode[Union[List[str], Dict[str, str]]]):
1314
return_dict: Return rows as dictionaries (requires has_header=True)
1415
"""
1516

16-
LINE_NUM_KEY = "line_num"
17+
NUM_LINES_YIELDED = "num_lines_yielded"
1718
HEADER_KEY = "header"
1819

1920
def __init__(
@@ -22,6 +23,7 @@ def __init__(
2223
has_header: bool = False,
2324
delimiter: str = ",",
2425
return_dict: bool = False,
26+
encoding: str = "utf-8",
2527
):
2628
super().__init__()
2729
self.file_path = file_path
@@ -30,64 +32,84 @@ def __init__(
3032
self.return_dict = return_dict
3133
if return_dict and not has_header:
3234
raise ValueError("return_dict=True requires has_header=True")
35+
self.encoding = encoding
3336
self._file: Optional[TextIO] = None
3437
self._reader: Optional[Iterator[Union[List[str], Dict[str, str]]]] = None
3538
self._header: Optional[Sequence[str]] = None
36-
self._line_num: int = 0
39+
self._num_lines_yielded: int = 0
3740
self.reset() # Initialize reader
3841

3942
def reset(self, initial_state: Optional[Dict[str, Any]] = None):
40-
super().reset(initial_state)
41-
42-
if self._file and not self._file.closed:
43-
self._file.close()
43+
super().reset()
44+
self.close()
45+
46+
# Reopen the file and reset counters
47+
self._file = open(self.file_path, encoding=self.encoding)
48+
self._num_lines_yielded = 0
49+
if initial_state is not None:
50+
self._handle_initial_state(initial_state)
51+
else:
52+
self._initialize_reader()
4453

45-
self._file = open(self.file_path, newline="", encoding="utf-8")
46-
self._line_num = 0
54+
def _handle_initial_state(self, state: Dict[str, Any]):
55+
"""Restore reader state from checkpoint."""
56+
# Validate header compatibility
57+
if (not self.has_header and self.HEADER_KEY in state) or (self.has_header and state[self.HEADER_KEY] is None):
58+
raise ValueError(f"Check if has_header={self.has_header} matches the state header={state[self.HEADER_KEY]}")
4759

48-
if initial_state:
49-
self._header = initial_state.get(self.HEADER_KEY)
50-
target_line_num = initial_state[self.LINE_NUM_KEY]
60+
self._header = state.get(self.HEADER_KEY)
61+
target_line_num = state[self.NUM_LINES_YIELDED]
62+
assert self._file is not None
63+
# Create appropriate reader
64+
if self.return_dict:
5165

52-
if self.return_dict:
53-
if self._header is None:
54-
raise ValueError("return_dict=True requires has_header=True")
55-
self._reader = csv.DictReader(self._file, delimiter=self.delimiter, fieldnames=self._header)
56-
else:
57-
self._reader = csv.reader(self._file, delimiter=self.delimiter)
66+
self._reader = csv.DictReader(self._file, delimiter=self.delimiter, fieldnames=self._header)
67+
else:
68+
self._reader = csv.reader(self._file, delimiter=self.delimiter)
69+
# Skip header if needed (applies only when file has header)
70+
71+
assert isinstance(self._reader, Iterator)
72+
if self.has_header:
73+
try:
74+
next(self._reader) # Skip header line
75+
except StopIteration:
76+
pass # Empty file
77+
# Fast-forward to target line using efficient slicing
78+
consumed = sum(1 for _ in islice(self._reader, target_line_num))
79+
self._num_lines_yielded = consumed
80+
81+
def _initialize_reader(self):
82+
"""Create fresh reader without state."""
83+
assert self._file is not None
84+
if self.return_dict:
85+
self._reader = csv.DictReader(self._file, delimiter=self.delimiter)
86+
self._header = self._reader.fieldnames
87+
else:
88+
self._reader = csv.reader(self._file, delimiter=self.delimiter)
5889

59-
assert isinstance(self._reader, Iterator)
6090
if self.has_header:
61-
next(self._reader) # Skip header
62-
for _ in range(target_line_num - self._line_num):
63-
try:
64-
next(self._reader)
65-
self._line_num += 1
66-
except StopIteration:
67-
break
68-
else:
6991

70-
if self.return_dict:
71-
self._reader = csv.DictReader(self._file, delimiter=self.delimiter)
72-
self._header = self._reader.fieldnames
73-
else:
74-
self._reader = csv.reader(self._file, delimiter=self.delimiter)
75-
if self.has_header:
92+
try:
7693
self._header = next(self._reader)
94+
except StopIteration:
95+
self._header = None # Handle empty file
7796

7897
def next(self) -> Union[List[str], Dict[str, str]]:
7998
try:
8099
assert isinstance(self._reader, Iterator)
81100
row = next(self._reader)
82-
self._line_num += 1
101+
self._num_lines_yielded += 1
83102
return row
84103

85104
except StopIteration:
86105
self.close()
87106
raise
88107

89108
def get_state(self) -> Dict[str, Any]:
90-
return {self.LINE_NUM_KEY: self._line_num, self.HEADER_KEY: self._header}
109+
return {
110+
self.NUM_LINES_YIELDED: self._num_lines_yielded,
111+
self.HEADER_KEY: self._header,
112+
}
91113

92114
def close(self):
93115
if self._file and not self._file.closed:

0 commit comments

Comments
 (0)