04. Pytorch中WeightedRandomSampler()的使用

04. Pytorch中WeightedRandomSampler()的使用


针对一些样本类别不均衡数据集,可使用加权的随机采样器平衡各类样本被采样(抽取)训练的概率,缓和偏向预测。

1. 加权随机采样器简介

image-20220329232206466

参数介绍:

  • weights:每个样本的采样权重(注意是每个样本),是一个长度为NlistN为数据集中的样本总个数。
  • num_samples:需要采样的样本个数。
  • replacement:是否可以重采样,如果可以重采样则num_samples可以大于N;否则num_samples <= N,且当num_samples=N时采样权重失效。
  • generator:用于生成采样器的生成器,一般不指定。

2. 使用示例

为了阐述WeightedRandomSampler()的工作原理,首先简单创建一个dataset

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
'''
创建简易数据集
'''
from torch.utils import data
import torch

class My_dataset(data.Dataset):
def __init__(self):
self.data = torch.randn(10, 3, 256, 256) # 生成虚拟图像数据
self.labels = torch.randint(0, 5, [10]) # 生层虚拟数据的label

def __getitem__(self, idx):
return self.data[idx], self.labels[idx], idx

def __len__(self):
return self.data.size(0)

上面创建了一个简易的数据集,里面只有10个(虚假)的图像数据,接下来将结合这个数据集类展示WeightedRandomSampler()来的用法。

2.1 生成采样索引

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
from torch.utils import data
import torch

weights = [1, 1, 1, 1, 1, 1, 1, 1, 1, 9] # 10个样本各自的采样权重
sampler = data.WeightedRandomSampler(weights, num_samples=20, replacement=True)
print(list(sampler))
### out
# 因为最后一个样本的采样权重比较大,所以采样得到样本索引大多数都是最后一个样本
# [9, 9, 1, 9, 9, 9, 4, 9, 4, 3, 4, 9, 0, 9, 2, 9, 9, 9, 7, 9]

weights = [9, 1, 1, 1, 1, 1, 1, 1, 1, 1] # 10个样本各自的采样权重
sampler = data.WeightedRandomSampler(weights, num_samples=20, replacement=True)
print(list(sampler))
### out
# 同理,因为第一个样本的采样权重比较大,所以采样得到样本索引大多数都是第一个样本
# [0, 0, 4, 7, 4, 0, 1, 0, 0, 4, 0, 0, 3, 6, 3, 0, 0, 8, 0, 0]

replacement=False时,各个样本的索引只能被采样一次,且当num_samples=N时,输出的采样样本就是原始数据集。

2.2 结合dataset演示

1
2
3
4
5
6
7
8
9
10
11
12
13
from torch.utils import data
import torch

weights = [1, 1, 1, 1, 1, 1, 1, 1, 1, 9] # 10个样本各自的采样权重
sampler = data.WeightedRandomSampler(weights, num_samples=20, replacement=True)
my_dataset = My_dataset()
my_dataloader = data.DataLoader(my_dataset, batch_size=10, sampler=sampler)
for (data, label, index) in my_dataloader:
print(index)

### out
# tensor([1, 9, 8, 8, 9, 3, 1, 9, 4, 9])
# tensor([9, 9, 0, 3, 9, 9, 5, 9, 7, 9])

加入sampler之后,dataloader在每次加载的数据样本就会依据采样的样本索引去获取对应索引样本的数据,以达到平衡不同类别的样本数量。