Usage and key points
1. Without a sampler
import torch

# Loader for the training set: splits the data into batches automatically and reshuffles the order every epoch
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=batch_size, drop_last=True, shuffle=True)
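To see what a loader configured like this yields, here is a minimal sketch. It assumes a hypothetical toy TensorDataset of 10 samples standing in for train_dataset and a hypothetical batch_size of 4. It is an illustration, not the author's training setup.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Hypothetical toy data: 10 samples with 3 features each, plus integer labels 0..9
features = torch.arange(30, dtype=torch.float32).reshape(10, 3)
labels = torch.arange(10)
train_dataset = TensorDataset(features, labels)
batch_size = 4  # hypothetical value for illustration

train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size,
                          drop_last=True, shuffle=True)

seen = []
for x, y in train_loader:
    # Every batch is full: drop_last=True discards the 2 leftover samples
    assert x.shape == (batch_size, 3)
    seen.extend(y.tolist())

print(len(train_loader))  # 2 full batches (10 // 4)
print(len(seen))          # 8 samples per epoch; which 8 changes because of shuffle=True
```

Because shuffle=True draws a new permutation each epoch, the two samples that drop_last discards also differ from epoch to epoch.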
2. With a sampler
First, we define the index sequence indices, which effectively assigns a number to every sample in test_dataset.
import torch

# Then define indices_val as the indices of the validation samples and indices_test as the indices of the test samples
indices = range(len(test_dataset))
indices_val = indices[:5000]
indices_test = indices[5000:]

# From these indices, build a SubsetRandomSampler for each subset; it draws the given indices in random order
sampler_val = torch.utils.data.sampler.SubsetRandomSampler(indices_val)
sampler_test = torch.utils.data.sampler.SubsetRandomSampler(indices_test)

# Define the loaders from the two samplers: sampler_val goes to validation_loader, sampler_test to test_loader
# (do not also pass shuffle=True: sampler and shuffle are mutually exclusive)
validation_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=batch_size, sampler=sampler_val)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=batch_size, sampler=sampler_test)
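A small check that the two samplers really split one dataset into disjoint subsets can be sketched as follows. It assumes a hypothetical 10-sample TensorDataset in place of test_dataset, with a hypothetical split point of 6 playing the role of 5000.

```python
import torch
from torch.utils.data import DataLoader, SubsetRandomSampler, TensorDataset

# Hypothetical stand-in for test_dataset: 10 samples whose values equal their indices
test_dataset = TensorDataset(torch.arange(10))
batch_size = 4
split = 6  # hypothetical; plays the role of 5000 in the code above

indices = range(len(test_dataset))
sampler_val = SubsetRandomSampler(indices[:split])
sampler_test = SubsetRandomSampler(indices[split:])

validation_loader = DataLoader(test_dataset, batch_size=batch_size, sampler=sampler_val)
test_loader = DataLoader(test_dataset, batch_size=batch_size, sampler=sampler_test)

# Each batch is a one-element list holding a tensor of sample values
val_ids = [i for (batch,) in validation_loader for i in batch.tolist()]
test_ids = [i for (batch,) in test_loader for i in batch.tolist()]
print(sorted(val_ids))   # [0, 1, 2, 3, 4, 5]
print(sorted(test_ids))  # [6, 7, 8, 9]
```

Both loaders read from the same underlying dataset; only the samplers decide which indices each one visits, and each sampler randomizes its own subset's order.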
Important note
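When the number of samples is not divisible by batch_size, the last batch of each epoch is smaller than batch_size. To discard that incomplete batch, pass drop_last=True.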
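The arithmetic behind this can be sketched in plain Python. The helper batch_sizes below is hypothetical. It models only the documented drop_last rule, not PyTorch's internal implementation.

```python
from typing import List


def batch_sizes(num_samples: int, batch_size: int, drop_last: bool) -> List[int]:
    """Hypothetical helper: sizes of the batches one epoch yields under the drop_last rule."""
    full, remainder = divmod(num_samples, batch_size)
    sizes = [batch_size] * full
    if remainder and not drop_last:
        sizes.append(remainder)  # the incomplete final batch
    return sizes


print(batch_sizes(10, 4, drop_last=False))  # [4, 4, 2]
print(batch_sizes(10, 4, drop_last=True))   # [4, 4]
```

With a sampler, the count is taken over the sampler's indices rather than the whole dataset. For example, the 5000 validation indices above with a hypothetical batch_size of 64 give 78 full batches plus a final batch of 8 (5000 = 78 × 64 + 8), unless drop_last=True.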