diff --git a/torchtext/csrc/register_torchbindings.cpp b/torchtext/csrc/register_torchbindings.cpp index ee509668cf..6f13bcb044 100644 --- a/torchtext/csrc/register_torchbindings.cpp +++ b/torchtext/csrc/register_torchbindings.cpp @@ -67,6 +67,7 @@ TORCH_LIBRARY_FRAGMENT(torchtext, m) { }, // __setstate__ [](torch::Tensor state) -> c10::intrusive_ptr { + state = state.to(at::kCPU); auto* data = static_cast(state.data_ptr()); auto numel = state.size(0); return c10::make_intrusive(std::string(data, numel)); diff --git a/torchtext/data/utils.py b/torchtext/data/utils.py index a1b53ab559..89a72ea455 100644 --- a/torchtext/data/utils.py +++ b/torchtext/data/utils.py @@ -228,7 +228,7 @@ def _get_ngrams(n): yield " ".join(x) -class RandomShuffler(object): +class RandomShuffler: """Use random functions while keeping track of the random state to make it reproducible and deterministic.""" diff --git a/torchtext/vocab/vectors.py b/torchtext/vocab/vectors.py index af60c4f60b..3743207557 100644 --- a/torchtext/vocab/vectors.py +++ b/torchtext/vocab/vectors.py @@ -31,7 +31,7 @@ def _infer_shape(f): return num_lines, vector_dim -class Vectors(object): +class Vectors: def __init__(self, name, cache=None, url=None, unk_init=None, max_vectors=None) -> None: """ Args: