🐛 Describe the bug
Consider the following example:
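from torchvision.io import read_image
from torchvision.models import resnet50, ResNet50_Weights

# Load the test image as a uint8 tensor and move it to the MPS device
img = read_image("dog2.jpg").to("mps")

# Pretrained ResNet-50 in inference mode on MPS
weights = ResNet50_Weights.DEFAULT
model = resnet50(weights=weights).to("mps")
model.eval()

# Apply the preprocessing bundled with the weights and add a batch dimension
preprocess = weights.transforms()
batch = preprocess(img).unsqueeze(0)

# Class probabilities for the single image in the batch
prediction = model(batch).squeeze(0).softmax(0)
class_id = prediction.argmax().item()
score = prediction[class_id].item()
category_name = weights.meta["categories"][class_id]
print(f"{category_name}: {100 * score:.1f}%")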
When run on MPS, the input image (dog2.jpg) is classified as a necklace.
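To check that the misclassification comes from the MPS backend rather than from the model or the preprocessing, the same batch can be run on CPU and the two outputs compared. The helper below is a hypothetical, framework-independent sketch of that comparison. It assumes both softmax outputs have already been converted to plain Python lists, e.g. with prediction.cpu().tolist(). The names compare_predictions and argmax, and the default tolerance, are illustrative choices, not part of PyTorch.

```python
from typing import Sequence


def argmax(values: Sequence[float]) -> int:
    """Index of the largest value (the first one wins on ties)."""
    return max(range(len(values)), key=lambda i: values[i])


def compare_predictions(
    cpu_probs: Sequence[float],
    mps_probs: Sequence[float],
    atol: float = 1e-3,  # hypothetical tolerance; tune for your model
) -> dict:
    """Summarise how far CPU and MPS probabilities for one input diverge."""
    if len(cpu_probs) != len(mps_probs):
        raise ValueError("probability vectors must have the same length")
    # Largest element-wise disagreement between the two devices
    max_abs_diff = max(abs(a - b) for a, b in zip(cpu_probs, mps_probs))
    cpu_top, mps_top = argmax(cpu_probs), argmax(mps_probs)
    return {
        "cpu_top1": cpu_top,
        "mps_top1": mps_top,
        "top1_agrees": cpu_top == mps_top,
        "max_abs_diff": max_abs_diff,
        "within_tol": max_abs_diff <= atol,
    }


# Toy vectors: CPU favours class 0, MPS (hypothetically) favours class 2.
report = compare_predictions([0.7, 0.2, 0.1], [0.1, 0.2, 0.7])
print(report)
```

On a correct backend both devices should give the same top-1 class and a small max_abs_diff; a disagreement in top1_agrees for the same batch points at the MPS kernels rather than at the model or the input.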
Versions
Nightly/1.13.0
cc @ezyang @gchanan @zou3519 @kulinseth @albanD @DenisVieriu97 @razarmehr @abhudev