Skip to content
Merged
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
290 changes: 190 additions & 100 deletions torchvision/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,156 @@
import collections


def _is_pil_image(img):
    """Return True if ``img`` is a PIL image.

    When the module-level ``accimage`` name is not None, ``accimage.Image``
    instances are accepted as well.
    """
    accepted = (Image.Image,) if accimage is None else (Image.Image, accimage.Image)
    return isinstance(img, accepted)


def _is_tensor_image(img):
return torch.is_tensor(img) and img.ndimension() == 3


def _is_numpy_image(img):
return isinstance(img, np.ndarray) and (img.ndim in {2, 3})


def to_tensor(pic):
    """Convert a PIL Image or ``numpy.ndarray`` to a tensor in C x H x W layout.

    Args:
        pic (PIL.Image or numpy.ndarray): Image to convert. An ndarray is
            read as H x W x C, or as H x W for a single-channel image.

    Returns:
        Tensor: Converted image. ndarray input and byte-backed PIL input are
        converted to float and divided by 255. PIL modes ``I`` and ``I;16``
        keep their integer dtype. accimage input is copied into a float32
        array and returned without further scaling.

    Raises:
        TypeError: If ``pic`` is neither a PIL Image nor a 2-D/3-D ndarray.
    """
    if not(_is_pil_image(pic) or _is_numpy_image(pic)):
        raise TypeError('pic should be PIL Image or ndarray. Got {}'.format(type(pic)))

    if isinstance(pic, np.ndarray):
        # _is_numpy_image also admits 2-D (H x W) arrays; give them an
        # explicit channel axis so the HWC -> CHW transpose below is valid.
        if pic.ndim == 2:
            pic = pic[:, :, None]
        img = torch.from_numpy(pic.transpose((2, 0, 1)))
        # backward compatibility
        return img.float().div(255)

    if accimage is not None and isinstance(pic, accimage.Image):
        nppic = np.zeros([pic.channels, pic.height, pic.width], dtype=np.float32)
        pic.copyto(nppic)
        return torch.from_numpy(nppic)

    # handle PIL Image
    if pic.mode == 'I':
        img = torch.from_numpy(np.array(pic, np.int32, copy=False))
    elif pic.mode == 'I;16':
        img = torch.from_numpy(np.array(pic, np.int16, copy=False))
    else:
        img = torch.ByteTensor(torch.ByteStorage.from_buffer(pic.tobytes()))
    # PIL image mode: 1, L, P, I, F, RGB, YCbCr, RGBA, CMYK
    if pic.mode == 'YCbCr':
        nchannel = 3
    elif pic.mode == 'I;16':
        nchannel = 1
    else:
        # For the remaining modes the channel count equals the number of
        # letters in the mode name (e.g. 'RGB' -> 3, 'L' -> 1).
        nchannel = len(pic.mode)
    # pic.size is (width, height), so the view is H x W x C.
    img = img.view(pic.size[1], pic.size[0], nchannel)
    # put it from HWC to CHW format
    # yikes, this transpose takes 80% of the loading time/CPU
    img = img.transpose(0, 1).transpose(0, 2).contiguous()
    # Only byte data is rescaled; integer modes are returned unchanged.
    if isinstance(img, torch.ByteTensor):
        return img.float().div(255)
    else:
        return img


