[Paper] FishNet: A Versatile Backbone for Image, Region, and Pixel Level Prediction
This is an implementation of the paper "FishNet: A Versatile Backbone for Image, Region, and Pixel Level Prediction".
The original paper is available at the link.
The implementation follows the official GitHub repository (link) released by the authors.
Block
This block plays the same role as a residual block and serves as the basic building unit of the model. The mode flag selects between the paper's DR-block (Down-sampling & Refinement, with an optional 1x1 projection shortcut) and UR-block (Up-sampling & Refinement, whose shortcut is replaced by channel-wise reduction).
import torch.nn as nn


class FishBlock(nn.Module):
    """Pre-activation bottleneck block, the basic unit of the FishNet architecture.

    mode='DR' uses a 1x1 projection shortcut when the shape changes;
    mode='UR' replaces the shortcut with channel-wise reduction by a factor of k.
    """
    def __init__(self, ch_in, ch_out, stride=1, mode='DR', k=1, dilation=1):
        super(FishBlock, self).__init__()
        self.mode = mode
        self.relu = nn.ReLU()
        self.k = k
        bottle_neck_ch = ch_out // 4
        # bottleneck path: three BN-ReLU-Conv stages (pre-activation order)
        self.bn1 = nn.BatchNorm2d(ch_in)
        self.conv1 = nn.Conv2d(ch_in, bottle_neck_ch, kernel_size=1, bias=False)
        self.bn2 = nn.BatchNorm2d(bottle_neck_ch)
        self.conv2 = nn.Conv2d(bottle_neck_ch, bottle_neck_ch, kernel_size=3, stride=stride,
                               padding=dilation, dilation=dilation, bias=False)
        self.bn3 = nn.BatchNorm2d(bottle_neck_ch)
        self.conv3 = nn.Conv2d(bottle_neck_ch, ch_out, kernel_size=1, bias=False)
        if mode == 'UR':
            # UR blocks take their residual from channel_wise_reduction instead
            self.shortcut = None
        elif ch_in != ch_out or stride > 1:
            # projection shortcut to match channel count and resolution
            self.shortcut = nn.Sequential(
                nn.BatchNorm2d(ch_in),
                self.relu,
                nn.Conv2d(ch_in, ch_out, kernel_size=1, stride=stride, bias=False)
            )
        else:
            self.shortcut = None

    def channel_wise_reduction(self, data):
        # reduce c channels to c//k by summing groups of k consecutive channels
        n, c, h, w = data.size()
        return data.view(n, c // self.k, self.k, h, w).sum(2)

    def forward(self, data):
        out = self.conv1(self.relu(self.bn1(data)))
        out = self.conv2(self.relu(self.bn2(out)))
        out = self.conv3(self.relu(self.bn3(out)))
        if self.mode == 'UR':
            residual = self.channel_wise_reduction(data)
        elif self.shortcut is not None:
            residual = self.shortcut(data)
        else:
            residual = data
        out += residual
        return out
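
As a quick sanity check (my own test, not part of the original code): in UR mode with k=2, a 256-channel input is reduced to 128 channels by summing pairs of channels, so the residual matches the conv path.

import torch

blk = FishBlock(256, 128, mode='UR', k=2)   # UR: residual via channel-wise reduction
x = torch.randn(2, 256, 14, 14)
print(blk(x).shape)                         # torch.Size([2, 128, 14, 14])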
Model
import math
import os
import sys

import torch
import torch.nn as nn

sys.path.append(os.path.dirname(os.path.abspath(__file__)))
from fish_block import FishBlock


class FishNet(nn.Module):
    def __init__(self, **kwargs):
        super(FishNet, self).__init__()
        ch_initial = kwargs['tail_ch_in'][0]
        # stem: three conv layers followed by max-pooling (224 -> 56 for ImageNet input)
        self.layer1 = self.layers(3, ch_initial // 2, stride=2)
        self.layer2 = self.layers(ch_initial // 2, ch_initial // 2)
        self.layer3 = self.layers(ch_initial // 2, ch_initial)
        self.pool = nn.MaxPool2d(3, padding=1, stride=2)
        self.fish = Fish(**kwargs)
        self.init_weights()

    @staticmethod
    def layers(ch_in, ch_out, stride=1):
        return nn.Sequential(
            nn.Conv2d(ch_in, ch_out, kernel_size=3, padding=1, stride=stride, bias=False),
            nn.BatchNorm2d(ch_out),
            nn.ReLU(inplace=True)
        )

    def init_weights(self):
        # He initialization for convs, constant initialization for batch norms
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def forward(self, data):
        out = self.layer1(data)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.pool(out)
        score = self.fish(out)
        return score.view(score.size(0), -1)

class Fish(nn.Module):
    def __init__(self, tail_ch_in, tail_ch_out, tail_res_blks, body_ch_in, body_ch_out,
                 body_res_blks, body_trans_blks, head_ch_in, head_ch_out, head_res_blks,
                 head_trans_blks, num_cls):
        super(Fish, self).__init__()
        self.num_cls = num_cls
        self.num_tail = len(tail_ch_out)
        self.tail_ch_in = tail_ch_in
        self.tail_ch_out = tail_ch_out
        self.tail_res_blks = tail_res_blks
        self.num_body = len(body_ch_out)
        self.body_ch_in = body_ch_in
        self.body_ch_out = body_ch_out
        self.body_res_blks = body_res_blks
        self.body_trans_blks = body_trans_blks
        self.num_head = len(head_ch_out)
        self.head_ch_in = head_ch_in
        self.head_ch_out = head_ch_out
        self.head_res_blks = head_res_blks
        self.head_trans_blks = head_trans_blks
        self.tail, self.se, self.body, self.head, self.score = self.make_fish()

    def make_fish(self):
        tail = make_blocks(self.num_tail, self.tail_ch_in, self.tail_ch_out, self.tail_res_blks, 'tail')
        # SE bridge between the tail and the body
        se = make_se_block(self.tail_ch_out[-1], self.tail_ch_out[-1])
        body = make_blocks(self.num_body, self.body_ch_in, self.body_ch_out, self.body_res_blks, 'body')
        head = make_blocks(self.num_head, self.head_ch_in, self.head_ch_out, self.head_res_blks, 'head')
        # pool so the classifier output is (N, num_cls, 1, 1)
        score = make_score(self.head_ch_out[-1] + self.tail_ch_out[-1], self.num_cls, has_pool=True)
        return tail, se, body, head, score
    def forward(self, data):
        # tail: down-sampling path
        tail0 = self.tail[0](data)
        tail1 = self.tail[1](tail0)
        # bridge: SE attention gates the tail feature
        # (the official code combines them as x * att + att)
        att = self.se(tail1)
        se = tail1 * att + att
        # body: up-sampling path with skip connections from the tail
        body0 = self.body[0](se)
        body1 = self.body[1](torch.cat((body0, tail0), dim=1))
        # head: down-sampling path with skip connections from the body
        head0 = self.head[0](torch.cat((body1, data), dim=1))
        head1 = self.head[1](torch.cat((head0, body0), dim=1))
        return self.score(torch.cat((head1, tail1), dim=1))

def make_blocks(num, ch_in, ch_out, res_blks, part):
    # tail/head stages down-sample with max-pooling; body stages up-sample
    blocks = []
    is_down = part != 'body'
    sampling = nn.MaxPool2d(2, stride=2) if is_down else nn.Upsample(scale_factor=2)
    for i in range(num):
        # body stages use UR blocks with channel-wise reduction factor k
        k = int(round(ch_in[i] / ch_out[i])) if part == 'body' else 1
        block = []
        block.extend(make_res_block(ch_in[i], ch_out[i], res_blks[i], k=k, is_down=is_down))
        block.append(sampling)
        blocks.append(nn.Sequential(*block))
    return nn.ModuleList(blocks)


def make_res_block(ch_in, ch_out, num_res_blocks, is_down=False, k=1, dilation=1):
    layers = []
    if is_down:
        layers.append(FishBlock(ch_in, ch_out, stride=1))
    else:
        layers.append(FishBlock(ch_in, ch_out, mode='UR', dilation=dilation, k=k))
    # remaining blocks keep the channel count and act as plain residual blocks
    for i in range(1, num_res_blocks):
        layers.append(FishBlock(ch_out, ch_out, stride=1, dilation=dilation))
    return layers

def make_se_block(ch_in, ch_out):
    # squeeze-and-excitation: global pooling, bottleneck (r=16), sigmoid gate
    return nn.Sequential(
        nn.BatchNorm2d(ch_in),
        nn.ReLU(inplace=True),
        nn.AdaptiveMaxPool2d(1),
        nn.Conv2d(ch_in, ch_out // 16, kernel_size=1),
        nn.ReLU(inplace=True),
        nn.Conv2d(ch_out // 16, ch_out, kernel_size=1),
        nn.Sigmoid()
    )
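
A small check of my own to show what the bridge produces: the SE block maps an (N, C, H, W) feature to an (N, C, 1, 1) attention map with values in (0, 1), which Fish.forward uses to gate the tail output.

se = make_se_block(256, 256)
att = se(torch.randn(2, 256, 14, 14))
print(att.shape)    # torch.Size([2, 256, 1, 1])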

def make_score(ch_in, ch_out, has_pool=False):
    # a 1x1 conv halves the channels, then a 1x1 conv produces the class scores
    layers = nn.Sequential(
        nn.BatchNorm2d(ch_in),
        nn.ReLU(inplace=True),
        nn.Conv2d(ch_in, ch_in // 2, kernel_size=1, bias=False),
        nn.BatchNorm2d(ch_in // 2),
        nn.ReLU(inplace=True)
    )
    if has_pool:
        fc = nn.Sequential(
            nn.AdaptiveMaxPool2d(1),
            nn.Conv2d(ch_in // 2, ch_out, kernel_size=1, bias=True)
        )
    else:
        fc = nn.Conv2d(ch_in // 2, ch_out, kernel_size=1, bias=True)
    return nn.Sequential(layers, fc)
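
To check the whole pipeline end to end, here is a minimal, illustrative configuration of my own (not the paper's FishNet-99 setting). The channel widths must line up with the concatenations in Fish.forward, e.g. head_ch_in[1] = head_ch_out[0] + body_ch_out[0]; body stages shrink channels with UR blocks (k=2 here), while tail and head stages grow them.

cfg = dict(
    tail_ch_in=[64, 128], tail_ch_out=[128, 256], tail_res_blks=[2, 2],
    body_ch_in=[256, 256], body_ch_out=[128, 128], body_res_blks=[1, 1],
    body_trans_blks=[1, 1],
    head_ch_in=[192, 320], head_ch_out=[192, 320], head_res_blks=[1, 1],
    head_trans_blks=[1, 1],
    num_cls=1000,
)
net = FishNet(**cfg)
out = net(torch.randn(2, 3, 224, 224))
print(out.shape)    # torch.Size([2, 1000])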