Shulou (Shulou.com) 05/31 report, updated 2025-01-18, from SLTechnology News&Howtos > Development
This article introduces how to implement the FedProx federated learning algorithm in PyTorch. The content is detailed and easy to follow and the steps are simple, so it should be a useful reference. Let's take a look.
I. Preface
For the principle behind FedProx, see the earlier post summarizing experiments on FedAvg federated learning and FedProx optimization for heterogeneous networks.
In federated learning there are multiple clients, each holding its own dataset that it does not want to share.
The dataset contains wind power data for ten areas of a city. We assume that the power departments of these ten areas are unwilling to share their own data, but still want a global model trained on all of it.
III. FedProx
Algorithm pseudo code:
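A minimal runnable sketch of the FedProx loop, mirroring the server code in this article (uniform sampling of max(C·K, 1) clients, dispatch of the global model, proximal local training, aggregation weighted by data size), is given here. Assumptions: the model is a single scalar, client k's local loss is the toy quadratic F_k(w) = (w − c_k)², and local training is plain gradient descent for a fixed number of steps; local_update and fedprox are hypothetical names, not the author's code.

```python
import random


def local_update(w_global, target, mu, lr=0.1, steps=50):
    """Approximately minimize h_k(w) = (w - target)**2 + mu / 2 * (w - w_global)**2
    with plain gradient descent, starting from the dispatched global model."""
    w = w_global
    for _ in range(steps):
        # gradient of the local loss plus gradient of the proximal term
        grad = 2 * (w - target) + mu * (w - w_global)
        w -= lr * grad
    return w


def fedprox(targets, sizes, rounds=20, C=0.5, mu=0.1, seed=0):
    """Toy FedProx server loop: client k's local loss is (w - targets[k])**2
    and sizes[k] plays the role of its dataset size (hypothetical setup)."""
    rng = random.Random(seed)
    K = len(targets)
    w_global = 0.0
    for _ in range(rounds):
        m = max(int(C * K), 1)           # how many clients take part this round
        index = rng.sample(range(K), m)  # sampling
        # dispatch + local updating: every sampled client starts from w_global
        local = {k: local_update(w_global, targets[k], mu) for k in index}
        # aggregation: average weighted by the sampled clients' data sizes
        total = sum(sizes[k] for k in index)
        w_global = sum(local[k] * sizes[k] / total for k in index)
    return w_global


# with every client sampled, the global model approaches the data-weighted mean 2.5
print(fedprox([1.0, 3.0], [1, 3], C=1.0))
```

With full participation the fixed point is the size-weighted mean of the clients' optima; a larger mu makes each local model stay closer to w_global, which slows how far a single round can move but limits client drift.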
1. Model definition
The client model is a simple four-layer fully connected neural network:
# -*- coding:utf-8 -*-
"""
@Time: 2022-03-03 12:23
@Author: KI
@File: model.py
@Motto: Hungry And Humble
"""
from torch import nn


class ANN(nn.Module):
    def __init__(self, args, name):
        super(ANN, self).__init__()
        self.name = name  # client (or 'server') name, also used to load its data
        self.len = 0      # number of training batches, used as aggregation weight
        self.loss = 0
        self.fc1 = nn.Linear(args.input_dim, 20)
        self.relu = nn.ReLU()         # defined but not used in forward()
        self.sigmoid = nn.Sigmoid()
        self.dropout = nn.Dropout()   # defined but not used in forward()
        self.fc2 = nn.Linear(20, 20)
        self.fc3 = nn.Linear(20, 20)
        self.fc4 = nn.Linear(20, 1)

    def forward(self, data):
        # four linear layers, each followed by a sigmoid
        x = self.fc1(data)
        x = self.sigmoid(x)
        x = self.fc2(x)
        x = self.sigmoid(x)
        x = self.fc3(x)
        x = self.sigmoid(x)
        x = self.fc4(x)
        x = self.sigmoid(x)
        return x

2. Server side
The server side is the same as in FedAvg: in every round it repeats three steps, namely client sampling, parameter dispatch, and parameter aggregation:
# -*- coding:utf-8 -*-
"""
@Time: 2022-03-03 12:50
@Author: KI
@File: server.py
@Motto: Hungry And Humble
"""
import copy
import random

import numpy as np
import torch

from model import ANN
from client import train, test


class FedProx:
    def __init__(self, args):
        self.args = args
        # global (server) model plus one copy per client
        self.nn = ANN(args=self.args, name='server').to(args.device)
        self.nns = []
        for i in range(self.args.K):
            temp = copy.deepcopy(self.nn)
            temp.name = self.args.clients[i]
            self.nns.append(temp)

    def server(self):
        for t in range(self.args.r):
            print('round', t + 1, ':')
            # sampling: choose max(C * K, 1) clients at random
            m = np.max([int(self.args.C * self.args.K), 1])
            index = random.sample(range(0, self.args.K), m)  # S_t
            # dispatch
            self.dispatch(index)
            # local updating
            self.client_update(index, t)
            # aggregation
            self.aggregation(index)

        return self.nn

    def aggregation(self, index):
        # total training size of the sampled clients
        s = 0
        for j in index:
            s += self.nns[j].len

        params = {}
        for k, v in self.nns[0].named_parameters():
            params[k] = torch.zeros_like(v.data)

        # weight each client's parameters by its share of the data
        for j in index:
            for k, v in self.nns[j].named_parameters():
                params[k] += v.data * (self.nns[j].len / s)

        for k, v in self.nn.named_parameters():
            v.data = params[k].data.clone()

    def dispatch(self, index):
        # copy the global parameters into every sampled client model
        for j in index:
            for old_params, new_params in zip(self.nns[j].parameters(), self.nn.parameters()):
                old_params.data = new_params.data.clone()

    def client_update(self, index, global_round):  # update nn
        for k in index:
            self.nns[k] = train(self.args, self.nns[k], self.nn, global_round)

    def global_test(self):
        model = self.nn
        model.eval()
        for client in self.args.clients:
            model.name = client
            test(self.args, model)

3. Client update
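In FedProx, each client k minimizes the local objective h_k(w; w^t) = F_k(w) + (μ/2)·||w − w^t||², where F_k(w) is the client's original (FedAvg) loss, w^t is the global model received from the server in round t, and μ ≥ 0 controls the strength of the extra term.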
On top of the FedAvg loss, the authors add a proximal term. Because of this term, the parameters w that a client obtains after local training do not drift too far from the server parameters w^t it started from.
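Differentiating the local objective shows what the proximal term adds to each update. Here F_k is the client's original loss and w^t the global model; the update is written for plain gradient descent with a learning rate η, which is a simplifying assumption (the code below uses Adam or SGD with momentum, and η is not a name from the article):

```latex
\begin{aligned}
\nabla_w \Big[ F_k(w) + \tfrac{\mu}{2}\,\lVert w - w^t \rVert^2 \Big] &= \nabla F_k(w) + \mu\,(w - w^t) \\
w &\leftarrow w - \eta\,\big( \nabla F_k(w) + \mu\,(w - w^t) \big)
\end{aligned}
```

The extra term μ(w − w^t) points from w^t toward w, so subtracting it pulls the local model back toward the global model; with μ = 0 the step is the ordinary FedAvg local update.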
The corresponding code is:
# client.py (excerpt)
import copy

import torch
from torch import nn

from data_process import nn_seq_wind  # assumed to be defined in data_process.py


def train(args, model, server, global_round):
    model.train()
    Dtr, Dte = nn_seq_wind(model.name, args.B)
    model.len = len(Dtr)  # used by the server as the aggregation weight
    global_model = copy.deepcopy(server)  # copy of w^t for the proximal term
    # nonzero weight_decay also acts as a per-round learning-rate decay factor here
    if args.weight_decay != 0:
        lr = args.lr * pow(args.weight_decay, global_round)
    else:
        lr = args.lr
    if args.optimizer == 'adam':
        optimizer = torch.optim.Adam(model.parameters(), lr=lr,
                                     weight_decay=args.weight_decay)
    else:
        optimizer = torch.optim.SGD(model.parameters(), lr=lr,
                                    momentum=0.9, weight_decay=args.weight_decay)
    print('training...')
    loss_function = nn.MSELoss().to(args.device)
    loss = 0
    for epoch in range(args.E):
        for (seq, label) in Dtr:
            seq = seq.to(args.device)
            label = label.to(args.device)
            y_pred = model(seq)
            optimizer.zero_grad()
            # compute the proximal term ||w - w^t||^2 (squared, as in the objective)
            proximal_term = 0.0
            for w, w_t in zip(model.parameters(), global_model.parameters()):
                proximal_term += (w - w_t).norm(2) ** 2
            loss = loss_function(y_pred, label) + (args.mu / 2) * proximal_term
            loss.backward()
            optimizer.step()
        print('epoch', epoch, ':', loss.item())

    return model
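The server and client functions read all of their settings from an args object built in args.py, whose contents the article does not show. A hypothetical stand-in is sketched here with argparse: the option names are the attributes the code reads (input_dim, K, C, r, E, B, lr, mu, weight_decay, optimizer, device, clients), but every default value and the client names are illustrative assumptions, not the author's settings.

```python
import argparse


def args_parser(argv=None):
    """Hypothetical args.py: option names match the attributes used by the
    FedProx code; all default values are illustrative assumptions."""
    parser = argparse.ArgumentParser(description='FedProx (illustrative settings)')
    parser.add_argument('--input_dim', type=int, default=28, help='input feature dimension')
    parser.add_argument('--K', type=int, default=10, help='number of clients (ten areas)')
    parser.add_argument('--C', type=float, default=0.5, help='fraction of clients sampled per round')
    parser.add_argument('--r', type=int, default=20, help='number of communication rounds')
    parser.add_argument('--E', type=int, default=20, help='local epochs per round')
    parser.add_argument('--B', type=int, default=50, help='local batch size')
    parser.add_argument('--lr', type=float, default=0.01, help='initial learning rate')
    parser.add_argument('--mu', type=float, default=0.01, help='proximal term coefficient')
    parser.add_argument('--weight_decay', type=float, default=0.0, help='weight decay (also lr decay factor)')
    parser.add_argument('--optimizer', type=str, default='adam', choices=['adam', 'sgd'])
    parser.add_argument('--device', type=str, default='cpu', help="passed to .to(), e.g. 'cpu' or 'cuda'")
    args = parser.parse_args(argv)
    # hypothetical client names; each one selects a client's data in nn_seq_wind
    args.clients = ['client_' + str(i + 1) for i in range(args.K)]
    return args


args = args_parser([])
print(args.K, args.mu, args.clients[:2])
```

Note that in train() a nonzero weight_decay also multiplies the learning rate by weight_decay ** global_round, so a typical small weight-decay value such as 1e-4 would shrink the learning rate sharply after the first round; the sketch therefore defaults it to 0.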
We add the proximal term to the original MSE loss:
for w, w_t in zip(model.parameters(), global_model.parameters()):
    proximal_term += (w - w_t).norm(2) ** 2
Backpropagation then computes the gradients, and optimizer.step() updates the parameters.
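How strongly the proximal term restrains a client depends on μ. For a hypothetical one-dimensional client with local loss F_k(w) = (w − c)², the proximal objective can be minimized in closed form, which makes the drift from the global model w^t easy to read off; c, w_t, and prox_minimizer are illustrative names, not from the article:

```python
def prox_minimizer(c, w_t, mu):
    """Exact minimizer of h(w) = (w - c)**2 + mu / 2 * (w - w_t)**2.
    Setting h'(w) = 2 * (w - c) + mu * (w - w_t) = 0 gives w = (2c + mu * w_t) / (2 + mu)."""
    return (2 * c + mu * w_t) / (2 + mu)


w_t = 0.0  # global model received from the server
c = 3.0    # optimum of the client's own loss
for mu in [0.0, 1.0, 10.0]:
    w = prox_minimizer(c, w_t, mu)
    print(f"mu={mu}: local model {w}, drift from w_t {abs(w - w_t)}")
# drift: 3.0 for mu=0 (plain FedAvg local solve), 2.0 for mu=1, 0.5 for mu=10
```

The drift 2|c − w^t| / (2 + μ) shrinks as μ grows, which is exactly the "does not drift too far" effect; the price is that each round moves the local model less toward the client's own optimum.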
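The original paper also introduces the notion of an inexact solution: for γ ∈ [0, 1], a local model is a γ-inexact solution if the gradient norm of the client's FedProx objective at that model is at most γ times the gradient norm at the starting point w^t. Each client therefore only needs to solve its local problem approximately, and γ can differ between clients and rounds.
However, I did not find an explanation of how to choose γ in the experimental section of the original paper. From what I found, it is related to proximal gradient descent. The code in this article does not consider the inexact solution; this may be added later.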
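The γ-inexact condition can be turned into a stopping rule for local training. A minimal sketch, assuming a scalar model, plain gradient descent, and a toy local loss supplied as its gradient function; inexact_local_solve and its parameters are hypothetical names, and this is not part of the article's code:

```python
def inexact_local_solve(grad_F, w_t, mu, gamma, lr=0.1, max_steps=10_000):
    """Gradient descent on h(w) = F(w) + mu / 2 * (w - w_t)**2, starting at w_t and
    stopping once |h'(w)| <= gamma * |h'(w_t)|, i.e. at a gamma-inexact solution."""
    def grad_h(w):
        return grad_F(w) + mu * (w - w_t)  # gradient of the proximal objective

    threshold = gamma * abs(grad_h(w_t))  # note: h'(w_t) equals F'(w_t)
    w, steps = w_t, 0
    while abs(grad_h(w)) > threshold and steps < max_steps:
        w -= lr * grad_h(w)
        steps += 1
    return w, steps


# toy client with F(w) = (w - 3)**2, so F'(w) = 2 * (w - 3)
for gamma in [0.5, 0.01]:
    w, steps = inexact_local_solve(lambda x: 2 * (x - 3.0), w_t=0.0, mu=1.0, gamma=gamma)
    print(f"gamma={gamma}: {steps} local steps, w={w:.4f}")
```

A smaller γ demands a more exact local solve and therefore more local steps (2 versus 13 in this toy run); γ = 0 would require the exact minimizer, which gradient descent only reaches in the limit, hence the max_steps cap.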
The project is organized as follows:
server.py implements the server-side operations.
client.py implements the client-side operations.
data_process.py handles data processing.
model.py defines the model.
args.py defines the parameters.
main.py is the main file. To run the project, simply run:
python main.py
This is the end of the article on how to implement the FedProx federated learning algorithm in PyTorch. Thank you for reading! I hope it has given you a working understanding of how to implement FedProx in PyTorch.