def to_pilimage(pic):
    """Convert a tensor or ``numpy.ndarray`` to a PIL Image.

    Args:
        pic (Tensor or numpy.ndarray): Image to convert. A tensor is read as
            C x H x W; an ndarray is read as H x W x C, or as H x W for a
            single-channel image.

    Returns:
        PIL.Image: Image built with ``Image.fromarray``. Single-channel data
        maps to mode ``L`` (uint8), ``I;16`` (int16), ``I`` (int32) or ``F``
        (float32); multi-channel uint8 data maps to ``RGB``.

    Raises:
        TypeError: If ``pic`` is neither a 3-D tensor nor a 2-D/3-D ndarray.
        AssertionError: If the dtype has no matching PIL mode.
    """
    if not(_is_numpy_image(pic) or _is_tensor_image(pic)):
        raise TypeError('pic should be Tensor or ndarray. Got {}.'.format(type(pic)))

    npimg = pic
    mode = None
    if isinstance(pic, torch.FloatTensor):
        # Float tensors are rescaled by 255 and stored as bytes.
        pic = pic.mul(255).byte()
    if torch.is_tensor(pic):
        npimg = np.transpose(pic.numpy(), (1, 2, 0))

    # Drop a trailing singleton channel axis. 2-D arrays pass validation too
    # and are already in single-channel form, so shape[2] must not be read
    # unconditionally.
    if npimg.ndim == 3 and npimg.shape[2] == 1:
        npimg = npimg[:, :, 0]

    if npimg.ndim == 2:
        if npimg.dtype == np.uint8:
            mode = 'L'
        elif npimg.dtype == np.int16:
            mode = 'I;16'
        elif npimg.dtype == np.int32:
            mode = 'I'
        elif npimg.dtype == np.float32:
            mode = 'F'
    elif npimg.dtype == np.uint8:
        mode = 'RGB'
    assert mode is not None, '{} is not supported'.format(npimg.dtype)
    return Image.fromarray(npimg, mode=mode)


def normalize(tensor, mean, std):
    """Normalize a tensor image in place and return it.

    Each slice along the first dimension of ``tensor`` has the matching entry
    of ``mean`` subtracted and is then divided by the matching entry of
    ``std``. The input tensor itself is modified.

    Args:
        tensor (Tensor): Three-dimensional tensor image.
        mean (sequence): One offset per slice.
        std (sequence): One divisor per slice.

    Returns:
        Tensor: The same ``tensor`` object, normalized.

    Raises:
        TypeError: If ``tensor`` is not a three-dimensional torch tensor.
    """
    if _is_tensor_image(tensor):
        # TODO: make efficient
        for channel, offset, divisor in zip(tensor, mean, std):
            channel.sub_(offset)
            channel.div_(divisor)
        return tensor
    raise TypeError('tensor is not a torch image.')


def scale(img, size, interpolation=Image.BILINEAR):
    """Resize a PIL image.

    Args:
        img (PIL.Image): Image to resize.
        size (int or 2-element iterable): If an int, the smaller edge of the
            image is resized to ``size`` and the other edge is scaled by the
            same ratio. Otherwise ``size`` is passed to ``img.resize`` as is.
        interpolation: Resampling filter passed to ``img.resize``.

    Returns:
        PIL.Image: Resized image. The input is returned unchanged when
        ``size`` is an int and the smaller edge already equals it.

    Raises:
        TypeError: If ``img`` is not a PIL image, or ``size`` is neither an
            int nor an iterable of length 2.
    """
    # The ABC aliases on the ``collections`` module were removed in
    # Python 3.10; resolve Iterable from collections.abc, falling back to the
    # old location for interpreters that lack it.
    try:
        from collections.abc import Iterable
    except ImportError:
        from collections import Iterable

    if not _is_pil_image(img):
        raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
    if not (isinstance(size, int) or (isinstance(size, Iterable) and len(size) == 2)):
        raise TypeError('Got inappropriate size arg: {}'.format(size))

    if not isinstance(size, int):
        return img.resize(size, interpolation)

    w, h = img.size
    # Smaller edge already matches: nothing to do.
    if (w <= h and w == size) or (h <= w and h == size):
        return img
    if w < h:
        ow = size
        oh = int(size * h / w)
    else:
        oh = size
        ow = int(size * w / h)
    return img.resize((ow, oh), interpolation)


