Boostcamp AI Tech

Fully Convolutional Networks(FCN) and VGGNet Segmentation Implemenation

ju_young 2021. 9. 12. 14:42
728x90

Segmentation

Classification은 특정 대상이 있는지 없는지를 확인하는 기술이고 흔히 CNN에서 마지막에 FC layer가 오는 형태를 사용한다. 하지만 Detection은 classification과 달리 대상의 존재 여부와 위치 정보도 포함해야한다. 이때 bounding box를 통해서 위치정보를 포함하는데 class를 가리는 softmax와 같이 구성된다.

 

Semantic segmentation은 이미지 안에 있는 pixel 단위로 클래스를 구분하는 것을 목표로 한다. 아래에서처럼 고양이, 하늘, 땅, 산 등을 각 픽셀을 segment하는 것이다.

한 가지 더 예를 들면 아래에서처럼 강아지와 고양이 사진을 input 값으로 넣으면 개에 해당하는 pixel, 고양이에 해당하는 pixel, 배경에 해당하는 pixel을 classification한다.

하지만 segmentation은 각 pixel이 무엇이냐를 알아낼 뿐만아니라 위치 정보도 알아야한다.

Fully Convolutional Networks(FCN)

  • 픽셀별로 라벨링이된 CNN을 사용
  • down-sampling과 up-sampling 연산을 수행
  • 공간 상의 위치가 주어지면 channel dimension의 output은 위치에 대응하는 pixel의 category prediction이 될 것이다.

다음과 같은 그림을 볼 경우 down-sampling이 수행 될때 feature extraction에 의해 각 pixel이 작아지면서 공간 정보도 같이 없어진다. 이렇게 공간 정보가 없어짐에따라 up-sampling을 수행할 때 문제가 된다. 즉, up-sampling을 하기 어렵다는 얘기이다.

이런 문제를 해결하기 위해 skip connection을 사용하여 이전의 pixel 위치 정보를 고려하는 방법을 사용하였다. (skip connection은 resnet에서도 사용한 방법이다.)

또한 가장 큰 장점은 이미지의 크기에 상관없이 작동할 수 있다는 점이다.

Implementation

1. Backbone

import torch
import torch.nn as nn

class VGG11BackBone(nn.Module):
  def __init__(self):
    super(VGG11BackBone, self).__init__()

    self.relu = nn.ReLU(inplace=True)

    self.conv1 = nn.Conv2d(3, 64, kernel_size=3, padding=1)
    self.bn1   = nn.BatchNorm2d(64)
    self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2) # 64 x 112 x 112

    self.conv2 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
    self.bn2   = nn.BatchNorm2d(128)
    self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2) # 128 x 56 x 56

    self.conv3_1 = nn.Conv2d(128, 256, kernel_size=3, padding=1)
    self.bn3_1   = nn.BatchNorm2d(256)
    self.conv3_2 = nn.Conv2d(256, 256, kernel_size=3, padding=1)
    self.bn3_2   = nn.BatchNorm2d(256)
    self.pool3   = nn.MaxPool2d(kernel_size=2, stride=2) # 256 x 28 x 28

    self.conv4_1 = nn.Conv2d(256, 512, kernel_size=3, padding=1)
    self.bn4_1   = nn.BatchNorm2d(512)
    self.conv4_2 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
    self.bn4_2   = nn.BatchNorm2d(512)
    self.pool4   = nn.MaxPool2d(kernel_size=2, stride=2) # 512 x 14 x 14

    self.conv5_1 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
    self.bn5_1   = nn.BatchNorm2d(512)
    self.conv5_2 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
    self.bn5_2   = nn.BatchNorm2d(512)

  def forward(self, x):
    x = self.conv1(x)
    x = self.bn1(x)
    x = self.relu(x)
    x = self.pool1(x)

    x = self.conv2(x)
    x = self.bn2(x)
    x = self.relu(x)
    x = self.pool2(x)

    x = self.conv3_1(x)
    x = self.bn3_1(x)
    x = self.relu(x)
    x = self.conv3_2(x)
    x = self.bn3_2(x)
    x = self.relu(x)
    x = self.pool3(x)

    x = self.conv4_1(x)
    x = self.bn4_1(x)
    x = self.relu(x)
    x = self.conv4_2(x)
    x = self.bn4_2(x)
    x = self.relu(x)
    x = self.pool4(x)

    x = self.conv5_1(x)
    x = self.bn5_1(x)
    x = self.relu(x)
    x = self.conv5_2(x)
    x = self.bn5_2(x)
    x = self.relu(x)

    return x

2. Classifier

class VGG11Classification(nn.Module):
  def __init__(self, num_classes = 7):
    super(VGG11Classification, self).__init__()

    self.backbone = VGG11BackBone()
    self.pool5   = nn.MaxPool2d(kernel_size=2, stride=2) # 512 x 7 x 7
    self.gap      = nn.AdaptiveAvgPool2d(1) # 512 x 1 x 1
    self.fc_out   = nn.Linear(512, num_classes)

  def forward(self, x):
    x = self.backbone(x)
    x = self.pool5(x)
    x = self.gap(x)
    x = torch.flatten(x, 1)
    x = self.fc_out(x)

    return x

3. Segmentation

class VGG11Segmentation(nn.Module):
  def __init__(self, num_classes = 7):
    super(VGG11Segmentation, self).__init__()

    self.backbone = VGG11BackBone()

    with torch.no_grad():
      self.conv_out = nn.Conv2d(512, num_classes, kernel_size=1, padding=0, stride=1)

    self.upsample = torch.nn.Upsample(scale_factor=16, mode='bilinear', align_corners=False)

  def forward(self, x):
    x = self.backbone(x) # 512 x 14 x 14
    x = self.conv_out(x) # 7 x 14 x 14
    x = self.upsample(x) # 1 x 7 x 14 * 16 x 14 * 16
    assert x.shape == (1, 7, 224, 224)

    return x

  def copy_last_layer(self, fc_out):

    reshaped_fc_out = fc_out.weight.detach()
    reshaped_fc_out = torch.reshape(reshaped_fc_out, (7,512,1,1))
    self.conv_out.weight = torch.nn.Parameter(reshaped_fc_out)

    assert self.conv_out.weight[0][0] == fc_out.weight[0][0]

    return

[Reference]

POSTECH Industrial AI Lab.

728x90