1
1
import csv
2
+ from itertools import islice
2
3
from typing import Any , Dict , Iterator , List , Optional , Sequence , TextIO , Union
3
4
4
5
from torchdata .nodes .base_node import BaseNode
@@ -13,7 +14,7 @@ class CSVReader(BaseNode[Union[List[str], Dict[str, str]]]):
13
14
return_dict: Return rows as dictionaries (requires has_header=True)
14
15
"""
15
16
16
- LINE_NUM_KEY = "line_num "
17
+ NUM_LINES_YIELDED = "num_lines_yielded "
17
18
HEADER_KEY = "header"
18
19
19
20
def __init__ (
@@ -22,6 +23,7 @@ def __init__(
22
23
has_header : bool = False ,
23
24
delimiter : str = "," ,
24
25
return_dict : bool = False ,
26
+ encoding : str = "utf-8" ,
25
27
):
26
28
super ().__init__ ()
27
29
self .file_path = file_path
@@ -30,64 +32,84 @@ def __init__(
30
32
self .return_dict = return_dict
31
33
if return_dict and not has_header :
32
34
raise ValueError ("return_dict=True requires has_header=True" )
35
+ self .encoding = encoding
33
36
self ._file : Optional [TextIO ] = None
34
37
self ._reader : Optional [Iterator [Union [List [str ], Dict [str , str ]]]] = None
35
38
self ._header : Optional [Sequence [str ]] = None
36
- self ._line_num : int = 0
39
+ self ._num_lines_yielded : int = 0
37
40
self .reset () # Initialize reader
38
41
39
42
def reset(self, initial_state: Optional[Dict[str, Any]] = None):
    """Reopen the CSV file and rebuild the reader, optionally resuming from a checkpoint.

    Any previously opened handle is released through ``close()`` first, the
    yielded-row counter is zeroed, and the reader is then either restored from
    ``initial_state`` or created fresh.

    Args:
        initial_state: State mapping as returned by ``get_state()``. When it is
            ``None`` reading starts from the top of the file.

    Raises:
        OSError: If ``file_path`` cannot be opened.
        ValueError: Propagated from ``_handle_initial_state`` when the saved
            header does not match ``has_header``.
    """
    super().reset()
    self.close()

    # newline="" hands newline handling to the csv module; without it, line
    # breaks embedded in quoted fields are not interpreted correctly.
    self._file = open(self.file_path, newline="", encoding=self.encoding)
    self._num_lines_yielded = 0
    try:
        if initial_state is not None:
            self._handle_initial_state(initial_state)
        else:
            self._initialize_reader()
    except BaseException:
        # Do not leak the handle opened above when building the reader fails.
        self.close()
        raise
54
+ def _handle_initial_state (self , state : Dict [str , Any ]):
55
+ """Restore reader state from checkpoint."""
56
+ # Validate header compatibility
57
+ if (not self .has_header and self .HEADER_KEY in state ) or (self .has_header and state [self .HEADER_KEY ] is None ):
58
+ raise ValueError (f"Check if has_header={ self .has_header } matches the state header={ state [self .HEADER_KEY ]} " )
47
59
48
- if initial_state :
49
- self ._header = initial_state .get (self .HEADER_KEY )
50
- target_line_num = initial_state [self .LINE_NUM_KEY ]
60
+ self ._header = state .get (self .HEADER_KEY )
61
+ target_line_num = state [self .NUM_LINES_YIELDED ]
62
+ assert self ._file is not None
63
+ # Create appropriate reader
64
+ if self .return_dict :
51
65
52
- if self .return_dict :
53
- if self ._header is None :
54
- raise ValueError ("return_dict=True requires has_header=True" )
55
- self ._reader = csv .DictReader (self ._file , delimiter = self .delimiter , fieldnames = self ._header )
56
- else :
57
- self ._reader = csv .reader (self ._file , delimiter = self .delimiter )
66
+ self ._reader = csv .DictReader (self ._file , delimiter = self .delimiter , fieldnames = self ._header )
67
+ else :
68
+ self ._reader = csv .reader (self ._file , delimiter = self .delimiter )
69
+ # Skip header if needed (applies only when file has header)
70
+
71
+ assert isinstance (self ._reader , Iterator )
72
+ if self .has_header :
73
+ try :
74
+ next (self ._reader ) # Skip header line
75
+ except StopIteration :
76
+ pass # Empty file
77
+ # Fast-forward to target line using efficient slicing
78
+ consumed = sum (1 for _ in islice (self ._reader , target_line_num ))
79
+ self ._num_lines_yielded = consumed
80
+
81
+ def _initialize_reader (self ):
82
+ """Create fresh reader without state."""
83
+ assert self ._file is not None
84
+ if self .return_dict :
85
+ self ._reader = csv .DictReader (self ._file , delimiter = self .delimiter )
86
+ self ._header = self ._reader .fieldnames
87
+ else :
88
+ self ._reader = csv .reader (self ._file , delimiter = self .delimiter )
58
89
59
- assert isinstance (self ._reader , Iterator )
60
90
if self .has_header :
61
- next (self ._reader ) # Skip header
62
- for _ in range (target_line_num - self ._line_num ):
63
- try :
64
- next (self ._reader )
65
- self ._line_num += 1
66
- except StopIteration :
67
- break
68
- else :
69
91
70
- if self .return_dict :
71
- self ._reader = csv .DictReader (self ._file , delimiter = self .delimiter )
72
- self ._header = self ._reader .fieldnames
73
- else :
74
- self ._reader = csv .reader (self ._file , delimiter = self .delimiter )
75
- if self .has_header :
92
+ try :
76
93
self ._header = next (self ._reader )
94
+ except StopIteration :
95
+ self ._header = None # Handle empty file
77
96
78
97
def next (self ) -> Union [List [str ], Dict [str , str ]]:
79
98
try :
80
99
assert isinstance (self ._reader , Iterator )
81
100
row = next (self ._reader )
82
- self ._line_num += 1
101
+ self ._num_lines_yielded += 1
83
102
return row
84
103
85
104
except StopIteration :
86
105
self .close ()
87
106
raise
88
107
89
108
def get_state (self ) -> Dict [str , Any ]:
90
- return {self .LINE_NUM_KEY : self ._line_num , self .HEADER_KEY : self ._header }
109
+ return {
110
+ self .NUM_LINES_YIELDED : self ._num_lines_yielded ,
111
+ self .HEADER_KEY : self ._header ,
112
+ }
91
113
92
114
def close (self ):
93
115
if self ._file and not self ._file .closed :
0 commit comments