     LoRALinear,
     set_trainable_params,
     validate_missing_and_unexpected_for_lora,
-    validate_state_dict_for_lora,
 )

 N_LAYERS = 3
@@ -261,9 +260,10 @@ def test_set_trainable_params(
             lora_attn_modules,
             apply_lora_to_mlp,
             apply_lora_to_output,
-            full_model_state_dict_keys,
-            lora_state_dict_keys,
-            base_model_state_dict_keys,
+            base_missing,
+            base_unexpected,
+            lora_missing,
+            lora_unexpected,
             expected
             """
         ),
@@ -272,188 +272,117 @@ def test_set_trainable_params(
                 ["q_proj", "k_proj"],
                 False,
                 False,
-                ["q_proj.lora_a.weight", "dummy_param.weight"],
                 ["q_proj.lora_a.weight"],
+                [],
                 ["dummy_param.weight"],
+                [],
                 "",
             ),
-            (
-                ["v_proj"],
-                False,
-                False,
-                ["param_a", "param_b"],
-                None,
-                ["param_a", "param_b"],
-                "",
-            ),
+            (["v_proj"], False, False, [], [], ["param_a", "param_b"], [], ""),
             (
                 ["output_proj"],
                 False,
                 True,
-                ["output_proj.weight", "output_proj.lora_a.weight"],
                 ["output_proj.lora_a.weight"],
+                [],
                 ["output_proj.weight"],
+                [],
                 "",
             ),
298- (["q_proj" ], False , False , ["param_a" ], [], [], "Missing non-LoRA" ),
299292 (
300- ["k_proj" , "output_proj " ],
293+ ["q_proj " ],
301294 False ,
302- True ,
303- ["k_proj.lora_a.weight" , " param_a" ],
304- ["k_proj.lora_a.weight" , "param_a" ],
295+ False ,
296+ ["param_a" ],
297+ [],
305298 ["param_a" ],
306- "found in LoRA" ,
299+ [],
300+ "Missing non-LoRA" ,
307301 ),
308302 (
309- ["k_proj" ],
310- False ,
303+ ["k_proj" , "output_proj" ],
311304 False ,
312- ["k_proj.lora_a.weight" ],
305+ True ,
306+ [],
313307 [],
314308 ["k_proj.lora_a.weight" ],
315- "found in base model" ,
309+ [],
310+ "Missing LoRA key" ,
316311 ),
317312 (
318- ["k_proj" ],
319- False ,
313+ ["q_proj" , " k_proj" ],
314+ True ,
320315 False ,
321- ["k_proj.lora_a.weight" ],
316+ ["k_proj.lora" ],
317+ [],
318+ ["q_proj.lora" ],
322319 [],
323- None ,
324320 "Missing LoRA" ,
325321 ),
326- (["q_proj" ], False , False , [], ["a" ], ["a" ], "overlapping" ),
327- (
328- ["v_proj" ],
329- False ,
330- False ,
331- ["dummy_param.weight" ],
332- ["v_proj.lora_a.weight" ],
333- ["dummy_param.weight" ],
334- "Extra" ,
335- ),
336322 (
337- ["w1 " , "w2" , "w3 " ],
323+ ["q_proj " , "k_proj " ],
338324 True ,
339325 False ,
340- ["w1.lora_a.weight" , "w2.weight" , "q_proj.weight" ],
341- ["w1.lora_a.weight" ],
342- ["q_proj.weight" ],
343- "Missing non-LoRA key" ,
326+ ["k_proj.lora" ],
327+ [],
328+ ["q_proj.magnitude" ],
329+ [],
330+ "Missing LoRA" ,
344331 ),
             (
-                ["q_proj", "output"],
-                False,
+                ["q_proj", "k_proj"],
                 True,
-                [
-                    "q_proj.lora_a",
-                    "output.weight",
-                    "output.lora_a",
-                    "output_proj.lora_b",
-                ],
-                ["q_proj.lora_a", "output.lora_a", "output_proj.lora_b"],
-                ["output.weight"],
-                "Missing non-LoRA key",
-            ),
-            (
-                ["q_proj", "v_proj"],
-                False,
                 False,
-                "lora_llama2_model_all_keys",
-                "lora_llama2_expected_adapter_keys",
-                "lora_llama2_expected_base_model_keys",
-                "",
+                ["output_proj.lora"],
+                [],
+                ["q_proj.lora"],
+                [],
+                "Missing non-LoRA",
             ),
             (
-                ["q_proj", "v_proj"],
-                False,
+                ["q_proj", "k_proj"],
+                True,
                 False,
-                "dora_llama2_model_all_keys",
-                "dora_llama2_expected_adapter_keys",
-                "lora_llama2_expected_base_model_keys",
-                "",
-            ),
-        ],
-    )
-    def test_validate_lora_state_dict(
-        self,
-        request,
-        lora_attn_modules,
-        apply_lora_to_mlp,
-        apply_lora_to_output,
-        full_model_state_dict_keys,
-        lora_state_dict_keys,
-        base_model_state_dict_keys,
-        expected,
-    ):
-        if isinstance(full_model_state_dict_keys, str):
-            full_model_state_dict_keys = request.getfixturevalue(
-                full_model_state_dict_keys
-            )
-        if isinstance(lora_state_dict_keys, str):
-            lora_state_dict_keys = request.getfixturevalue(lora_state_dict_keys)
-        if isinstance(base_model_state_dict_keys, str):
-            base_model_state_dict_keys = request.getfixturevalue(
-                base_model_state_dict_keys
-            )
-        if expected:
-            with pytest.raises(AssertionError, match=expected):
-                validate_state_dict_for_lora(
-                    lora_attn_modules,
-                    apply_lora_to_mlp,
-                    apply_lora_to_output,
-                    full_model_state_dict_keys=full_model_state_dict_keys,
-                    lora_state_dict_keys=lora_state_dict_keys,
-                    base_model_state_dict_keys=base_model_state_dict_keys,
-                )
-        else:
-            validate_state_dict_for_lora(
-                lora_attn_modules,
-                apply_lora_to_mlp,
-                apply_lora_to_output,
-                full_model_state_dict_keys=full_model_state_dict_keys,
-                lora_state_dict_keys=lora_state_dict_keys,
-                base_model_state_dict_keys=base_model_state_dict_keys,
-            )
-
-    @pytest.mark.parametrize(
-        (
-            """
-            base_missing,
-            base_unexpected,
-            lora_missing,
-            lora_unexpected,
-            expected
-            """
-        ),
-        [
-            (["k_proj.lora"], [], ["q_proj.lora"], [], "Missing LoRA"),
-            (["k_proj.lora"], [], ["q_proj.magnitude"], [], "Missing LoRA"),
-            (["output_proj.lora"], [], ["q_proj.lora"], [], "Missing non-LoRA"),
-            (
435346 ["k_proj.lora" ],
436347 ["output.weight" ],
437348 ["q_proj.base_weight" ],
438349 [],
439350 "loading base model" ,
440351 ),
441352 (
353+ ["q_proj" , "k_proj" ],
354+ True ,
355+ False ,
442356 ["k_proj.lora" ],
443357 [],
444358 ["q_proj.base_weight" ],
445359 ["output.weight" ],
446360 "loading adapter" ,
447361 ),
448- (["k_proj.lora" ], [], ["q_proj.base_weight" ], [], "" ),
362+ (
363+ ["q_proj" , "k_proj" ],
364+ True ,
365+ False ,
366+ ["k_proj.lora" ],
367+ [],
368+ ["q_proj.base_weight" ],
369+ [],
370+ "" ,
371+ ),
449372 ],
     )
     def test_validate_missing_and_unexpected_for_lora(
-        self, base_missing, base_unexpected, lora_missing, lora_unexpected, expected
+        self,
+        lora_attn_modules,
+        apply_lora_to_mlp,
+        apply_lora_to_output,
+        base_missing,
+        base_unexpected,
+        lora_missing,
+        lora_unexpected,
+        expected,
     ):
-        lora_attn_modules = ["q_proj", "k_proj"]
-        apply_lora_to_mlp = True
-        apply_lora_to_output = False
+
         if expected:
             with pytest.raises(AssertionError, match=expected):
                 validate_missing_and_unexpected_for_lora(
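
A note on how these key lists arise in practice: the four lists this test parametrizes correspond to what PyTorch's load_state_dict(..., strict=False) reports when a recipe loads the base and adapter checkpoints into a LoRA model one after the other. The sketch below is illustrative only, not code from this PR: load_and_validate, model, base_sd, and adapter_sd are hypothetical names, and the validator's import path and keyword signature are inferred from the calls in this test.

from typing import Any, Dict, List

from torch import nn

# Assumed import path for the validator exercised by this test.
from torchtune.modules.peft import validate_missing_and_unexpected_for_lora


def load_and_validate(
    model: nn.Module,
    base_sd: Dict[str, Any],
    adapter_sd: Dict[str, Any],
    lora_attn_modules: List[str],
    apply_lora_to_mlp: bool,
    apply_lora_to_output: bool,
) -> None:
    # Each non-strict load returns (missing_keys, unexpected_keys). After the
    # base load, only adapter keys should be missing; after the adapter load,
    # only base keys should be missing; nothing should be unexpected.
    base_missing, base_unexpected = model.load_state_dict(base_sd, strict=False)
    lora_missing, lora_unexpected = model.load_state_dict(adapter_sd, strict=False)

    # Raises AssertionError (the `expected` patterns above match its messages)
    # if, e.g., a LoRA key is still missing after the adapter load or a
    # non-LoRA key is still missing after the base load.
    validate_missing_and_unexpected_for_lora(
        lora_attn_modules=lora_attn_modules,
        apply_lora_to_mlp=apply_lora_to_mlp,
        apply_lora_to_output=apply_lora_to_output,
        base_missing=base_missing,
        base_unexpected=base_unexpected,
        lora_missing=lora_missing,
        lora_unexpected=lora_unexpected,
    )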