简介:nn.Flatten()是PyTorch中的一个函数,用于将多维张量平坦化为二维张量。该函数通过连续堆叠输入张量的所有元素,生成一个新的二维张量。本篇文章将详细介绍nn.Flatten()函数的作用、参数和使用方法,并通过示例展示其应用。
nn.Flatten()函数的作用是将多维张量平坦化为二维张量。在深度学习中,常常需要将高维输入数据展平为一维向量或二维矩阵,以便进行后续的神经网络操作。nn.Flatten()函数提供了简便的方法来实现这一需求。
参数:
start_dim(int):起始维度。从哪个维度开始展平输入张量。默认为0。end_dim(int):结束维度。展平到哪个维度结束。默认为-1,表示展平到倒数第二个维度。在上面的示例中,我们首先创建了一个形状为[batch_size, channels, height, width]的四维张量x,然后使用nn.Flatten()函数将其展平为二维张量x_flat。最后,我们输出了展平后的张量形状。需要注意的是,nn.Flatten()函数会自动计算展平后的维度大小,因此不需要手动指定展平后的维度。
import torchimport torch.nn as nn# 创建一个四维张量,形状为[batch_size, channels, height, width]x = torch.randn(32, 3, 28, 28) # 假设batch_size=32,channels=3,height=28,width=28# 使用nn.Flatten()函数将张量展平为二维张量x_flat = nn.Flatten()(x)# 输出展平后的张量形状print(x_flat.shape) # 输出:[32, 3 * 28 * 28]