How to implement the basic algorithm FedAvg in PyTorch

This article walks through a PyTorch implementation of FedAvg, the basic federated averaging algorithm, in detail: the data, the feature construction, the server-side and client-side logic, and the experimental results.
Data introduction
There are multiple clients in federated learning, each with its own dataset that they do not want to share.
The dataset selected in this article is the real power-load data of ten districts/counties of a city in northern China from 2016 to 2019. The collection interval is one hour, i.e. there are 24 load values per day.
We assume that the power departments of these 10 regions are unwilling to share their own data, but still want a global model trained on all of the data.
In addition to the power-load data, there is an alternative dataset: a wind power dataset. The two datasets are distinguished by the parameter type: type == 'load' selects the load data and type == 'wind' selects the wind power data.
Feature construction
The load value at a given moment is predicted from the load values of the previous 24 moments together with the relevant meteorological data (such as temperature, humidity, and pressure) at that moment.
For the wind power data, the wind power values of the previous 24 moments and the meteorological data at that moment are likewise used to predict the wind power value at that moment.
The regions should agree on how the feature set is constructed. The features used in this article are consistent across all regions and can be used directly; a sketch of the construction is given below.
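The article's helper functions nn_seq() and nn_seq_wind() (used in train() below) are not shown. A minimal sketch of the sliding-window construction they imply, with illustrative array names, might look like this:

import numpy as np

def make_samples(load, weather, window=24):
    # load: 1-D array of hourly load values
    # weather: 2-D array, one row of meteorological features per hour
    X, y = [], []
    for i in range(window, len(load)):
        past_load = load[i - window:i]   # load at the previous 24 moments
        current_weather = weather[i]     # meteorological data at moment i
        X.append(np.concatenate([past_load, current_weather]))
        y.append(load[i])                # target: load at moment i
    return np.array(X), np.array(y)

With 24 lagged load values plus six meteorological features per hour, each sample would have the input_dim of 30 used for the load data later in the article.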
Federated learning
1. Overall framework
FedAvg, as proposed in the original paper, alternates between local training on a sampled subset of clients and weight averaging on the server.
The client model is built with PyTorch:
# shared imports and device used by all of the code below
import copy
import random

import numpy as np
import torch
from torch import nn

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

class ANN(nn.Module):
    def __init__(self, input_dim, name, B, E, type, lr):
        super(ANN, self).__init__()
        self.name = name    # client (or 'server') name
        self.B = B          # local batch size
        self.E = E          # number of local epochs
        self.len = 0        # number of local training batches, set in train()
        self.type = type    # 'load' or 'wind'
        self.lr = lr        # learning rate
        self.loss = 0
        self.fc1 = nn.Linear(input_dim, 20)
        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()
        self.dropout = nn.Dropout()
        self.fc2 = nn.Linear(20, 20)
        self.fc3 = nn.Linear(20, 20)
        self.fc4 = nn.Linear(20, 1)

    def forward(self, data):
        x = self.fc1(data)
        x = self.sigmoid(x)
        x = self.fc2(x)
        x = self.sigmoid(x)
        x = self.fc3(x)
        x = self.sigmoid(x)
        x = self.fc4(x)
        x = self.sigmoid(x)
        return x
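As a quick, illustrative sanity check (not part of the article's code), the model maps a batch of inputs to one prediction per sample:

model = ANN(input_dim=30, name='demo', B=50, E=50, type='load', lr=0.08)
print(model(torch.randn(8, 30)).shape)   # torch.Size([8, 1])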
2. Server side
The server side performs the following steps:
In short, in each round of communication only some of the clients are selected; those clients update their parameters on local data and send the updated parameters back to the server, and the server aggregates the clients' updates into the latest global parameters. In the next round of communication, the server distributes the latest parameters to the newly selected clients for the next round of updates.
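In symbols, the aggregation step of the original paper is

    w_{t+1} = Σ_{k ∈ S_t} (n_k / n) · w_{t+1}^k

where S_t is the set of clients sampled in round t, w_{t+1}^k are the parameters uploaded by client k, n_k is client k's number of training samples, and n = Σ_{k ∈ S_t} n_k. (The implementation below uses each client's number of training batches, self.nns[j].len, as n_k.)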
3. Client
There is little to say about the client: it simply uses its local data to update the parameters of its neural network model.
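In the original paper the local update is plain mini-batch gradient descent: for each of E local epochs and each batch b of size B, the client computes w ← w − η∇ℓ(w; b). Note that the train() function shown below uses the Adam optimizer in place of vanilla SGD.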
Code implementation
1. Initialization

class FedAvg:
    def __init__(self, options):
        self.C = options['C']
        self.E = options['E']
        self.B = options['B']
        self.K = options['K']
        self.r = options['r']
        self.input_dim = options['input_dim']
        self.type = options['type']
        self.lr = options['lr']
        self.clients = options['clients']
        # the global model held by the server
        self.nn = ANN(input_dim=self.input_dim, name='server', B=self.B,
                      E=self.E, type=self.type, lr=self.lr).to(device)
        # one local model per client, each starting as a copy of the global model
        self.nns = []
        for i in range(self.K):
            temp = copy.deepcopy(self.nn)
            temp.name = self.clients[i]
            self.nns.append(temp)
Parameters:
K: the number of clients; in this article, 10 (the ten regions).
C: the selection rate; only C * K clients are selected in each round of communication.
E: the number of epochs each selected client trains on its local dataset when updating its local model.
B: the batch size each client uses on its local dataset when updating its local model.
r: the total number of communication rounds between the server and the clients.
clients: the collection of client names.
type: the data type, i.e. load forecasting or wind power forecasting.
lr: the learning rate.
input_dim: the input dimension of the data.
nn: the global model.
nns: the collection of client models.
2. Server side
The server-side code is as follows:
def server(self):
    for t in range(self.r):
        print('round', t + 1, 'of communication:')
        # sampling: select max(C * K, 1) clients
        m = np.max([int(self.C * self.K), 1])
        index = random.sample(range(0, self.K), m)
        # dispatch the global parameters to the selected clients
        self.dispatch(index)
        # local updating
        self.client_update(index)
        # aggregation
        self.aggregation(index)
    # return the global model
    return self.nn
Where client_update(index) is:

def client_update(self, index):  # update nn
    for k in index:
        self.nns[k] = train(self.nns[k])
And aggregation(index) is:

def aggregation(self, index):
    # total number of training batches across the selected clients
    s = 0
    for j in index:
        # normal
        s += self.nns[j].len
    # zero-initialized container for the aggregated parameters
    params = {}
    with torch.no_grad():
        for k, v in self.nns[0].named_parameters():
            params[k] = copy.deepcopy(v)
            params[k].zero_()
    # weighted sum of the selected clients' parameters
    for j in index:
        with torch.no_grad():
            for k, v in self.nns[j].named_parameters():
                params[k] += v * (self.nns[j].len / s)
    # copy the aggregated parameters into the global model
    with torch.no_grad():
        for k, v in self.nn.named_parameters():
            v.copy_(params[k])
And dispatch(index) is:

def dispatch(self, index):
    # copy the global parameters into every selected client's local model
    params = {}
    with torch.no_grad():
        for k, v in self.nn.named_parameters():
            params[k] = copy.deepcopy(v)
    for j in index:
        with torch.no_grad():
            for k, v in self.nns[j].named_parameters():
                v.copy_(params[k])
Here is an analysis of the important code:
The choice of clients

m = np.max([int(self.C * self.K), 1])
index = random.sample(range(0, self.K), m)
index stores m distinct integers drawn from [0, 10), each one the serial number of a selected client.
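For example, with C = 0.5 and K = 10 we get m = 5; one possible draw (the sample is random, so yours will differ) is:

>>> import random
>>> random.sample(range(0, 10), 5)
[3, 0, 9, 4, 7]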
Client updates
self.client_update(index)

which, as shown above, retrains each selected client's model on its local data via train().
The server aggregates the parameters of the client models

For how the models are aggregated, you can refer to another of my articles: Understanding the Process of Model Aggregation in FedAvg.

Of course, this is just a very simple aggregation scheme; there are other kinds.

Three aggregation schemes are summarized in Electricity Consumer Characteristics Identification: A Federated Learning Approach:
Normal: the scheme in the original paper, i.e. each client's weight in the final combination is determined by its number of samples.
LA: each client's weight is determined by the ratio of its model's loss to the sum of all clients' losses.
LS: each client's weight is determined by the product of its loss and its number of samples; see the sketch below.
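As a minimal sketch of the three schemes (not the article's code), assuming each client model's .len holds its batch count and its .loss attribute has been set to its final training loss (the train() shown here keeps the loss only in a local variable):

def aggregation_weights(models, scheme='normal'):
    # Return one weight per model; the weights sum to 1.
    if scheme == 'normal':
        raw = [m.len for m in models]            # original FedAvg: by sample count
    elif scheme == 'LA':
        raw = [m.loss for m in models]           # by share of the total loss
    else:                                        # 'LS'
        raw = [m.loss * m.len for m in models]   # by loss * sample count
    s = sum(raw)
    return [x / s for x in raw]

Distribute the updated parameters to the selected clients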
def dispatch(self, index):
    params = {}
    with torch.no_grad():
        for k, v in self.nn.named_parameters():
            params[k] = copy.deepcopy(v)
    for j in index:
        with torch.no_grad():
            for k, v in self.nns[j].named_parameters():
                v.copy_(params[k])

3. Client
The client only needs to update with local data:
def client_update(self, index):  # update nn
    for k in index:
        self.nns[k] = train(self.nns[k])
Where train() is:

def train(ann):
    ann.train()
    if ann.type == 'load':
        Dtr, Dte = nn_seq(ann.name, ann.B, ann.type)
    else:
        Dtr, Dte = nn_seq_wind(ann.name, ann.B)
    ann.len = len(Dtr)  # number of training batches, used as the aggregation weight
    loss_function = nn.MSELoss().to(device)
    loss = 0
    optimizer = torch.optim.Adam(ann.parameters(), lr=ann.lr)
    for epoch in range(ann.E):
        cnt = 0
        for (seq, label) in Dtr:
            cnt += 1
            seq = seq.to(device)
            label = label.to(device)
            y_pred = ann(seq)
            loss = loss_function(y_pred, label)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        print('epoch', epoch, ':', loss.item())
    return ann

4. Test

def global_test(self):
    model = self.nn
    model.eval()
    c = clients if self.type == 'load' else clients_wind
    for client in c:
        model.name = client
        test(model)

5. Experiment and results
The parameters of this experiment are as follows:
K = 10, C = 0.5, E = 50, B = 50, r = 5

if __name__ == '__main__':
    K, C, E, B, r = 10, 0.5, 50, 50, 5
    type = 'load'
    input_dim = 30 if type == 'load' else 28
    _client = clients if type == 'load' else clients_wind
    lr = 0.08
    options = {'K': K, 'C': C, 'E': E, 'B': B, 'r': r, 'type': type,
               'clients': _client, 'input_dim': input_dim, 'lr': lr}
    fedavg = FedAvg(options)
    fedavg.server()
    fedavg.global_test()
The performance of each client on its local test set after individual training (50 local epochs with a batch size of 50) is as follows:
Client number   1      2      3      4      5      6      7      8      9      10
MAPE / %        5.33   4.11   3.03   4.20   3.02   2.70   2.94   2.99   2.30   4.10
As you can see, because each client's data is plentiful, the locally trained model of each client already achieves high prediction accuracy.
After 5 rounds of communication between the server and the clients, the global model on the server performs as follows on the 10 clients' test sets:
Client number   1      2      3      4      5      6      7      8      9      10
MAPE / %        6.84   4.54   3.56   5.11   3.75   4.47   4.30   3.90   3.15   4.58
As you can see, under the federated learning framework the global model also performs well on each client; this is because the data distributions of the ten regions are similar.
Here is a comparison between the numpy implementation and the PyTorch implementation, alongside the individually trained local models:
Client number   1      2      3      4      5      6      7      8      9      10
local           5.33   4.11   3.03   4.20   3.02   2.70   2.94   2.99   2.30   4.10
numpy           6.58   4.19   3.17   5.13   3.58   4.69   4.71   3.75   2.94   4.77
PyTorch         6.84   4.54   3.56   5.11   3.75   4.47   4.30   3.90   3.15   4.58
As before, the individually trained local models perform best. The network built with PyTorch performs about the same as the one built with numpy, but using PyTorch is recommended over reinventing the wheel.
That concludes this walkthrough of how to implement the basic algorithm FedAvg in PyTorch; to truly master it, you still need to practice with the code yourself.