Network Security Internet Technology Development Database Servers Mobile Phone Android Software Apple Software Computer Software News IT Information

In addition to Weibo, there is also WeChat

Please pay attention

WeChat public account

Shulou

How to calculate kl divergence by pytorch

2025-04-11 Update From: SLTechnology News&Howtos shulou NAV: SLTechnology News&Howtos > Development >

Share

Shulou(Shulou.com)06/01 Report--

This article will explain in detail how pytorch calculates kl divergence. The editor thinks it is very practical, so I share it with you as a reference. I hope you can get something after reading this article.

cause

Editor accidentally from the pytorch discussion forum to see a problem, kl divergence in TensorFlow and pytorch calculation results are different, usually do not notice, record.

Kl divergence introduction

KL divergence (Kullback-Leibler divergence), also known as relative entropy, is a method to describe the difference between two probability distributions P and Q. Calculation formula:

It can be found that the number of elements in P and Q is not equal, only the discrete elements in the two distributions are the same.

Take a simple example:

The two discrete distributions are P and Q respectively.

The distribution of P is {1pm 1pm 2pm 2pm 3}.

The distribution of Q is as follows: {1pm 1pm 1pm 2pm 3pm 3pm 3pm 3}

We find that although the number of elements in the two distributions is different, the number of elements in P is 5 and the number of elements in Q is 10. But all the elements have "1", "2" and "3".

When x = 1, in P distribution, the number of "1" element is 2, so P (x = 1) = 2max 5 = 0.4, in Q distribution, the number of "1" element is 5, so Q (x = 1) = 5max 10 = 0.5

The same principle

When x = 2, P (x = 2) = 2max 5 = 0.4, Q (x = 2) = 1max 10 = 0.1

When x = 3, P (x = 3) = 1max 5 = 0.2, Q (x = 3) = 4max 10 = 0.4

Bring the above probability into the formula:

At this point, the KL divergence of the distribution of two discrete variables is calculated.

Kl_div function in pytorch

There is a function kl_div in pytorch for calculating kl divergence.

Torch.nn.functional.kl_div (input, target, size_average=None, reduce=None, reduction='mean')

Calculate D (p | | Q)

1. The calculated result without this function is:

The result is the same as the manual calculation.

2. Use the function:

(this calculation is correct, and the result is different because the pytorch function is based on e by default.)

Note:

1. The position of p / Q in the function is opposite (that is, if you want to calculate D (p | | Q), it should be written in the form of kl_div (q.log (), p), and Q should be taken as log first.

2. Reduction chooses what to do with each part of the result. The default is to take the average, and here you choose summation.

What an awkward usage. I don't know why the authorities designed it like this.

Supplement: the implementation of KL divergence of pytorch

Look at the code ~ import torch.nn.functional as F# p_logit: [batch, class_num] # q_logit: [batch, class_num] def kl_categorical (p_logit, q_logit): P = F.softmax (p_logit, dim=-1) _ kl = torch.sum (p * (F.log_softmax (p_logit, dim=-1)-F.log_softmax (q_logit, dim=-1)) 1) this is the end of the return torch.mean (_ kl) article on "how pytorch calculates kl divergence". Hope that the above content can be helpful to you, so that you can learn more knowledge, if you think the article is good, please share it for more people to see.

Welcome to subscribe "Shulou Technology Information " to get latest news, interesting things and hot topics in the IT industry, and controls the hottest and latest Internet news, technology news and IT industry trends.

Views: 0

*The comments in the above article only represent the author's personal views and do not represent the views and positions of this website. If you have more insights, please feel free to contribute and share.

Share To

Development

Wechat

© 2024 shulou.com SLNews company. All rights reserved.

12
Report