简介:pytorch自带的one-hot编码方法
pytorch自带的one-hot编码方法
在PyTorch中,我们通常使用torch.nn.functional.one_hot函数来进行one-hot编码。one-hot编码是一种将类别变量转换为机器学习模型可以处理的格式的技术。在one-hot编码中,每个类别都被视为一个独立的类别,并且分配一个唯一的二进制向量。这个向量在该类别对应的索引位置为1,其他位置为0。
在PyTorch中,torch.nn.functional.one_hot函数可以非常方便地实现这一转换。该函数的基本语法如下:
torch.nn.functional.one_hot(indices, num_classes=None, dtype=None, device=None, requires_grad=False)
其中:
indices是输入的类别索引张量。num_classes是类别的总数。如果未指定,则假定输入的张量中的最大值决定了类别的总数。dtype是输出的数据类型。如果未指定,则默认为torch.uint8。device是将张量放在哪个设备上。如果未指定,则默认为CPU。requires_grad如果为True,则梯度将被计算并存储在返回的张量中。torch.nn.functional.one_hot进行one-hot编码的例子:输出:
import torchimport torch.nn.functional as F# 假设我们有一个类别索引张量indices = torch.tensor([0, 2, 1, 3])# 使用F.one_hot进行one-hot编码one_hot = F.one_hot(indices)print(one_hot)
注意,上述代码将类别索引张量
tensor([[1, 0, 0, 0],[0, 0, 1, 0],[0, 1, 0, 0],[0, 0, 0, 1]])
indices转换为one-hot编码张量one_hot。在这个例子中,我们假设总共有4个类别,并且类别索引从0开始。因此,indices中的每个值都被转换为一个长度为4的二进制向量,其中只有一个位置为1,其余位置为0。例如,索引0被转换为[1, 0, 0, 0],索引2被转换为[0, 0, 1, 0],等等。