2025-01-19 Update From: SLTechnology News&Howtos
Shulou (Shulou.com) 06/01 Report
This article explains how to use model.eval() in PyTorch when a model contains batch normalization (BN) layers. It is fairly detailed and should be a useful reference for anyone interested in the topic.
Look at the code:

```python
import torch
import torch.nn as nn

class ConvNet(nn.Module):
    def __init__(self, num_classes=10):
        super(ConvNet, self).__init__()
        # Conv -> BatchNorm -> ReLU -> MaxPool; assuming 28x28 input, output is 16x14x14
        self.layer1 = nn.Sequential(
            nn.Conv2d(1, 16, kernel_size=5, stride=1, padding=2),
            nn.BatchNorm2d(16),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2))
        # Second block: 16x14x14 -> 32x7x7
        self.layer2 = nn.Sequential(
            nn.Conv2d(16, 32, kernel_size=5, stride=1, padding=2),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2))
        self.fc = nn.Linear(7 * 7 * 32, num_classes)

    def forward(self, x):
        out = self.layer1(x)
        out = self.layer2(out)
        print(out.size())
        out = out.reshape(out.size(0), -1)  # flatten to (batch, 7*7*32)
        out = self.fc(out)
        return out

# Test the model (model, test_loader and device are assumed to be defined earlier)
model.eval()  # eval mode (batchnorm uses moving mean/variance instead of mini-batch mean/variance)
with torch.no_grad():
    correct = 0
    total = 0
    for images, labels in test_loader:
        images = images.to(device)
        labels = labels.to(device)
        outputs = model(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
```
If the model contains BN layers, switch it to evaluation mode with model.eval() before making predictions.
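A quick way to see what model.eval() does is to check the `training` flag that every nn.Module carries: eval() sets it to False recursively on all submodules, and train() sets it back to True. The small nn.Sequential below is a hypothetical stand-in for a real network, used only for illustration.

```python
import torch.nn as nn

# Hypothetical tiny model containing a BatchNorm layer (illustration only)
model = nn.Sequential(nn.Linear(4, 8), nn.BatchNorm1d(8), nn.ReLU(), nn.Linear(8, 2))

model.eval()   # recursively sets .training = False on the model and every submodule
print(all(not m.training for m in model.modules()))  # True

model.train()  # recursively switches everything back to training mode
print(all(m.training for m in model.modules()))      # True
```

Because the flag is set recursively, one call on the top-level model is enough; there is no need to visit each BN layer by hand.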
In evaluation mode, a BN layer normalizes with its running (moving) mean and variance, which are accumulated during training and serve as an estimate of the mean and variance of the whole training set.
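The "moving" statistics are an estimate, not literally the statistics of the whole training set. As described in PyTorch's BatchNorm documentation, each training-mode forward pass updates them as an exponential moving average with a `momentum` factor (default 0.1), and the running variance is updated with the unbiased variance of the batch. The plain-Python sketch below re-implements that update rule for a single channel; the function names and sample batches are hypothetical.

```python
def batch_stats(batch):
    """Return the mean and the unbiased variance of one mini-batch (one channel)."""
    n = len(batch)
    mean = sum(batch) / n
    var_unbiased = sum((x - mean) ** 2 for x in batch) / (n - 1)
    return mean, var_unbiased

def update_running_stats(running_mean, running_var, batch, momentum=0.1):
    """One training-mode update: new = (1 - momentum) * old + momentum * batch value."""
    mean, var = batch_stats(batch)
    running_mean = (1 - momentum) * running_mean + momentum * mean
    running_var = (1 - momentum) * running_var + momentum * var
    return running_mean, running_var

# Same initial values as a freshly constructed BN layer: mean 0, variance 1
running_mean, running_var = 0.0, 1.0
for batch in ([1.0, 2.0, 3.0, 4.0], [2.0, 4.0, 6.0, 8.0]):
    running_mean, running_var = update_running_stats(running_mean, running_var, batch)

print(round(running_mean, 4), round(running_var, 4))  # 0.725 1.6267
```

After only two batches the running mean (0.725) is still far from the mean of all eight values (3.75), because it starts at 0 and moves only 10% toward each batch mean. Over many training iterations it settles near the data statistics, which is why it is usable at evaluation time.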
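In training mode, a BN layer instead uses the mean and variance of the current mini-batch, and every forward pass also updates the running statistics. Pay special attention to this: a model left in training mode gives predictions that depend on how the test data happens to be batched.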
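The difference between the two modes can be shown with a plain-Python sketch of BN's normalization step for one channel, assuming the default affine parameters (weight 1, bias 0) and eps of 1e-5. The function `bn_forward` and the running values 5.0 and 4.0 are hypothetical; they are not taken from a real trained model.

```python
import math

def bn_forward(batch, running_mean, running_var, training, eps=1e-5):
    """Normalize one channel of a mini-batch the way BN does (weight=1, bias=0)."""
    if training:
        # Training mode: statistics of the current mini-batch (biased variance)
        mean = sum(batch) / len(batch)
        var = sum((x - mean) ** 2 for x in batch) / len(batch)
    else:
        # Eval mode: fixed running statistics collected during training
        mean, var = running_mean, running_var
    return [(x - mean) / math.sqrt(var + eps) for x in batch]

running_mean, running_var = 5.0, 4.0  # hypothetical values learned during training
batch = [5.0, 7.0]

train_out = bn_forward(batch, running_mean, running_var, training=True)
eval_out = bn_forward(batch, running_mean, running_var, training=False)
print([round(v, 3) for v in train_out])  # [-1.0, 1.0]
print([round(v, 3) for v in eval_out])   # [0.0, 1.0]

# In training mode the same sample gets a different output in a different batch
other = bn_forward([7.0, 9.0], running_mean, running_var, training=True)
print(round(other[0], 3))  # -1.0: the value 7.0 now maps to -1.0 instead of 1.0
```

In eval mode the output for 7.0 is 1.0 regardless of which other samples share its batch, which is exactly the property wanted at test time.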
Addendum: a large difference between PyTorch results in training mode and eval mode (Pytorch train and eval), and a solution
After model.eval(), the results of a PyTorch model are sometimes very different from those obtained with train(True). This difference comes mainly from BN: in eval mode, BN uses its fixed running statistics (running mean and variance), while in train mode it normalizes with the statistics of each input batch and keeps updating the running statistics.
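The discrepancy is easy to reproduce. The sketch below assumes a hypothetical two-layer model and random inputs whose statistics are far from a fresh BN layer's initial running values (mean 0, variance 1), then feeds the same batch once in train mode and once in eval mode.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
# Hypothetical toy model: its only mode-dependent layer is BatchNorm1d
model = nn.Sequential(nn.Linear(3, 4), nn.BatchNorm1d(4))
x = torch.randn(8, 3) * 5 + 2  # inputs far from mean 0 / variance 1

model.train()
with torch.no_grad():
    out_train = model(x)  # normalized with this batch's own statistics
model.eval()
with torch.no_grad():
    out_eval = model(x)   # normalized with running_mean / running_var

print((out_train - out_eval).abs().max().item())  # clearly greater than 0
```

Note that the train-mode pass also moves running_mean and running_var toward the batch statistics, even inside torch.no_grad(), so running a model in train mode at test time silently changes its stored statistics.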
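The solution is to freeze BN:

```python
import torch.nn as nn

def freeze_bn(m):
    # Put BatchNorm2d layers into eval mode so they use, and stop updating, their running statistics
    if isinstance(m, nn.BatchNorm2d):
        m.eval()

# Call this after model.train(), which switches every BN layer back to training mode
model.apply(freeze_bn)
```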
In this way the BN layers behave the same whether the rest of the model is in train or eval mode, so a stable output can be obtained.
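The effect of freeze_bn can be checked directly. In the sketch below (a hypothetical toy model with no Dropout, so BN is its only mode-dependent layer), the model as a whole stays in training mode while its BN layer is frozen; the running statistics stay unchanged and the output matches full eval mode.

```python
import torch
import torch.nn as nn

def freeze_bn(m):
    # Switch BatchNorm2d layers to eval mode: they use, and stop updating, running stats
    if isinstance(m, nn.BatchNorm2d):
        m.eval()

torch.manual_seed(0)
# Hypothetical toy model for illustration
model = nn.Sequential(nn.Conv2d(1, 4, kernel_size=3, padding=1), nn.BatchNorm2d(4), nn.ReLU())
x = torch.randn(2, 1, 8, 8)

model.train()           # model.train() puts BN back into training mode ...
model.apply(freeze_bn)  # ... so freeze BN again after every call to it
bn = model[1]
before = bn.running_mean.clone()

with torch.no_grad():
    out_frozen = model(x)
model.eval()
with torch.no_grad():
    out_eval = model(x)

print(torch.equal(before, bn.running_mean))  # True: running stats were not updated
print(torch.allclose(out_frozen, out_eval))  # True: same output as full eval mode
```

freeze_bn only fixes the normalization statistics; BN's learnable weight and bias still receive gradients. If those should stay fixed as well, their requires_grad flags would also need to be set to False.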
That is all for the article "how pytorch uses model.eval() and BN layers". Thank you for reading, and I hope it helps!