1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93
| import random import torch import torch.nn as nn
class BasicConv2d(nn.Module):
    """Bias-free ``nn.Conv2d`` followed by ``BatchNorm2d(eps=0.001)`` and an in-place ReLU.

    Args:
        in_channels: channels of the incoming tensor.
        out_channels: channels produced by the convolution (and normalised by BN).
        **kwargs: forwarded verbatim to ``nn.Conv2d`` (e.g. ``kernel_size``,
            ``stride``, ``padding``, ``groups``). ``bias`` must not be passed,
            since it is already fixed to ``False`` here.
    """

    def __init__(self, in_channels, out_channels, **kwargs):
        super().__init__()
        # No conv bias: the BatchNorm right after it re-centres the activations anyway.
        self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs)
        self.bn = nn.BatchNorm2d(out_channels, eps=0.001)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """Return ``relu(bn(conv(x)))``."""
        return self.relu(self.bn(self.conv(x)))
class Identity_layer(nn.Module):
    """Identity mapping: a parameter-free module whose ``forward`` hands back its input object unchanged."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        # No copy is made; the very same object is returned.
        return x
class shufflev2_block(nn.Module):
    """ShuffleNet V2 building unit.

    Two branches are computed, concatenated along dim 1 (channels), and the
    result is channel-shuffled so information mixes between the branches.

    Args:
        in_channels: number of input channels.
        out_channels: number of output channels. Mode 1 requires
            ``out_channels == in_channels``.
        mode: ``1`` - stride-1 unit: the input is split along the channel axis;
            the first part passes through unchanged (``Identity_layer``), the
            second goes through 1x1 conv -> 3x3 depthwise conv + BN -> 1x1 conv.
            ``2`` - stride-2 unit: both branches see the full input and
            downsample it with a stride-2 depthwise conv; each branch emits
            ``out_channels // group`` channels.
        group: split factor used to size the branches. Mode 1 only works when
            it splits ``in_channels`` into two equal halves, i.e. ``group == 2``
            with an even ``in_channels``.
            NOTE(review): in mode 2 the output has
            ``2 * (out_channels // group)`` channels, which equals
            ``out_channels`` only for ``group == 2`` and even ``out_channels``.

    Raises:
        ValueError: if ``mode`` is not 1 or 2, or if mode 1 cannot split the
            input into two equal halves.
        AssertionError: mode 1 with ``in_channels != out_channels``.
    """

    def __init__(self, in_channels, out_channels, mode=1, group=2):
        super().__init__()
        # Any other mode would build no branches at all and forward() would
        # just pass the input through; reject it up front.
        if mode not in (1, 2):
            raise ValueError(f"mode must be 1 or 2, got {mode!r}")
        self.group = group
        self.in_channels = in_channels
        self.mode = mode
        self.out_channels = out_channels
        if self.mode == 1:
            assert in_channels == out_channels, "Under the MODE 1 input and output channel number should be equal."
            mid_channels = self.in_channels // group
            # forward() gives branch1 the first `mid_channels` channels and
            # branch2 the remaining ones, but branch2 is built for exactly
            # `mid_channels` inputs: both parts must have the same size.
            if in_channels != 2 * mid_channels:
                raise ValueError(
                    "mode 1 splits the input into two equal halves, which needs "
                    f"group == 2 and an even in_channels; got in_channels={in_channels}, "
                    f"group={group}"
                )
            self.branch1 = Identity_layer()
            self.branch2 = nn.Sequential(
                BasicConv2d(mid_channels, mid_channels, kernel_size=1),
                # Depthwise 3x3 (groups == channels), stride 1 keeps spatial size.
                nn.Conv2d(mid_channels, mid_channels, kernel_size=3, stride=1, padding=1,
                          groups=mid_channels, bias=False),
                nn.BatchNorm2d(mid_channels),
                BasicConv2d(mid_channels, mid_channels, kernel_size=1),
            )
        elif mode == 2:
            mid_channels = self.out_channels // group
            self.branch1 = nn.Sequential(
                # Depthwise 3x3, stride 2: halves the spatial resolution.
                nn.Conv2d(self.in_channels, self.in_channels, kernel_size=3, stride=2, padding=1,
                          groups=self.in_channels, bias=False),
                nn.BatchNorm2d(self.in_channels),
                BasicConv2d(self.in_channels, mid_channels, kernel_size=1),
            )
            self.branch2 = nn.Sequential(
                BasicConv2d(self.in_channels, mid_channels, kernel_size=1),
                nn.Conv2d(mid_channels, mid_channels, kernel_size=3, stride=2, padding=1,
                          groups=mid_channels, bias=False),
                nn.BatchNorm2d(mid_channels),
                BasicConv2d(mid_channels, mid_channels, kernel_size=1),
            )

    def forward(self, x):
        """Run both branches, concatenate them on dim 1, then channel-shuffle.

        ``x`` is presumably an NCHW batch (the branches are ``nn.Conv2d``-based)
        with ``in_channels`` channels - TODO confirm with callers.
        """
        if self.mode == 1:
            channel_per_group = self.in_channels // self.group
            x = torch.cat([self.branch1(x[:, :channel_per_group]),
                           self.branch2(x[:, channel_per_group:])], dim=1)
        elif self.mode == 2:
            x = torch.cat([self.branch1(x), self.branch2(x)], dim=1)
        return self._shuffle(x)

    def _shuffle(self, x, groups=2):
        """Deterministic ShuffleNet channel shuffle.

        Channels are viewed as a ``(groups, C // groups)`` grid and transposed,
        so with the default ``groups=2`` (forward() always concatenates exactly
        two branches) channels from branch1 and branch2 alternate. The
        permutation is fixed, so every forward pass sees the same channel order.

        Args:
            x: tensor of shape ``(N, C, *spatial)``.
            groups: number of channel groups to interleave; must divide ``C``.

        Returns:
            Tensor of the same shape as ``x`` with its channels permuted.

        Raises:
            ValueError: if ``C`` is not divisible by ``groups``.
        """
        batch, channels = x.shape[0], x.shape[1]
        spatial = tuple(x.shape[2:])
        if channels % groups:
            raise ValueError(
                f"channel count {channels} is not divisible by groups={groups}"
            )
        # reshape (not view) tolerates non-contiguous inputs.
        x = x.reshape(batch, groups, channels // groups, *spatial)
        return x.transpose(1, 2).reshape(batch, channels, *spatial)
if __name__ == "__main__":
    # Smoke test: a mode-1 (stride-1) block should return the input shape.
    # Guarded so that merely importing this module does not allocate a
    # 10x512x128x128 tensor (~335 MB with the default float32 dtype) and run
    # a forward pass.
    fake_data = torch.randn(10, 512, 128, 128)
    net = shufflev2_block(in_channels=512, out_channels=512, mode=1)
    print(net(fake_data).size())
|