简介:PyTorch中Linear层的原理 | PyTorch系列(十六)
在PyTorch中,线性层(Linear layer)是一种重要的基础层,它提供了全连接、批量输入和输出以及多维数据输入等功能。线性层在很多深度学习模型中都起着核心作用,特别是在神经网络中。本文将深入探讨PyTorch中Linear层的原理以及其重要特性。
在PyTorch中,Linear层是由torch.nn.Linear类定义的,它基本上是一个全连接的神经网络层。全连接意味着该层中的每个神经元都与前一层的所有神经元相连。
import torch.nn as nnclass MyNet(nn.Module):def __init__(self):super(MyNet, self).__init__()self.fc1 = nn.Linear(in_features=784, out_features=128)def forward(self, x):x = self.fc1(x)return x
在这个例子中,我们定义了一个名为fc1的线性层,输入特征数为784,输出特征数为128。在前向传播的过程中,该层将接受一个784维的输入向量,然后通过线性变换(权重乘以输入再加上偏置项)生成一个128维的输出向量。
线性层中的两个关键参数是权重(weights)和偏置项(bias)。权重是用于映射输入特征的系数,而偏置项则为输出向量添加一个固定的偏移。在训练过程中,这些参数会通过反向传播和梯度下降等算法进行更新。
PyTorch中的nn.Linear类会自动为每个输入通道创建一个权重矩阵和一个偏置向量。当输入具有多个通道时(如彩色图像),这些权重和偏置项会被重复使用。
线性层的运算过程主要包括三个步骤:
在这个例子中,我们添加了一个ReLU激活函数,用于对线性层的输出进行非线性转换。在训练深度学习模型时,激活函数可以帮助引入非线性因素,从而提高模型的表达能力。
import torchimport torch.nn as nnimport torch.nn.functional as Fclass MyNet(nn.Module):def __init__(self):super(MyNet, self).__init__()self.fc1 = nn.Linear(in_features=784, out_features=128)self.relu = nn.ReLU() # 定义ReLU激活函数def forward(self, x):x = self.fc1(x)x = self.relu(x) # 通过ReLU激活函数return x