
Commit 0c31907

ebsmothers and krammnic authored
Remove unused FSDP components (#2016)
Co-authored-by: krammnic <[email protected]>
Co-authored-by: Mark <[email protected]>
1 parent ac14e96 commit 0c31907

12 files changed: +69 / -586 lines changed


docs/source/api_ref_modules.rst

Lines changed: 0 additions & 1 deletion
@@ -77,7 +77,6 @@ PEFT Components
    peft.set_trainable_params
    peft.get_adapter_state_dict
    peft.validate_missing_and_unexpected_for_lora
-   peft.validate_state_dict_for_lora
    peft.disable_adapter

docs/source/api_ref_training.rst

Lines changed: 0 additions & 3 deletions
@@ -50,12 +50,9 @@ Utilities for enabling and working with distributed training.
    :toctree: generated/
    :nosignatures:

-   FSDPPolicyType
    init_distributed
    is_distributed
    get_world_size_and_rank
-   get_full_finetune_fsdp_wrap_policy
-   lora_fsdp_wrap_policy
    gather_cpu_state_dict

 .. _ac_label:

docs/source/tutorials/lora_finetune.rst

Lines changed: 1 addition & 2 deletions
@@ -205,8 +205,7 @@ model without any wrappers or custom checkpoint conversion logic.

 .. note::
     Whenever loading weights with :code:`strict=False`, you should verify that any missing or extra keys in
-    the loaded :code:`state_dict` are as expected. torchtune's LoRA recipes do this by default via e.g.
-    :func:`validate_state_dict_for_lora() <torchtune.modules.peft.validate_state_dict_for_lora>` or
+    the loaded :code:`state_dict` are as expected. torchtune's LoRA recipes do this by default via
     :func:`validate_missing_and_unexpected_for_lora() <torchtune.modules.peft.validate_missing_and_unexpected_for_lora>`.

 Once we've loaded the base model weights, we also want to set only LoRA parameters to trainable.
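
Note: the tutorial passage above leans on PyTorch's non-strict loading behaviour. As a minimal, self-contained sketch (plain PyTorch, not the tutorial's own code), load_state_dict(strict=False) returns the missing and unexpected key lists instead of raising, and those are exactly the lists that validate_missing_and_unexpected_for_lora inspects:

# Minimal sketch of strict=False loading: missing/unexpected keys are returned,
# not raised, so the caller can verify them explicitly.
import torch
import torch.nn as nn

model = nn.Linear(4, 4)
ckpt = {"weight": torch.zeros(4, 4), "extra_key": torch.zeros(1)}  # no "bias"

missing, unexpected = model.load_state_dict(ckpt, strict=False)
print(missing)     # ['bias']      -> present in the model, absent from the checkpoint
print(unexpected)  # ['extra_key'] -> present in the checkpoint, absent from the model

In torchtune's LoRA recipes, these two lists are then passed to validate_missing_and_unexpected_for_lora, which asserts that, for example, the only keys missing from the base checkpoint are LoRA parameters.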

docs/source/tutorials/qat_finetune.rst

Lines changed: 0 additions & 5 deletions
@@ -168,11 +168,6 @@ modifications accordingly:
   fake_quant_after_n_steps: 1000
   memory_efficient_fsdp_wrap: False

-.. note::
-
-    QAT in torchtune is currently not compatible with `memory_efficient_fsdp_wrap <https://pytorch.org/torchtune/stable/generated/torchtune.utils.get_full_finetune_fsdp_wrap_policy.html#torchtune.utils.get_full_finetune_fsdp_wrap_policy>`_.
-    This is a known issue and will be fixed in a future torchtune version.
-
 Empirically, we observed that disabling fake quantization for the first N steps
 led to better results, presumably because doing so allows the weights to stabilize
 before we start introducing quantization noise to the fine-tuning process.
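
For context on fake_quant_after_n_steps, here is a rough sketch of the delayed-fake-quant idea described above, written against the generic torch.ao.quantization toggles rather than torchtune's actual QAT recipe (which is built on torchao and may use different helpers); model, dataloader, and optimizer are assumed to exist:

# Illustrative only: keep fake quantization off for the first N optimizer steps
# so the weights can stabilize before quantization noise is introduced.
from torch.ao.quantization import disable_fake_quant, enable_fake_quant

fake_quant_after_n_steps = 1000  # mirrors the config value shown above

for step, batch in enumerate(dataloader):
    if step == 0:
        model.apply(disable_fake_quant)   # start with fake quant disabled
    elif step == fake_quant_after_n_steps:
        model.apply(enable_fake_quant)    # switch it on once weights have settled
    loss = model(batch).mean()
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()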

recipes/lora_dpo_single_device.py

Lines changed: 0 additions & 14 deletions
@@ -27,7 +27,6 @@
     get_merged_lora_ckpt,
     set_trainable_params,
     validate_missing_and_unexpected_for_lora,
-    validate_state_dict_for_lora,
 )
 from torchtune.recipe_interfaces import FTRecipeInterface

@@ -271,19 +270,6 @@ def _setup_model(
                 model, auto_wrap_policy={modules.TransformerSelfAttentionLayer}
             )

-        validate_state_dict_for_lora(
-            lora_attn_modules=cfg_model.lora_attn_modules,
-            apply_lora_to_mlp=cfg_model.apply_lora_to_mlp,
-            apply_lora_to_output=getattr(cfg_model, "apply_lora_to_output", False),
-            full_model_state_dict_keys=model.state_dict().keys(),
-            lora_state_dict_keys=(
-                lora_weights_state_dict.keys()
-                if lora_weights_state_dict is not None
-                else None
-            ),
-            base_model_state_dict_keys=base_model_state_dict.keys(),
-        )
-
         base_missing, base_unexpected = model.load_state_dict(
             base_model_state_dict, strict=False
         )
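
For readers tracking what replaces the deleted call: the recipe already captures the missing/unexpected keys from its non-strict loads and validates them afterwards. A hedged sketch of that pattern follows (names such as cfg_model, base_model_state_dict, and lora_weights_state_dict mirror the surrounding recipe code; this is not the recipe verbatim):

# Sketch: validate after loading, instead of pre-validating full key sets.
base_missing, base_unexpected = model.load_state_dict(
    base_model_state_dict, strict=False
)
lora_missing, lora_unexpected = None, None
if lora_weights_state_dict is not None:
    lora_missing, lora_unexpected = model.load_state_dict(
        lora_weights_state_dict, strict=False
    )

validate_missing_and_unexpected_for_lora(
    lora_attn_modules=cfg_model.lora_attn_modules,
    apply_lora_to_mlp=cfg_model.apply_lora_to_mlp,
    apply_lora_to_output=getattr(cfg_model, "apply_lora_to_output", False),
    base_missing=base_missing,
    base_unexpected=base_unexpected,
    lora_missing=lora_missing,
    lora_unexpected=lora_unexpected,
)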

tests/torchtune/modules/peft/test_utils.py

Lines changed: 62 additions & 133 deletions
@@ -21,7 +21,6 @@
     LoRALinear,
     set_trainable_params,
     validate_missing_and_unexpected_for_lora,
-    validate_state_dict_for_lora,
 )

 N_LAYERS = 3
@@ -261,9 +260,10 @@ def test_set_trainable_params(
             lora_attn_modules,
             apply_lora_to_mlp,
             apply_lora_to_output,
-            full_model_state_dict_keys,
-            lora_state_dict_keys,
-            base_model_state_dict_keys,
+            base_missing,
+            base_unexpected,
+            lora_missing,
+            lora_unexpected,
             expected
             """
         ),
@@ -272,188 +272,117 @@ def test_set_trainable_params(
                 ["q_proj", "k_proj"],
                 False,
                 False,
-                ["q_proj.lora_a.weight", "dummy_param.weight"],
                 ["q_proj.lora_a.weight"],
+                [],
                 ["dummy_param.weight"],
+                [],
                 "",
             ),
-            (
-                ["v_proj"],
-                False,
-                False,
-                ["param_a", "param_b"],
-                None,
-                ["param_a", "param_b"],
-                "",
-            ),
+            (["v_proj"], False, False, [], [], ["param_a", "param_b"], [], ""),
             (
                 ["output_proj"],
                 False,
                 True,
-                ["output_proj.weight", "output_proj.lora_a.weight"],
                 ["output_proj.lora_a.weight"],
+                [],
                 ["output_proj.weight"],
+                [],
                 "",
             ),
-            (["q_proj"], False, False, ["param_a"], [], [], "Missing non-LoRA"),
            (
-                ["k_proj", "output_proj"],
+                ["q_proj"],
                 False,
-                True,
-                ["k_proj.lora_a.weight", "param_a"],
-                ["k_proj.lora_a.weight", "param_a"],
+                False,
+                ["param_a"],
+                [],
                 ["param_a"],
-                "found in LoRA",
+                [],
+                "Missing non-LoRA",
            ),
            (
-                ["k_proj"],
-                False,
+                ["k_proj", "output_proj"],
                 False,
-                ["k_proj.lora_a.weight"],
+                True,
+                [],
                 [],
                 ["k_proj.lora_a.weight"],
-                "found in base model",
+                [],
+                "Missing LoRA key",
            ),
            (
-                ["k_proj"],
-                False,
+                ["q_proj", "k_proj"],
+                True,
                 False,
-                ["k_proj.lora_a.weight"],
+                ["k_proj.lora"],
+                [],
+                ["q_proj.lora"],
                 [],
-                None,
                 "Missing LoRA",
            ),
-            (["q_proj"], False, False, [], ["a"], ["a"], "overlapping"),
-            (
-                ["v_proj"],
-                False,
-                False,
-                ["dummy_param.weight"],
-                ["v_proj.lora_a.weight"],
-                ["dummy_param.weight"],
-                "Extra",
-            ),
            (
-                ["w1", "w2", "w3"],
+                ["q_proj", "k_proj"],
                 True,
                 False,
-                ["w1.lora_a.weight", "w2.weight", "q_proj.weight"],
-                ["w1.lora_a.weight"],
-                ["q_proj.weight"],
-                "Missing non-LoRA key",
+                ["k_proj.lora"],
+                [],
+                ["q_proj.magnitude"],
+                [],
+                "Missing LoRA",
            ),
            (
-                ["q_proj", "output"],
-                False,
+                ["q_proj", "k_proj"],
                 True,
-                [
-                    "q_proj.lora_a",
-                    "output.weight",
-                    "output.lora_a",
-                    "output_proj.lora_b",
-                ],
-                ["q_proj.lora_a", "output.lora_a", "output_proj.lora_b"],
-                ["output.weight"],
-                "Missing non-LoRA key",
-            ),
-            (
-                ["q_proj", "v_proj"],
-                False,
                 False,
-                "lora_llama2_model_all_keys",
-                "lora_llama2_expected_adapter_keys",
-                "lora_llama2_expected_base_model_keys",
-                "",
+                ["output_proj.lora"],
+                [],
+                ["q_proj.lora"],
+                [],
+                "Missing non-LoRA",
            ),
            (
-                ["q_proj", "v_proj"],
-                False,
+                ["q_proj", "k_proj"],
+                True,
                 False,
-                "dora_llama2_model_all_keys",
-                "dora_llama2_expected_adapter_keys",
-                "lora_llama2_expected_base_model_keys",
-                "",
-            ),
-        ],
-    )
-    def test_validate_lora_state_dict(
-        self,
-        request,
-        lora_attn_modules,
-        apply_lora_to_mlp,
-        apply_lora_to_output,
-        full_model_state_dict_keys,
-        lora_state_dict_keys,
-        base_model_state_dict_keys,
-        expected,
-    ):
-        if isinstance(full_model_state_dict_keys, str):
-            full_model_state_dict_keys = request.getfixturevalue(
-                full_model_state_dict_keys
-            )
-        if isinstance(lora_state_dict_keys, str):
-            lora_state_dict_keys = request.getfixturevalue(lora_state_dict_keys)
-        if isinstance(base_model_state_dict_keys, str):
-            base_model_state_dict_keys = request.getfixturevalue(
-                base_model_state_dict_keys
-            )
-        if expected:
-            with pytest.raises(AssertionError, match=expected):
-                validate_state_dict_for_lora(
-                    lora_attn_modules,
-                    apply_lora_to_mlp,
-                    apply_lora_to_output,
-                    full_model_state_dict_keys=full_model_state_dict_keys,
-                    lora_state_dict_keys=lora_state_dict_keys,
-                    base_model_state_dict_keys=base_model_state_dict_keys,
-                )
-        else:
-            validate_state_dict_for_lora(
-                lora_attn_modules,
-                apply_lora_to_mlp,
-                apply_lora_to_output,
-                full_model_state_dict_keys=full_model_state_dict_keys,
-                lora_state_dict_keys=lora_state_dict_keys,
-                base_model_state_dict_keys=base_model_state_dict_keys,
-            )
-
-    @pytest.mark.parametrize(
-        (
-            """
-            base_missing,
-            base_unexpected,
-            lora_missing,
-            lora_unexpected,
-            expected
-            """
-        ),
-        [
-            (["k_proj.lora"], [], ["q_proj.lora"], [], "Missing LoRA"),
-            (["k_proj.lora"], [], ["q_proj.magnitude"], [], "Missing LoRA"),
-            (["output_proj.lora"], [], ["q_proj.lora"], [], "Missing non-LoRA"),
-            (
                 ["k_proj.lora"],
                 ["output.weight"],
                 ["q_proj.base_weight"],
                 [],
                 "loading base model",
            ),
            (
+                ["q_proj", "k_proj"],
+                True,
+                False,
                 ["k_proj.lora"],
                 [],
                 ["q_proj.base_weight"],
                 ["output.weight"],
                 "loading adapter",
            ),
-            (["k_proj.lora"], [], ["q_proj.base_weight"], [], ""),
+            (
+                ["q_proj", "k_proj"],
+                True,
+                False,
+                ["k_proj.lora"],
+                [],
+                ["q_proj.base_weight"],
+                [],
+                "",
+            ),
        ],
    )
    def test_validate_missing_and_unexpected_for_lora(
-        self, base_missing, base_unexpected, lora_missing, lora_unexpected, expected
+        self,
+        lora_attn_modules,
+        apply_lora_to_mlp,
+        apply_lora_to_output,
+        base_missing,
+        base_unexpected,
+        lora_missing,
+        lora_unexpected,
+        expected,
    ):
-        lora_attn_modules = ["q_proj", "k_proj"]
-        apply_lora_to_mlp = True
-        apply_lora_to_output = False
+
        if expected:
            with pytest.raises(AssertionError, match=expected):
                validate_missing_and_unexpected_for_lora(
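
Put together, the updated parametrization exercises the validator directly on the kind of output load_state_dict(strict=False) produces. Below is a standalone sketch of one failing case; the values are copied from the parametrized tuples above, and the keyword names follow the test's parameter names (an assumption about the call style, since the test may pass them positionally):

# One of the failing cases from the table above: a LoRA key still missing after
# the adapter load should trip the "Missing LoRA" assertion.
import pytest
from torchtune.modules.peft import validate_missing_and_unexpected_for_lora

with pytest.raises(AssertionError, match="Missing LoRA"):
    validate_missing_and_unexpected_for_lora(
        lora_attn_modules=["q_proj", "k_proj"],
        apply_lora_to_mlp=True,
        apply_lora_to_output=False,
        base_missing=["k_proj.lora"],
        base_unexpected=[],
        lora_missing=["q_proj.lora"],
        lora_unexpected=[],
    )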
