반응형
Single object detection을 위한 간단한 모델을 생성하겠습니다.
convolutional layer, pooling layer, skip connection을 활용한 모델입니다.
# implement the model class
import torch.nn as nn
import torch.nn.functional as F
# define the bulk of the model class
class Net(nn.Module):
def __init__(self, params):
super(Net, self).__init__()
C_in, H_in, W_in = params['input_shape']
init_f = params['initial_filters']
num_outputs = params['num_outputs']
self.conv1 = nn.Conv2d(C_in, init_f, kernel_size=3, stride=2, padding=1)
self.conv2 = nn.Conv2d(init_f+C_in, 2*init_f, kernel_size=3, stride=1, padding=1)
self.conv3 = nn.Conv2d(3*init_f+C_in, 4*init_f, kernel_size=3, padding=1)
self.conv4 = nn.Conv2d(7*init_f+C_in, 8*init_f, kernel_size=3, padding=1)
self.conv5 = nn.Conv2d(15*init_f+C_in, 16*init_f, kernel_size=3, padding=1)
self.fc1 = nn.Linear(16*init_f, num_outputs)
def forward(self, x):
identity = F.avg_pool2d(x, 4, 4)
x = F.relu(self.conv1(x))
x = F.max_pool2d(x, 2, 2)
x = torch.cat((x, identity), dim=1)
identity = F.avg_pool2d(x, 2, 2)
x = F.relu(self.conv2(x))
x = F.max_pool2d(x, 2, 2)
x = torch.cat((x, identity), dim=1)
identity = F.avg_pool2d(x, 2, 2)
x = F.relu(self.conv3(x))
x = F.max_pool2d(x, 2, 2)
x = torch.cat((x, identity), dim=1)
identity = F.avg_pool2d(x, 2, 2)
x = F.relu(self.conv4(x))
x = F.max_pool2d(x, 2, 2)
x = torch.cat((x, identity), dim=1)
x = F.relu(self.conv5(x))
x = F.adaptive_avg_pool2d(x, 1)
x = x.reshape(x.size(0), -1)
x = self.fc1(x)
return x
single object는 모든 이미지에 object가 포함되어 있으므로 class를 분류하지 않습니다. 바운딩 박스 중심 좌표만 출력하면 됩니다. 따라서 num_outputs를 2로 설정합니다.
# define an object of the Net class
params_model = {
'input_shape': (3, 256, 256),
'initial_filters': 16,
'num_outputs': 2,
}
model = Net(params_model)
# move the model to the CUDA device
if torch.cuda.is_available():
device = torch.device('cuda')
model = model.to(device)
print(model)
model summary 함수로 잘 구현되었는지 확인합니다.
# get the model summary
from torchsummary import summary
summary(model, input_size=(3, 256, 256), device=device.type)
반응형
'Python > PyTorch 공부' 카테고리의 다른 글
[PyTorch] YOLOv3 학습을 위한 VOC2007 커스텀 데이터셋 생성하기 (2) | 2021.03.15 |
---|---|
[PyTorch] 러닝 레이트 스케쥴러(Learning Rate Scheduler) ReducedLROnPlateau 함수 (2) | 2021.03.06 |
[PyTorch] 커스텀 데이터셋(custom dataset) 생성하기 (0) | 2021.03.06 |
[PyTorch] data augmentation(resize, flip, shift, brightness, contrast, gamma) 함수 정의하기 (0) | 2021.03.06 |
[PyTorch] 이미지 크기와 바운딩박스 좌표를 resize 하는 함수 정의 (0) | 2021.03.06 |