PyTorch实现BatchNorm2d:自定义BN层的深度探索

作者:渣渣辉2023.12.25 15:06浏览量:9

简介:DBN pytorch代码实现 pytorch batchnorm2d

DBN pytorch代码实现 pytorch batchnorm2d
深度学习中的批量归一化(Batch Normalization,简称BN)是一种重要的技术,它能够加速训练过程,提高模型的稳定性。PyTorch中已经内置了torch.nn.BatchNorm2d这样的模块,但在某些特殊情况下,我们可能需要进行更细粒度的控制,这就需要自定义实现BN层。下面将展示如何在PyTorch中手动实现一个BN层。
Batch Normalization(BN)的步骤主要如下:

  1. 对每一个batch的数据进行均值和方差的归一化。
  2. 对归一化后的数据添加一个线性变换(scale和shift)。
    以下是一个PyTorch实现的例子:
    1. import torch
    2. import torch.nn as nn
    3. class MyNet(nn.Module):
    4. def __init__(self):
    5. super(MyNet, self).__init__()
    6. self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1)
    7. self.bn1 = MyBatchNorm2d(64) # 在这里实例化我们自定义的BN层
    8. self.conv2 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1)
    9. self.bn2 = MyBatchNorm2d(128) # 在这里实例化我们自定义的BN层
    10. self.fc1 = nn.Linear(128 * 7 * 7, 256)
    11. self.fc2 = nn.Linear(256, 10)
    12. def forward(self, x):
    13. x = F.relu(self.bn1(self.conv1(x))) # 先经过卷积,然后经过自定义的BN层,最后经过ReLU激活函数
    14. x = F.relu(self.bn2(self.conv2(x))) # 先经过卷积,然后经过自定义的BN层,最后经过ReLU激活函数
    15. x = x.view(-1, 128 * 7 * 7) # 展平操作,为全连接层做准备
    16. x = F.relu(self.fc1(x)) # 经过全连接层
    17. x = self.fc2(x) # 再次经过全连接层
    18. return x
    在上面的代码中,我们定义了一个自定义的BN层MyBatchNorm2d,并在卷积层之后使用这个自定义的BN层。下面是MyBatchNorm2d的实现:
    ```python
    class MyBatchNorm2d(nn.Module):
    def init(self, numfeatures):
    super(MyBatchNorm2d, self)._init
    ()
    self.num_features = num_features
    self.weight = nn.Parameter(torch.randn(num_features)) # weight在归一化时作为scale参数,因此初始化为随机数即可。
    self.bias = nn.Parameter(torch.zeros(num_features)) # bias在归一化时作为shift参数,因此初始化为0即可。
    self.running_mean = torch.zeros(num_features) # 运行均值,用于存储训练过程中的均值。这个变量需要是可训练的,因此初始化为0。
    self.running_var = torch.ones(num_features) # 运行方差,用于存储训练过程中的方差。这个变量需要是可训练的,因此初始化为1。
    self.epsilon = 1e-5 # 防止除0错误的小常数。
    def forward(self, x):

    计算均值和方差

    mean = x.mean([0, 2, 3], keepdim=True) # 对batch和channel进行均值计算。
    var = x.var([0, 2, 3], unbiased=False, keepdim=True) # 对batch和channel进行方差计算。注意这里的unbiased=False表示使用有偏估计。

    归一化和线性变换(也可以调用F.batch_norm函数来实现这两个步骤)

    x_normalized = (x-mean)/var.sqrt()+mean # 进行归一化,并进行线性变换。注意这里进行了两次mean操作,一次是减去均值,一次是加上均值。这是因为在归一化时减去均值后,需要再通过线性变换加上均值,以实现数据的平移。这个过程等价于先计算数据的偏移量并