def pad(img, padding, fill=0):
    """Pad a PIL image on all sides via ``ImageOps.expand``.

    Args:
        img (PIL.Image): Image to pad.
        padding (number or tuple): Border size. A tuple must have 2 or 4
            elements.
        fill (number, str or tuple): Fill value passed to ``ImageOps.expand``.
            Defaults to 0.

    Returns:
        PIL.Image: Padded image.

    Raises:
        TypeError: If ``img`` is not a PIL image, or ``padding`` / ``fill``
            has an unsupported type.
        ValueError: If ``padding`` is a tuple whose length is not 2 or 4.
    """
    if not _is_pil_image(img):
        raise TypeError('img should be PIL Image. Got {}'.format(type(img)))

    if not isinstance(padding, (numbers.Number, tuple)):
        raise TypeError('Got inappropriate padding arg')
    if not isinstance(fill, (numbers.Number, str, tuple)):
        raise TypeError('Got inappropriate fill arg')

    # ``padding`` is a Number or a tuple at this point, so a plain tuple check
    # covers the sequence case without ``collections.Sequence``, which was
    # removed from the ``collections`` namespace in Python 3.10.
    if isinstance(padding, tuple) and len(padding) not in [2, 4]:
        raise ValueError("Padding must be an int or a 2, or 4 element tuple, not a " +
                         "{} element tuple".format(len(padding)))

    return ImageOps.expand(img, border=padding, fill=fill)


def crop(img, x, y, w, h):
    """Crop a PIL image to a ``w`` x ``h`` box starting at ``(x, y)``.

    Args:
        img (PIL.Image): Image to crop.
        x: First coordinate of the box origin.
        y: Second coordinate of the box origin.
        w: Box width.
        h: Box height.

    Returns:
        PIL.Image: Result of ``img.crop((x, y, x + w, y + h))``.

    Raises:
        TypeError: If ``img`` is not a PIL image.
    """
    if _is_pil_image(img):
        box = (x, y, x + w, y + h)
        return img.crop(box)
    raise TypeError('img should be PIL Image. Got {}'.format(type(img)))

This comment was marked as off-topic.



def scaled_crop(img, x, y, w, h, size, interpolation=Image.BILINEAR):
    """Crop a PIL image, then resize the crop.

    Args:
        img (PIL.Image): Image to crop and resize.
        x, y: Origin of the crop box, forwarded to ``crop``.
        w, h: Width and height of the crop box, forwarded to ``crop``.
        size (int or 2-element iterable): Target size, forwarded to ``scale``.
        interpolation: Resampling filter, forwarded to ``scale``.

    Returns:
        PIL.Image: The cropped and resized image.
    """
    assert _is_pil_image(img), 'img should be PIL Image'
    img = crop(img, x, y, w, h)
    img = scale(img, size, interpolation)
    # Hand the result back; without this the function always returns None.
    return img


def hflip(img):
    """Return ``img`` transposed with ``Image.FLIP_LEFT_RIGHT``.

    Args:
        img (PIL.Image): Image to flip.

    Returns:
        PIL.Image: Flipped image.

    Raises:
        TypeError: If ``img`` is not a PIL image.
    """
    if _is_pil_image(img):
        return img.transpose(Image.FLIP_LEFT_RIGHT)
    raise TypeError('img should be PIL Image. Got {}'.format(type(img)))


class Compose(object):
"""Composes several transforms together.

Expand Down Expand Up @@ -50,39 +200,7 @@ def __call__(self, pic):
Returns:
Tensor: Converted image.
"""
if isinstance(pic, np.ndarray):
# handle numpy array
img = torch.from_numpy(pic.transpose((2, 0, 1)))
# backward compatibility
return img.float().div(255)

if accimage is not None and isinstance(pic, accimage.Image):
nppic = np.zeros([pic.channels, pic.height, pic.width], dtype=np.float32)
pic.copyto(nppic)
return torch.from_numpy(nppic)

