-
Notifications
You must be signed in to change notification settings - Fork 7.1k
[proto] Speed up adjust color ops #6784
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
94e918c
fc4f237
58eec29
ffc5c4f
b7b5178
4f3491a
2a5e4d8
a170513
0b55072
a99d6ad
b7fdd39
4117957
a82cf8c
247ed7d
f19edc9
eff6c6f
2397725
f364efc
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||||||
---|---|---|---|---|---|---|---|---|---|---|
|
@@ -2,9 +2,29 @@ | |||||||||
from torchvision.prototype import features | ||||||||||
from torchvision.transforms import functional_pil as _FP, functional_tensor as _FT | ||||||||||
|
||||||||||
from ._meta import get_dimensions_image_tensor | ||||||||||
from ._meta import _rgb_to_gray, get_dimensions_image_tensor, get_num_channels_image_tensor | ||||||||||
|
||||||||||
|
||||||||||
def _blend(image1: torch.Tensor, image2: torch.Tensor, ratio: float) -> torch.Tensor: | ||||||||||
ratio = float(ratio) | ||||||||||
fp = image1.is_floating_point() | ||||||||||
bound = 1.0 if fp else 255.0 | ||||||||||
output = image1.mul(ratio).add_(image2, alpha=(1.0 - ratio)).clamp_(0, bound) | ||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I like this! Supersedes the work at #6765 |
||||||||||
return output if fp else output.to(image1.dtype) | ||||||||||
|
||||||||||
|
||||||||||
def adjust_brightness_image_tensor(image: torch.Tensor, brightness_factor: float) -> torch.Tensor:
    """Adjust brightness by scaling pixel values with ``brightness_factor``.

    Args:
        image: Image tensor with 1 or 3 channels.
        brightness_factor: Non-negative factor; 0 gives a black image, 1 the
            original image, and larger values a brighter image.

    Returns:
        The adjusted image, clamped to the valid value range and returned in
        the input dtype.

    Raises:
        ValueError: If ``brightness_factor`` is negative.
        TypeError: If the image does not have 1 or 3 channels.
    """
    if brightness_factor < 0:
        raise ValueError(f"brightness_factor ({brightness_factor}) is not non-negative.")

    # Validate channels the same way as the other adjust_* ops in this module
    # (saturation/contrast) instead of going through _FT._assert_channels; the
    # raised TypeError and its message are identical.
    c = get_num_channels_image_tensor(image)
    if c not in [1, 3]:
        raise TypeError(f"Input image tensor permitted channel values are {[1, 3]}, but found {c}")

    fp = image.is_floating_point()
    bound = 1.0 if fp else 255.0
    output = image.mul(brightness_factor).clamp_(0, bound)
    return output if fp else output.to(image.dtype)
|
||||||||||
|
||||||||||
adjust_brightness_image_tensor = _FT.adjust_brightness | ||||||||||
adjust_brightness_image_pil = _FP.adjust_brightness | ||||||||||
|
||||||||||
|
||||||||||
|
@@ -21,7 +41,20 @@ def adjust_brightness(inpt: features.InputTypeJIT, brightness_factor: float) -> | |||||||||
return adjust_brightness_image_pil(inpt, brightness_factor=brightness_factor) | ||||||||||
|
||||||||||
|
||||||||||
adjust_saturation_image_tensor = _FT.adjust_saturation | ||||||||||
def adjust_saturation_image_tensor(image: torch.Tensor, saturation_factor: float) -> torch.Tensor:
    """Adjust color saturation by blending the image with its grayscale version.

    ``saturation_factor`` of 0 yields a grayscale image, 1 the original image,
    and larger values an oversaturated image. Must be non-negative.
    """
    if saturation_factor < 0:
        raise ValueError(f"saturation_factor ({saturation_factor}) is not non-negative.")

    num_channels = get_num_channels_image_tensor(image)
    if num_channels not in [1, 3]:
        raise TypeError(f"Input image tensor permitted channel values are {[1, 3]}, but found {num_channels}")

    if num_channels == 1:
        # Single-channel images are already grayscale; return unchanged to
        # match PIL behaviour.
        return image

    grayscale = _rgb_to_gray(image)
    return _blend(image, grayscale, saturation_factor)
|
||||||||||
|
||||||||||
adjust_saturation_image_pil = _FP.adjust_saturation | ||||||||||
|
||||||||||
|
||||||||||
|
@@ -38,7 +71,19 @@ def adjust_saturation(inpt: features.InputTypeJIT, saturation_factor: float) -> | |||||||||
return adjust_saturation_image_pil(inpt, saturation_factor=saturation_factor) | ||||||||||
|
||||||||||
|
||||||||||
adjust_contrast_image_tensor = _FT.adjust_contrast | ||||||||||
def adjust_contrast_image_tensor(image: torch.Tensor, contrast_factor: float) -> torch.Tensor:
    """Adjust contrast by blending the image with its mean grayscale intensity.

    ``contrast_factor`` of 0 yields a uniform gray image, 1 the original image,
    and larger values a higher-contrast image. Must be non-negative.
    """
    if contrast_factor < 0:
        raise ValueError(f"contrast_factor ({contrast_factor}) is not non-negative.")

    num_channels = get_num_channels_image_tensor(image)
    if num_channels not in [1, 3]:
        raise TypeError(f"Input image tensor permitted channel values are {[1, 3]}, but found {num_channels}")

    # The mean is accumulated in floating point; integer inputs are promoted
    # to float32 for that purpose.
    if torch.is_floating_point(image):
        acc_dtype = image.dtype
    else:
        acc_dtype = torch.float32
    grayscale = image if num_channels == 1 else _rgb_to_gray(image)
    mean = grayscale.to(acc_dtype).mean(dim=(-3, -2, -1), keepdim=True)
    return _blend(image, mean, contrast_factor)
|
||||||||||
|
||||||||||
adjust_contrast_image_pil = _FP.adjust_contrast | ||||||||||
|
||||||||||
|
||||||||||
|
@@ -74,7 +119,7 @@ def adjust_sharpness_image_tensor(image: torch.Tensor, sharpness_factor: float) | |||||||||
else: | ||||||||||
needs_unsquash = False | ||||||||||
|
||||||||||
output = _FT._blend(image, _FT._blurred_degenerate_image(image), sharpness_factor) | ||||||||||
output = _blend(image, _FT._blurred_degenerate_image(image), sharpness_factor) | ||||||||||
|
||||||||||
if needs_unsquash: | ||||||||||
output = output.reshape(shape) | ||||||||||
|
@@ -183,13 +228,13 @@ def autocontrast(inpt: features.InputTypeJIT) -> features.InputTypeJIT: | |||||||||
return autocontrast_image_pil(inpt) | ||||||||||
|
||||||||||
|
||||||||||
def _equalize_image_tensor_vec(img: torch.Tensor) -> torch.Tensor: | ||||||||||
# input img shape should be [N, H, W] | ||||||||||
shape = img.shape | ||||||||||
def _equalize_image_tensor_vec(image: torch.Tensor) -> torch.Tensor: | ||||||||||
# input image shape should be [N, H, W] | ||||||||||
shape = image.shape | ||||||||||
# Compute image histogram: | ||||||||||
flat_img = img.flatten(start_dim=1).to(torch.long) # -> [N, H * W] | ||||||||||
hist = flat_img.new_zeros(shape[0], 256) | ||||||||||
hist.scatter_add_(dim=1, index=flat_img, src=flat_img.new_ones(1).expand_as(flat_img)) | ||||||||||
flat_image = image.flatten(start_dim=1).to(torch.long) # -> [N, H * W] | ||||||||||
hist = flat_image.new_zeros(shape[0], 256) | ||||||||||
hist.scatter_add_(dim=1, index=flat_image, src=flat_image.new_ones(1).expand_as(flat_image)) | ||||||||||
|
||||||||||
# Compute image cdf | ||||||||||
chist = hist.cumsum_(dim=1) | ||||||||||
|
@@ -213,7 +258,7 @@ def _equalize_image_tensor_vec(img: torch.Tensor) -> torch.Tensor: | |||||||||
zeros = lut.new_zeros((1, 1)).expand(shape[0], 1) | ||||||||||
lut = torch.cat([zeros, lut[:, :-1]], dim=1) | ||||||||||
|
||||||||||
return torch.where((step == 0).unsqueeze(-1), img, lut.gather(dim=1, index=flat_img).reshape_as(img)) | ||||||||||
return torch.where((step == 0).unsqueeze(-1), image, lut.gather(dim=1, index=flat_image).reshape_as(image)) | ||||||||||
|
||||||||||
|
||||||||||
def equalize_image_tensor(image: torch.Tensor) -> torch.Tensor: | ||||||||||
|
Uh oh!
There was an error while loading. Please reload this page.