04. Pytorch中WeightedRandomSampler()的使用
04. Pytorch中WeightedRandomSampler()的使用
针对一些样本类别不均衡数据集,可使用加权的随机采样器平衡各类样本被采样(抽取)训练的概率,缓和偏向预测。
1. 加权随机采样器简介
参数介绍:
weights:
每个样本的采样权重(注意是每个样本),是一个长度为N
的list
,N
为数据集中的样本总个数。num_samples:
需要采样的样本个数。replacement:
是否可以重采样,如果可以重采样则num_samples
可以大于N
;否则num_samples <= N
,且当num_samples=N
时采样权重失效。generator:
用于生成采样器的生成器,一般不指定。
2. 使用示例
为了阐述WeightedRandomSampler()
的工作原理,首先简单创建一个dataset
。
1 |
|
上面创建了一个简易的数据集,里面只有10个(虚假)的图像数据,接下来将结合这个数据集类展示WeightedRandomSampler()
来的用法。
2.1 生成采样索引
1 |
|
当
replacement=False
时,各个样本的索引只能被采样一次,且当num_samples=N
时,输出的采样样本就是原始数据集。
2.2 结合dataset
演示
1 |
|
加入sampler
之后,dataloader
在每次加载的数据样本就会依据采样的样本索引去获取对应索引样本的数据,以达到平衡不同类别的样本数量。