# handle PIL Image
if pic.mode == 'I':
img = torch.from_numpy(np.array(pic, np.int32, copy=False))
elif pic.mode == 'I;16':
img = torch.from_numpy(np.array(pic, np.int16, copy=False))
else:
img = torch.ByteTensor(torch.ByteStorage.from_buffer(pic.tobytes()))
# PIL image mode: 1, L, P, I, F, RGB, YCbCr, RGBA, CMYK
if pic.mode == 'YCbCr':
nchannel = 3
elif pic.mode == 'I;16':
nchannel = 1
else:
nchannel = len(pic.mode)
img = img.view(pic.size[1], pic.size[0], nchannel)
# put it from HWC to CHW format
# yikes, this transpose takes 80% of the loading time/CPU
img = img.transpose(0, 1).transpose(0, 2).contiguous()
if isinstance(img, torch.ByteTensor):
return img.float().div(255)
else:
return img
return to_tensor(pic)


class ToPILImage(object):
Expand All @@ -101,29 +219,7 @@ def __call__(self, pic):
PIL.Image: Image converted to PIL.Image.

"""
npimg = pic
mode = None
if isinstance(pic, torch.FloatTensor):
pic = pic.mul(255).byte()
if torch.is_tensor(pic):
npimg = np.transpose(pic.numpy(), (1, 2, 0))
assert isinstance(npimg, np.ndarray), 'pic should be Tensor or ndarray'
if npimg.shape[2] == 1:
npimg = npimg[:, :, 0]

if npimg.dtype == np.uint8:
mode = 'L'
if npimg.dtype == np.int16:
mode = 'I;16'
if npimg.dtype == np.int32:
mode = 'I'
elif npimg.dtype == np.float32:
mode = 'F'
else:
if npimg.dtype == np.uint8:
mode = 'RGB'
assert mode is not None, '{} is not supported'.format(npimg.dtype)
return Image.fromarray(npimg, mode=mode)
return to_pilimage(pic)


class Normalize(object):
Expand Down Expand Up @@ -151,10 +247,7 @@ def __call__(self, tensor):
Returns:
Tensor: Normalized image.
"""
# TODO: make efficient
for t, m, s in zip(tensor, self.mean, self.std):
t.sub_(m).div_(s)
return tensor
return normalize(tensor, self.mean, self.std)


class Scale(object):
Expand Down Expand Up @@ -183,20 +276,7 @@ def __call__(self, img):
Returns:
PIL.Image: Rescaled image.
"""
if isinstance(self.size, int):
w, h = img.size
if (w <= h and w == self.size) or (h <= w and h == self.size):
return img
if w < h:
ow = self.size
oh = int(self.size * h / w)
return img.resize((ow, oh), self.interpolation)
else:
oh = self.size
ow = int(self.size * w / h)
return img.resize((ow, oh), self.interpolation)
else:
return img.resize(self.size, self.interpolation)
return scale(img, self.size, self.interpolation)


class CenterCrop(object):
Expand All @@ -214,6 +294,13 @@ def __init__(self, size):
else:
self.size = size

def get_params(self, img):
    """Compute the centered crop box for ``img``.

    Reads ``self.size`` as ``(height, width)`` and ``img.size`` as
    ``(width, height)``.

    Returns:
        tuple: ``(x, y, width, height)`` of the crop box, where ``x`` and
        ``y`` are half of the size difference, rounded.
    """
    img_w, img_h = img.size
    crop_h, crop_w = self.size
    left = int(round((img_w - crop_w) / 2.))
    top = int(round((img_h - crop_h) / 2.))
    return left, top, crop_w, crop_h

def __call__(self, img):
"""
Args:
Expand All @@ -222,11 +309,8 @@ def __call__(self, img):
Returns:
PIL.Image: Cropped image.
"""
w, h = img.size
th, tw = self.size
x1 = int(round((w - tw) / 2.))
y1 = int(round((h - th) / 2.))
return img.crop((x1, y1, x1 + tw, y1 + th))
x1, y1, tw, th = self.get_params(img)
return crop(img, x1, y1, tw, th)

This comment was marked as off-topic.

This comment was marked as off-topic.

This comment was marked as off-topic.

This comment was marked as off-topic.



class Pad(object):
Expand Down Expand Up @@ -260,7 +344,7 @@ def __call__(self, img):
Returns:
PIL.Image: Padded image.
"""
return ImageOps.expand(img, border=self.padding, fill=self.fill)
return pad(img, self.padding, self.fill)


