Skip to content

Commit 85e4429

Browse files
authored
Merge branch 'pytorch:main' into main
2 parents 1e578b7 + db3a905 commit 85e4429

File tree

6 files changed

+511
-1
lines changed

6 files changed

+511
-1
lines changed

scripts/README.rst

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
Utility scripts
2+
===============
3+
4+
* `fbcode_to_main_sync.sh`
5+
6+
This shell script is used to synchronise internal changes with the main repository.
7+
8+
To run this script:
9+
10+
.. code:: bash
11+
12+
chmod +x fbcode_to_main_sync.sh
13+
./fbcode_to_main_sync.sh <commit_hash> <fork_name> <fork_main_branch>
14+
15+
where
16+
17+
``commit_hash`` represents the commit hash in fbsync branch from where we should start the sync.
18+
19+
``fork_name`` is the name of the remote corresponding to your fork, you can check it by doing `"git remote -v"`.
20+
21+
``fork_main_branch`` (optional) is the name of the main branch on your fork(default="main").
22+
23+
This script will create PRs corresponding to the commits in fbsync. Please review these, add the [FBCode->GH] prefix on the title and publish them. Most importantly, add the [FBCode->GH] prefix at the beginning of the merge message as well.

scripts/fbcode_to_main_sync.sh

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
#!/bin/bash
2+
3+
if [ -z $1 ]
4+
then
5+
echo "Commit hash is required to be passed when running this script."
6+
echo "./fbcode_to_main_sync.sh <commit_hash> <fork_name> <fork_main_branch>"
7+
exit 1
8+
fi
9+
commit_hash=$1
10+
11+
if [ -z $2 ]
12+
then
13+
echo "Fork name is required to be passed when running this script."
14+
echo "./fbcode_to_main_sync.sh <commit_hash> <fork_name> <fork_main_branch>"
15+
exit 1
16+
fi
17+
fork_name=$2
18+
19+
if [ -z $3 ]
20+
then
21+
fork_main_branch="main"
22+
else
23+
fork_main_branch=$3
24+
fi
25+
26+
from_branch="fbsync"
27+
git stash
28+
git checkout $from_branch
29+
git pull
30+
# Add random prefix in the new branch name to keep it unique per run
31+
prefix=$RANDOM
32+
IFS='
33+
'
34+
for line in $(git log --pretty=oneline "$commit_hash"..HEAD)
35+
do
36+
if [[ $line != *\[fbsync\]* ]]
37+
then
38+
echo "Parsing $line"
39+
hash=$(echo $line | cut -f1 -d' ')
40+
git checkout $fork_main_branch
41+
git checkout -B cherrypick_${prefix}_${hash}
42+
git cherry-pick -x "$hash"
43+
git push $fork_name cherrypick_${prefix}_${hash}
44+
git checkout $from_branch
45+
fi
46+
done
47+
echo "Please review the PRs, add [FBCode->GH] prefix in the title and publish them."

torchvision/prototype/datasets/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,5 +11,5 @@
1111
from ._home import home
1212

1313
# Load this last, since some parts depend on the above being loaded first
14-
from ._api import register, _list as list, info, load # usort: skip
14+
from ._api import register, _list as list, info, load, find # usort: skip
1515
from ._folder import from_data_folder, from_image_folder

torchvision/prototype/models/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,9 @@
44
from .vgg import *
55
from .efficientnet import *
66
from .mobilenetv3 import *
7+
from .mobilenetv2 import *
78
from .mnasnet import *
9+
from .regnet import *
810
from . import detection
911
from . import quantization
1012
from . import segmentation
Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
import warnings
2+
from functools import partial
3+
from typing import Any, Optional
4+
5+
from torchvision.transforms.functional import InterpolationMode
6+
7+
from ...models.mobilenetv2 import MobileNetV2
8+
from ..transforms.presets import ImageNetEval
9+
from ._api import Weights, WeightEntry
10+
from ._meta import _IMAGENET_CATEGORIES
11+
12+
13+
__all__ = ["MobileNetV2", "MobileNetV2Weights", "mobilenet_v2"]
14+
15+
16+
_common_meta = {"size": (224, 224), "categories": _IMAGENET_CATEGORIES, "interpolation": InterpolationMode.BILINEAR}
17+
18+
19+
class MobileNetV2Weights(Weights):
20+
ImageNet1K_RefV1 = WeightEntry(
21+
url="https://download.pytorch.org/models/mobilenet_v2-b0353104.pth",
22+
transforms=partial(ImageNetEval, crop_size=224),
23+
meta={
24+
**_common_meta,
25+
"recipe": "https://github.com/pytorch/vision/tree/main/references/classification#mobilenetv2",
26+
"acc@1": 71.878,
27+
"acc@5": 90.286,
28+
},
29+
)
30+
31+
32+
def mobilenet_v2(weights: Optional[MobileNetV2Weights] = None, progress: bool = True, **kwargs: Any) -> MobileNetV2:
33+
if "pretrained" in kwargs:
34+
warnings.warn("The argument pretrained is deprecated, please use weights instead.")
35+
weights = MobileNetV2Weights.ImageNet1K_RefV1 if kwargs.pop("pretrained") else None
36+
weights = MobileNetV2Weights.verify(weights)
37+
38+
if weights is not None:
39+
kwargs["num_classes"] = len(weights.meta["categories"])
40+
41+
model = MobileNetV2(**kwargs)
42+
43+
if weights is not None:
44+
model.load_state_dict(weights.state_dict(progress=progress))
45+
46+
return model

0 commit comments

Comments
 (0)