- Federated learning algorithm: see Base Method for structure details.
- Client algorithm: see Base Client for structure details.
- Server algorithm: see Base Server for structure details.
🔩 Implementing FedAvg with a Proximal Term
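- The main difference from the Federated Averaging algorithm is a regularizer (the proximal term) added to each client's local loss function:
  $$ \ell_i(w, x, y) \rightarrow \ell_i(w, x, y) + \dfrac{\lambda}{2}\|w_i - w_g\|_2^2 $$
  Here $w_i$ are the local weights of client $i$ and $w_g$ are the global weights received from the server.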
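As a rough numerical illustration of the regularizer, the sketch below evaluates the proximal term on plain Python lists and compares its analytic gradient, $\lambda (w_i - w_g)$, with a finite-difference estimate. The helper names `proximal_term`, `proximal_grad` and `finite_difference_grad` are hypothetical and not part of the library; real clients operate on model tensors.

```python
def proximal_term(w_local, w_global, lam):
    """Return (lam / 2) * ||w_local - w_global||^2 for flat lists of floats."""
    return 0.5 * lam * sum((a - b) ** 2 for a, b in zip(w_local, w_global))


def proximal_grad(w_local, w_global, lam):
    """Analytic gradient of the proximal term with respect to w_local."""
    return [lam * (a - b) for a, b in zip(w_local, w_global)]


def finite_difference_grad(w_local, w_global, lam, eps=1e-6):
    """Central-difference estimate of the same gradient, used as a check."""
    grad = []
    for k in range(len(w_local)):
        up = list(w_local)
        down = list(w_local)
        up[k] += eps
        down[k] -= eps
        diff = proximal_term(up, w_global, lam) - proximal_term(down, w_global, lam)
        grad.append(diff / (2 * eps))
    return grad


w_i = [1.0, -2.0, 0.5]  # local weights (toy values)
w_g = [0.0, -1.0, 0.5]  # global weights (toy values)
lam = 0.1

penalty = proximal_term(w_i, w_g, lam)
analytic = proximal_grad(w_i, w_g, lam)
numeric = finite_difference_grad(w_i, w_g, lam)
```

The gradient pulls each local weight back toward its global counterpart with strength $\lambda$; with $\lambda = 0$ the local objective reduces to the plain Federated Averaging loss.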
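- To add this term, redefine the loss computation in the Client Algorithm (`client.py` --> `fedprox_client.py`):

      # fedprox_client.py, body of FedProxClient.get_loss_value(outputs, targets)
      loss = super().get_loss_value(outputs, targets)
      # keep_vars=True keeps the local weights attached to the autograd graph.
      # A plain state_dict() returns detached tensors, so the proximal term
      # would change the loss value but contribute no gradient.
      proximity = (
          0.5
          * self.fed_prox_lambda
          * sum(
              (p.float() - q.float()).norm() ** 2
              for (_, p), (_, q) in zip(
                  self.model.state_dict(keep_vars=True).items(),
                  self.server_model_state.items(),
              )
          )
      )
      loss += proximity

- We redefined the Client Algorithm, so we also need to update `client_cls`:

      def _init_client_cls(self):
          super()._init_client_cls()
          self.client_cls = FedProxClient
          self.client_kwargs["client_cls"] = self.client_cls
          # Pass the regularization strength on to every client
          self.client_args.extend([self.fed_prox_lambda])

- Let's also add a warmup parameter, `num_fedavg_rounds`, that specifies the round from which the proximal term is added.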
- To do this, the current round has to be passed to the clients:

      def get_communication_content(self, rank):
          # In FedProx we additionally send the current round, needed for warmup
          content = super().get_communication_content(rank)
          content["current_round"] = self.cur_round
          return content
- And process it on the client side:

      def create_pipe_commands(self):
          pipe_commands_map = super().create_pipe_commands()
          pipe_commands_map["current_round"] = self.set_cur_round
          return pipe_commands_map

      def set_cur_round(self, round):
          self.cur_com_round = round
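The two hooks above can be pictured with a small stand-alone mock. The classes `ToyBaseServer`, `ToyFedProxServer`, `ToyBaseClient` and `ToyFedProxClient`, the `"weights"` key, and the `receive` dispatch loop are assumptions made for illustration and do not reproduce the library's real Base Client or Base Server; only the method names `get_communication_content`, `create_pipe_commands` and `set_cur_round` come from the text.

```python
class ToyBaseServer:
    """Hypothetical stand-in for the base server."""

    def __init__(self):
        self.cur_round = 0
        self.global_weights = [0.0, 0.0]

    def get_communication_content(self, rank):
        # Assumed base payload: just a copy of the global weights.
        return {"weights": list(self.global_weights)}


class ToyFedProxServer(ToyBaseServer):
    def get_communication_content(self, rank):
        content = super().get_communication_content(rank)
        content["current_round"] = self.cur_round
        return content


class ToyBaseClient:
    """Hypothetical stand-in for the base client."""

    def __init__(self):
        self.weights = None
        self.pipe_commands_map = self.create_pipe_commands()

    def create_pipe_commands(self):
        return {"weights": self.set_weights}

    def set_weights(self, weights):
        self.weights = weights

    def receive(self, content):
        # Assumed dispatch: call the handler registered for each key.
        for key, value in content.items():
            self.pipe_commands_map[key](value)


class ToyFedProxClient(ToyBaseClient):
    def __init__(self):
        self.cur_com_round = 0
        super().__init__()

    def create_pipe_commands(self):
        pipe_commands_map = super().create_pipe_commands()
        pipe_commands_map["current_round"] = self.set_cur_round
        return pipe_commands_map

    def set_cur_round(self, round):
        self.cur_com_round = round


server = ToyFedProxServer()
client = ToyFedProxClient()
server.cur_round = 4
client.receive(server.get_communication_content(rank=0))
```

In this mock, adding one key to the outgoing content and one handler to the command map is enough for the client to learn the round number; the existing `"weights"` handling is left untouched.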
- Finally, apply the proximal term only once the warmup rounds have passed:

      loss = super().get_loss_value(outputs, targets)
      if self.cur_com_round > self.num_fedavg_rounds - 1:
          # keep_vars=True keeps the local weights attached to the autograd
          # graph; a plain state_dict() returns detached tensors.
          proximity = (
              0.5
              * self.fed_prox_lambda
              * sum(
                  (p.float() - q.float()).norm() ** 2
                  for (_, p), (_, q) in zip(
                      self.model.state_dict(keep_vars=True).items(),
                      self.server_model_state.items(),
                  )
              )
          )
          loss += proximity
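To see which rounds the warmup condition affects, here is a minimal sketch of the gate on its own. The function `uses_proximal_term` is a hypothetical helper, and counting rounds from 0 is an assumption; the text does not say whether rounds start at 0 or 1.

```python
def uses_proximal_term(cur_com_round, num_fedavg_rounds):
    """Mirror the condition `cur_com_round > num_fedavg_rounds - 1`."""
    return cur_com_round > num_fedavg_rounds - 1


num_fedavg_rounds = 3  # toy value: three warmup rounds of plain FedAvg

# One flag per round, for rounds 0..5 (0-based counting is an assumption).
schedule = [uses_proximal_term(r, num_fedavg_rounds) for r in range(6)]
first_proximal_round = schedule.index(True)
```

For integer rounds the condition is equivalent to `cur_com_round >= num_fedavg_rounds`, so under 0-based counting the first `num_fedavg_rounds` rounds run as plain FedAvg and the proximal term is active from then on.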