How to save, load, and migrate models in PyTorch

2025-01-31 Update From: SLTechnology News&Howtos

Shulou (Shulou.com), 06/03

This article explains how to save and migrate models in PyTorch. Many people run into difficulties with these steps in real projects, so the sections below walk through how to handle each situation. I hope you read it carefully and get something out of it!

Contents

1 Introduction

2 Saving and reusing models

2.1 Viewing network model parameters

2.2 Loading the model for inference

2.3 Loading the model for training

2.4 Loading the model for migration

3 Summary

1 Introduction

Hello, dear friends. Welcome to Yuelai Inn. Today I want to introduce how to save and load a model in the PyTorch framework, as well as how to migrate and retrain a model. Generally speaking, the most common scenario is inference after training. Once training is complete, a network model usually has to make predictions on new samples, so you only need to build the forward propagation of the model and then load the trained parameters to initialize the network.

The second scenario is retraining. A model is saved locally after being trained on one batch of data, and a new batch of data may be collected some time later. At that point the previous model needs to be loaded for incremental training on the new data (or for full training on all of the data).

The third scenario is transfer learning. Here you take a pretrained model that someone else has trained and use it to initialize part of the parameters of your own network. For example, if you add several fully connected layers on top of BERT for a classification task, you need to load the parameters of the original BERT model and use them to initialize the weights of the BERT part of your network.

In the rest of this article, the author uses these three scenarios as examples to show how to carry out each process with PyTorch.

2 Saving and reusing models

In PyTorch, the main steps of the scenarios above can be completed with torch.save() and torch.load(). The author uses the previously introduced LeNet5 network model as the example for each of them. Before that, however, let's look at the form in which model parameters are stored in PyTorch.

2.1 Viewing network model parameters

(1) Viewing parameters

First, define the network structure of LeNet5, as shown in the following code:

    import torch
    import torch.nn as nn

    class LeNet5(nn.Module):
        def __init__(self):
            super(LeNet5, self).__init__()
            self.conv = nn.Sequential(  # input: [n, 1, 28, 28]
                nn.Conv2d(1, 6, 5, padding=2),  # in_channels, out_channels, kernel_size
                nn.ReLU(),  # [n, 6, 28, 28]
                nn.MaxPool2d(2, 2),  # kernel_size, stride -> [n, 6, 14, 14]
                nn.Conv2d(6, 16, 5),  # [n, 16, 10, 10]
                nn.ReLU(),
                nn.MaxPool2d(2, 2))  # [n, 16, 5, 5]
            self.fc = nn.Sequential(
                nn.Flatten(),
                nn.Linear(16 * 5 * 5, 120),
                nn.ReLU(),
                nn.Linear(120, 84),
                nn.ReLU(),
                nn.Linear(84, 10))

        def forward(self, img):
            output = self.conv(img)
            output = self.fc(output)
            return output

Once the LeNet5 class is defined, instantiating it also initializes the corresponding weight parameters in the network, so every parameter already has an initial value. We can access these parameters as follows:

    model = LeNet5()
    # Print the model's state_dict
    print("Model's state_dict:")
    for param_tensor in model.state_dict():
        print(param_tensor, "\t", model.state_dict()[param_tensor].size())

The output is:

    Model's state_dict:
    conv.0.weight    torch.Size([6, 1, 5, 5])
    conv.0.bias      torch.Size([6])
    conv.3.weight    torch.Size([16, 6, 5, 5])
    ....
    ....

As you can see, the parameters returned by model.state_dict() are stored in the form of a dictionary (essentially an OrderedDict from the collections module):

    print(model.state_dict().keys())
    # odict_keys(['conv.0.weight', 'conv.0.bias', 'conv.3.weight', 'conv.3.bias', 'fc.1.weight', 'fc.1.bias', 'fc.3.weight', 'fc.3.bias', 'fc.5.weight', 'fc.5.bias'])
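Because the state_dict is an ordinary ordered dictionary of tensors, it can be inspected with plain Python. The following is a minimal sketch, assuming the same layer layout as the LeNet5 above (the class is repeated in compact form only so that the snippet runs on its own); the variable names shapes and total_params are illustrative and not part of the original code.

```python
from collections import OrderedDict

import torch.nn as nn


class LeNet5(nn.Module):
    """Same layer layout as the LeNet5 above, repeated so the snippet runs alone."""

    def __init__(self):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(1, 6, 5, padding=2), nn.ReLU(), nn.MaxPool2d(2, 2),
            nn.Conv2d(6, 16, 5), nn.ReLU(), nn.MaxPool2d(2, 2))
        self.fc = nn.Sequential(
            nn.Flatten(), nn.Linear(16 * 5 * 5, 120), nn.ReLU(),
            nn.Linear(120, 84), nn.ReLU(), nn.Linear(84, 10))

    def forward(self, img):
        return self.fc(self.conv(img))


state = LeNet5().state_dict()

# state_dict() returns an OrderedDict that maps parameter names to tensors.
is_ordered_dict = isinstance(state, OrderedDict)

# Collect every shape and count all trainable values stored in the dictionary.
shapes = {name: tuple(t.shape) for name, t in state.items()}
total_params = sum(t.numel() for t in state.values())

print(is_ordered_dict, len(state), total_params)
```

Counting the entries this way is a quick sanity check that a saved file really belongs to the architecture you expect before loading it.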

(2) Custom parameter prefix

Two points are worth noting here: ① the fc and conv prefixes in the parameter names come from the attribute names used when nn.Sequential() was defined above; ② the number in each parameter name is the position of that layer inside its Sequential(). For example, suppose the network structure is defined in the following form:

    class LeNet5(nn.Module):
        def __init__(self):
            super(LeNet5, self).__init__()
            self.moon = nn.Sequential(  # input: [n, 1, 28, 28]
                nn.Conv2d(1, 6, 5, padding=2),  # in_channels, out_channels, kernel_size
                nn.ReLU(),  # [n, 6, 28, 28]
                nn.MaxPool2d(2, 2),  # kernel_size, stride
                nn.Conv2d(6, 16, 5),  # [n, 16, 10, 10]
                nn.ReLU(),
                nn.MaxPool2d(2, 2),  # [n, 16, 5, 5]
                nn.Flatten(),
                nn.Linear(16 * 5 * 5, 120),
                nn.ReLU(),
                nn.Linear(120, 84),
                nn.ReLU(),
                nn.Linear(84, 10))

Then the parameter names are:

    print(model.state_dict().keys())
    # odict_keys(['moon.0.weight', 'moon.0.bias', 'moon.3.weight', 'moon.3.bias', 'moon.7.weight', 'moon.7.bias', 'moon.9.weight', 'moon.9.bias', 'moon.11.weight', 'moon.11.bias'])

Understanding this naming rule is very helpful later when we parse and load pretrained models.
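The naming rule can be checked on a much smaller network. The sketch below is illustrative only: the class Tiny and the names moon, head and proj are hypothetical and chosen just to show where each part of a key comes from. It also shows that a Sequential built from an OrderedDict uses the given names instead of position numbers.

```python
from collections import OrderedDict

import torch.nn as nn


class Tiny(nn.Module):
    """Hypothetical two-block network used only to illustrate key naming."""

    def __init__(self):
        super().__init__()
        # Positional Sequential: children are named "0", "1", "2".
        self.moon = nn.Sequential(nn.Linear(4, 3), nn.ReLU(), nn.Linear(3, 2))
        # Sequential built from an OrderedDict: children keep the given names.
        self.head = nn.Sequential(OrderedDict([("proj", nn.Linear(2, 2))]))


keys = list(Tiny().state_dict().keys())
print(keys)
# ['moon.0.weight', 'moon.0.bias', 'moon.2.weight', 'moon.2.bias',
#  'head.proj.weight', 'head.proj.bias']
```

Index 1 is missing from the keys because nn.ReLU() has no parameters, which is the same reason the LeNet5 keys jump from conv.0 to conv.3.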

In addition, other objects such as the optimizer also provide a state_dict() method for obtaining their parameters, for example:

    optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
    print("Optimizer's state_dict:")
    for var_name in optimizer.state_dict():
        print(var_name, "\t", optimizer.state_dict()[var_name])

    # Optimizer's state_dict:
    # state    {}
    # param_groups    [{'lr': 0.001, 'momentum': 0.9, 'dampening': 0, 'weight_decay': 0, 'nesterov': False, 'params': [140239245300504, 140239208339784, 140239245311360, 140239245310856, 140239266942480, 140239266942552, 140239266942624, 140239266942696, 140239266942912, 140239267041352]}]

Now that we know how to view the model parameters, we can move on to reusing the model.

2.2 Loading the model for inference

(1) Saving the model

In PyTorch, saving a model is very simple and usually takes the following two lines of code (assuming os and torch have been imported):

    model_save_path = os.path.join(model_save_dir, 'model.pt')
    torch.save(model.state_dict(), model_save_path)

When naming the saved model file, PyTorch officially recommends the suffix .pt or .pth (this is not mandatory, of course). After that, you only need to add the second line of code at the appropriate place to save the model.
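A complete round trip may make the two lines above more concrete. This is a minimal sketch under stated assumptions: nn.Linear(4, 2) is only a stand-in for a trained network, and a temporary directory plays the role of model_save_dir so that nothing is left on disk.

```python
import os
import tempfile

import torch
import torch.nn as nn

model = nn.Linear(4, 2)  # stand-in for any trained network

with tempfile.TemporaryDirectory() as model_save_dir:
    model_save_path = os.path.join(model_save_dir, "model.pt")
    # Save only the parameters (the state_dict), not the whole Python object.
    torch.save(model.state_dict(), model_save_path)
    # Read the parameters back while the temporary file still exists.
    loaded_paras = torch.load(model_save_path)

# A freshly constructed network has different random weights until we load.
restored = nn.Linear(4, 2)
restored.load_state_dict(loaded_paras)

same = all(
    torch.equal(a, b)
    for a, b in zip(model.state_dict().values(), restored.state_dict().values())
)
print(same)
```

Saving the state_dict rather than the model object keeps the file independent of how the class is defined or imported, at the cost of having to construct the network yourself before loading.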

If you want to keep the best model found so far during training and save it under some condition, you should do it like this:

    from copy import deepcopy

    best_model_state = deepcopy(model.state_dict())
    torch.save(best_model_state, model_save_path)

Instead of:

    best_model_state = model.state_dict()
    torch.save(best_model_state, model_save_path)

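In the latter case, best_model_state only holds references to the same tensors as model.state_dict(), so its values keep changing as training continues. If it is written to disk at a later point, it no longer contains the best weights.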
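The difference between the two variants can be observed directly. In this small sketch, nn.Linear(2, 1) is a stand-in model and the in-place add_ call is only an assumption standing in for an optimizer step; the names alias and snapshot are illustrative.

```python
from copy import deepcopy

import torch
import torch.nn as nn

model = nn.Linear(2, 1)

alias = model.state_dict()               # tensors share memory with the model
snapshot = deepcopy(model.state_dict())  # independent copy of the values
before = snapshot["weight"].clone()

# Stand-in for an optimizer step: parameters are updated in place.
with torch.no_grad():
    model.weight.add_(1.0)

alias_changed = not torch.equal(alias["weight"], before)
snapshot_unchanged = torch.equal(snapshot["weight"], before)
print(alias_changed, snapshot_unchanged)  # True True
```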

(2) Reusing the model for inference

For inference, you first initialize the network and then load the existing model parameters to overwrite the weights in the network. Example code:

    def inference(data_iter, device, model_save_dir='./MODEL'):
        model = LeNet5()  # initialize the network
        model.to(device)
        model_save_path = os.path.join(model_save_dir, 'model.pt')
        if os.path.exists(model_save_path):
            loaded_paras = torch.load(model_save_path, map_location=device)
            model.load_state_dict(loaded_paras)  # overwrite the network weights with the saved model
        model.eval()  # don't forget this
        with torch.no_grad():
            acc_sum, n = 0.0, 0
            for x, y in data_iter:
                x, y = x.to(device), y.to(device)
                logits = model(x)
                acc_sum += (logits.argmax(1) == y).float().sum().item()
                n += len(y)
            print("Accuracy in test data is:", acc_sum / n)

In the above code, lines 4-7 load the local model parameters and overwrite the original parameters in the network. After that, inference can be carried out:

    Accuracy in test data is: 0.8851

2.3 Loading the model for training

With saving and reuse covered, additional training of the network is straightforward. The easiest way is to save only the network weights during training, and then load only those weights to initialize the network when additional training starts. An example follows (see [2] for the complete code):

    def train(self):
        # ...
        model_save_path = os.path.join(self.model_save_dir, 'model.pt')
        if os.path.exists(model_save_path):
            loaded_paras = torch.load(model_save_path)
            self.model.load_state_dict(loaded_paras)
            print("# Successfully loaded the existing model for additional training.")
        optimizer = torch.optim.Adam(self.model.parameters(), lr=self.learning_rate)  # define the optimizer
        # ...
        for epoch in range(self.epochs):
            for i, (x, y) in enumerate(train_iter):
                x, y = x.to(device), y.to(device)
                logits = self.model(x)
                # ...
            print("Epochs[{}/{}]--acc on test {:.4}".format(
                epoch, self.epochs, self.evaluate(test_iter, self.model, device)))
            torch.save(self.model.state_dict(), model_save_path)

In this way, additional training of the model is carried out:

    # Successfully loaded the existing model for additional training.
    Epochs[0/5]---batch[938/0]---acc 0.9062---loss 0.2926
    Epochs[0/5]---batch[938/100]---acc 0.9375---loss 0.1598
    ...

In addition, when saving the parameters you can also save the optimizer parameters, the loss value, and so on, and later restore the model together with these other values, as shown in the following example:

    model_save_path = os.path.join(model_save_dir, 'model.pt')
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': loss,
        # ...
    }, model_save_path)
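To see why the optimizer state is worth saving alongside the weights, the following sketch runs one SGD step with momentum and then writes and restores a checkpoint. It is only an illustration under assumptions: nn.Linear(3, 1) is a stand-in model, and an in-memory io.BytesIO buffer replaces the file on disk so the snippet is self-contained.

```python
import io

import torch
import torch.nn as nn

model = nn.Linear(3, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

# Before any step the optimizer has no per-parameter state yet.
empty_before_step = len(optimizer.state_dict()["state"]) == 0

loss = model(torch.ones(2, 3)).sum()
loss.backward()
optimizer.step()  # creates one momentum buffer per parameter

buffer = io.BytesIO()  # in-memory stand-in for a checkpoint file
torch.save({"model_state_dict": model.state_dict(),
            "optimizer_state_dict": optimizer.state_dict()}, buffer)
buffer.seek(0)
checkpoint = torch.load(buffer)

# Rebuild model and optimizer, then restore both from the checkpoint.
new_model = nn.Linear(3, 1)
new_optimizer = torch.optim.SGD(new_model.parameters(), lr=0.001, momentum=0.9)
new_model.load_state_dict(checkpoint["model_state_dict"])
new_optimizer.load_state_dict(checkpoint["optimizer_state_dict"])

restored_state = new_optimizer.state_dict()["state"]
has_momentum = all("momentum_buffer" in s for s in restored_state.values())
print(empty_before_step, len(restored_state), has_momentum)
```

If only the model weights were restored, the new optimizer would start with empty momentum buffers, so the first updates after resuming would differ from those of an uninterrupted run.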

The loading method is as follows:

    checkpoint = torch.load(model_save_path)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    epoch = checkpoint['epoch']
    loss = checkpoint['loss']

2.4 Loading the model for migration

(1) Defining a new model

That completes the first two application scenarios, and as you can see, neither is complicated. The third scenario is a little more involved.

Suppose there is a LeNet6 network model that adds one fully connected layer on top of LeNet5. It is defined as follows:

    class LeNet6(nn.Module):
        def __init__(self):
            super(LeNet6, self).__init__()
            self.conv = nn.Sequential(  # input: [n, 1, 28, 28]
                nn.Conv2d(1, 6, 5, padding=2),  # in_channels, out_channels, kernel_size
                nn.ReLU(),  # [n, 6, 28, 28]
                nn.MaxPool2d(2, 2),  # kernel_size, stride -> [n, 6, 14, 14]
                nn.Conv2d(6, 16, 5),  # [n, 16, 10, 10]
                nn.ReLU(),
                nn.MaxPool2d(2, 2))  # [n, 16, 5, 5]
            self.fc = nn.Sequential(
                nn.Flatten(),
                nn.Linear(16 * 5 * 5, 120),
                nn.ReLU(),
                nn.Linear(120, 84),
                nn.ReLU(),
                nn.Linear(84, 64),
                nn.ReLU(),
                nn.Linear(64, 10))  # newly added fully connected layer

Next, we need to transfer the weights trained on LeNet5 to the LeNet6 network. From the definition of LeNet6 above, we can see that although only one extra fully connected layer was added, the shape of the layer at position fc.5 has changed as well (its weight goes from [10, 84] in LeNet5 to [64, 84] in LeNet6). Therefore, LeNet6 can reuse only the weights of the first four parameterized layers of LeNet5 (conv.0, conv.3, fc.1 and fc.3).

(2) Viewing model parameters

After obtaining a saved model, we can first load it and look at its parameters:

    model_save_path = os.path.join('./MODEL', 'model.pt')
    loaded_paras = torch.load(model_save_path)
    for param_tensor in loaded_paras:
        print(param_tensor, "\t", loaded_paras[param_tensor].size())

    # ---- reusable part
    # conv.0.weight    torch.Size([6, 1, 5, 5])
    # conv.0.bias      torch.Size([6])
    # conv.3.weight    torch.Size([16, 6, 5, 5])
    # conv.3.bias      torch.Size([16])
    # fc.1.weight      torch.Size([120, 400])
    # fc.1.bias        torch.Size([120])
    # fc.3.weight      torch.Size([84, 120])
    # fc.3.bias        torch.Size([84])
    # ---- non-reusable part
    # fc.5.weight      torch.Size([10, 84])
    # fc.5.bias        torch.Size([10])

The parameter information of the LeNet6 network, for comparison, is:

    model = LeNet6()
    for param_tensor in model.state_dict():
        print(param_tensor, "\t", model.state_dict()[param_tensor].size())

    # conv.0.weight    torch.Size([6, 1, 5, 5])
    # conv.0.bias      torch.Size([6])
    # conv.3.weight    torch.Size([16, 6, 5, 5])
    # conv.3.bias      torch.Size([16])
    # fc.1.weight      torch.Size([120, 400])
    # fc.1.bias        torch.Size([120])
    # fc.3.weight      torch.Size([84, 120])
    # fc.3.bias        torch.Size([84])
    # ---- newly added part
    # fc.5.weight      torch.Size([64, 84])
    # fc.5.bias        torch.Size([64])
    # fc.7.weight      torch.Size([10, 64])
    # fc.7.bias        torch.Size([10])

With the parameters of the old and new models laid out side by side, we can take the parameters we need from LeNet5 and load them into the LeNet6 network.

(3) Model migration

Although both the locally loaded parameters (loaded_paras above) and the model's initialized parameters (model.state_dict() above) are dictionaries, replacing an entry in the dictionary returned by model.state_dict() does not change the weights in the network. Instead, you need to construct a complete state_dict and then reinitialize the parameters of the network through the model.load_state_dict() method.
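The point about the dictionary can be verified with a few lines. In this sketch, nn.Linear(2, 2) is a stand-in model and new_weight is an arbitrary illustrative tensor; neither comes from the original code.

```python
import torch
import torch.nn as nn

model = nn.Linear(2, 2)
new_weight = torch.zeros(2, 2)

# Rebinding a key only changes this temporary dictionary, not the module.
model.state_dict()["weight"] = new_weight
changed_by_assignment = torch.equal(model.weight, new_weight)

# Building a full state_dict and loading it does update the module.
state_dict = model.state_dict()
state_dict["weight"] = new_weight
model.load_state_dict(state_dict)
changed_by_loading = torch.equal(model.weight, new_weight)

print(changed_by_assignment, changed_by_loading)  # False True
```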

In the process, we also need to filter out the parts of the local model that cannot be reused, as follows:

    def para_state_dict(model, model_save_dir):
        state_dict = deepcopy(model.state_dict())
        model_save_path = os.path.join(model_save_dir, 'model.pt')
        if os.path.exists(model_save_path):
            loaded_paras = torch.load(model_save_path)
            for key in state_dict:  # traverse the parameters of the new network model
                if key in loaded_paras and state_dict[key].size() == loaded_paras[key].size():
                    print("Successfully initialized parameter:", key)
                    state_dict[key] = loaded_paras[key]
        return state_dict

In the above code, line 2 first copies the original parameters of the network (LeNet6). Lines 6-9 replace the corresponding parts of LeNet6 with the reusable parameters of the local model (LeNet5), and line 7 is the condition that decides whether a parameter can be reused: the name must exist in the saved model and the shapes must match. Note that the filtering rule may differ from case to case, so each situation needs its own analysis, but the overall logic stays the same.
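As a variation on the same idea, the filtered parameters can also be passed to load_state_dict() with strict=False, which reports the keys that were left untouched. This is a sketch of an alternative, not the author's method; the two small nn.Sequential networks are hypothetical stand-ins for LeNet5 and LeNet6. The shape check is still needed, because strict=False only tolerates missing or unexpected keys, not mismatched shapes.

```python
import torch
import torch.nn as nn

# Hypothetical stand-ins: "old" plays the role of LeNet5, "new" of LeNet6.
old = nn.Sequential(nn.Linear(4, 3), nn.ReLU(), nn.Linear(3, 2))
new = nn.Sequential(nn.Linear(4, 3), nn.ReLU(), nn.Linear(3, 5),
                    nn.ReLU(), nn.Linear(5, 2))

loaded_paras = old.state_dict()
target = new.state_dict()

# Keep only parameters whose name exists in the new model with the same shape.
reusable = {k: v for k, v in loaded_paras.items()
            if k in target and v.size() == target[k].size()}

# strict=False lets the remaining layers keep their random initialization.
result = new.load_state_dict(reusable, strict=False)

print(sorted(reusable))      # ['0.bias', '0.weight']
print(result.missing_keys)   # ['2.weight', '2.bias', '4.weight', '4.bias']
```

The missing_keys list gives a direct record of which layers still have to be trained from scratch.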

Finally, we only need to call this function before model training and then reinitialize part of the weight parameters in LeNet6 [2]:

    state_dict = para_state_dict(self.model, self.model_save_dir)
    self.model.load_state_dict(state_dict)

The training results are as follows:

    Successfully initialized parameter: conv.0.weight
    Successfully initialized parameter: conv.0.bias
    Successfully initialized parameter: conv.3.weight
    Successfully initialized parameter: conv.3.bias
    Successfully initialized parameter: fc.1.weight
    Successfully initialized parameter: fc.1.bias
    Successfully initialized parameter: fc.3.weight
    Successfully initialized parameter: fc.3.bias
    # Successfully loaded the existing model for additional training.
    Epochs[0/5]---batch[938/0]---acc 0.1094---loss 2.512
    Epochs[0/5]---batch[938/100]---acc 0.9375---loss 0.2141
    Epochs[0/5]---batch[938/200]---acc 0.9219---loss 0.2729
    Epochs[0/5]---batch[938/300]---acc 0.8906---loss 0.2958
    ...
    Epochs[0/5]---batch[938/900]---acc 0.8906---loss 0.2828
    Epochs[0/5]--acc on test 0.8808

As the log shows, after only about 100 batches the reported batch accuracy has already risen from 0.1094 to 0.9375.

This concludes the introduction to saving and migrating models in PyTorch. Thank you for reading.


© 2024 shulou.com SLNews company. All rights reserved.
