Commit 2924e26

V1.6.0: Refactor the core dependency module

2 parents adeb14c + e55140a

File tree

16 files changed: +1458 −1388 lines changed

README.md

Lines changed: 3 additions & 3 deletions

@@ -2,16 +2,16 @@
 <div align="center">
 <img src="https://user-images.githubusercontent.com/18592211/232830417-0b21a874-516e-4420-8984-4de414a35085.png" width="400px"></img>
 <h2></h2>
-<h3>Towards Any Structural Pruning<h3>
-<img src="assets/intro.png" width="50%">
+<img src="https://github.com/user-attachments/assets/50b03774-7345-4eb6-bf28-209195d354b0" width="40%">
+<h2></h2>
 </div>

 <p align="center">
 <a href="https://github.com/VainF/Torch-Pruning/actions"><img src="https://img.shields.io/badge/tests-passing-9c27b0.svg" alt="Test Status"></a>
 <a href="https://pytorch.org/"><img src="https://img.shields.io/badge/PyTorch-1.x %20%7C%202.x-673ab7.svg" alt="Tested PyTorch Versions"></a>
 <a href="https://opensource.org/licenses/MIT"><img src="https://img.shields.io/badge/License-MIT-4caf50.svg" alt="License"></a>
 <a href="https://pepy.tech/project/Torch-Pruning"><img src="https://static.pepy.tech/badge/Torch-Pruning?color=2196f3" alt="Downloads"></a>
-<a href="https://github.com/VainF/Torch-Pruning/releases/latest"><img src="https://img.shields.io/badge/Latest%20Version-1.5.3-3f51b5.svg" alt="Latest Version"></a>
+<a href="https://github.com/VainF/Torch-Pruning/releases/latest"><img src="https://img.shields.io/badge/Latest%20Version-1.6.0-3f51b5.svg" alt="Latest Version"></a>
 <a href="https://colab.research.google.com/drive/1TRvELQDNj9PwM-EERWbF3IQOyxZeDepp?usp=sharing">
 <img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/>
 </a>

examples/timm_models/prune_timm_models.py

Lines changed: 12 additions & 6 deletions

@@ -9,6 +9,8 @@
 from timm.models.vision_transformer import Attention
 import torch_pruning as tp
 import argparse
+from typing import Optional, Type
+from timm.models.vision_transformer import maybe_add_mask

 parser = argparse.ArgumentParser(description='Prune timm models')
 parser.add_argument('--model', default=None, type=str, help='model name')
@@ -18,9 +20,11 @@
 parser.add_argument('--list_models', default=False, action='store_true', help='list all models in timm')
 args = parser.parse_args()

-
-def forward(self, x):
-    """https://github.com/huggingface/pytorch-image-models/blob/054c763fcaa7d241564439ae05fbe919ed85e614/timm/models/vision_transformer.py#L79"""
+def forward(
+        self,
+        x: torch.Tensor,
+        attn_mask: Optional[torch.Tensor] = None,
+) -> torch.Tensor:
     B, N, C = x.shape
     qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
     q, k, v = qkv.unbind(0)
@@ -29,21 +33,23 @@ def forward(self, x):
     if self.fused_attn:
         x = F.scaled_dot_product_attention(
             q, k, v,
-            dropout_p=self.attn_drop.p,
+            attn_mask=attn_mask,
+            dropout_p=self.attn_drop.p if self.training else 0.,
         )
     else:
         q = q * self.scale
         attn = q @ k.transpose(-2, -1)
+        attn = maybe_add_mask(attn, attn_mask)
         attn = attn.softmax(dim=-1)
         attn = self.attn_drop(attn)
         x = attn @ v

-    x = x.transpose(1, 2).reshape(B, N, -1)  # original implementation: x = x.transpose(1, 2).reshape(B, N, C)
+    x = x.transpose(1, 2).reshape(B, N, -1)
+    x = self.norm(x)
     x = self.proj(x)
     x = self.proj_drop(x)
     return x

-
 def main():
     timm_models = timm.list_models()
     if args.list_models:
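
The hunk above updates the example's patched attention to match newer timm, which threads an `attn_mask` through `Attention.forward` and applies a `norm` layer after the head merge, while keeping the pruning-relevant change: reshaping to `(B, N, -1)` instead of the hard-coded `C`, so the output shape follows the pruned head dimension. How the patch is installed is not shown in this hunk; below is a minimal sketch of the usual rebinding approach, assuming the module-level `forward` from the diff above is in scope (the loop and the model name are illustrative, not taken from this diff):

```python
import timm
import torch
from timm.models.vision_transformer import Attention

model = timm.create_model('vit_base_patch16_224', pretrained=False)

# Rebind the patched module-level `forward` (defined in the diff above)
# onto every Attention block so pruned shapes propagate correctly.
for m in model.modules():
    if isinstance(m, Attention):
        m.forward = forward.__get__(m, Attention)  # bind plain function as a method

out = model(torch.randn(1, 3, 224, 224))
print(out.shape)  # torch.Size([1, 1000])
```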

setup.py

Lines changed: 1 addition & 1 deletion

@@ -5,7 +5,7 @@

 setuptools.setup(
     name="torch-pruning",
-    version="v1.5.3",
+    version="v1.6.0",
     author="Gongfan Fang",
     author_email="[email protected]",
     description="Towards Any Structural Pruning",

torch_pruning/_helpers.py

Lines changed: 0 additions & 85 deletions

@@ -42,91 +42,6 @@ def is_scalar(x):
         return False
     return False

-
-class _FlattenIndexMapping(object):
-    def __init__(self, stride=1, reverse=False):
-        self._stride = stride
-        self.reverse = reverse
-
-    def __call__(self, idxs: _HybridIndex):
-        new_idxs = []
-
-        if self.reverse == True:
-            for i in idxs:
-                new_idxs.append(_HybridIndex(idx=(i.idx // self._stride), root_idx=i.root_idx))
-            new_idxs = list(set(new_idxs))
-        else:
-            for i in idxs:
-                new_idxs.extend(
-                    [_HybridIndex(idx=k, root_idx=i.root_idx) for k in range(i.idx * self._stride, (i.idx + 1) * self._stride)]
-                )
-        return new_idxs
-
-
-class _ConcatIndexMapping(object):
-    def __init__(self, offset, reverse=False):
-        self.offset = offset
-        self.reverse = reverse
-
-    def __call__(self, idxs: _HybridIndex):
-        if self.reverse == True:
-            new_idxs = [
-                _HybridIndex(idx=i.idx - self.offset[0], root_idx=i.root_idx)
-                for i in idxs
-                if (i.idx >= self.offset[0] and i.idx < self.offset[1])
-            ]
-        else:
-            new_idxs = [_HybridIndex(idx=i.idx + self.offset[0], root_idx=i.root_idx) for i in idxs]
-        return new_idxs
-
-class _GQAIndexMapping(object):
-    def __init__(self, repeat, head_dim, reverse=False):
-        self.repeat = repeat
-        self.reverse = reverse
-        self.head_dim = head_dim
-
-    def __call__(self, idxs: _HybridIndex):
-        head_dim = self.head_dim
-        repeat = self.repeat
-        if self.reverse == True:
-            new_idxs = [_HybridIndex(idx=(i.idx - i.idx // (head_dim * repeat) * head_dim * (repeat - 1) - i.idx // head_dim % repeat * head_dim), root_idx=None) for i in idxs]
-        else:
-            new_idxs = []
-
-        return new_idxs
-
-class _SliceIndexMapping(object):
-    def __init__(self, dim, start, step, end, reverse=False):
-        self.start = start
-        self.step = step
-        self.end = end
-        self.reverse = reverse
-        self.dim = dim
-
-    def __call__(self, idxs: _HybridIndex):
-
-        if self.reverse == True:
-            new_idxs = [_HybridIndex(idx=i.idx * self.step + self.start, root_idx=i.root_idx) for i in idxs]
-        else:
-            new_idxs = [_HybridIndex(idx=(i.idx - self.start) // self.step, root_idx=i.root_idx) for i in idxs if (i.idx >= self.start and i.idx < self.end and (i.idx - self.start) % self.step == 0)]
-        return new_idxs
-
-class _SplitIndexMapping(object):
-    def __init__(self, offset, reverse=False):
-        self.offset = offset
-        self.reverse = reverse
-
-    def __call__(self, idxs: _HybridIndex):
-        if self.reverse == True:
-            new_idxs = [_HybridIndex(idx=i.idx + self.offset[0], root_idx=i.root_idx) for i in idxs]
-        else:
-            new_idxs = [
-                _HybridIndex(idx=i.idx - self.offset[0], root_idx=i.root_idx)
-                for i in idxs
-                if (i.idx >= self.offset[0] and i.idx < self.offset[1])
-            ]
-        return new_idxs
-
 class ScalarSum:
     def __init__(self):
         self._results = {}
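
These `*IndexMapping` helpers translate pruned channel indices across structure-changing ops (flatten, concat, split, slice, grouped-query-attention repeat) while the dependency graph is traversed; this commit removes them from `_helpers.py` as part of the core dependency module refactor. For intuition, here is a self-contained sketch of what the removed `_FlattenIndexMapping` computed, using a stand-in `HybridIndex` in place of the library's `_HybridIndex` (names and the example values are illustrative):

```python
from dataclasses import dataclass
from typing import List, Optional

@dataclass(frozen=True)
class HybridIndex:           # stand-in for torch_pruning's _HybridIndex
    idx: int                 # index in the current layer
    root_idx: Optional[int]  # index in the layer where pruning started

def flatten_mapping(idxs: List[HybridIndex], stride: int, reverse: bool) -> List[HybridIndex]:
    """Mimics the removed _FlattenIndexMapping: nn.Flatten turns each of C
    channels into `stride` features (stride = H*W), so channel i maps to the
    feature range [i*stride, (i+1)*stride); the reverse direction divides by
    the stride to recover the channel index."""
    if reverse:
        # features -> channels; dedupe, since many features share one channel
        return list({HybridIndex(i.idx // stride, i.root_idx) for i in idxs})
    return [HybridIndex(k, i.root_idx)
            for i in idxs
            for k in range(i.idx * stride, (i.idx + 1) * stride)]

# Pruning channel 2 of a (C=4, H=W=2) feature map that feeds nn.Flatten:
print(flatten_mapping([HybridIndex(2, 2)], stride=4, reverse=False))
# -> features 8..11 of the following nn.Linear must be pruned as well
```

The other removed classes follow the same forward/reverse contract: `_ConcatIndexMapping` and `_SplitIndexMapping` shift indices by a slice offset and filter out-of-range entries, while `_SliceIndexMapping` maps through a strided slice.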
