Update lora_quantization_layers.py #10876

Open
tugang-baidu wants to merge 1 commit into develop

Conversation

tugang-baidu

Fix parallel QLoRA, following paddlenlp/quantization/quantization_linear.py

The original lora_quantization_layers.py lives in paddlenlp/peft/lora.

In class QuantizationLoRABaseLinear:
In method __init__, insert the following code between 'self.bias = layer.bias' and 'self.lora_config = lora_config':

        self.state = 0
        if self.weight_quantize_algo in ["a8w8linear", "a8w4linear", "fp8linear"]:
            self.act_scale = self.create_parameter(
                shape=[1],
                dtype=self._dtype,
                is_bias=False,
                default_initializer=nn.initializer.Constant(value=0.0),
            )
            self.act_scale.is_distributed = False
            self.act_scale.stop_gradient = True
            self.group = get_act_scale_group(is_row=True)
        else:
            raise NotImplementedError(
                f"Not supported weight_quantize_algo {self.weight_quantize_algo}"
            )
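For context, a minimal sketch of where this block lands in the patched __init__; the import path for get_act_scale_group is an assumption based on the quantization_linear.py file this PR references, and the comments describe the apparent intent rather than confirmed behavior:

```python
import paddle.nn as nn

# Assumed import path, based on the reference file named in this PR.
from paddlenlp.quantization.quantization_linear import get_act_scale_group


class QuantizationLoRABaseLinear(nn.Layer):
    def __init__(self, layer, lora_config):
        ...  # existing setup, ending with:
        self.bias = layer.bias

        # Inserted block: per-layer activation-quantization state.
        self.state = 0  # forward-pass counter, handed to quant_weight_linear
        if self.weight_quantize_algo in ["a8w8linear", "a8w4linear", "fp8linear"]:
            # Running activation scale: a single scalar that is neither
            # sharded across ranks nor updated by the optimizer.
            self.act_scale = self.create_parameter(
                shape=[1],
                dtype=self._dtype,
                is_bias=False,
                default_initializer=nn.initializer.Constant(value=0.0),
            )
            self.act_scale.is_distributed = False
            self.act_scale.stop_gradient = True
            # Communication group for keeping act_scale consistent
            # across tensor-parallel ranks.
            self.group = get_act_scale_group(is_row=True)
        else:
            raise NotImplementedError(
                f"Not supported weight_quantize_algo {self.weight_quantize_algo}"
            )

        self.lora_config = lora_config
        ...  # rest of the existing __init__
```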

In method forward, add 'act_state=(self.state, self.training, self.act_scale, self.group)' to the argument list of the 'output = quant_weight_linear(...)' call, and insert the following before 'return output':

        if self.training:
            self.state += 1
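And a matching sketch of the patched forward; quant_weight_linear's other arguments are elided, and the import path is again an assumption based on the reference file:

```python
# Assumed import path, based on the reference file named in this PR.
from paddlenlp.quantization.quantization_linear import quant_weight_linear


def forward(self, x):  # method of QuantizationLoRABaseLinear
    output = quant_weight_linear(
        x,
        ...,  # existing weight/bias/scale arguments, unchanged by this PR
        act_state=(self.state, self.training, self.act_scale, self.group),
    )
    ...  # existing LoRA low-rank path, unchanged
    if self.training:
        self.state += 1  # count one more training-mode forward pass
    return output
```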

However, after this change I found that the loss starts converging from a different initial value depending on the launch mode (I deleted all checkpoints before starting each new run):

  1. No parallelism: 7.68577909
  2. Tensor parallelism with paddle.distributed.launch: 14.6875057
  3. paddle.distributed.launch: 7.88057899
  4. Pipeline parallelism with paddle.distributed.launch: 13.96667099


paddle-bot commented Jul 21, 2025

Thanks for your contribution!


CLAassistant commented Jul 21, 2025

CLA assistant check: all committers have signed the CLA.
