3 분 소요

“FishNet: A Versatile Backbone for Image, Region, and Pixel Level Prediction”이란 논문에 대한 구현 코드입니다.

논문 원문은 arXiv(https://arxiv.org/abs/1901.03495)에서 확인할 수 있습니다.

구현은 논문 저자가 공개한 GitHub 저장소(https://github.com/kevin-ssy/FishNet)를 참고하여 진행하였습니다.

Block

Residual block과 동일한 역할을 하는 block으로 모델 구조에서 기본 단위로 사용합니다.

class FishBlock(nn.Module):
    """Pre-activation bottleneck residual unit, the basic building block of FishNet.

    Residual branch: BN-ReLU-Conv1x1 -> BN-ReLU-Conv3x3 -> BN-ReLU-Conv1x1, with a
    bottleneck width of ``ch_out // 4``.

    Args:
        ch_in: number of input channels.
        ch_out: number of output channels.
        stride: stride of the 3x3 conv (and of the projection shortcut, if any).
        mode: ``'UR'`` makes the identity path sum every ``k`` adjacent input
            channels instead of using a shortcut; any other value uses an
            identity shortcut, or a BN-ReLU-Conv1x1 projection when the
            channel count or resolution changes.
        k: channel-reduction factor used by the ``'UR'`` identity path.
        dilation: dilation (and matching padding) of the 3x3 conv.
    """

    def __init__(self, ch_in, ch_out, stride=1, mode='DR', k=1, dilation=1):
        super(FishBlock, self).__init__()
        self.mode = mode
        self.relu = nn.ReLU()
        self.k = k

        mid = ch_out // 4
        self.bn1 = nn.BatchNorm2d(ch_in)
        self.conv1 = nn.Conv2d(ch_in, mid, kernel_size=1, bias=False)
        self.bn2 = nn.BatchNorm2d(mid)
        self.conv2 = nn.Conv2d(mid, mid, kernel_size=3, stride=stride,
                               padding=dilation, dilation=dilation, bias=False)
        self.bn3 = nn.BatchNorm2d(mid)
        self.conv3 = nn.Conv2d(mid, ch_out, kernel_size=1, bias=False)

        # 'UR' blocks never project; other blocks project only when the shape changes.
        needs_projection = mode != 'UR' and (ch_in != ch_out or stride > 1)
        if needs_projection:
            self.shortcut = nn.Sequential(
                nn.BatchNorm2d(ch_in),
                self.relu,  # the block's single ReLU instance is reused here
                nn.Conv2d(ch_in, ch_out, kernel_size=1, stride=stride, bias=False),
            )
        else:
            self.shortcut = None

    def channel_wise_reduction(self, data):
        """Sum each group of ``self.k`` adjacent channels: (N, C, H, W) -> (N, C // k, H, W).

        ``C`` must be divisible by ``self.k``; otherwise ``view`` raises.
        """
        batch, channels, height, width = data.size()
        grouped = data.view(batch, channels // self.k, self.k, height, width)
        return grouped.sum(dim=2)

    def forward(self, data):
        """Return residual-branch output plus the identity/shortcut path."""
        out = data
        for bn, conv in ((self.bn1, self.conv1), (self.bn2, self.conv2), (self.bn3, self.conv3)):
            out = conv(self.relu(bn(out)))

        if self.mode == 'UR':
            residual = self.channel_wise_reduction(data)
        else:
            residual = data if self.shortcut is None else self.shortcut(data)

        out += residual
        return out

Model

import math
import torch
import torch.nn as nn
import sys
import os

sys.path.append(f'{os.path.dirname(os.path.abspath(__file__))}')

from fish_block import FishBlock


class FishNet(nn.Module):
    """FishNet classifier: a three-conv stem, a 3x3/stride-2 max-pool, then the Fish trunk.

    All keyword arguments are forwarded to ``Fish``. ``tail_ch_in[0]`` also sets
    the stem width: the stem produces ``tail_ch_in[0]`` channels, with the first
    two convs at half that width.
    """

    def __init__(self, **kwargs):
        super(FishNet, self).__init__()

        stem_out = kwargs['tail_ch_in'][0]
        stem_mid = stem_out // 2
        self.layer1 = self.layers(3, stem_mid, stride=2)
        self.layer2 = self.layers(stem_mid, stem_mid)
        self.layer3 = self.layers(stem_mid, stem_out)

        self.pool = nn.MaxPool2d(3, padding=1, stride=2)
        self.fish = Fish(**kwargs)
        self.init_weights()

    @staticmethod
    def layers(ch_in, ch_out, stride=1):
        """Return a Conv3x3 (no bias) -> BatchNorm -> in-place ReLU unit."""
        conv = nn.Conv2d(ch_in, ch_out, kernel_size=3, padding=1, stride=stride, bias=False)
        return nn.Sequential(conv, nn.BatchNorm2d(ch_out), nn.ReLU(inplace=True))

    def init_weights(self):
        """Init every Conv2d with N(0, sqrt(2 / fan_out)) and every BatchNorm2d with gamma=1, beta=0."""
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                kh, kw = module.kernel_size
                fan_out = kh * kw * module.out_channels
                module.weight.data.normal_(0, math.sqrt(2. / fan_out))
            elif isinstance(module, nn.BatchNorm2d):
                module.weight.data.fill_(1)
                module.bias.data.zero_()

    def forward(self, data):
        """Run stem + pool + trunk and flatten the trunk output per sample."""
        stem = self.pool(self.layer3(self.layer2(self.layer1(data))))
        score = self.fish(stem)
        return score.view(stem.size(0), -1)


class Fish(nn.Module):
    """FishNet trunk: tail (down-sampling), SE gate, body (up-sampling), head (down-sampling).

    ``forward`` is hard-wired for exactly two stages in each of tail, body and
    head. Extra stages present in the configuration are built but not used.

    Args:
        tail_ch_in, tail_ch_out, tail_res_blks: per-stage input channels, output
            channels and FishBlock counts for the tail.
        body_ch_in, body_ch_out, body_res_blks: same, for the body.
        head_ch_in, head_ch_out, head_res_blks: same, for the head.
        body_trans_blks, head_trans_blks: stored on the module but currently
            unused by this implementation.
        num_cls: number of output classes of the final score head.
    """

    def __init__(self, tail_ch_in, tail_ch_out, tail_res_blks, body_ch_in, body_ch_out, body_res_blks, body_trans_blks,
                 head_ch_in, head_ch_out, head_res_blks, head_trans_blks, num_cls):
        super(Fish, self).__init__()
        self.num_cls = num_cls

        # Each part's stage count comes from its own channel list.
        self.num_tail = len(tail_ch_out)
        self.tail_ch_in = tail_ch_in
        self.tail_ch_out = tail_ch_out
        self.tail_res_blks = tail_res_blks

        self.num_body = len(body_ch_out)
        self.body_ch_in = body_ch_in
        self.body_ch_out = body_ch_out
        self.body_res_blks = body_res_blks
        self.body_trans_blks = body_trans_blks

        self.num_head = len(head_ch_out)
        self.head_ch_in = head_ch_in
        self.head_ch_out = head_ch_out
        self.head_res_blks = head_res_blks
        self.head_trans_blks = head_trans_blks
        self.tail, self.se, self.body, self.head, self.score = self.make_fish()

    def make_fish(self):
        """Build and return (tail, se, body, head, score) sub-modules."""
        tail = make_blocks(self.num_tail, self.tail_ch_in, self.tail_ch_out, self.tail_res_blks, 'tail')

        se = make_se_block(self.tail_ch_out[-1], self.tail_ch_out[-1])

        body = make_blocks(self.num_body, self.body_ch_in, self.body_ch_out, self.body_res_blks, 'body')

        head = make_blocks(self.num_head, self.head_ch_in, self.head_ch_out, self.head_res_blks, 'head')

        # Global pooling before the 1x1 classifier yields (N, num_cls, 1, 1), which
        # FishNet.forward flattens into per-sample class scores (N, num_cls).
        score = make_score(self.head_ch_out[-1] + self.tail_ch_out[-1], self.num_cls, has_pool=True)
        return tail, se, body, head, score

    def forward(self, data):
        """Run the fish trunk on stem features and return the class-score map."""
        # Tail: two down-sampling stages.
        tail0 = self.tail[0](data)
        tail1 = self.tail[1](tail0)

        # SE gate on the deepest tail feature. The whole SE Sequential must run
        # (BN-ReLU-pool-conv-ReLU-conv-sigmoid); its (N, C, 1, 1) attention is
        # broadcast over H, W. "x * att + att" follows the reference FishNet code.
        att = self.se(tail1)
        gated = tail1 * att + att

        # Body: up-sample and fuse with same-resolution tail features.
        body0 = self.body[0](gated)
        body1 = self.body[1](torch.cat((body0, tail0), dim=1))

        # Head: down-sample again, fusing with same-resolution earlier features.
        head0 = self.head[0](torch.cat((body1, data), dim=1))
        head1 = self.head[1](torch.cat((head0, body0), dim=1))

        # Apply the complete score head (features + pooled num_cls classifier).
        return self.score(torch.cat((head1, tail1), dim=1))


def make_blocks(num, ch_in, ch_out, res_blks, part):
    """Build ``num`` stages for one part of the fish ('tail', 'body' or 'head').

    Stage ``i`` is ``res_blks[i]`` FishBlocks mapping ``ch_in[i] -> ch_out[i]``,
    followed by a resampling layer. 'tail' and 'head' use a 2x2/stride-2 max-pool;
    'body' uses 2x up-sampling. Body stages start with a 'UR' block whose
    channel-reduction factor is ``round(ch_in[i] / ch_out[i])``.

    The same resampling module instance is shared by every stage (it holds no
    parameters).

    Returns:
        nn.ModuleList of nn.Sequential stages.
    """
    is_body = part == 'body'
    is_down = not is_body
    sampling = nn.Upsample(scale_factor=2) if is_body else nn.MaxPool2d(2, stride=2)

    stages = []
    for idx in range(num):
        k = int(round(ch_in[idx] / ch_out[idx])) if is_body else 1
        res_layers = make_res_block(ch_in[idx], ch_out[idx], res_blks[idx], k=k, is_down=is_down)
        stages.append(nn.Sequential(*res_layers, sampling))
    return nn.ModuleList(stages)


def make_res_block(ch_in, ch_out, num_res_blocks, is_down=False, k=1, dilation=1):
    """Return a list of FishBlocks for one stage; the first block is always built.

    The first block maps ``ch_in -> ch_out``. When ``is_down`` it is a plain
    stride-1 block (``dilation`` is not passed to it). Otherwise it is an 'UR'
    block with channel-reduction factor ``k``. The remaining
    ``num_res_blocks - 1`` blocks keep ``ch_out`` channels.
    """
    if is_down:
        first = FishBlock(ch_in, ch_out, stride=1)
    else:
        first = FishBlock(ch_in, ch_out, mode='UR', dilation=dilation, k=k)

    rest = [FishBlock(ch_out, ch_out, stride=1, dilation=dilation) for _ in range(num_res_blocks - 1)]
    return [first] + rest


def make_se_block(ch_in, ch_out):
    """Build a squeeze-and-excitation style attention head.

    Pipeline: BN -> ReLU -> global max-pool -> 1x1 conv (ch_in -> ch_out // 16)
    -> ReLU -> 1x1 conv (-> ch_out) -> sigmoid. For a 4-D input the output is
    (N, ch_out, 1, 1) with values in [0, 1].

    NOTE(review): the SE paper squeezes with global *average* pooling; max
    pooling here is kept as-is. Confirm whether that is intentional.
    """
    squeezed = ch_out // 16
    bn = nn.BatchNorm2d(ch_in)
    conv_squeeze = nn.Conv2d(ch_in, squeezed, kernel_size=1)
    conv_excite = nn.Conv2d(squeezed, ch_out, kernel_size=1)
    relu = nn.ReLU(inplace=True)
    pool = nn.AdaptiveMaxPool2d(1)
    # The single ReLU instance is reused at both activation positions.
    return nn.Sequential(bn, relu, pool, conv_squeeze, relu, conv_excite, nn.Sigmoid())


def make_score(ch_in, ch_out, has_pool=False):
    """Build a score head: BN-ReLU-Conv1x1(ch_in -> ch_in // 2)-BN-ReLU, then a 1x1 conv to ch_out.

    With ``has_pool`` a global max-pool precedes the final conv, so the output is
    (N, ch_out, 1, 1). Otherwise the spatial size is preserved.

    Returns:
        nn.Sequential(features, classifier).
    """
    mid = ch_in // 2
    relu = nn.ReLU(inplace=True)  # shared by both activation positions
    features = nn.Sequential(
        nn.BatchNorm2d(ch_in),
        relu,
        nn.Conv2d(ch_in, mid, kernel_size=1, bias=False),
        nn.BatchNorm2d(mid),
        relu,
    )
    classifier = nn.Conv2d(mid, ch_out, kernel_size=1, bias=True)
    if has_pool:
        classifier = nn.Sequential(nn.AdaptiveMaxPool2d(1), classifier)
    return nn.Sequential(features, classifier)

댓글남기기