Skip to content

Commit bdbf845

Browse files
committed
more tests
1 parent 2de80ae commit bdbf845

File tree

2 files changed

+146
-122
lines changed

2 files changed

+146
-122
lines changed

xarray/core/accessor_str.py

Lines changed: 26 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -78,14 +78,10 @@ def get(self, i, default=None):
7878
items : array of objects
7979
'''
8080
if default is None:
81-
default = ''
81+
default = self._obj.dtype.type('')
8282

83-
def f(x):
84-
n = len(x)
85-
if n <= i or i < -n:
86-
return default
87-
return x[i]
88-
return self._apply(f)
83+
obj = slice(i, i + 1)
84+
return self._apply(lambda x: x[obj])
8985

9086
def slice(self, start=None, stop=None, step=None):
9187
'''
@@ -130,12 +126,14 @@ def slice_replace(self, start=None, stop=None, repl=''):
130126
-------
131127
replaced : same type as values
132128
'''
129+
repl = self._obj.dtype.type(repl)
130+
133131
def f(x):
134-
if x[start:stop] == '':
132+
if len(x[start:stop]) == 0:
135133
local_stop = start
136134
else:
137135
local_stop = stop
138-
y = ''
136+
y = self._obj.dtype.type('')
139137
if start is not None:
140138
y += x[:start]
141139
y += repl
@@ -314,6 +312,7 @@ def count(self, pat, flags=0):
314312
-------
315313
counts : array of int
316314
'''
315+
pat = self._obj.dtype.type(pat)
317316
regex = re.compile(pat, flags=flags)
318317
f = lambda x: len(regex.findall(x))
319318
return self._apply(f, dtype=int)
@@ -333,6 +332,7 @@ def startswith(self, pat):
333332
An array of booleans indicating whether the given pattern matches
334333
the start of each string element.
335334
'''
335+
pat = self._obj.dtype.type(pat)
336336
f = lambda x: x.startswith(pat)
337337
return self._apply(f, dtype=bool)
338338

@@ -351,6 +351,7 @@ def endswith(self, pat):
351351
A Series of booleans indicating whether the given pattern matches
352352
the end of each string element.
353353
'''
354+
pat = self._obj.dtype.type(pat)
354355
f = lambda x: x.endswith(pat)
355356
return self._apply(f, dtype=bool)
356357

@@ -374,7 +375,7 @@ def pad(self, width, side='left', fillchar=' '):
374375
Array with a minimum number of char in each element.
375376
'''
376377
width = int(width)
377-
fillchar = str(fillchar)
378+
fillchar = self._obj.dtype.type(fillchar)
378379
if len(fillchar) != 1:
379380
raise TypeError('fillchar must be a character, not str')
380381

@@ -491,6 +492,7 @@ def contains(self, pat, case=True, flags=0, regex=True):
491492
given pattern is contained within the string of each element
492493
of the array.
493494
'''
495+
pat = self._obj.dtype.type(pat)
494496
if regex:
495497
if not case:
496498
flags |= re.IGNORECASE
@@ -530,6 +532,7 @@ def match(self, pat, case=True, flags=0):
530532
if not case:
531533
flags |= re.IGNORECASE
532534

535+
pat = self._obj.dtype.type(pat)
533536
regex = re.compile(pat, flags=flags)
534537
f = lambda x: bool(regex.match(x))
535538
return self._apply(f, dtype=bool)
@@ -554,6 +557,9 @@ def strip(self, to_strip=None, side='both'):
554557
-------
555558
stripped : same type as values
556559
'''
560+
if to_strip is not None:
561+
to_strip = self._obj.dtype.type(to_strip)
562+
557563
if side == 'both':
558564
f = lambda x: x.strip(to_strip)
559565
elif side == 'left':
@@ -703,7 +709,8 @@ def find(self, sub, start=0, end=None, side='left'):
703709
-------
704710
found : array of integer values
705711
'''
706-
sub = str(sub)
712+
sub = self._obj.dtype.type(sub)
713+
707714
if side == 'left':
708715
method = 'find'
709716
elif side == 'right':
@@ -761,7 +768,7 @@ def index(self, sub, start=0, end=None, side='left'):
761768
-------
762769
found : array of integer values
763770
'''
764-
sub = str(sub)
771+
sub = self._obj.dtype.type(sub)
765772

766773
if side == 'left':
767774
method = 'index'
@@ -837,6 +844,12 @@ def replace(self, pat, repl, n=-1, case=None, flags=0, regex=True):
837844
if not (_is_str_like(repl) or callable(repl)):
838845
raise TypeError("repl must be a string or callable")
839846

847+
if _is_str_like(pat):
848+
pat = self._obj.dtype.type(pat)
849+
850+
if _is_str_like(repl):
851+
repl = self._obj.dtype.type(repl)
852+
840853
is_compiled_re = isinstance(pat, type(re.compile('')))
841854
if regex:
842855
if is_compiled_re:
@@ -906,4 +919,4 @@ def encode(self, encoding, errors='strict'):
906919
else:
907920
encoder = codecs.getencoder(encoding)
908921
f = lambda x: encoder(x, errors)[0]
909-
return self._apply(f, dtype=np.string_)
922+
return self._apply(f, dtype=np.bytes_)

0 commit comments

Comments
 (0)