class Lambda(object):

This comment was marked as off-topic.

This comment was marked as off-topic.

Expand Down Expand Up @@ -298,6 +382,16 @@ def __init__(self, size, padding=0):
self.size = size
self.padding = padding

def get_params(self, img):
    """Pick a random crop box of ``self.size`` inside ``img``.

    Reads ``self.size`` as ``(height, width)`` and ``img.size`` as
    ``(width, height)``.

    Returns:
        tuple: ``(x, y, width, height)`` of the crop box. When the image
        already has the target size the box covers the whole image.
    """
    w, h = img.size
    th, tw = self.size
    if w == tw and h == th:
        # Nothing to sample: the only valid box is the full image. Return it
        # in the same 4-tuple form as the general case so callers can always
        # unpack the result (returning the image itself breaks unpacking).
        return 0, 0, tw, th

    x1 = random.randint(0, w - tw)
    y1 = random.randint(0, h - th)
    return x1, y1, tw, th

def __call__(self, img):
"""
Args:
Expand All @@ -307,16 +401,11 @@ def __call__(self, img):
PIL.Image: Cropped image.
"""
if self.padding > 0:
img = ImageOps.expand(img, border=self.padding, fill=0)
img = pad(img, self.padding)

w, h = img.size
th, tw = self.size
if w == tw and h == th:
return img
x1, y1, tw, th = self.get_params(img)

x1 = random.randint(0, w - tw)
y1 = random.randint(0, h - th)
return img.crop((x1, y1, x1 + tw, y1 + th))
return crop(img, x1, y1, tw, th)


class RandomHorizontalFlip(object):
Expand All @@ -331,7 +420,7 @@ def __call__(self, img):
PIL.Image: Randomly flipped image.
"""
if random.random() < 0.5:
return img.transpose(Image.FLIP_LEFT_RIGHT)
return hflip(img)
return img


Expand All @@ -352,7 +441,7 @@ def __init__(self, size, interpolation=Image.BILINEAR):
self.size = size
self.interpolation = interpolation

def __call__(self, img):
def get_params(self, img):

This comment was marked as off-topic.

for attempt in range(10):
area = img.size[0] * img.size[1]
target_area = random.uniform(0.08, 1.0) * area
Expand All @@ -365,15 +454,16 @@ def __call__(self, img):
w, h = h, w

if w <= img.size[0] and h <= img.size[1]:
x1 = random.randint(0, img.size[0] - w)
y1 = random.randint(0, img.size[1] - h)

img = img.crop((x1, y1, x1 + w, y1 + h))
assert(img.size == (w, h))

return img.resize((self.size, self.size), self.interpolation)
x = random.randint(0, img.size[0] - w)
y = random.randint(0, img.size[1] - h)
return x, y, w, h

# Fallback
scale = Scale(self.size, interpolation=self.interpolation)
crop = CenterCrop(self.size)
return crop(scale(img))
w = min(img.size[0], img.shape[1])
x = (img.shape[0] - w) // 2
y = (img.shape[1] - w) // 2
return x, y, w, w

def __call__(self, img):
    """Crop ``img`` to the box chosen by ``get_params`` and resize it.

    Args:
        img: Image forwarded to ``self.get_params`` and ``scaled_crop``.

    Returns:
        The value returned by ``scaled_crop`` for the chosen box,
        ``self.size`` and ``self.interpolation``.
    """
    left, top, width, height = self.get_params(img)
    return scaled_crop(img, left, top, width, height, self.size, self.interpolation)

This comment was marked as off-topic.

This comment was marked as off-topic.

This comment was marked as off-topic.