torch的使用
pytorch 的使用
random_split
不用自己写划分数据集的函数,pytorch 已经给我们封装好了,那就是 torch.utils.data.random_split()。
torch.utils.data.random_split(dataset, lengths, generator=<torch._C.Generator object>)
dataset(Dataset) – 要划分的数据集
lengths(Sequence) – 要划分的长度,这是个Sequence
generator (Generator) – 用于随机排列的生成器。简单使用
import torch
from torch.utils.data import random_split
# 随机生成 10 个数
dataset = range(10)
train_dataset, test_dataset = random_split(
dataset=dataset,
lengths=[7, 3],
generator=torch.Generator().manual_seed(0)
)
print(list(train_dataset))
print(list(test_dataset))
# 输出
# [4, 1, 7, 5, 3, 9, 0]
# [8, 6, 2]实际使用
from torch.utils.data import DataLoader, random_split
# 2/10 作为验证数据,剩余作为训练数据
val_size = int(len(data_list) * 0.2)
train_size = len(data_list) - val_size
# 这样就将数据分成2份
train_list, val_list = random_split(data_list, [train_size, val_size])torch.randn
生成 符合标准正态分布(均值=0,方差=1)的浮点数张量。
torch.randn(*sizes, dtype=None, device=None)*sizes:张量的形状,例如(10, 3, 32, 32)dtype:数据类型,默认float32device:设备,如"cpu"或"cuda"