
How to use the torch.topk() function in PyTorch

2025-03-26 Update From: SLTechnology News&Howtos

Shulou (Shulou.com) 06/01 Report

This article introduces how to use the torch.topk() function in PyTorch. Many people are unsure how to use torch.topk() in day-to-day work, so the editor went through various materials and put together some simple, practical usage examples. I hope they help answer your questions about torch.topk(). Please follow along!

What the function does:

As the name suggests, topk returns the k largest elements of a tensor along a given dimension (or the k smallest, when largest=False is passed).

It returns two values as a named tuple (values, indices): the first holds the selected elements, sorted from largest to smallest by default (sorted=True), and the second holds the positions of those elements along dim in the original tensor.
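A minimal sketch on a small 1-D tensor (the tensor `t` is made-up illustration data). The result can be read through its `values` and `indices` fields or unpacked like an ordinary tuple; when `dim` is not given, torch.topk works along the last dimension.

```python
import torch

# A small 1-D tensor with distinct values (illustrative data)
t = torch.tensor([3.0, 9.0, 1.0, 7.0, 5.0])

# Ask for the 2 largest elements; dim defaults to the last dimension
result = torch.topk(t, k=2)

# The result is a named tuple: access it by field name...
print(result.values)   # tensor([9., 7.])
print(result.indices)  # tensor([1, 3])

# ...or unpack it like an ordinary tuple
values, indices = result
print(torch.equal(values, t[indices]))  # True
```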

For example:

import torch

# Prepare a 4 x 11 tensor
tensor1 = torch.tensor([[10, 1, 2, 1, 1, 1, 1, 1, 1, 1, 10],
                        [3, 4, 5, 1, 1, 1, 1, 1, 1, 1, 1],
                        [7, 8, 9, 1, 1, 1, 1, 1, 1, 1, 1],
                        [1, 4, 7, 1, 1, 1, 1, 1, 1, 1, 1]], dtype=torch.float32)

# Print the original tensor
print('tensor1:')
print(tensor1)

# k=3 takes three elements; dim=1 takes them along dimension 1, i.e. from each of the four rows
# largest=True takes the largest elements, returned from largest to smallest
print('Result of torch.topk():')
print(torch.topk(tensor1, k=3, dim=1, largest=True))

# The first return value, topk[0], holds the values
print('The first return value, topk[0]:')
print(torch.topk(tensor1, k=3, dim=1, largest=True)[0])

# The second return value, topk[1], holds the indices
print('The second return value, topk[1]:')
print(torch.topk(tensor1, k=3, dim=1, largest=True)[1])

Run result (values are the three largest elements of each row along dim=1, from largest to smallest; indices are the positions of those elements along dim=1 in the original tensor):

tensor1:
tensor([[10.,  1.,  2.,  1.,  1.,  1.,  1.,  1.,  1.,  1., 10.],
        [ 3.,  4.,  5.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.],
        [ 7.,  8.,  9.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.],
        [ 1.,  4.,  7.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.]])
Result of torch.topk():
torch.return_types.topk(
values=tensor([[10., 10.,  2.],
        [ 5.,  4.,  3.],
        [ 9.,  8.,  7.],
        [ 7.,  4.,  1.]]),
indices=tensor([[ 0, 10,  2],
        [ 2,  1,  0],
        [ 2,  1,  0],
        [ 2,  1,  0]]))
The first return value, topk[0]:
tensor([[10., 10.,  2.],
        [ 5.,  4.,  3.],
        [ 9.,  8.,  7.],
        [ 7.,  4.,  1.]])
The second return value, topk[1]:
tensor([[ 0, 10,  2],
        [ 2,  1,  0],
        [ 2,  1,  0],
        [ 2,  1,  0]])

Note that when several elements are equal (the two 10s in row 0, or the many 1s in row 3), which of the tied positions is returned is not guaranteed, so the indices you get may differ from the ones shown here.
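Since indices records where each value came from along dim, torch.gather with the same dim can read the values back out of the original tensor. A short check, rebuilding the 4 x 11 tensor from the example above so the snippet runs on its own:

```python
import torch

tensor1 = torch.tensor([[10, 1, 2, 1, 1, 1, 1, 1, 1, 1, 10],
                        [3, 4, 5, 1, 1, 1, 1, 1, 1, 1, 1],
                        [7, 8, 9, 1, 1, 1, 1, 1, 1, 1, 1],
                        [1, 4, 7, 1, 1, 1, 1, 1, 1, 1, 1]], dtype=torch.float32)

values, indices = torch.topk(tensor1, k=3, dim=1, largest=True)

# Look up the elements at the returned positions along dim=1
picked = torch.gather(tensor1, dim=1, index=indices)

# Whichever tied positions were chosen, they always point at the returned values
print(torch.equal(picked, values))  # True
print(values.shape, indices.shape)  # torch.Size([4, 3]) torch.Size([4, 3])
```

Because of the tie caveat above, comparing gathered values (rather than raw indices) is the more robust way to check a topk result.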

This function is commonly used to find the k largest or smallest elements of a tensor together with their positions, and it is a basic function that comes up frequently.
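For the "smallest" case, a minimal sketch with largest=False (the scores tensor is illustrative data). With the default sorted=True, the k smallest elements come back in ascending order.

```python
import torch

scores = torch.tensor([[0.2, -1.5, 3.0, 0.7],
                       [4.1, 2.2, -0.3, 1.0]])

# largest=False returns the k smallest elements of each row, smallest first
values, indices = torch.topk(scores, k=2, dim=1, largest=False)
print(values)   # values:  [[-1.5, 0.2], [-0.3, 1.0]]
print(indices)  # indices: [[1, 0], [2, 3]]
```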

Example demonstration

Task 1:

Take the top 1 along dim=0 (the maximum of each column):

import torch

pred = torch.tensor([[-0.5816, -0.3873, -1.0215, -1.0145, 0.4053],
                     [0.7265, 1.4164, 1.3443, 1.2035, 1.8823],
                     [-0.4451, 0.1673, 1.2590, -2.0757, 1.7255],
                     [0.2021, 0.3041, 0.1383, 0.3849, -1.6311]])
print(pred)
values, indices = pred.topk(1, dim=0, largest=True, sorted=True)
print(indices)
print(values)
# Compare with the result of max. Set keepdim=True to avoid dimension reduction:
# topk keeps the dimension it works along (with size k), so its indices have the same number of dimensions as the input
_, indices_max = pred.max(dim=0, keepdim=True)
print(indices_max)
print(indices_max == indices)

Output:

tensor([[-0.5816, -0.3873, -1.0215, -1.0145,  0.4053],
        [ 0.7265,  1.4164,  1.3443,  1.2035,  1.8823],
        [-0.4451,  0.1673,  1.2590, -2.0757,  1.7255],
        [ 0.2021,  0.3041,  0.1383,  0.3849, -1.6311]])
tensor([[1, 1, 1, 1, 1]])
tensor([[0.7265, 1.4164, 1.3443, 1.2035, 1.8823]])
tensor([[1, 1, 1, 1, 1]])
tensor([[True, True, True, True, True]])
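The keepdim comment in Task 1 can be checked directly by looking at shapes. A small sketch reusing the same pred values: topk with k=1 keeps dim 0 (size 1), while max drops it unless keepdim=True is passed.

```python
import torch

pred = torch.tensor([[-0.5816, -0.3873, -1.0215, -1.0145, 0.4053],
                     [0.7265, 1.4164, 1.3443, 1.2035, 1.8823],
                     [-0.4451, 0.1673, 1.2590, -2.0757, 1.7255],
                     [0.2021, 0.3041, 0.1383, 0.3849, -1.6311]])

_, idx_topk = pred.topk(1, dim=0)                 # keeps dim 0 with size k=1
_, idx_max = pred.max(dim=0)                      # drops dim 0 by default
_, idx_max_keep = pred.max(dim=0, keepdim=True)   # keeps dim 0 with size 1

print(idx_topk.shape)      # torch.Size([1, 5])
print(idx_max.shape)       # torch.Size([5])
print(idx_max_keep.shape)  # torch.Size([1, 5])

# squeeze(0) removes the size-1 dimension so the two results line up exactly
print(torch.equal(idx_topk.squeeze(0), idx_max))  # True
```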

Task 2:

Keep the top k values in each row, and set every value smaller than the k-th largest value to -inf:

import torch

pred = torch.tensor([[-0.5816, -0.3873, -1.0215, -1.0145, 0.4053],
                     [0.7265, 1.4164, 1.3443, 1.2035, 1.8823],
                     [-0.4451, 0.1673, 1.2590, -2.0757, 1.7255],
                     [0.2021, 0.3041, 0.1383, 0.3849, -1.6311]])
print(pred)
top_k = 2  # keep the two largest values of each row
filter_value = -float('Inf')
# torch.topk(pred, top_k)[0] holds the values; [..., -1, None] takes the k-th largest value of each row
# and keeps a trailing dimension of size 1 so that it broadcasts against pred
indices_to_remove = pred < torch.topk(pred, top_k)[0][..., -1, None]
print(indices_to_remove)
pred[indices_to_remove] = filter_value  # set every logit outside the top k to negative infinity
print(pred)

Output:

tensor([[-0.5816, -0.3873, -1.0215, -1.0145,  0.4053],
        [ 0.7265,  1.4164,  1.3443,  1.2035,  1.8823],
        [-0.4451,  0.1673,  1.2590, -2.0757,  1.7255],
        [ 0.2021,  0.3041,  0.1383,  0.3849, -1.6311]])
tensor([[ True, False,  True,  True, False],
        [ True, False,  True,  True, False],
        [ True,  True, False,  True, False],
        [ True, False,  True, False,  True]])
tensor([[   -inf, -0.3873,    -inf,    -inf,  0.4053],
        [   -inf,  1.4164,    -inf,    -inf,  1.8823],
        [   -inf,    -inf,  1.2590,    -inf,  1.7255],
        [   -inf,  0.3041,    -inf,  0.3849,    -inf]])
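The comparison pred < kth_value keeps every element equal to the k-th largest value, so a row with ties can keep more than top_k entries. If exactly top_k entries per row are wanted, one alternative (a sketch, not the method used above) is to build the mask from the returned indices with scatter. The helper name keep_top_k is hypothetical.

```python
import torch

def keep_top_k(logits: torch.Tensor, k: int, fill: float = float("-inf")) -> torch.Tensor:
    """Keep exactly k entries per row (along the last dim); replace the rest with `fill`."""
    _, idx = torch.topk(logits, k, dim=-1)
    # Boolean mask that is True only at the positions topk selected
    keep = torch.zeros_like(logits, dtype=torch.bool).scatter(
        -1, idx, torch.ones_like(idx, dtype=torch.bool)
    )
    # masked_fill returns a new tensor; the input is left unchanged
    return logits.masked_fill(~keep, fill)

row = torch.tensor([[1.0, 3.0, 3.0, 3.0, 0.5]])
print(keep_top_k(row, 2))
# Two of the three 3.0 entries are kept; which two is not guaranteed.

# For comparison, the threshold approach from Task 2 keeps all three tied entries
kth = torch.topk(row, 2)[0][..., -1, None]
print((row >= kth).sum().item())  # 3
```

Unlike Task 2, this version does not modify the input in place, which can be convenient when the original logits are still needed.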

Task 3:

import torch

tensor1 = torch.tensor([[10, 1, 2, 1, 1, 1, 1, 1, 1, 1, 10],
                        [3, 4, 5, 1, 1, 1, 1, 1, 1, 1, 1],
                        [7, 8, 9, 1, 1, 1, 1, 1, 1, 1, 1],
                        [1, 4, 7, 1, 1, 1, 1, 1, 1, 1, 1]], dtype=torch.float32)

print('tensor1:')
print(tensor1)
print('If you print topk directly, you get two things; what we need is the second one, the indices')
print(torch.topk(tensor1, k=3, dim=1, largest=True))
print('topk[0] is as follows')
print(torch.topk(tensor1, k=3, dim=1, largest=True)[0])
print('topk[1] is as follows')
print(torch.topk(tensor1, k=3, dim=1, largest=True)[1])

Output:

tensor1:
tensor([[10.,  1.,  2.,  1.,  1.,  1.,  1.,  1.,  1.,  1., 10.],
        [ 3.,  4.,  5.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.],
        [ 7.,  8.,  9.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.],
        [ 1.,  4.,  7.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.]])
If you print topk directly, you get two things; what we need is the second one, the indices
torch.return_types.topk(
values=tensor([[10., 10.,  2.],
        [ 5.,  4.,  3.],
        [ 9.,  8.,  7.],
        [ 7.,  4.,  1.]]),
indices=tensor([[ 0, 10,  2],
        [ 2,  1,  0],
        [ 2,  1,  0],
        [ 2,  1,  0]]))
topk[0] is as follows
tensor([[10., 10.,  2.],
        [ 5.,  4.,  3.],
        [ 9.,  8.,  7.],
        [ 7.,  4.,  1.]])
topk[1] is as follows
tensor([[ 0, 10,  2],
        [ 2,  1,  0],
        [ 2,  1,  0],
        [ 2,  1,  0]])

This concludes the study of how to use the torch.topk() function in PyTorch. I hope it has cleared up your questions. Combining theory with practice is the best way to learn, so go and try it!
