How to use hook functions in PyTorch

2025-04-06 Update From: SLTechnology News&Howtos

Shulou(Shulou.com)06/01 Report--

This article explains how to use hook functions in PyTorch. It covers tensor hooks (Tensor.register_hook) and module hooks (register_forward_hook, register_forward_pre_hook and register_backward_hook), and then uses a forward hook to visualize feature maps. The course code below walks through each case in turn.
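As a first, minimal sketch of the idea (not the course code itself): a tensor hook is a function that autograd calls with the gradient of that tensor during backward. It can record the gradient, or return a new tensor that replaces it. The tensors mirror the ones used in the course code; the names captured and scale_grad are illustrative assumptions.

```python
import torch

w = torch.tensor([1.0], requires_grad=True)
x = torch.tensor([2.0], requires_grad=True)
a = torch.add(w, x)      # non-leaf tensor: autograd does not keep a.grad
b = torch.add(w, 1)
y = torch.mul(a, b)      # y = (w + x) * (w + 1)

captured = []            # hypothetical container for the recorded gradient

# Hook 1: record dy/da, which would otherwise be discarded after backward.
# list.append returns None, so the gradient itself is left unchanged.
h_a = a.register_hook(lambda grad: captured.append(grad.clone()))


# Hook 2: return a new tensor instead of editing `grad` in place.
def scale_grad(grad):
    return grad * 2


h_w = w.register_hook(scale_grad)

y.backward()

print("dy/da captured by hook:", captured[0])
print("w.grad after scaling:", w.grad)

# Remove the hooks once they are no longer needed.
h_a.remove()
h_w.remove()
```

Here dy/da equals b = 2, and dy/dw equals a + b = 5, which the second hook doubles to 10 before it is stored in w.grad.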

Course code

"""
@brief: hook function of pytorch
"""
import torch
import torch.nn as nn
from tools.common_tools2 import set_seed  # helper module shipped with the course

set_seed(1)

# ----------------------------------- 1 tensor hook 1 -----------------------------------
# flag = 0
flag = 1
if flag:

    w = torch.tensor([1.], requires_grad=True)
    x = torch.tensor([2.], requires_grad=True)
    a = torch.add(w, x)
    b = torch.add(w, 1)
    y = torch.mul(a, b)

    a_grad = list()

    def grad_hook(grad):
        # Record the gradient of the non-leaf tensor a
        a_grad.append(grad)

    handle = a.register_hook(grad_hook)

    y.backward()

    # View gradients (a.grad and y.grad are None because a and y are non-leaf tensors)
    print("gradient:", w.grad, x.grad, a.grad, b.grad, y.grad)
    print("a_grad[0]: ", a_grad[0])
    handle.remove()

# ----------------------------------- 2 tensor hook 2 -----------------------------------
# flag = 0
flag = 1
if flag:

    w = torch.tensor([1.], requires_grad=True)
    x = torch.tensor([2.], requires_grad=True)
    a = torch.add(w, x)
    b = torch.add(w, 1)
    y = torch.mul(a, b)

    a_grad = list()

    def grad_hook(grad):
        # The returned tensor replaces the gradient of w.
        # Note: the PyTorch docs say a hook should not modify its argument;
        # the in-place "*=" is kept here as in the course code.
        grad *= 2
        return grad * 3

    handle = w.register_hook(grad_hook)

    y.backward()

    # View gradient
    print("w.grad: ", w.grad)
    handle.remove()

# ----------------------------------- 3 Module.register_forward_hook and pre hook -----------------------------------
# flag = 0
flag = 1
if flag:

    class Net(nn.Module):
        def __init__(self):
            super(Net, self).__init__()
            self.conv1 = nn.Conv2d(1, 2, 3)
            self.pool1 = nn.MaxPool2d(2, 2)

        def forward(self, x):
            x = self.conv1(x)
            x = self.pool1(x)
            return x

    def forward_hook(module, data_input, data_output):
        fmap_block.append(data_output)
        input_block.append(data_input)

    def forward_pre_hook(module, data_input):
        print("forward_pre_hook input: {}".format(data_input))

    def backward_hook(module, grad_input, grad_output):
        print("backward hook input: {}".format(grad_input))
        print("backward hook output: {}".format(grad_output))

    # Initialize the network
    net = Net()
    net.conv1.weight[0].detach().fill_(1)
    net.conv1.weight[1].detach().fill_(2)
    net.conv1.bias.data.detach().zero_()

    # Register hooks
    fmap_block = list()
    input_block = list()
    net.conv1.register_forward_hook(forward_hook)
    net.conv1.register_forward_pre_hook(forward_pre_hook)
    # register_backward_hook is deprecated in recent PyTorch releases
    # in favor of register_full_backward_hook
    net.conv1.register_backward_hook(backward_hook)

    # Inference
    fake_img = torch.ones((1, 1, 4, 4))  # batch size * channel * H * W
    output = net(fake_img)  # forward propagation

    loss_fnc = nn.L1Loss()
    target = torch.randn_like(output)
    loss = loss_fnc(target, output)
    loss.backward()

    # Observe
    print("output shape: {}\noutput value: {}\n".format(output.shape, output))
    print("feature maps shape: {}\noutput value: {}\n".format(fmap_block[0].shape, fmap_block[0]))
    print("input shape: {}\ninput value: {}".format(input_block[0][0].shape, input_block[0]))

Assignment

1. Use the torch.nn.Module.register_forward_hook mechanism to visualize the output feature maps of the first convolutional layer of AlexNet. Then change line 28 of /torchvision/models/alexnet.py to nn.ReLU(inplace=False) and observe the difference between inplace=True and inplace=False.
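The effect the assignment asks about can be reproduced on a toy network without downloading AlexNet. The sketch below is an illustration under stated assumptions, not the course solution: the two-layer model, the all-ones weights and the helper name capture_conv_output are hypothetical. A forward hook stores the convolution output; when the following ReLU runs in place it overwrites that same tensor, so the stored feature map is already rectified by the time it is inspected.

```python
import torch
import torch.nn as nn


def capture_conv_output(inplace):
    """Return the conv output stored by a forward hook, inspected after the full forward pass."""
    conv = nn.Conv2d(1, 2, kernel_size=3)
    with torch.no_grad():
        conv.weight.fill_(1.0)   # fixed weights so the result is deterministic
        conv.bias.zero_()
    model = nn.Sequential(conv, nn.ReLU(inplace=inplace))

    stored = []
    # list.append returns None, so the hook does not replace the module output.
    handle = conv.register_forward_hook(
        lambda module, inputs, output: stored.append(output)
    )
    with torch.no_grad():
        # An all-negative input makes every conv output negative.
        model(-torch.ones((1, 1, 4, 4)))
    handle.remove()
    return stored[0]


fmap_inplace = capture_conv_output(inplace=True)
fmap_copy = capture_conv_output(inplace=False)

print("max with inplace=True :", fmap_inplace.max().item())
print("max with inplace=False:", fmap_copy.max().item())
```

With inplace=True the hook's tensor has been clamped to zero by the ReLU, while with inplace=False it still holds the raw negative convolution outputs. If the raw values are needed with an in-place activation, storing output.clone() inside the hook is one option.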

1. Drawing feature maps with a hook

# -*- coding:utf-8 -*-
"""
@brief: use hook function to visualize feature map
"""
import torch.nn as nn
import numpy as np
from PIL import Image
import torchvision.transforms as transforms
import torchvision.utils as vutils
from torch.utils.tensorboard import SummaryWriter
from tools.common_tools2 import set_seed  # helper module shipped with the course
import torchvision.models as models

set_seed(1)  # set random seed

# ----------------------------------- feature map visualization -----------------------------------
# flag = 0
flag = 1
if flag:
    writer = SummaryWriter(comment='test_your_comment', filename_suffix="_test_your_filename_suffix")

    # Data
    path_img = "./lena.png"  # your path to image
    normMean = [0.49139968, 0.48215827, 0.44653124]
    normStd = [0.24703233, 0.24348505, 0.26158768]

    norm_transform = transforms.Normalize(normMean, normStd)
    img_transforms = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        norm_transform
    ])

    img_pil = Image.open(path_img).convert('RGB')
    if img_transforms is not None:
        img_tensor = img_transforms(img_pil)
    img_tensor.unsqueeze_(0)  # chw --> bchw

    # Model (newer torchvision releases prefer the weights= argument over pretrained=True)
    alexnet = models.alexnet(pretrained=True)

    # Register hooks on every convolutional layer
    fmap_dict = dict()
    for name, sub_module in alexnet.named_modules():

        if isinstance(sub_module, nn.Conv2d):
            key_name = str(sub_module.weight.shape)
            fmap_dict.setdefault(key_name, list())

            n1, n2 = name.split(".")

            def hook_func(m, i, o):
                # Key the stored feature map by the weight shape of the layer
                key_name = str(m.weight.shape)
                fmap_dict[key_name].append(o)

            alexnet._modules[n1]._modules[n2].register_forward_hook(hook_func)

    # Forward
    output = alexnet(img_tensor)

    # Add images
    for layer_name, fmap_list in fmap_dict.items():
        fmap = fmap_list[0]
        fmap.transpose_(0, 1)  # bchw --> cbhw, so each channel becomes one grid cell

        nrow = int(np.sqrt(fmap.shape[0]))
        fmap_grid = vutils.make_grid(fmap, normalize=True, scale_each=True, nrow=nrow)
        writer.add_image('feature map in {}'.format(layer_name), fmap_grid, global_step=322)

By now you should have a better understanding of how to use hook functions in PyTorch, so try them out in practice.
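When trying this on other models, one possible variant of the registration loop is sketched below. It is an assumption-laden illustration rather than the course method: a small stand-in network replaces AlexNet, and the names save_output and handles are hypothetical. Since named_modules() already yields each module object, the hook can be registered on it directly instead of going through _modules, and functools.partial binds the layer name as the dictionary key.

```python
from functools import partial

import torch
import torch.nn as nn

# Small stand-in network (hypothetical); the course code uses torchvision's AlexNet.
model = nn.Sequential(
    nn.Conv2d(3, 4, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.Conv2d(4, 8, kernel_size=3, padding=1),
)

fmap_dict = {}
handles = []


def save_output(name, module, inputs, output):
    # Detach so the stored feature map does not keep the autograd graph alive.
    fmap_dict.setdefault(name, []).append(output.detach())


for name, sub_module in model.named_modules():
    if isinstance(sub_module, nn.Conv2d):
        # Register directly on the module object returned by named_modules().
        handles.append(sub_module.register_forward_hook(partial(save_output, name)))

model(torch.randn(1, 3, 16, 16))

for name, fmaps in fmap_dict.items():
    print(name, tuple(fmaps[0].shape))

# Remove the hooks so later forward passes do not keep appending feature maps.
for handle in handles:
    handle.remove()
```

Keying by module name rather than by weight shape keeps layers apart even when two of them happen to have identical weight shapes, and keeping the returned handles makes it straightforward to detach the hooks afterwards.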