 
 import torch
 from torch import nn
-
-from torchtune.modules.vision_transformer import VisionTransformer, CLSProjection
-from torchtune.models.clip._position_embeddings import TokenPositionalEmbedding, TiledTokenPositionalEmbedding, TilePositionalEmbedding
+from torchtune.models.clip._position_embeddings import (
+    TiledTokenPositionalEmbedding,
+    TilePositionalEmbedding,
+    TokenPositionalEmbedding,
+)
 
 from torchtune.modules import (
-    TransformerSelfAttentionLayer,
+    FeedForward,
+    Fp32LayerNorm,
     MultiHeadAttention,
     TanhGate,
-    FeedForward,
-    Fp32LayerNorm
+    TransformerSelfAttentionLayer,
 )
 
 from torchtune.modules.common_utils import reparametrize_as_dtype_state_dict_post_hook
 
 from torchtune.modules.peft import DoRALinear, LORA_ATTN_MODULES, LoRALinear
 
+from torchtune.modules.vision_transformer import CLSProjection, VisionTransformer
+
 
 def clip_vision_encoder(
     tile_size: int,
@@ -43,7 +47,7 @@ def clip_vision_encoder(
 ) -> VisionTransformer:
     """
     Builds the vision encoder associated with the clip model. This includes:
-
+
     - TransformerEncoderLayer
     - positional embeddings
     - CLS projection (optional)
@@ -82,21 +86,25 @@ def clip_vision_encoder(
8286 """
8387 assert embed_dim % num_heads == 0 , "embed_dim must be divisible by num_heads"
8488
85- cls_projection = CLSProjection (embed_dim = embed_dim , cls_output_dim = cls_output_dim ) if output_cls_projection else None
89+ cls_projection = (
90+ CLSProjection (embed_dim = embed_dim , cls_output_dim = cls_output_dim )
91+ if output_cls_projection
92+ else None
93+ )
8694
8795 # transformer layer
8896 self_attn = MultiHeadAttention (
89- embed_dim = embed_dim ,
90- num_heads = num_heads ,
91- num_kv_heads = num_heads ,
92- head_dim = embed_dim // num_heads ,
93- q_proj = nn .Linear (embed_dim , embed_dim , bias = attn_bias ),
94- k_proj = nn .Linear (embed_dim , embed_dim , bias = attn_bias ),
95- v_proj = nn .Linear (embed_dim , embed_dim , bias = attn_bias ),
96- output_proj = nn .Linear (embed_dim , embed_dim , bias = attn_bias ),
97- pos_embeddings = None ,
98- attn_dropout = 0.0 ,
99- is_causal = False ,
97+ embed_dim = embed_dim ,
98+ num_heads = num_heads ,
99+ num_kv_heads = num_heads ,
100+ head_dim = embed_dim // num_heads ,
101+ q_proj = nn .Linear (embed_dim , embed_dim , bias = attn_bias ),
102+ k_proj = nn .Linear (embed_dim , embed_dim , bias = attn_bias ),
103+ v_proj = nn .Linear (embed_dim , embed_dim , bias = attn_bias ),
104+ output_proj = nn .Linear (embed_dim , embed_dim , bias = attn_bias ),
105+ pos_embeddings = None ,
106+ attn_dropout = 0.0 ,
107+ is_causal = False ,
100108 )
101109 mlp = clip_mlp (
102110 in_dim = embed_dim ,
@@ -107,8 +115,8 @@ def clip_vision_encoder(
     transformer_layer = TransformerSelfAttentionLayer(
         attn=self_attn,
         mlp=mlp,
-        sa_norm=Fp32LayerNorm(embed_dim, eps=1e-5),
-        mlp_norm=Fp32LayerNorm(embed_dim, eps=1e-5),
+        sa_norm=Fp32LayerNorm(embed_dim, eps=1e-5),
+        mlp_norm=Fp32LayerNorm(embed_dim, eps=1e-5),
         sa_scale=None,
         mlp_scale=None,
     )
@@ -118,17 +126,21 @@ def clip_vision_encoder(
         pre_tile_pos_embed = None
         post_tile_pos_embed = None
         token_pos_embedding = TokenPositionalEmbedding(
-            embed_dim=embed_dim,
-            patch_size=patch_size,
-            tile_size=tile_size)
+            embed_dim=embed_dim, patch_size=patch_size, tile_size=tile_size
+        )
     else:
-        pre_tile_pos_embed = TilePositionalEmbedding(max_num_tiles=max_num_tiles, embed_dim=embed_dim)
-        post_tile_pos_embed = TilePositionalEmbedding(max_num_tiles=max_num_tiles, embed_dim=embed_dim)
+        pre_tile_pos_embed = TilePositionalEmbedding(
+            max_num_tiles=max_num_tiles, embed_dim=embed_dim
+        )
+        post_tile_pos_embed = TilePositionalEmbedding(
+            max_num_tiles=max_num_tiles, embed_dim=embed_dim
+        )
         token_pos_embedding = TiledTokenPositionalEmbedding(
-            max_num_tiles=max_num_tiles,
-            embed_dim=embed_dim,
-            patch_size=patch_size,
-            tile_size=tile_size)
+            max_num_tiles=max_num_tiles,
+            embed_dim=embed_dim,
+            patch_size=patch_size,
+            tile_size=tile_size,
+        )
 
     return VisionTransformer(
         num_layers=num_layers,
@@ -145,13 +157,29 @@ def clip_vision_encoder(
     )
 
 
-def clip_mlp(in_dim: int, out_dim: int, hidden_dim: int, activation: nn.Module, quantize_base: bool = False) -> FeedForward:
+def clip_mlp(
+    in_dim: int,
+    out_dim: int,
+    hidden_dim: int,
+    activation: nn.Module,
+    quantize_base: bool = False,
+) -> FeedForward:
     """
     Build the MLP layer associated with the clip model.
     """
-    gate_proj = nn.Linear(in_dim, hidden_dim) if not quantize_base else FrozenNF4Linear(in_dim, hidden_dim)
-    down_proj = nn.Linear(hidden_dim, out_dim) if not quantize_base else FrozenNF4Linear(hidden_dim, out_dim)
-    return FeedForward(gate_proj=gate_proj, down_proj=down_proj, up_proj=None, activation=activation)
+    gate_proj = (
+        nn.Linear(in_dim, hidden_dim)
+        if not quantize_base
+        else FrozenNF4Linear(in_dim, hidden_dim)
+    )
+    down_proj = (
+        nn.Linear(hidden_dim, out_dim)
+        if not quantize_base
+        else FrozenNF4Linear(hidden_dim, out_dim)
+    )
+    return FeedForward(
+        gate_proj=gate_proj, down_proj=down_proj, up_proj=None, activation=activation
+    )
 
 
 # ------------------ LoRA CLIP ------------------
@@ -222,42 +250,46 @@ def lora_clip_vision_encoder(
         quantize_base: (bool): Whether to quantize base model weights or not. Only applied to base
             weights within linear layers LoRA is applied to. The final output linear projection is not
             supported for quantization currently.
-
+
 
     Returns:
         VisionTransformer: Instantiation of VisionTransformer model.
     """
     assert embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"
 
     # TODO: add support for quantizing and LoRA for the final output projection
-    cls_projection = CLSProjection(embed_dim=embed_dim, cls_output_dim=cls_output_dim) if output_cls_projection else None
+    cls_projection = (
+        CLSProjection(embed_dim=embed_dim, cls_output_dim=cls_output_dim)
+        if output_cls_projection
+        else None
+    )
 
     # transformer layer
     self_attn = lora_clip_attention(
-        lora_modules=lora_modules,
-        embed_dim=embed_dim,
-        num_heads=num_heads,
-        num_kv_heads=num_heads,
-        head_dim=embed_dim // num_heads,
-        attn_dropout=0.0,
+        lora_modules=lora_modules,
+        embed_dim=embed_dim,
+        num_heads=num_heads,
+        num_kv_heads=num_heads,
+        head_dim=embed_dim // num_heads,
+        attn_dropout=0.0,
+        lora_rank=lora_rank,
+        lora_alpha=lora_alpha,
+        lora_dropout=lora_dropout,
+        use_dora=use_dora,
+        quantize_base=quantize_base,
+    )
+    if apply_lora_to_mlp:
+        mlp = lora_clip_mlp(
+            in_dim=embed_dim,
+            hidden_dim=4 * embed_dim,
+            out_dim=embed_dim,
+            activation=activation(),
             lora_rank=lora_rank,
             lora_alpha=lora_alpha,
+            quantize_base=quantize_base,
             lora_dropout=lora_dropout,
             use_dora=use_dora,
-        quantize_base=quantize_base,
-    )
-    if apply_lora_to_mlp:
-        mlp = lora_clip_mlp(
-            in_dim=embed_dim,
-            hidden_dim=4 * embed_dim,
-            out_dim=embed_dim,
-            activation=activation(),
-            lora_rank=lora_rank,
-            lora_alpha=lora_alpha,
-            quantize_base=quantize_base,
-            lora_dropout=lora_dropout,
-            use_dora=use_dora,
-        )
+        )
     else:
         mlp = clip_mlp(
             in_dim=embed_dim,
@@ -269,8 +301,8 @@ def lora_clip_vision_encoder(
     transformer_layer = TransformerSelfAttentionLayer(
         attn=self_attn,
         mlp=mlp,
-        sa_norm=Fp32LayerNorm(embed_dim, eps=1e-5),
-        mlp_norm=Fp32LayerNorm(embed_dim, eps=1e-5),
+        sa_norm=Fp32LayerNorm(embed_dim, eps=1e-5),
+        mlp_norm=Fp32LayerNorm(embed_dim, eps=1e-5),
         sa_scale=None,
         mlp_scale=None,
     )
@@ -280,17 +312,21 @@ def lora_clip_vision_encoder(
         pre_tile_pos_embed = None
         post_tile_pos_embed = None
         token_pos_embedding = TokenPositionalEmbedding(
-            embed_dim=embed_dim,
-            patch_size=patch_size,
-            tile_size=tile_size)
+            embed_dim=embed_dim, patch_size=patch_size, tile_size=tile_size
+        )
     else:
-        pre_tile_pos_embed = TilePositionalEmbedding(max_num_tiles=max_num_tiles, embed_dim=embed_dim)
-        post_tile_pos_embed = TilePositionalEmbedding(max_num_tiles=max_num_tiles, embed_dim=embed_dim)
+        pre_tile_pos_embed = TilePositionalEmbedding(
+            max_num_tiles=max_num_tiles, embed_dim=embed_dim
+        )
+        post_tile_pos_embed = TilePositionalEmbedding(
+            max_num_tiles=max_num_tiles, embed_dim=embed_dim
+        )
         token_pos_embedding = TiledTokenPositionalEmbedding(
-            max_num_tiles=max_num_tiles,
-            embed_dim=embed_dim,
-            patch_size=patch_size,
-            tile_size=tile_size)
+            max_num_tiles=max_num_tiles,
+            embed_dim=embed_dim,
+            patch_size=patch_size,
+            tile_size=tile_size,
+        )
 
     model = VisionTransformer(
         num_layers=num_layers,
@@ -467,19 +503,23 @@ def lora_clip_mlp(
467503 """
468504 adapter_cls = DoRALinear if use_dora else LoRALinear
469505 gate_proj = adapter_cls (
470- in_dim = dim ,
506+ in_dim = in_dim ,
471507 out_dim = hidden_dim ,
472508 rank = lora_rank ,
473509 alpha = lora_alpha ,
474510 dropout = lora_dropout ,
475511 quantize_base = quantize_base ,
512+ use_bias = True ,
476513 )
477514 down_proj = adapter_cls (
478515 in_dim = hidden_dim ,
479- out_dim = dim ,
516+ out_dim = out_dim ,
480517 rank = lora_rank ,
481518 alpha = lora_alpha ,
482519 dropout = lora_dropout ,
483520 quantize_base = quantize_base ,
521+ use_bias = True ,
522+ )
523+ return FeedForward (
524+ gate_proj = gate_proj , down_proj = down_proj , up_proj = None , activation = activation
484525 )
485- return FeedForward (gate_proj = gate_proj , down_proj = down_proj , up_proj = None , activation = activation )
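
For readers less familiar with these builders, here is a minimal usage sketch; it is not part of the commit. It assumes the keyword arguments referenced inside the builder body above (tile_size, patch_size, embed_dim, num_layers, num_heads, activation, max_num_tiles, output_cls_projection) are exposed as parameters of clip_vision_encoder, and that the builder is importable from torchtune.models.clip._component_builders (the file under change); the configuration values are arbitrary and chosen only for illustration.

from torch import nn

# Import path assumed from the file under change; torchtune may also re-export
# this builder from the torchtune.models.clip package.
from torchtune.models.clip._component_builders import clip_vision_encoder

# Illustrative configuration only; defaults in torchtune may differ.
encoder = clip_vision_encoder(
    tile_size=224,                # each tile is a 224x224 crop of the input image
    patch_size=14,                # (224 / 14) ** 2 = 256 patch tokens per tile
    embed_dim=1024,
    num_layers=4,                 # kept small for the sketch
    num_heads=16,                 # embed_dim must be divisible by num_heads (see the assert above)
    activation=nn.GELU,           # passed as a class; the builders instantiate it via activation()
    max_num_tiles=4,              # values > 1 select the tiled positional-embedding branch
    output_cls_projection=False,  # skip the optional CLSProjection head
)

# The builder returns a VisionTransformer nn.Module.
print(isinstance(encoder, nn.Module), sum(p.numel() for p in encoder.parameters()))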