OnlineKNN: OnlineQueue buffers live on CPU, causing device mismatches and CPU->GPU transfers #379

@sami-bg

Description

OrderedQueue buffers are always created on CPU because OnlineQueue.setup() never moves them to the model's device after creation. This causes problems in a few places:

NNCLR is broken because _find_nearest_neighbors in forward.py does torch.mm(query_norm, support_norm.t()) where the query is on CUDA but the support set from the queue is on CPU.
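A minimal sketch of that lookup, with function and variable names taken from the issue (the body is illustrative, not the library's exact code). With the queue left on CPU and the query on CUDA, the torch.mm line is exactly where the device-mismatch RuntimeError fires:

```python
import torch

def find_nearest_neighbors(query: torch.Tensor, support: torch.Tensor) -> torch.Tensor:
    # Cosine similarity via normalized dot products, as in forward.py
    query_norm = torch.nn.functional.normalize(query, dim=1)
    support_norm = torch.nn.functional.normalize(support, dim=1)
    # torch.mm requires both operands on the same device; in the bug the
    # query is on cuda and support (from the CPU-resident queue) is not.
    similarity = torch.mm(query_norm, support_norm.t())
    idx = similarity.argmax(dim=1)
    return support[idx]

query = torch.tensor([[1.0, 0.0]])
support = torch.tensor([[0.0, 1.0], [2.0, 0.1]])
print(find_nearest_neighbors(query, support))  # picks the support row most aligned with the query
```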

SwAV and KNN have band-aids that cover this: swav_forward does queue.get().clone().detach().to(proj1.device), and OnlineKNN._compute_knn_predictions checks if cached_features.device != features.device and moves the cache. These work, but they copy the entire queue from CPU to GPU every time the data is consumed: for NNCLR's support set that is 65K x 256 (~64 MB) every training step, and for KNN it is 20K x 512 every validation batch.
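Back-of-the-envelope cost of those workaround copies, assuming "65K" means 2^16 rows and float32 entries (both assumptions, not stated in the code):

```python
# Bytes moved CPU->GPU per copy of a (rows x dim) float32 queue.
def queue_copy_bytes(rows: int, dim: int, bytes_per_el: int = 4) -> int:
    return rows * dim * bytes_per_el

nnclr = queue_copy_bytes(65536, 256)  # copied every training step
knn = queue_copy_bytes(20000, 512)    # copied every validation batch
print(nnclr / 2**20, knn / 2**20)     # -> 64.0 39.0625 (MiB per copy)
```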

The fix: in OnlineQueue.setup(), after creating or resizing the OrderedQueue, move it to pl_module's device:

device = next(pl_module.parameters()).device
self._shared_queues[self.key].to(device)
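A self-contained sketch of how that fix could slot into the callback, assuming OrderedQueue stores its data as a single tensor buffer and exposes an in-place .to() (the class bodies and the buffer attribute are illustrative; only the setup() lines mirror the snippet above):

```python
import torch

class OrderedQueue:
    def __init__(self, size: int, dim: int):
        self.buffer = torch.zeros(size, dim)  # torch.zeros defaults to CPU

    def to(self, device: torch.device) -> "OrderedQueue":
        self.buffer = self.buffer.to(device)
        return self

class OnlineQueue:
    def __init__(self, key: str, size: int, dim: int):
        self.key = key
        self._shared_queues = {key: OrderedQueue(size, dim)}

    def setup(self, pl_module: torch.nn.Module) -> None:
        # Proposed fix: follow the model's device once, at setup time,
        # instead of copying the whole queue on every read.
        device = next(pl_module.parameters()).device
        self._shared_queues[self.key].to(device)

model = torch.nn.Linear(8, 8)  # stands in for pl_module
cb = OnlineQueue("feats", size=16, dim=8)
cb.setup(model)
print(cb._shared_queues["feats"].buffer.device)  # same device as the model
```

With the queue resident on the model's device, the .to(...) calls in swav_forward and _compute_knn_predictions become no-ops and could be dropped.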
