2025-02-06 Update From: SLTechnology News&Howtos
This article explains the common ways to save and load PyTorch models, with a short code example for each scenario.
A state_dict is simply a Python dictionary (an OrderedDict) that maps each layer to its parameter tensors. Only layers with learnable parameters (convolutional layers, linear layers, and so on) and registered buffers (such as batchnorm's running mean) have entries in a model's state_dict. Optimizer objects (torch.optim) also have a state_dict, which contains the optimizer's state and the hyperparameters it uses.
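A quick way to see which entries end up in a state_dict is to inspect a few small built-in modules. In the sketch below (the module choices are arbitrary), a BatchNorm layer records its registered buffers next to its learnable parameters, a parameter-free ReLU records nothing, and an optimizer exposes its own state_dict with 'state' and 'param_groups' entries.

```python
import torch.nn as nn
import torch.optim as optim

# BatchNorm has learnable parameters (weight, bias) and registered buffers
# (running statistics), so all of them appear in its state_dict.
bn = nn.BatchNorm1d(3)
bn_keys = list(bn.state_dict().keys())
print(bn_keys)

# ReLU has no learnable parameters and no buffers, so its state_dict is empty.
relu_keys = list(nn.ReLU().state_dict().keys())
print(relu_keys)

# An optimizer keeps its own state_dict: per-parameter state plus hyperparameters.
opt = optim.SGD(bn.parameters(), lr=0.01)
opt_keys = list(opt.state_dict().keys())
print(opt_keys)
```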
A simple example:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim


# Define the model
class TheModelClass(nn.Module):
    def __init__(self):
        super(TheModelClass, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x


# Initialize the model
model = TheModelClass()

# Initialize the optimizer
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

# Print the model's state_dict
print("Model's state_dict:")
for param_tensor in model.state_dict():
    print(param_tensor, "\t", model.state_dict()[param_tensor].size())

# Print the optimizer's state_dict
print("Optimizer's state_dict:")
for var_name in optimizer.state_dict():
    print(var_name, "\t", optimizer.state_dict()[var_name])
```
Output:

```
Model's state_dict:
conv1.weight     torch.Size([6, 3, 5, 5])
conv1.bias       torch.Size([6])
conv2.weight     torch.Size([16, 6, 5, 5])
conv2.bias       torch.Size([16])
fc1.weight       torch.Size([120, 400])
fc1.bias         torch.Size([120])
fc2.weight       torch.Size([84, 120])
fc2.bias         torch.Size([84])
fc3.weight       torch.Size([10, 84])
fc3.bias         torch.Size([10])
Optimizer's state_dict:
state            {}
param_groups     [{'lr': 0.001, 'momentum': 0.9, 'dampening': 0, 'weight_decay': 0, 'nesterov': False, 'params': [4675713712, 4675713784, 4675714000, 4675714072, 4675714216, 4675714288, 4675714504, 4675714576, 4675714648, 4675714720]}]
```

The values under 'params' identify the model's ten parameter tensors. Older PyTorch versions print object ids, as above, which differ on every run; newer versions print integer indices (0 to 9 here) and may list additional hyperparameters in param_groups.
Save / load state_dict (recommended)
Save:
```python
torch.save(model.state_dict(), PATH)
```
Load:
```python
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH))
model.eval()
```
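Note: if the model was wrapped in nn.DataParallel (for example, to use multiple GPUs on one machine) when its state_dict was saved, every key carries a "module." prefix. When loading, either wrap the new model in nn.DataParallel as well before calling load_state_dict(), or strip the prefix from the keys.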
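To see where the "module." prefix comes from, the sketch below wraps a small nn.Linear (an arbitrary stand-in for a real model) in nn.DataParallel. It assumes, as in current PyTorch releases, that nn.DataParallel can be constructed on a machine without GPUs, in which case it simply holds the wrapped module. It shows two ways around the mismatch: loading into a model that is also wrapped, or saving wrapped.module.state_dict() so the keys carry no prefix.

```python
import torch.nn as nn

model = nn.Linear(4, 2)
wrapped = nn.DataParallel(model)

# The wrapper stores the original model as a submodule named "module",
# so every key in its state_dict gets a "module." prefix.
wrapped_keys = list(wrapped.state_dict().keys())
print(wrapped_keys)  # ['module.weight', 'module.bias']

# Option 1: load a prefixed state_dict into a model that is also wrapped.
new_wrapped = nn.DataParallel(nn.Linear(4, 2))
new_wrapped.load_state_dict(wrapped.state_dict())

# Option 2: use the inner module's state_dict, whose keys have no prefix,
# and load it into a plain (unwrapped) model.
plain_keys = list(wrapped.module.state_dict().keys())
plain_model = nn.Linear(4, 2)
plain_model.load_state_dict(wrapped.module.state_dict())
print(plain_keys)  # ['weight', 'bias']
```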
To save a model for inference, you only need to save its trained parameters. Saving the state_dict with torch.save() gives you the most flexibility when restoring the model later, which is why this is the recommended approach.
Remember to call model.eval() before running inference, to set dropout and batch normalization layers to evaluation mode. Otherwise, inference results will be inconsistent.
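The effect of model.eval() is easy to observe on a standalone nn.Dropout layer (an arbitrary example): in training mode it randomly zeroes elements, so repeated calls on the same input differ, while in evaluation mode it passes the input through unchanged.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
dropout = nn.Dropout(p=0.5)
x = torch.ones(1000)

# Training mode (the default): elements are randomly zeroed and the rest are
# scaled by 1 / (1 - p), so two passes over 1000 elements almost surely differ.
dropout.train()
train_a = dropout(x)
train_b = dropout(x)

# Evaluation mode: dropout is disabled and the input is returned unchanged.
dropout.eval()
eval_a = dropout(x)
eval_b = dropout(x)

print(torch.equal(eval_a, eval_b), torch.equal(eval_a, x))  # True True
```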
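Note that load_state_dict() takes a dictionary object, not the path to a saved file. You must first deserialize the saved state_dict with torch.load() and then pass the result to load_state_dict(); model.load_state_dict(PATH), for example, will not work.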
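A minimal round trip of the recommended workflow might look like the sketch below. It uses a small nn.Sequential as a hypothetical stand-in for TheModelClass and a temporary file in place of PATH; torch.load() is called first and its result is passed to load_state_dict().

```python
import os
import tempfile

import torch
import torch.nn as nn


def make_model():
    # Hypothetical stand-in for TheModelClass(*args, **kwargs).
    return nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))


model = make_model()
path = os.path.join(tempfile.mkdtemp(), "model.pt")

# Save only the learned parameters and buffers.
torch.save(model.state_dict(), path)

# Rebuild the architecture, deserialize the dict, then load it.
restored = make_model()
state_dict = torch.load(path)
restored.load_state_dict(state_dict)
restored.eval()

x = torch.randn(3, 4)
same_output = torch.equal(model(x), restored(x))
print(same_output)  # True
```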
Save / load the entire model
Save:
```python
torch.save(model, PATH)
```
Load:
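```python
# The model class must be defined somewhere in the code that loads the model.
# weights_only=False is required on PyTorch 2.6+, where torch.load() defaults
# to weights_only=True and refuses to unpickle arbitrary classes.
model = torch.load(PATH, weights_only=False)
model.eval()
```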
Saving and loading the entire model uses the most intuitive syntax and the least code. It relies on Python's pickle module to serialize the whole module. The drawback is that the serialized data is bound to the specific classes and the exact directory structure used when the model was saved: pickle does not save the model class itself, only a path to the file containing the class, which it uses at load time. As a result, loading can fail when the model is used in other projects or after the code has been refactored.
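For comparison, the sketch below saves and loads an entire model object. The class TinyNet is hypothetical; it must be importable (here, defined at module level) when torch.load() runs, because pickle stores only a reference to the class. The weights_only=False argument assumes PyTorch 1.13 or later, where the parameter exists; from 2.6 onward it is required for this pattern because the default became weights_only=True.

```python
import os
import tempfile

import torch
import torch.nn as nn


class TinyNet(nn.Module):
    # Hypothetical model class; it must be defined or importable wherever
    # the pickled model is loaded.
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 2)

    def forward(self, x):
        return self.fc(x)


model = TinyNet()
path = os.path.join(tempfile.mkdtemp(), "whole_model.pt")

# Pickle the whole module object (a reference to its class plus its weights).
torch.save(model, path)

# Unpickle it; there is no need to construct TinyNet() first.
loaded = torch.load(path, weights_only=False)
loaded.eval()

print(type(loaded).__name__)  # TinyNet
```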
By convention, PyTorch models are saved with a .pt or .pth file extension.
As before, call model.eval() before running inference so that dropout and batch normalization layers are in evaluation mode; otherwise, inference results will be inconsistent.
Save and load a general checkpoint for inference or resuming training
Save:
```python
torch.save({
    'epoch': epoch,
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'loss': loss,
    # ...
}, PATH)
```
Load:
```python
model = TheModelClass(*args, **kwargs)
optimizer = TheOptimizerClass(*args, **kwargs)

checkpoint = torch.load(PATH)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']

model.eval()
# - or -
model.train()
```
When saving a general checkpoint, to be used for either inference or resuming training, you must save more than just the model's state_dict. It is also important to save the optimizer's state_dict, because it contains buffers and parameters that are updated as the model trains. Other items you may want to save include the epoch at which you stopped, the latest recorded training loss, external torch.nn.Embedding layers, and so on.
To save multiple components, organize them in a dictionary and serialize the dictionary with torch.save(). A common convention is to give these checkpoints a .tar file extension.
To load the components, first initialize the model and optimizer, then load the saved dictionary with torch.load(). You can then retrieve each saved item by querying the dictionary.
Also, don't forget to call model.eval() before running inference; if you want to resume training instead, call model.train() so that these layers are back in training mode.
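The sketch below walks through a general checkpoint end to end on a toy regression problem (the data, model, and single training step are hypothetical). It takes one SGD step with momentum, saves a dictionary of components, restores them into fresh objects, and lets you check that the optimizer's momentum buffers come back too.

```python
import os
import tempfile

import torch
import torch.nn as nn
import torch.optim as optim

torch.manual_seed(0)
model = nn.Linear(3, 1)
optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
x, y = torch.randn(8, 3), torch.randn(8, 1)

# One hypothetical training step so the optimizer creates momentum buffers.
loss = nn.functional.mse_loss(model(x), y)
optimizer.zero_grad()
loss.backward()
optimizer.step()

epoch = 1
path = os.path.join(tempfile.mkdtemp(), "checkpoint.tar")
torch.save({
    'epoch': epoch,
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'loss': loss.item(),
}, path)

# Initialize fresh objects first, then load the saved dictionary into them.
new_model = nn.Linear(3, 1)
new_optimizer = optim.SGD(new_model.parameters(), lr=0.1, momentum=0.9)
checkpoint = torch.load(path)
new_model.load_state_dict(checkpoint['model_state_dict'])
new_optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
start_epoch = checkpoint['epoch'] + 1
last_loss = checkpoint['loss']

new_model.train()  # resume training (call new_model.eval() for inference)
```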
Save multiple models to one file
Save:
```python
torch.save({
    'modelA_state_dict': modelA.state_dict(),
    'modelB_state_dict': modelB.state_dict(),
    'optimizerA_state_dict': optimizerA.state_dict(),
    'optimizerB_state_dict': optimizerB.state_dict(),
    # ...
}, PATH)
```
Load:
```python
modelA = TheModelAClass(*args, **kwargs)
modelB = TheModelBClass(*args, **kwargs)
optimizerA = TheOptimizerAClass(*args, **kwargs)
optimizerB = TheOptimizerBClass(*args, **kwargs)

checkpoint = torch.load(PATH)
modelA.load_state_dict(checkpoint['modelA_state_dict'])
modelB.load_state_dict(checkpoint['modelB_state_dict'])
optimizerA.load_state_dict(checkpoint['optimizerA_state_dict'])
optimizerB.load_state_dict(checkpoint['optimizerB_state_dict'])

modelA.eval()
modelB.eval()
# - or -
modelA.train()
modelB.train()
```
When saving a model made of multiple torch.nn.Modules, such as a GAN, a sequence-to-sequence model, or an ensemble of models, follow the same approach as for a general checkpoint: save a dictionary containing each model's state_dict and the state_dict of its corresponding optimizer. You can also store in this dictionary any other items that help you resume training.
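As a minimal sketch, the two hypothetical modules below stand in for a GAN's generator and discriminator. Each has its own optimizer, and all four state_dicts go into one dictionary under distinct keys.

```python
import os
import tempfile

import torch
import torch.nn as nn
import torch.optim as optim

# Hypothetical stand-ins for a GAN's generator and discriminator.
netG = nn.Linear(2, 4)
netD = nn.Linear(4, 1)
optG = optim.Adam(netG.parameters(), lr=2e-4)
optD = optim.Adam(netD.parameters(), lr=2e-4)

path = os.path.join(tempfile.mkdtemp(), "gan.tar")
torch.save({
    'netG_state_dict': netG.state_dict(),
    'netD_state_dict': netD.state_dict(),
    'optG_state_dict': optG.state_dict(),
    'optD_state_dict': optD.state_dict(),
}, path)

# Rebuild every component, then restore each one from its own key.
newG, newD = nn.Linear(2, 4), nn.Linear(4, 1)
newOptG = optim.Adam(newG.parameters(), lr=2e-4)
newOptD = optim.Adam(newD.parameters(), lr=2e-4)
checkpoint = torch.load(path)
newG.load_state_dict(checkpoint['netG_state_dict'])
newD.load_state_dict(checkpoint['netD_state_dict'])
newOptG.load_state_dict(checkpoint['optG_state_dict'])
newOptD.load_state_dict(checkpoint['optD_state_dict'])
```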
Warm-start a model using parameters from a different model
Save:
```python
torch.save(modelA.state_dict(), PATH)
```
Load:
```python
modelB = TheModelBClass(*args, **kwargs)
modelB.load_state_dict(torch.load(PATH), strict=False)
```
Partially loading a model, or loading a partial model, is common in transfer learning and when training a new complex model. Leveraging trained parameters, even if only a few of them are usable, helps warm-start the training process and can make the model converge faster than training from scratch.
Because model weights are saved and loaded as key-value pairs, you are likely to run into mismatched keys when loading parameters from another model. Whether the state_dict you load is missing some keys or has more keys than the model you load it into, you can set strict=False in load_state_dict() to ignore the non-matching keys.
If you want to load the parameters of one layer into another layer but some keys do not match, rename the keys in the state_dict you are loading so that they match the keys of the model you are loading into.
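With strict=False, load_state_dict() does not raise on mismatched keys; in recent PyTorch versions it returns a named tuple with missing_keys and unexpected_keys, which is worth checking so that a mismatch does not go unnoticed. The two small models below are hypothetical: modelB shares its first layer with modelA but adds a new head.

```python
import torch.nn as nn

# Hypothetical source model (modelA) and a larger target model (modelB)
# whose first layer has the same name and shape.
modelA = nn.Sequential(nn.Linear(4, 8))
modelB = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))

result = modelB.load_state_dict(modelA.state_dict(), strict=False)
print(result.missing_keys)     # ['2.weight', '2.bias']
print(result.unexpected_keys)  # []
```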
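One way to do the renaming is sketched below, for a hypothetical case where the source model calls its layer "encoder" and the target model calls the same-shaped layer "backbone": build a new dict with rewritten keys and load that instead.

```python
import torch.nn as nn


class Source(nn.Module):
    # Hypothetical source model with a layer named "encoder".
    def __init__(self):
        super().__init__()
        self.encoder = nn.Linear(4, 8)


class Target(nn.Module):
    # Hypothetical target model with a same-shaped layer named "backbone".
    def __init__(self):
        super().__init__()
        self.backbone = nn.Linear(4, 8)
        self.head = nn.Linear(8, 2)


source, target = Source(), Target()

# Rewrite "encoder.*" keys to "backbone.*" so they match the target model.
renamed = {
    k.replace("encoder.", "backbone.", 1): v
    for k, v in source.state_dict().items()
}
result = target.load_state_dict(renamed, strict=False)
print(sorted(renamed))      # ['backbone.bias', 'backbone.weight']
print(result.missing_keys)  # ['head.weight', 'head.bias']
```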
Save and load models across devices
Save on GPU, load on CPU
Save:
```python
torch.save(model.state_dict(), PATH)
```
Load:
```python
device = torch.device('cpu')
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH, map_location=device))
```
When loading on a CPU a model that was trained on a GPU, pass map_location=torch.device('cpu') to torch.load(). map_location dynamically remaps the tensors' underlying storage to the CPU device.
The code above works when the model was trained without nn.DataParallel, for example on a single GPU. If the model was wrapped in nn.DataParallel to train on multiple GPUs, loading it on the CPU this way produces an error similar to the following:
KeyError: 'unexpected key "module.conv1.weight" in state_dict'
The reason is that when a model is wrapped in nn.DataParallel for multi-GPU training and then saved, every parameter name gets a "module." prefix. You can remove this prefix from the keys when loading the model:
```python
from collections import OrderedDict

import torch

# Original state_dict saved from an nn.DataParallel model
# (map_location='cpu' lets it load on a machine without a GPU)
state_dict = torch.load('myfile.pth.tar', map_location='cpu')

# Create a new OrderedDict whose keys do not contain the `module.` prefix
new_state_dict = OrderedDict()
for k, v in state_dict.items():
    name = k[7:]  # remove `module.`
    new_state_dict[name] = v

# Load the parameters
model.load_state_dict(new_state_dict)
```
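A slightly more defensive variant of this loop is sketched below. It strips the prefix only from keys that actually start with "module." (a fixed k[7:] slice would cut seven characters from any key) and loads with map_location set to the CPU. The helper name strip_module_prefix is hypothetical, and the checkpoint is built in memory from a DataParallel-wrapped model rather than read from 'myfile.pth.tar'.

```python
import io
from collections import OrderedDict

import torch
import torch.nn as nn


def strip_module_prefix(state_dict, prefix="module."):
    # Hypothetical helper: drop the DataParallel prefix where present,
    # leaving all other keys untouched.
    new_state_dict = OrderedDict()
    for k, v in state_dict.items():
        name = k[len(prefix):] if k.startswith(prefix) else k
        new_state_dict[name] = v
    return new_state_dict


# Simulate a checkpoint saved from an nn.DataParallel-wrapped model.
buffer = io.BytesIO()
torch.save(nn.DataParallel(nn.Linear(4, 2)).state_dict(), buffer)
buffer.seek(0)

# map_location remaps any CUDA tensors onto the CPU while loading.
state_dict = torch.load(buffer, map_location=torch.device('cpu'))
model = nn.Linear(4, 2)
model.load_state_dict(strip_module_prefix(state_dict))
print(list(state_dict.keys()))  # ['module.weight', 'module.bias']
```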
Save on GPU, load on GPU
Save:
```python
torch.save(model.state_dict(), PATH)
```
Load:
```python
device = torch.device("cuda")
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH))
model.to(device)
# Make sure to call input = input.to(device) on any input tensors you feed to the model
```
When loading onto a GPU a model that was trained and saved on a GPU, simply call model.to(torch.device('cuda')) on the initialized model, which moves its parameters and buffers to the GPU. Also make sure to call .to(torch.device('cuda')) on all inputs to the model. Note that my_tensor.to(device) returns a new copy of my_tensor on the GPU and does not overwrite my_tensor, so remember to reassign the result: my_tensor = my_tensor.to(torch.device('cuda')). By contrast, nn.Module.to() moves the module in place.
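The difference between moving a tensor and moving a module can be checked without a GPU by converting dtypes instead of devices. In the sketch below, torch.float64 stands in for a device change: Tensor.to() returns a new tensor and leaves the original alone, while nn.Module.to() converts the module in place and returns the same object. The last lines show a common device-selection idiom that falls back to the CPU when CUDA is unavailable.

```python
import torch
import torch.nn as nn

t = torch.ones(3)          # float32 by default
t64 = t.to(torch.float64)  # returns a converted copy; t itself is unchanged

model = nn.Linear(3, 1)
returned = model.to(torch.float64)  # converts the module's parameters in place

print(t.dtype, t64.dtype)  # torch.float32 torch.float64
print(returned is model)   # True
print(model.weight.dtype)  # torch.float64

# Device-selection idiom: use the GPU when available, otherwise the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
t64 = t64.to(device)  # reassign, because Tensor.to() does not work in place
```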
Save on CPU, load on GPU
Save:
```python
torch.save(model.state_dict(), PATH)
```
Load:
```python
device = torch.device("cuda")
model = TheModelClass(*args, **kwargs)
# Choose whatever GPU device number you want to use
model.load_state_dict(torch.load(PATH, map_location="cuda:0"))
model.to(device)
```
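When the same loading code has to run on machines with and without a GPU, one option is to choose map_location at runtime. The sketch below uses a hypothetical load_for_device() helper and a small nn.Linear in place of TheModelClass; torch.load() accepts a torch.device (or a string such as "cuda:0") for map_location.

```python
import io

import torch
import torch.nn as nn


def load_for_device(model, source, device):
    # Hypothetical helper: remap the saved tensors straight onto `device`,
    # load them, and make sure the module itself lives there as well.
    state_dict = torch.load(source, map_location=device)
    model.load_state_dict(state_dict)
    return model.to(device)


device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

buffer = io.BytesIO()
torch.save(nn.Linear(4, 2).state_dict(), buffer)  # saved from the CPU
buffer.seek(0)

model = load_for_device(nn.Linear(4, 2), buffer, device)
print(model.weight.device)
```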
Save the torch.nn.DataParallel model
Save:
```python
torch.save(model.module.state_dict(), PATH)
```

torch.nn.DataParallel is a wrapper that enables parallel GPU use. model.module is the original model it wraps, so saving model.module.state_dict() produces keys without the "module." prefix and lets you load the model onto any device you want.