Tensorflow Probability中Categorical


簡介

TensorFlow Probability 是 TensorFlow 中用於概率推理和統計分析的庫。

安裝

安裝最新版本的 TensorFlow Probability:

 pip install --upgrade tensorflow-probability

安裝指定版本的 TensorFlow Probability:

 pip install tensorflow-probability==版本號

有關 TensorFlow 和 TensorFlow Probability 之間的版本對應關系,請參閱 TFP 版本說明

使用

這里僅介紹我常用的一個根據概率分布采樣的功能,其余功能參考官方文檔

Categorical類

用途:創建一個用於表示不同種類概率分布的對象。

import tensorflow_probability as tfp

dist = tfp.distributions.Categorical(
    						logits=None,  # 傳入的是logits的分布(未經過sotfmax)
  							probs=None,  # 傳入的是概率分布
  							dtype=tf.int32,  # 種類的數據類型
  							validate_args=False,
    						allow_nan_stats=True, 
  							name='Categorical'
)

類的屬性和方法

  • dist.probs:得到傳入的probs
  • dist.logits:得到傳入的logits
  • dist.prob(value):返回某個種類的概率
  • dist.log_prob(value):返回某個種類的概率的log
  • dist.sample(sample_shape=(), seed=None, name='sample', **kwargs):按probs的分布采樣種類

舉例

>>> import tensorflow as tf
>>> import tensorflow_probability as tfp
>>> dist = tfp.distributions.Categorical(probs=[0.1, 0.2, 0.7], dtype='float32')

>>> print(dist.probs)
tf.Tensor([0.1 0.2 0.7], shape=(3,), dtype=float32)

>>> print(dist.logits)
None

>>> dist.sample()
<tf.Tensor: shape=(), dtype=float32, numpy=2.0>

>>> dist.log_prob(0)  # 計算種類0對應的prob的log,即log(0.1)
<tf.Tensor: shape=(), dtype=float32, numpy=-2.3025851>

>>> tf.math.log(0.1)  # 結果和上面一樣
<tf.Tensor: shape=(), dtype=float32, numpy=-2.3025851>
>>> dist2 = tfp.distributions.Categorical(logits=[0.1, 0.2, 0.3], dtype='float32')
# 傳入logits在執行prob()時,會自動對其作sotfmax操作
>>> dist2.prob(0)
<tf.Tensor: shape=(), dtype=float32, numpy=0.3006096>

>>> dist2.prob(1)
<tf.Tensor: shape=(), dtype=float32, numpy=0.33222497>

>>> dist2.prob(2)
<tf.Tensor: shape=(), dtype=float32, numpy=0.3671654>

>>> tf.nn.softmax([0.1, 0.2, 0.3])  # 結果和上面一樣
<tf.Tensor: shape=(3,), dtype=float32, numpy=array([0.3006096, 0.332225 , 0.3671654], dtype=float32)>

參考

https://tensorflow.google.cn/probability/api_docs/python/tfp/distributions/Categorical?skip_cache=true#attributes


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM