What is the hook mechanism in pytorch?


This article introduces the basics of the hook mechanism in PyTorch. Working through a small, runnable example, it shows how forward hooks are registered, what they capture, and how to verify that the captured values are correct. I hope you read it carefully and get something out of it!

1. Hook background

Hook refers to the hook mechanism. It was not invented by PyTorch; it has long been used in Windows programming, which distinguishes in-process hooks from global hooks. As I understand it, a hook works by having the system maintain a linked list of callbacks, allowing users to intercept (obtain) messages and handle events as they pass through.

PyTorch provides two kinds of hook registration functions, forward and backward, which are used to obtain a module's inputs and outputs during the forward and backward passes. As far as I understand, the goal is to capture intermediate results without changing the network's definition code and without having the forward() function return every layer of interest, which would clutter the code.
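As an illustration of that goal, here is a minimal sketch of my own (not from the original article) that grabs an intermediate feature map from a torchvision ResNet without editing its source. It assumes torchvision >= 0.13 (for the weights= argument); resnet18 and layer4 are simply the model and layer picked for the demo:

import torch
from torchvision import models

features = {}

def save_output(module, inputs, output):
    # stash the layer's output under a fixed key
    features["layer4"] = output.detach()

model = models.resnet18(weights=None)
model.eval()
handle = model.layer4.register_forward_hook(save_output)

with torch.no_grad():
    model(torch.randn(1, 3, 224, 224))

print(features["layer4"].shape)  # torch.Size([1, 512, 7, 7])
handle.remove()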

2. Source code reading

register_forward_hook() must be called before forward() runs, because the hook only fires when forward() computes an output. Note also the caveat in the source comment: the hook "can modify the input inplace but it will not have effect on forward since this is called after :func:`forward` is called". In other words, the hook runs after forward(), so changing the input in place at that point cannot alter the output that has already been computed.

Goal: obtain the input and output of each layer during the forward pass, so that we can later compare them against what the hook records.

def register_forward_hook(self, hook):
    r"""Registers a forward hook on the module.

    The hook will be called every time after :func:`forward` has computed
    an output. It should have the following signature::

        hook(module, input, output) -> None or modified output

    The hook can modify the output. It can modify the input inplace but
    it will not have effect on forward since this is called after
    :func:`forward` is called.

    Returns:
        :class:`torch.utils.hooks.RemovableHandle`:
            a handle that can be used to remove the added hook by calling
            ``handle.remove()``
    """
    handle = hooks.RemovableHandle(self._forward_hooks)
    self._forward_hooks[handle.id] = hook
    return handle
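Note the return value: register_forward_hook() hands back a RemovableHandle. A quick sketch (using the net, hook, and x defined later in this article) of how it is used to unregister a hook:

handle = net.linear_1.register_forward_hook(hook)
net(x)           # the hook fires during this forward pass
handle.remove()  # unregister it
net(x)           # no "hooker working" is printed this time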

3. Define a class to test the hook

If every layer were initialized randomly, there would be no way to tell whether the tensors we capture really are the inputs and outputs seen in forward(), so each layer's weight and bias are set to recognizable values (here, all initialized to 1). The network has two layers (Linear, which has parameters requiring gradients, counts as a layer; ReLU, which has none, does not), and initialize() is called in __init__() to initialize them all.

Note: forward() returns the output of every layer except ReLU6, because no hook is registered for that layer in the tests below.

class TestForHook(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear_1 = nn.Linear(in_features=2, out_features=2)
        self.linear_2 = nn.Linear(in_features=2, out_features=1)
        self.relu = nn.ReLU()
        self.relu6 = nn.ReLU6()
        self.initialize()

    def forward(self, x):
        linear_1 = self.linear_1(x)
        linear_2 = self.linear_2(linear_1)
        relu = self.relu(linear_2)
        relu_6 = self.relu6(relu)
        layers_in = (x, linear_1, linear_2)
        layers_out = (linear_1, linear_2, relu)
        return relu_6, layers_in, layers_out

    def initialize(self):
        """Special initialization so the recorded tensors are recognizable."""
        self.linear_1.weight = torch.nn.Parameter(torch.FloatTensor([[1, 1], [1, 1]]))
        self.linear_1.bias = torch.nn.Parameter(torch.FloatTensor([1, 1]))
        self.linear_2.weight = torch.nn.Parameter(torch.FloatTensor([[1, 1]]))
        self.linear_2.bias = torch.nn.Parameter(torch.FloatTensor([1]))
        return True

4. Define the hook function

The hook() function is the argument that register_forward_hook() requires. The benefit is that the user decides what to do with the intercepted intermediate information. Here we simply record the network's inputs and outputs, but more complex operations, such as modifying the output, are also possible (see the sketch below).
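Since a forward hook that returns a non-None value replaces the module's output, modification is as simple as returning something. A tiny sketch of my own (the doubling is arbitrary, purely illustrative):

import torch
import torch.nn as nn

def double_output(module, fea_in, fea_out):
    # a returned value replaces the module's output
    return fea_out * 2.0

layer = nn.Linear(2, 2)
layer.register_forward_hook(double_output)
y = layer(torch.ones(1, 2))  # y is twice the bare layer's output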

First define several containers for recording:

Define the containers that will hold each layer's input and output tensors, and module_name to record the corresponding module class:

module_name = []
features_in_hook = []
features_out_hook = []

The hook function takes three parameters; they are passed in by the framework, and this signature cannot be changed:

The hook function appends the captured input and output to the feature lists and records the corresponding module class:

def hook(module, fea_in, fea_out):
    print("hooker working")
    module_name.append(module.__class__)
    features_in_hook.append(fea_in)
    features_out_hook.append(fea_out)
    return None

5. Register the hook for the required layers

The hook must be registered before the forward pass executes, i.e. before calling the network on its input. The code below registers the hook on every layer of the network except ReLU6 (you can also register only selected layers; a name-based variant is sketched after the code).

Hooks can be registered selectively, layer by layer:

net = TestForHook()
net_children = net.children()
for child in net_children:
    if not isinstance(child, nn.ReLU6):
        child.register_forward_hook(hook=hook)
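Alternatively, for selective registration it can be more convenient to filter by name. A small sketch of my own (not from the original) using named_modules(), which yields every submodule together with its dotted name:

for name, module in net.named_modules():
    if name in ("linear_1", "linear_2"):  # register only these two layers
        module.register_forward_hook(hook)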

6. Test the consistency of forward() returns and hook records

6.1 Test the inputs and outputs returned by forward()

Since forward() already returns the features we want to record, they can be tested directly. The input used here is x = torch.FloatTensor([[0.1, 0.1], [0.1, 0.1]]), as defined in the complete code below:

out, features_in_forward, features_out_forward = net(x)
print("*" * 5 + "forward return features" + "*" * 5)
print(features_in_forward)
print(features_out_forward)
print("*" * 5 + "forward return features" + "*" * 5)

With the all-1s initialization, the expected output follows:

*****forward return features*****
(tensor([[0.1000, 0.1000],
        [0.1000, 0.1000]]), tensor([[1.2000, 1.2000],
        [1.2000, 1.2000]], grad_fn=<AddmmBackward0>), tensor([[3.4000],
        [3.4000]], grad_fn=<AddmmBackward0>))
(tensor([[1.2000, 1.2000],
        [1.2000, 1.2000]], grad_fn=<AddmmBackward0>), tensor([[3.4000],
        [3.4000]], grad_fn=<AddmmBackward0>), tensor([[3.4000],
        [3.4000]], grad_fn=<ReluBackward0>))
*****forward return features*****

6.2 Inputs and outputs recorded by the hook

The hook records into plain Python lists, so they can be printed directly.

Test whether features_in_hook stores the inputs:

Print ("*" * 5 + "hook record features" + "*" * 5) print (features_in_hook) print (features_out_hook) print (module_name) print ("*" * 5 + "hook record features" + "*" * 5)

The result is the same as what forward() returned:

*****hook record features*****
[(tensor([[0.1000, 0.1000],
        [0.1000, 0.1000]]),), (tensor([[1.2000, 1.2000],
        [1.2000, 1.2000]], grad_fn=<AddmmBackward0>),), (tensor([[3.4000],
        [3.4000]], grad_fn=<AddmmBackward0>),)]
[tensor([[1.2000, 1.2000],
        [1.2000, 1.2000]], grad_fn=<AddmmBackward0>), tensor([[3.4000],
        [3.4000]], grad_fn=<AddmmBackward0>), tensor([[3.4000],
        [3.4000]], grad_fn=<ReluBackward0>)]
[<class 'torch.nn.modules.linear.Linear'>, <class 'torch.nn.modules.linear.Linear'>, <class 'torch.nn.modules.activation.ReLU'>]
*****hook record features*****

6.3 Subtract the hook records from the forward() returns

If you are worried about inconsistencies beyond the displayed decimals, or about data-type mismatches, subtract the hook-recorded features from the forward-returned ones (torch.allclose would serve the same purpose):

Test whether the features_in returned by forward() is consistent with the hook record:

Print ("sub result'") for forward_return, hook_record in zip (features_in_forward, features_in_hook): print (forward_return-hook_record [0])

Every entry is 0, which means the hook recorded everything correctly:

sub result
tensor([[0., 0.],
        [0., 0.]])
tensor([[0., 0.],
        [0., 0.]], grad_fn=<SubBackward0>)
tensor([[0.],
        [0.]], grad_fn=<SubBackward0>)

7. Complete code

import torch
import torch.nn as nn


class TestForHook(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear_1 = nn.Linear(in_features=2, out_features=2)
        self.linear_2 = nn.Linear(in_features=2, out_features=1)
        self.relu = nn.ReLU()
        self.relu6 = nn.ReLU6()
        self.initialize()

    def forward(self, x):
        linear_1 = self.linear_1(x)
        linear_2 = self.linear_2(linear_1)
        relu = self.relu(linear_2)
        relu_6 = self.relu6(relu)
        layers_in = (x, linear_1, linear_2)
        layers_out = (linear_1, linear_2, relu)
        return relu_6, layers_in, layers_out

    def initialize(self):
        """Special initialization so the recorded tensors are recognizable."""
        self.linear_1.weight = torch.nn.Parameter(torch.FloatTensor([[1, 1], [1, 1]]))
        self.linear_1.bias = torch.nn.Parameter(torch.FloatTensor([1, 1]))
        self.linear_2.weight = torch.nn.Parameter(torch.FloatTensor([[1, 1]]))
        self.linear_2.bias = torch.nn.Parameter(torch.FloatTensor([1]))
        return True

Define the containers that hold each layer's input and output tensors, and module_name to record the corresponding module class:

module_name = []
features_in_hook = []
features_out_hook = []

The hook function appends the captured input and output to the feature lists and records the corresponding module class:

def hook(module, fea_in, fea_out):
    print("hooker working")
    module_name.append(module.__class__)
    features_in_hook.append(fea_in)
    features_out_hook.append(fea_out)
    return None

Define an input whose entries are all 0.1:

x = torch.FloatTensor([[0.1, 0.1], [0.1, 0.1]])

Hooks can be registered selectively, layer by layer:

net = TestForHook()
net_children = net.children()
for child in net_children:
    if not isinstance(child, nn.ReLU6):
        child.register_forward_hook(hook=hook)

Test the network output:

out, features_in_forward, features_out_forward = net(x)
print("*" * 5 + "forward return features" + "*" * 5)
print(features_in_forward)
print(features_out_forward)
print("*" * 5 + "forward return features" + "*" * 5)

Test whether features_in_hook stores the inputs:

Print ("*" * 5 + "hook record features" + "*" * 5) print (features_in_hook) print (features_out_hook) print (module_name) print ("*" * 5 + "hook record features" + "*" * 5)

Test whether the features_in returned by forward() is consistent with the hook record:

Print ("sub result")

For forward_return, hook_record in zip (features_in_forward, features_in_hook):

Print (forward_return-hook_record [0])
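Section 1 mentioned that PyTorch offers backward hooks as well as forward ones, though this article only exercises the forward kind. To round things off, here is a minimal sketch of my own (not part of the original code) that captures gradients with register_full_backward_hook (PyTorch >= 1.8; the older register_backward_hook is deprecated). grads_out_hook is a name I made up:

grads_out_hook = []

def backward_hook(module, grad_input, grad_output):
    # called during backward with gradients w.r.t. the module's input/output
    grads_out_hook.append(grad_output)

net = TestForHook()
net.linear_1.register_full_backward_hook(backward_hook)
out, _, _ = net(torch.FloatTensor([[0.1, 0.1]]))
out.sum().backward()
print(grads_out_hook)  # gradient flowing out of linear_1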

This is the end of "what is the hook mechanism in pytorch". Thank you for reading!
