
Parallelize deheap_sort#168

Merged
lmcinnes merged 9 commits into lmcinnes:master from jamestwebber:parallel-deheap
Feb 10, 2022

Conversation

@jamestwebber
Collaborator

@jamestwebber jamestwebber commented Feb 7, 2022

This is a very simple change: I added parallel=True to utils.deheap_sort and I used a numba.prange in the top loop.

Sometimes, when I'm watching CPU usage for NNDescent on very large arrays, I've seen it spend a fair amount of time in a single thread near the end of the operation. Looking at the code, I suspect it's just dealing with this final deheap_sort call.

Because the function operates on each row independently, it's pretty simple to wrap it in a parallel loop and let numba figure it out. It seems to work but I haven't tested rigorously yet. Hoping travis can do that for me.

As an additional tweak, I added python 3.9 to the test matrix here, to see what happens. edit: obsolete given recent updates
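To give a sense of the shape of the change, here is a simplified stand-in for deheap_sort (it uses np.argsort per row instead of the real heap pops, and falls back to pure Python if numba isn't installed, so treat it as a sketch rather than the actual code):

```python
import numpy as np

try:
    from numba import njit, prange
except ImportError:  # pure-Python fallback so the sketch still runs
    njit = lambda *args, **kwargs: (lambda f: f)
    prange = range

@njit(parallel=True)
def deheap_sort_sketch(indices, distances):
    # Each row is an independent heap, so rows can be processed
    # in parallel; numba distributes the prange iterations.
    for i in prange(indices.shape[0]):
        order = np.argsort(distances[i])
        indices[i] = indices[i][order]
        distances[i] = distances[i][order]
    return indices, distances
```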

@jamestwebber
Collaborator Author

Oh my branch was way behind, let me rebase...

@lmcinnes
Owner

lmcinnes commented Feb 8, 2022

One catch is that parallel=True can induce some overhead for single-threaded cases (and there is a deheap sort at the end of query, and queries often come one at a time, as in ann-benchmarks). Any idea what the overhead costs involved are like?

@jamestwebber
Collaborator Author

Any idea what the overhead costs involved are like?

No idea at all! I tested this very lightly, and it might not be worth it. In particular I'm not sure this was really the slow part in the larger datasets I was processing, maybe there was something else going on. Sometimes NNDescent takes longer than I would expect for mysterious reasons.

I can look into it more in a bit, no reason to rush it in without more investigation.

@lmcinnes
Owner

lmcinnes commented Feb 9, 2022

It looks promising -- in that it is an easy change that could have benefits. Let's just leave it pending for now until you've dug a little deeper.

@jamestwebber
Collaborator Author

It definitely seems like there's considerable overhead when calling with a single query. It's hard to measure because it depends on the parameters and the data, but it's consistent. Of course, with a batch query this can be faster.

So, maybe this isn't a good idea. I'm not sure why I thought this was a bottleneck, I would probably need to run a huge array through to see any effect.

It would be nice if there was some way to dispatch based on array shape, but I don't think that's possible.
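For reference, this is roughly the kind of micro-benchmark I've been using; the row-wise kernel is a stand-in rather than the real deheap_sort, the warm-up call keeps compilation time out of the measurement, and there's a pure-Python fallback so it runs without numba:

```python
import time
import numpy as np

try:
    from numba import njit, prange
except ImportError:  # pure-Python fallback so the sketch still runs
    njit = lambda *args, **kwargs: (lambda f: f)
    prange = range

def _nearest_per_row(distances):
    # Stand-in row-wise kernel: smallest distance in each row.
    out = np.empty(distances.shape[0])
    for i in prange(distances.shape[0]):
        out[i] = np.min(distances[i])
    return out

serial_fn = njit(parallel=False)(_nearest_per_row)
parallel_fn = njit(parallel=True)(_nearest_per_row)

one_query = np.random.rand(1, 30)
serial_fn(one_query), parallel_fn(one_query)  # warm up: trigger compilation

for name, fn in (("serial", serial_fn), ("parallel", parallel_fn)):
    start = time.perf_counter()
    for _ in range(1000):
        fn(one_query)
    print(name, time.perf_counter() - start)
```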

@jamestwebber
Collaborator Author

jamestwebber commented Feb 9, 2022

One possibility which I haven't quite figured out but seems promising: use guvectorize with a slightly different signature. With target="parallel" this can be run in parallel over arrays but doesn't seem to have as much overhead on single vectors.

But I need to do more debugging and testing, as I'm not familiar with writing these signatures and I think I'm doing something wrong with inplace modification.

One question: the existing code returns the output as np.int64. Is there a reason for this? I think the input is always int32 so this necessitates a copy. edit: if I use guvectorize I guess I do a copy anyway, as numba isn't reliable with inplace modification. But that seems compatible with the existing code.
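The rough shape of what I'm trying (the signature and layout here are my guesses, and the pure-Python fallback is just so the sketch runs without numba); writing into preallocated output arrays sidesteps the unreliable in-place modification:

```python
import numpy as np

try:
    from numba import guvectorize
except ImportError:  # pure-Python fallback so the sketch still runs
    def guvectorize(*args, **kwargs):
        return lambda f: f

# One core dimension (n); with target="parallel" the gufunc broadcasts
# over leading dimensions in parallel, but a single 1-D row also works.
@guvectorize(
    ["void(int32[:], float32[:], int32[:], float32[:])"],
    "(n),(n)->(n),(n)",
    target="parallel",
)
def sort_row(idx_in, dst_in, idx_out, dst_out):
    order = np.argsort(dst_in)
    for j in range(order.shape[0]):
        idx_out[j] = idx_in[order[j]]
        dst_out[j] = dst_in[order[j]]
```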

@jamestwebber
Collaborator Author

jamestwebber commented Feb 9, 2022

These errors are confusing, it seems like numba doesn't know what to do with its own ufunc objects. Maybe I need additional annotation somewhere.

Apparently this is a long-standing issue, although it doesn't seem to be documented as far as I can tell. Functions made with guvectorize can't be called within other JITted functions for some reason. I guess I will look into something else.

@lmcinnes
Owner

lmcinnes commented Feb 9, 2022

I'm not sure what to make of the errors either. Thinking about this, it might be easiest to do something along the lines of the following:

def deheap_sort_base(heaps):
    ...
    for i in numba.prange(indices.shape[0]):
        ...

deheap_sort_bulk = numba.njit(parallel=True)(deheap_sort_base)
deheap_sort_small = numba.njit(parallel=False)(deheap_sort_base)

and then you can call deheap_sort_small in the query method and deheap_sort_bulk for the nndescent index build. Potentially you can even switch between implementations based on dataset size. That may be the simplest approach?
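Fleshed out, the two-compilation pattern might look like this; the row-wise body and the size threshold are placeholders, and there's a pure-Python fallback so the sketch runs without numba:

```python
import numpy as np

try:
    from numba import njit, prange
except ImportError:  # pure-Python fallback so the sketch still runs
    njit = lambda *args, **kwargs: (lambda f: f)
    prange = range

def _sort_rows_base(indices, distances):
    # Placeholder row-wise body; the real deheap_sort pops from a heap.
    for i in prange(indices.shape[0]):
        order = np.argsort(distances[i])
        indices[i] = indices[i][order]
        distances[i] = distances[i][order]
    return indices, distances

# Same source compiled twice: a threaded build for bulk work, and one
# without the parallel-runtime overhead for small or single queries.
sort_rows_bulk = njit(parallel=True)(_sort_rows_base)
sort_rows_small = njit(parallel=False)(_sort_rows_base)

def sort_rows(indices, distances, threshold=64):
    # Hypothetical size-based dispatch between the two compilations.
    if indices.shape[0] >= threshold:
        return sort_rows_bulk(indices, distances)
    return sort_rows_small(indices, distances)
```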

@jamestwebber
Collaborator Author

jamestwebber commented Feb 9, 2022

and then you can call deheap_sort_small in the query method and deheap_sort_bulk for the nndescent index build. Potentially you can even switch between implementations based on dataset size. That may be the simplest approach?

I was going to do that but I think I have a slightly better option, which matches the design of query already. You already have a parallel_batch_queries flag in NNDescent that controls the parallelization of the search closure. We can do the same to recompile the deheap_sort function if desired:

self._deheap_function = numba.njit(parallel=self.parallel_batch_queries)(
    deheap_sort.py_func
)

edit: this solution is basically the same as what you suggest, except that the user decides whether to turn this on for queries. I'm not sure what makes for the cleanest code, if you have a preference I can change it.
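For the record, the .py_func trick works because a numba dispatcher keeps the original undecorated Python function; a minimal standalone version with an arbitrary stand-in function (the getattr covers the no-numba fallback, where the plain function has no .py_func attribute):

```python
import numpy as np

try:
    from numba import njit, prange
except ImportError:  # pure-Python fallback so the sketch still runs
    njit = lambda *args, **kwargs: (lambda f: f)
    prange = range

@njit(parallel=False)
def row_sums(x):
    out = np.empty(x.shape[0])
    for i in prange(x.shape[0]):
        out[i] = x[i].sum()
    return out

# The dispatcher exposes the original function as .py_func, so the
# same source can be recompiled with a different parallel flag.
row_sums_parallel = njit(parallel=True)(getattr(row_sums, "py_func", row_sums))
```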

@lmcinnes
Owner

lmcinnes commented Feb 9, 2022

Yes, I like that option a lot. It makes it all pretty clean from the user control perspective.

@lmcinnes lmcinnes merged commit 3182314 into lmcinnes:master Feb 10, 2022
@jamestwebber jamestwebber deleted the parallel-deheap branch February 10, 2022 15:08