Skip to content

Conversation

@MarcBresson
Copy link
Contributor

@MarcBresson MarcBresson commented Aug 23, 2023

This new version is much quicker (granted, it will not save a lot of absolute time).

It avoids enumerating on a tensor, which is always slow.

def old_uniform(kernel_size: int):
    max, min = 2.5, -2.5
    ksize_half = (kernel_size - 1) * 0.5
    kernel = torch.linspace(-ksize_half, ksize_half, steps=kernel_size)
    for i, j in enumerate(kernel):
        if min <= j <= max:
            kernel[i] = 1 / (max - min)
        else:
            kernel[i] = 0

    return kernel.unsqueeze(dim=0)

def new_uniform(kernel_size):
    kernel = torch.zeros(kernel_size)

    start_uniform_index = max(kernel_size // 2 - 2, 0)
    end_uniform_index = min(kernel_size // 2 + 3, kernel_size)

    min_, max_  = -2.5, 2.5
    kernel[start_uniform_index:end_uniform_index] = 1 / (max_ - min_)

    return kernel.unsqueeze(dim=0)

Performance comparison

%timeit old_uniform(11)
>>> 354 µs ± 14.7 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)

%timeit new_uniform(11)
>>> 13.6 µs ± 303 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)

%timeit old_uniform(3)
>>> 123 µs ± 2.19 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)

%timeit new_uniform(3)
>>> 11.6 µs ± 1.13 µs per loop (mean ± std. dev. of 7 runs, 100,000 loops each)

Check for equality

for kernel_size in range(1, 101, 2):
    torch.testing.assert_close(old_uniform(kernel_size), new_uniform(kernel_size))

Check list:

  • New tests are added (if a new feature is added)
  • New doc strings: description and/or example code are in RST format
  • Documentation is updated (if required)

@github-actions github-actions bot added the module: metrics Metrics module label Aug 23, 2023
@vfdev-5
Copy link
Collaborator

vfdev-5 commented Aug 23, 2023

Nice improvements, @MarcBresson !

For equality tests I would use torch.testing.assert_close(old_uniform(kernel_size), new_uniform(kernel_size)).

I wonder how to see that

start_uniform_index = max(kernel_size // 2 - 2, 0)
end_uniform_index = min(kernel_size // 2 + 3, kernel_size)
kernel[start_uniform_index:end_uniform_index] = 1 / (max_ - min_)

is equivalent to

        if min <= j <= max:
            kernel[i] = 1 / (max - min)

?

@MarcBresson
Copy link
Contributor Author

It was hard to wrap my head around the possible decisions that led to this code.

Basically, the former code was creating a tensor that went from -kernel_size // 2 to kernel_size // 2. Then all the values of this tensor that were < to -2.5 or > to 2.5 were replaced by 0, and the other were set to 1 / (max - min).

The new code just put 0 everywhere then compute the indices that should be set to 1 / (max - min).

The only difference between the two codes is when the kernel size is an even number, but this raise an error beforehand.

Copy link
Collaborator

@vfdev-5 vfdev-5 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@MarcBresson thanks a lot for the perf improvement!
LGTM

@vfdev-5 vfdev-5 merged commit 178d82c into pytorch:master Aug 24, 2023
@MarcBresson MarcBresson deleted the refactor-ssim-_uniform branch August 24, 2023 08:21
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

module: metrics Metrics module

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants