
Commit 1e5f0d5

felipemello1 and Felipe Mello authored
LoRA typo fix + bias=True (#1881)
Co-authored-by: Felipe Mello <[email protected]>
1 parent ca37c59 commit 1e5f0d5

File tree

2 files changed: +145 −95 lines changed

torchtune/models/clip/_component_builders.py

Lines changed: 110 additions & 70 deletions
@@ -9,22 +9,26 @@
 
 import torch
 from torch import nn
-
-from torchtune.modules.vision_transformer import VisionTransformer, CLSProjection
-from torchtune.models.clip._position_embeddings import TokenPositionalEmbedding, TiledTokenPositionalEmbedding, TilePositionalEmbedding
+from torchtune.models.clip._position_embeddings import (
+    TiledTokenPositionalEmbedding,
+    TilePositionalEmbedding,
+    TokenPositionalEmbedding,
+)
 
 from torchtune.modules import (
-    TransformerSelfAttentionLayer,
+    FeedForward,
+    Fp32LayerNorm,
     MultiHeadAttention,
     TanhGate,
-    FeedForward,
-    Fp32LayerNorm
+    TransformerSelfAttentionLayer,
 )
 
 from torchtune.modules.common_utils import reparametrize_as_dtype_state_dict_post_hook
 
 from torchtune.modules.peft import DoRALinear, LORA_ATTN_MODULES, LoRALinear
 
+from torchtune.modules.vision_transformer import CLSProjection, VisionTransformer
+
 
 def clip_vision_encoder(
     tile_size: int,
@@ -43,7 +47,7 @@ def clip_vision_encoder(
 ) -> VisionTransformer:
     """
     Builds the vision encoder associated with the clip model. This includes:
-
+
     - TransformerEncoderLayer
     - positional embeddings
     - CLS projection (optional)
@@ -82,21 +86,25 @@ def clip_vision_encoder(
     """
     assert embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"
 
-    cls_projection = CLSProjection(embed_dim=embed_dim, cls_output_dim=cls_output_dim) if output_cls_projection else None
+    cls_projection = (
+        CLSProjection(embed_dim=embed_dim, cls_output_dim=cls_output_dim)
+        if output_cls_projection
+        else None
+    )
 
     # transformer layer
     self_attn = MultiHeadAttention(
-            embed_dim=embed_dim,
-            num_heads=num_heads,
-            num_kv_heads=num_heads,
-            head_dim=embed_dim // num_heads,
-            q_proj=nn.Linear(embed_dim, embed_dim, bias=attn_bias),
-            k_proj=nn.Linear(embed_dim, embed_dim, bias=attn_bias),
-            v_proj=nn.Linear(embed_dim, embed_dim, bias=attn_bias),
-            output_proj=nn.Linear(embed_dim, embed_dim, bias=attn_bias),
-            pos_embeddings=None,
-            attn_dropout=0.0,
-            is_causal=False,
+        embed_dim=embed_dim,
+        num_heads=num_heads,
+        num_kv_heads=num_heads,
+        head_dim=embed_dim // num_heads,
+        q_proj=nn.Linear(embed_dim, embed_dim, bias=attn_bias),
+        k_proj=nn.Linear(embed_dim, embed_dim, bias=attn_bias),
+        v_proj=nn.Linear(embed_dim, embed_dim, bias=attn_bias),
+        output_proj=nn.Linear(embed_dim, embed_dim, bias=attn_bias),
+        pos_embeddings=None,
+        attn_dropout=0.0,
+        is_causal=False,
     )
     mlp = clip_mlp(
         in_dim=embed_dim,
@@ -107,8 +115,8 @@ def clip_vision_encoder(
     transformer_layer = TransformerSelfAttentionLayer(
         attn=self_attn,
         mlp=mlp,
-        sa_norm= Fp32LayerNorm(embed_dim, eps=1e-5),
-        mlp_norm= Fp32LayerNorm(embed_dim, eps=1e-5),
+        sa_norm=Fp32LayerNorm(embed_dim, eps=1e-5),
+        mlp_norm=Fp32LayerNorm(embed_dim, eps=1e-5),
         sa_scale=None,
         mlp_scale=None,
     )
@@ -118,17 +126,21 @@ def clip_vision_encoder(
         pre_tile_pos_embed = None
         post_tile_pos_embed = None
         token_pos_embedding = TokenPositionalEmbedding(
-                embed_dim=embed_dim,
-                patch_size=patch_size,
-                tile_size=tile_size)
+            embed_dim=embed_dim, patch_size=patch_size, tile_size=tile_size
+        )
     else:
-        pre_tile_pos_embed = TilePositionalEmbedding(max_num_tiles=max_num_tiles, embed_dim=embed_dim)
-        post_tile_pos_embed = TilePositionalEmbedding(max_num_tiles=max_num_tiles, embed_dim=embed_dim)
+        pre_tile_pos_embed = TilePositionalEmbedding(
+            max_num_tiles=max_num_tiles, embed_dim=embed_dim
+        )
+        post_tile_pos_embed = TilePositionalEmbedding(
+            max_num_tiles=max_num_tiles, embed_dim=embed_dim
+        )
         token_pos_embedding = TiledTokenPositionalEmbedding(
-                max_num_tiles=max_num_tiles,
-                embed_dim=embed_dim,
-                patch_size=patch_size,
-                tile_size=tile_size)
+            max_num_tiles=max_num_tiles,
+            embed_dim=embed_dim,
+            patch_size=patch_size,
+            tile_size=tile_size,
+        )
 
     return VisionTransformer(
         num_layers=num_layers,
@@ -145,13 +157,29 @@ def clip_vision_encoder(
     )
 
 
-def clip_mlp(in_dim: int, out_dim: int, hidden_dim: int, activation: nn.Module, quantize_base: bool = False) -> FeedForward:
+def clip_mlp(
+    in_dim: int,
+    out_dim: int,
+    hidden_dim: int,
+    activation: nn.Module,
+    quantize_base: bool = False,
+) -> FeedForward:
     """
     Build the MLP layer associated with the clip model.
     """
-    gate_proj = nn.Linear(in_dim, hidden_dim) if not quantize_base else FrozenNF4Linear(in_dim, hidden_dim)
-    down_proj = nn.Linear(hidden_dim, out_dim) if not quantize_base else FrozenNF4Linear(hidden_dim, out_dim)
-    return FeedForward(gate_proj=gate_proj, down_proj=down_proj, up_proj=None, activation=activation)
+    gate_proj = (
+        nn.Linear(in_dim, hidden_dim)
+        if not quantize_base
+        else FrozenNF4Linear(in_dim, hidden_dim)
+    )
+    down_proj = (
+        nn.Linear(hidden_dim, out_dim)
+        if not quantize_base
+        else FrozenNF4Linear(hidden_dim, out_dim)
+    )
+    return FeedForward(
+        gate_proj=gate_proj, down_proj=down_proj, up_proj=None, activation=activation
+    )
 
 
 # ------------------ LoRA CLIP ------------------
@@ -222,42 +250,46 @@ def lora_clip_vision_encoder(
         quantize_base: (bool): Whether to quantize base model weights or not. Only applied to base
             weights within linear layers LoRA is applied to. The final output linear projection is not
             supported for quantization currently.
-
+
 
     Returns:
         VisionTransformer: Instantiation of VisionTransformer model.
     """
     assert embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"
 
     # TODO: add support for quantizing and LoRA for the final output projection
-    cls_projection = CLSProjection(embed_dim=embed_dim, cls_output_dim=cls_output_dim) if output_cls_projection else None
+    cls_projection = (
+        CLSProjection(embed_dim=embed_dim, cls_output_dim=cls_output_dim)
+        if output_cls_projection
+        else None
+    )
 
     # transformer layer
     self_attn = lora_clip_attention(
-            lora_modules=lora_modules,
-            embed_dim=embed_dim,
-            num_heads=num_heads,
-            num_kv_heads=num_heads,
-            head_dim=embed_dim // num_heads,
-            attn_dropout=0.0,
+        lora_modules=lora_modules,
+        embed_dim=embed_dim,
+        num_heads=num_heads,
+        num_kv_heads=num_heads,
+        head_dim=embed_dim // num_heads,
+        attn_dropout=0.0,
+        lora_rank=lora_rank,
+        lora_alpha=lora_alpha,
+        lora_dropout=lora_dropout,
+        use_dora=use_dora,
+        quantize_base=quantize_base,
+    )
+    if apply_lora_to_mlp:
+        mlp = lora_clip_mlp(
+            in_dim=embed_dim,
+            hidden_dim=4 * embed_dim,
+            out_dim=embed_dim,
+            activation=activation(),
             lora_rank=lora_rank,
             lora_alpha=lora_alpha,
+            quantize_base=quantize_base,
             lora_dropout=lora_dropout,
             use_dora=use_dora,
-            quantize_base=quantize_base,
-    )
-    if apply_lora_to_mlp:
-        mlp = lora_clip_mlp(
-                in_dim=embed_dim,
-                hidden_dim=4 * embed_dim,
-                out_dim=embed_dim,
-                activation=activation(),
-                lora_rank=lora_rank,
-                lora_alpha=lora_alpha,
-                quantize_base=quantize_base,
-                lora_dropout=lora_dropout,
-                use_dora=use_dora,
-        )
+        )
     else:
         mlp = clip_mlp(
             in_dim=embed_dim,
@@ -269,8 +301,8 @@ def lora_clip_vision_encoder(
     transformer_layer = TransformerSelfAttentionLayer(
         attn=self_attn,
         mlp=mlp,
-        sa_norm= Fp32LayerNorm(embed_dim, eps=1e-5),
-        mlp_norm= Fp32LayerNorm(embed_dim, eps=1e-5),
+        sa_norm=Fp32LayerNorm(embed_dim, eps=1e-5),
+        mlp_norm=Fp32LayerNorm(embed_dim, eps=1e-5),
         sa_scale=None,
         mlp_scale=None,
     )
@@ -280,17 +312,21 @@ def lora_clip_vision_encoder(
         pre_tile_pos_embed = None
         post_tile_pos_embed = None
         token_pos_embedding = TokenPositionalEmbedding(
-                embed_dim=embed_dim,
-                patch_size=patch_size,
-                tile_size=tile_size)
+            embed_dim=embed_dim, patch_size=patch_size, tile_size=tile_size
+        )
     else:
-        pre_tile_pos_embed = TilePositionalEmbedding(max_num_tiles=max_num_tiles, embed_dim=embed_dim)
-        post_tile_pos_embed = TilePositionalEmbedding(max_num_tiles=max_num_tiles, embed_dim=embed_dim)
+        pre_tile_pos_embed = TilePositionalEmbedding(
+            max_num_tiles=max_num_tiles, embed_dim=embed_dim
+        )
+        post_tile_pos_embed = TilePositionalEmbedding(
+            max_num_tiles=max_num_tiles, embed_dim=embed_dim
+        )
         token_pos_embedding = TiledTokenPositionalEmbedding(
-                max_num_tiles=max_num_tiles,
-                embed_dim=embed_dim,
-                patch_size=patch_size,
-                tile_size=tile_size)
+            max_num_tiles=max_num_tiles,
+            embed_dim=embed_dim,
+            patch_size=patch_size,
+            tile_size=tile_size,
+        )
 
     model = VisionTransformer(
         num_layers=num_layers,
@@ -467,19 +503,23 @@ def lora_clip_mlp(
     """
    adapter_cls = DoRALinear if use_dora else LoRALinear
     gate_proj = adapter_cls(
-        in_dim=dim,
+        in_dim=in_dim,
         out_dim=hidden_dim,
         rank=lora_rank,
         alpha=lora_alpha,
         dropout=lora_dropout,
         quantize_base=quantize_base,
+        use_bias=True,
     )
     down_proj = adapter_cls(
         in_dim=hidden_dim,
-        out_dim=dim,
+        out_dim=out_dim,
         rank=lora_rank,
         alpha=lora_alpha,
         dropout=lora_dropout,
         quantize_base=quantize_base,
+        use_bias=True,
+    )
+    return FeedForward(
+        gate_proj=gate_proj, down_proj=down_proj, up_proj=None, activation=activation
     )
-    return FeedForward(gate_proj=gate_proj, down_proj=down_proj, up_proj=None, activation=activation)
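The last hunk is the substantive change: lora_clip_mlp previously referenced `dim` instead of its `in_dim`/`out_dim` parameters and built the LoRA-wrapped linears without a bias, while the plain clip_mlp path uses nn.Linear, whose bias defaults to True. A minimal sketch of what the corrected builder now wires together, using the same modules that appear in the diff (the sizes, rank, alpha, dropout, and activation below are illustrative values, not taken from the commit):

from torch import nn

from torchtune.modules import FeedForward
from torchtune.modules.peft import LoRALinear

# Illustrative hyperparameters only; the commit does not pin any of these.
embed_dim, hidden_dim = 1280, 4 * 1280

# With use_bias=True, each LoRA-wrapped projection keeps a frozen base bias,
# matching the bias=True default of the nn.Linear layers on the non-LoRA path.
gate_proj = LoRALinear(
    in_dim=embed_dim, out_dim=hidden_dim, rank=8, alpha=16, dropout=0.0, use_bias=True
)
down_proj = LoRALinear(
    in_dim=hidden_dim, out_dim=embed_dim, rank=8, alpha=16, dropout=0.0, use_bias=True
)

# Same FeedForward wiring as the builder: gate_proj and down_proj only, no up_proj.
mlp = FeedForward(
    gate_proj=gate_proj, down_proj=down_proj, up_proj=None, activation=nn.SiLU()
)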
