How to Achieve PyTorch Batch Visualization
This article explains how to achieve batch visualization in PyTorch with TensorBoard; interested readers may wish to have a look. The methods introduced here are simple, fast, and practical. Let's walk through them step by step!
Goals:

1. Visualize the Loss and Accuracy curves of any network model during training, with the Train and Valid curves in the same graph (see the minimal sketch after this list).
2. Batch-visualize arbitrary image training input data using make_grid.
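Before the full training script, here is a minimal, self-contained sketch of the logging call behind goal 1: add_scalars groups every entry of the dict under one main tag, so the Train and Valid series land on the same TensorBoard chart. The loss values below are placeholders, purely for illustration.

from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter(comment='_scalars_demo')
for step in range(100):
    # placeholder losses, just to draw two curves in one chart
    writer.add_scalars("Loss", {"Train": 1.0 / (step + 1),
                                "Valid": 1.2 / (step + 1)}, step)
writer.close()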
1. Loss and Accuracy curves
Note: this script is not run here; it is provided read-only for reference.
#-*-coding:utf-8-*-"" @ brief: monitor loss, accuracy, weights, gradients "import osimport numpy as npimport torchimport torch.nn as nnfrom torch.utils.data import DataLoaderimport torchvision.transforms as transformsfrom torch.utils.tensorboard import SummaryWriterimport torch.optim as optimfrom matplotlib import pyplot as pltfrom model.lenet import LeNetfrom tools.my_dataset import RMBDatasetfrom tools.common_tools2 import set_seedset_seed () # set random seed rmb_label = {" 1 ": 0 MAX_EPOCH = 10BATCH_SIZE = 16LR = 0.01log_interval = 10val_interval = data = = split_dir = os.path.join ("..", "data", "rmb_split") train_dir = os.path.join (split_dir, "train") valid_dir = os.path.join (split_dir, "valid") norm_mean = [0.485, 0.456 0.406] norm_std = [0.229, 0.224, 0.225] train_transform = transforms.Compose ([transforms.Resize ((32,32)), transforms.RandomCrop (32, padding=4), transforms.RandomGrayscale (pendant 0.8), transforms.ToTensor (), transforms.Normalize (norm_mean, norm_std),]) valid_transform = transforms.Compose ([transforms.Resize ((32,32)), transforms.ToTensor (), transforms.Normalize (norm_mean) Norm_std),]) # build MyDataset instance train_data = RMBDataset (data_dir=train_dir, transform=train_transform) valid_data = RMBDataset (data_dir=valid_dir, transform=valid_transform) # build DataLodertrain_loader = DataLoader (dataset=train_data, batch_size=BATCH_SIZE, shuffle=True) valid_loader = DataLoader (dataset=valid_data Batch_size=BATCH_SIZE) # = = step 2amp 5 Model = = net = LeNet (classes=2) net.initialize_weights () # = = step 3ram 5 loss function = = criterion = nn.CrossEntropyLoss () # Select loss function # = = step 4ram 5 optimizer = = optimizer = optim.SGD (net.parameters (), lr=LR, momentum=0.9) # Select optimizer scheduler = torch.optim.lr_scheduler.StepLR (optimizer, step_size=10) Gamma=0.1) # set learning rate reduction strategy # = = step 5 valid_curve 5 training = = train_curve = list () valid_curve = list () iter_count = building SummaryWriterwriter = SummaryWriter (comment='test_your_comment', filename_suffix= "_ test_your_filename_suffix") for epoch in range (MAX_EPOCH): loss_mean = 0. Correct = 0. Total = 0. Net.train () for I, data in enumerate (train_loader): iter_count + = 1 # forward inputs, labels = data outputs = net (inputs) # backward optimizer.zero_grad () loss = criterion (outputs, labels) loss.backward () # update weights optimizer.step () # Statistical classification _, predicted = torch.max (outputs.data 1) total + = labels.size (0) correct + = (predicted = = labels). Squeeze () .sum () .numpy () # print training information loss_mean + = loss.item () train_curve.append (loss.item ()) if (I + 1)% log_interval = = 0: loss_mean = loss_mean / log_interval print ( "Training:Epoch [{: 0 > 3} / {: 0 > 3}] Iteration [{: 0 > 3} / {: 0 > 3}] Loss: {: .4f} Acc: {: .2%}" .format (epoch MAX_EPOCH, I + 1, len (train_loader), loss_mean, correct / total)) loss_mean = 0. # record data, save it in event file writer.add_scalars ("Loss", {"Train": loss.item ()}, iter_count) writer.add_scalars ("Accuracy", {"Train": correct / total}, iter_count) # each epoch, record gradient Weight for name, param in net.named_parameters (): writer.add_histogram (name +'_ grad', param.grad, epoch) writer.add_histogram (name +'_ data', param, epoch) scheduler.step () # update learning rate # validate the model if (epoch + 1)% val_interval = = 0: correct_val = 0. Total_val = 0. Loss_val = 0. 
Net.eval () with torch.no_grad (): for j, data in enumerate (valid_loader): inputs, labels = data outputs = net (inputs) loss = criterion (outputs, labels) _, predicted = torch.max (outputs.data 1) total_val + = labels.size (0) correct_val + = (predicted = = labels). Squeeze (). Sum (). Numpy () loss_val + = loss.item () valid_curve.append (loss.item ()) print ("Valid:\ t Epoch [{: 0 > 3} / {: 0 > 3}] Iteration [{: 0 > 3} / {: 0 > 3} ] Loss: {: .4F} Acc: {: 0.2%} ".format (epoch MAX_EPOCH, j + 1, len (valid_loader), loss_val, correct / total)) # record data Saved in event file writer.add_scalars ("Loss", {"Valid": np.mean (valid_curve)}, iter_count) writer.add_scalars ("Accuracy", {"Valid": correct / total}, iter_count) train_x = range (len (train_curve)) train_y = train_curvetrain_iters = len (train_loader) valid_x = np.arange (1 Len (valid_curve) + 1) * train_iters * val_interval # because epochloss is recorded in valid The recording point needs to be converted to iterationsvalid_y = valid_curveplt.plot (train_x, train_y, label='Train') plt.plot (valid_x, valid_y, label='Valid') plt.legend (loc='upper right') plt.ylabel ('loss value') plt.xlabel (' Iteration') plt.show () 2. Batch visualization
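With the defaults used above, SummaryWriter writes its event files under ./runs/ (a per-run subdirectory named with the date, hostname, and the comment string). Assuming that default location, the Loss/Accuracy curves and the weight/gradient histograms can be inspected by launching TensorBoard and opening the printed URL (typically http://localhost:6006):

tensorboard --logdir=./runs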
2. Batch visualization

Note: this script is likewise not run here; it is provided read-only for reference.
#-*-coding:utf-8-*-"@ brief: visualization of convolution kernels and feature graphs" import torch.nn as nnfrom PIL import Imageimport torchvision.transforms as transformsfrom torch.utils.tensorboard import SummaryWriterimport torchvision.utils as vutilsfrom tools.common_tools import set_seedimport torchvision.models as modelsset_seed (1) # set random seeds #-- -kernel visualization-- # flag = 0flag = 1if flag: writer = SummaryWriter (comment='test_your_comment' Filename_suffix= "_ test_your_filename_suffix") alexnet = models.alexnet (pretrained=True) kernel_num =-1 vis_max = 1 for sub_module in alexnet.modules (): if isinstance (sub_module, nn.Conv2d): kernel_num + = 1 if kernel_num > vis_max: break kernels = sub_module.weight c_out, c_int, Kremw Khamh = tuple (kernels.shape) for o_idx in range (c_out): kernel_idx = Kernels [o _ idx,:] .unsqueeze (1) # make_grid requires BCHW Here we expand the C dimension kernel_grid = vutils.make_grid (kernel_idx, normalize=True, scale_each=True, nrow=c_int) writer.add_image ('{} _ Convlayer_split_in_channel'.format (kernel_num), kernel_grid, global_step=o_idx) kernel_all = kernels.view (- 1,3, kumbh, kumbw) # 3, h W kernel_grid = vutils.make_grid (kernel_all, normalize=True, scale_each=True, nrow=8) # c, h, w writer.add_image ('{} _ all'.format (kernel_num), kernel_grid, global_step=322) print ("{} _ convlayer shape: {}" .format (kernel_num) Tuple (kernels.shape)) writer.close () #-- feature map visualization-- # flag = 0flag = 1if flag: writer = SummaryWriter (comment='test_your_comment' Filename_suffix= "_ test_your_filename_suffix") # data path_img = ". / lena.png" # your path to image normMean = [0.49139968, 0.48215827, 0.44653124] normStd = [0.24703233, 0.24348505, 0.26158768] norm_transform = transforms.Normalize (normMean, normStd) img_transforms = transforms.Compose ([transforms.Resize ((224,224)), transforms.ToTensor () Norm_transform]) img_pil = Image.open (path_img) .convert ('RGB') if img_transforms is not None: img_tensor = img_transforms (img_pil) img_tensor.unsqueeze_ (0) # chw-- > bchw # Model alexnet = models.alexnet (pretrained=True) # forward convlayer1 = alexnet.features [0] fmap_1 = convlayer1 (img_tensor) # preprocessing fmap_1.transpose_ (0 1) # bchw= (1,64,55,55)-> (64,1,55,55) fmap_1_grid = vutils.make_grid (fmap_1, normalize=True, scale_each=True, nrow=8) writer.add_image ('feature map in conv1', fmap_1_grid, global_step=322) writer.close () so far I believe that everyone on the "PyTorch batch visualization how to achieve" have a deeper understanding, might as well to the actual operation of it! Here is the website, more related content can enter the relevant channels to inquire, follow us, continue to learn!
So far, I believe everyone has a deeper understanding of how to achieve PyTorch batch visualization; why not try it out in practice!