PyTorch中的随机数生成函数:torch.rand()、torch.randn()、torch.randint()和torch.randperm()用法详解

作者:php是最好的2024.02.16 18:12浏览量:39

简介:本文将详细介绍PyTorch中用于生成随机数的四个函数:torch.rand()、torch.randn()、torch.randint()和torch.randperm()。这些函数在深度学习和科学计算中经常被用来初始化参数或进行随机采样。本文将通过示例和说明,帮助您理解它们的用法和工作原理,以及在何时使用它们。

1. torch.rand()

torch.rand() 用于生成0到1之间的随机浮点数。这个函数返回一个形状为 [dim1, dim2, ...] 的张量,其中 dim1, dim2, ... 是输入参数。

示例:

  1. import torch
  2. x = torch.rand(3, 4) # 生成一个3x4的张量,元素值为0到1之间的随机数
  3. print(x)

2. torch.randn()

torch.randn() 用于生成均值为0、标准差为1的正态分布随机数。这个函数返回一个形状为 [dim1, dim2, ...] 的张量,其中 dim1, dim2, ... 是输入参数。

示例:

  1. import torch
  2. x = torch.randn(3, 4) # 生成一个3x4的张量,元素值来自均值为0、标准差为1的正态分布
  3. print(x)

3. torch.randint()

torch.randint() 用于生成指定范围内的随机整数。该函数接受三个参数:low(包含下限)、high(不包含上限)和size(生成的张量形状)。

示例:

  1. import torch
  2. x = torch.randint(0, 5, (3, 4)) # 生成一个3x4的张量,元素值为0到4的随机整数(包含0,不包含5)
  3. print(x)

4. torch.randperm()

torch.randperm() 用于生成一个随机排列的整数张量。该函数接受一个参数:生成的张量形状的长度。返回的张量中每个元素都是从0到形状长度减1的整数,且这些整数是随机排列的。

示例:

  1. import torch
  2. x = torch.randperm(5) # 生成一个包含0到4的随机排列整数的张量
  3. print(x)