【Categorical()】torch.distributions.categorical.Categorical()
torch.distributions.categorical.Categorical 是PyTorch库中的一个类,它用于处理离散概率分布问题,特别是涉及在多个类别中进行选择的情况。这个类在机器学习和深度学习中非常有用,特别是在处理分类问题、强化学习的策略输出或者多分类任务的概率分布建模时。调用sample()方法来执行采样。
·
torch.distributions.categorical.Categorical 是PyTorch库中的一个类,它用于处理离散概率分布问题,特别是涉及在多个类别中进行选择的情况。
这个类在机器学习和深度学习中非常有用,特别是在处理分类问题、强化学习的策略输出或者多分类任务的概率分布建模时。
调用sample()方法来执行采样
import torch
from torch.distributions import Categorical
# 假设我们有三个类别,它们的概率分别为0.2, 0.5, 0.3
probs = torch.tensor([0.2, 0.5, 0.3])
# 创建Categorical分布实例
# dist = Categorical(probs=probs)
dist = Categorical(probs)
# 采样一个样本
sample = dist.sample()
print("单个样本:", sample.item())
# 采样一批样本,比如10个样本
samples = dist.sample((10,))
print("10个样本:", samples)
魔乐社区(Modelers.cn) 是一个中立、公益的人工智能社区,提供人工智能工具、模型、数据的托管、展示与应用协同服务,为人工智能开发及爱好者搭建开放的学习交流平台。社区通过理事会方式运作,由全产业链共同建设、共同运营、共同享有,推动国产AI生态繁荣发展。
更多推荐

所有评论(0)