PyTorchによる最尤推定の損失関数の実装例をまとめました。

【PyTorch】最尤推定の損失関数

機械学習で最尤推定を行うためには、どのような損失関数を設定すればいいんでしょうか?

この記事では、PyTorchを用いた最尤推定の損失関数の実装例をご紹介します。

理論(初心者向け)

ざっくり言うと、正解データに対する確率密度が大きくなるような平均・分散を出力できるように、損失関数を設定します。

以前にかみ砕いた理論の記事を書いたので、よろしければ読んでみてください。
機械学習で最尤推定【初心者向け】

損失関数の実装例

# outputs:ネットワーク出力
# labels:正解データ
# hogeGetMu():ネットワーク出力から平均を抽出する関数
# hogeGetCov():ネットワーク出力から共分散行列を抽出する関数

def criterion(outputs, labels):
	mu = hogeGetMu(outputs)
	cov = hogeGetCov(outputs)
	dist = torch.distributions.MultivariateNormal(mu, cov)
	loss = -dist.log_prob(labels)
	loss = loss.mean()
	return loss

解説

上記の実装例について、少し解説します。

ネットワークは何を出力するの?

ネットワークは、平均と共分散行列を出力します。厳密には、上記の実装例のように、何かしらの関数で平均と共分散行列を抽出することになるでしょう。

多変量正規分布

多変量正規分布はtorch.distributions.MultivariateNormalを使うことで簡単に実装できます。

引数の詳細などは公式ドキュメントで確認できます。
Probability distributions - torch.distributions

対数確率密度

MultivariateNormalのメンバ関数log_prob(value)を使うことで、確率密度の自然対数をとった値を計算できます。

自然対数をとる理由は、勾配消失をふせいだり、誤差逆伝播の計算コストを減らすためなどです。

マイナスをとる理由

機械学習は損失関数の値を小さくするように重みパラメータを更新するので、最大化したい対数確率密度にマイナスをかけます。

さいごに

少し複雑な最尤推定の損失関数も、PyTorchの関数を使うことで簡単に実装できました。参考になれば幸いです。


以上です。

Ad.