1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64
class AttentionResNet(nn.Module):
    """Encoder-decoder network with a torchvision ResNet encoder and an attention-map head.

    The encoder stages of a ResNet-34/101/152 are reused as ``conv1`` .. ``conv5``.
    A ``center`` block and five ``DecoderBlockV2`` stages form the decoder, with the
    outputs of ``conv5`` .. ``conv2`` concatenated onto the decoder path along the
    channel dimension. A small head (``ConvRelu`` followed by a 1x1 convolution)
    produces the final map; no activation is applied after that 1x1 convolution
    in this class.

    Args:
        in_channels: Stored on the instance as ``self.in_channels``.
            NOTE(review): this value is not used to build the network; the
            encoder stem is taken unchanged from torchvision, so it presumably
            still expects the stock ResNet input channel count -- confirm
            before passing anything other than the default.
        out_channels: Number of channels produced by the final 1x1 convolution.
        num_filters: Base width used to size the decoder blocks.
        encoder_depth: Which ResNet to use as encoder: 34, 101 or 152.
        pretrained: Forwarded to the torchvision model constructor.

    Raises:
        NotImplementedError: If ``encoder_depth`` is not 34, 101 or 152.
    """

    def __init__(self, in_channels=3, out_channels=1, num_filters=32, encoder_depth=34, pretrained=True):
        super(AttentionResNet, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.num_filters = num_filters

        # bottom_channel_nr is the channel width assumed for the deepest encoder
        # stage; the shallower skip connections are sized as fractions of it below.
        if encoder_depth == 34:
            self.encoder = torchvision.models.resnet34(pretrained=pretrained)
            bottom_channel_nr = 512
        elif encoder_depth == 101:
            self.encoder = torchvision.models.resnet101(pretrained=pretrained)
            bottom_channel_nr = 2048
        elif encoder_depth == 152:
            self.encoder = torchvision.models.resnet152(pretrained=pretrained)
            bottom_channel_nr = 2048
        else:
            raise NotImplementedError('only 34, 101, 152 version of Resnet are implemented')

        self.pool = nn.MaxPool2d(2, 2)
        # Not used in forward(); kept so the module's attribute set is unchanged.
        self.relu = nn.ReLU(inplace=True)

        # Encoder stages are the same module objects held by self.encoder, so
        # their parameters are shared rather than copied.
        self.conv1 = nn.Sequential(
            self.encoder.conv1,
            self.encoder.bn1,
            self.encoder.relu,
            self.pool,
        )
        self.conv2 = self.encoder.layer1
        self.conv3 = self.encoder.layer2
        self.conv4 = self.encoder.layer3
        self.conv5 = self.encoder.layer4

        # Decoder. Each input width is (skip-connection channels + channels
        # coming up from the previous decoder stage), matching the torch.cat
        # calls in forward().
        self.center = DecoderBlockV2(bottom_channel_nr, num_filters * 8 * 2, num_filters * 8)
        self.dec5 = DecoderBlockV2(bottom_channel_nr + num_filters * 8, num_filters * 8 * 2, num_filters * 8)
        self.dec4 = DecoderBlockV2(bottom_channel_nr // 2 + num_filters * 8, num_filters * 8 * 2, num_filters * 8)
        self.dec3 = DecoderBlockV2(bottom_channel_nr // 4 + num_filters * 8, num_filters * 4 * 2, num_filters * 2)
        self.dec2 = DecoderBlockV2(bottom_channel_nr // 8 + num_filters * 2, num_filters * 2 * 2, num_filters * 2 * 2)
        # dec1 has no skip connection: it consumes dec2's output only.
        self.dec1 = DecoderBlockV2(num_filters * 2 * 2, num_filters * 2 * 2, num_filters)

        # Output head. The 1x1 convolution honours ``out_channels`` (previously
        # hard-coded to 1, which silently ignored the constructor argument).
        # With the default out_channels=1 the layer shape is unchanged.
        self.attention_map = nn.Sequential(
            ConvRelu(num_filters, num_filters),
            nn.Conv2d(num_filters, out_channels, kernel_size=1),
        )

    def forward(self, x):
        """Run the encoder, decoder and head on ``x`` and return the head's output.

        NOTE(review): presumably ``x`` is an image batch shaped
        (batch, channels, height, width) whose spatial size survives the
        repeated downsampling so the skip concatenations line up -- confirm
        against DecoderBlockV2 and the callers.
        """
        # Encoder path; every stage output is kept for a skip connection
        # except conv1.
        conv1 = self.conv1(x)
        conv2 = self.conv2(conv1)
        conv3 = self.conv3(conv2)
        conv4 = self.conv4(conv3)
        conv5 = self.conv5(conv4)

        # Extra pooling before the center block.
        center = self.center(self.pool(conv5))

        # Decoder path: concatenate along dim 1 (channels) with the matching
        # encoder output, then decode.
        dec5 = self.dec5(torch.cat([center, conv5], 1))
        dec4 = self.dec4(torch.cat([dec5, conv4], 1))
        dec3 = self.dec3(torch.cat([dec4, conv3], 1))
        dec2 = self.dec2(torch.cat([dec3, conv2], 1))
        dec1 = self.dec1(dec2)

        return self.attention_map(dec1)
|