논문 읽기/Style Transfer

[논문 읽기] PyTorch 구현 코드로 살펴보는 Style Transfer(2015)

AI 꿈나무 2021. 6. 11. 20:53
반응형

 안녕하세요, 오늘 읽은 논문은 A Neural Algorithm of Artistic Style 입니다. 해당 논문은 Style Transfer의 시초가 되는 논문이라고 하네요 ㅎㅎ 논문을 읽어도 이해가 잘 되지 않아, 구현코드를 살펴보면서 공부해보았습니다.

 

 전체 코드와 google colab 환경에서 style transfer을 진행하는 포스팅은 아래에서 확인하실 수 있습니다.

 

[논문 구현] PyTorch로 Style Transfer(2015)를 구현하고 학습하기

 안녕하세요, 이번 포스팅에서는 Style transfer의 시초가 되는 A Neural Algorithm of Artistic Style 논문을 구현하고 이미지 합성을 진행하겠습니다. 작업 환경은 Google Colab에서 진행했습니다.  논문 리뷰..

deep-learning-study.tistory.com

 

 Style transfer은 content 이미지와 style 이미지로 content + style 합성된 이미지를 생성합니다.

 

 

 style image를 CNN에 전달하면, 층이 깊어질 수록 layer는 pixel value보다 high-level 정보를 포함합니다. 얕은 층의 출력값은 detailed pixel value 정보를 포함하고 있습니다. 아래 그림을 살펴보면, 층이 깊어질수록 style image의 representation을 더 많이 포함하고 있는 것을 확인할 수 있습니다.

 

 

 어떻게 content image와 style image를 합성한 이미지를 생성할 수 있을까요??

 

 그것은 loss함수를 살펴보면 알 수 있습니다.

 

 우선 content image와 동일한 input image를 생성합니다. content image를 복사하여 input image를 생성하는 것입니다. 이제 신경망이 학습을 진행하면서, content loss와 style loss를 최소화 하는 방향으로 input image 픽셀 값을 수정할 것입니다. 

 

content loss

 content image를 신경망에 전달한 후에 마지막 레이어의 출력값을 저장합니다. 논문에서는 VGG19를 사용합니다. 또한 input image도 신경망에 전달하여 마지막 레이어의 출력값을 저장합니다. 두 출력값을 사용하여 content loss를 계산합니다. 손실 함수는 MSE를 사용합니다.

 

 

 아래 코드는 Style Transfer 구현 코드중 일부분을 추출해온 것이며 전체 코드는 아래 깃허브에서 확인하실 수 있습니다.

 

Seonghoon-Yu/Paper_Review_and_Implementation_in_PyTorch

공부 목적으로 논문을 리뷰하고 해당 논문 파이토치 재구현을 합니다. Contribute to Seonghoon-Yu/Paper_Review_and_Implementation_in_PyTorch development by creating an account on GitHub.

github.com

# content loss
def get_content_loss(pred_features, target_features, layer):
    target = target_features[layer]
    pred = pred_features[layer]
    loss = F.mse_loss(pred, target)
    return loss
    
# input image
input_tensor = con_tensor.clone().requires_grad_(True)

# 특징을 추출할 레이어 선택
feature_layers = {'0': 'conv1_1',
                  '5': 'conv2_1',
                  '10': 'conv3_1',
                  '19': 'conv4_1',
                  '21': 'conv4_2',
                  '28': 'conv5_1'}

# input image를 CNN에 전달하여 각 레이어의 출력값 추출
input_features = get_features(input_tensor, model_vgg, feature_layers)

# content image를 CNN에 전달하여 각 레이어의 출력값 추출
content_features = get_features(con_tensor, model_vgg, feature_layers)

# 신경망의 마지막 레이어 지정
content_layer = 'conv5_1'

# content와 input의 마지막 레이어 출력값을 사용하여 MSE 계산
content_loss = get_content_loss(input_features, content_features, content_layer)

 

 즉, input image와 content image의 CNN 마지막 레이어의 출력값의 MSE가 최소화 되는 방향으로 input image의 값을 변경합니다.

 

style loss

 style image도 CNN에 전달하여 각 layer에서 출력값을 추출합니다. style loss는 style image의 각 layer 출력값에서 Gram matrix를 계산한 후에, input image의 각 layer 출력값의 Gram matrix와의 MSE를 계산합니다. 두 Gram matrix가 비슷해지도록 input image의 pixel value가 학습되는 것입니다.

 

gram matrix

 

각 layer에서 gram matrix의 mse

 

각 layer의 mse에 가중치 적용

 

# Gram matrix를 계산하는 함수를 정의합니다.
def gram_matrix(x):
    n, c, h, w = x.size()
    x = x.view(n*c, h*w)
    gram = torch.mm(x,x.t()) # 행렬간 곱셈 수행
    return gram
    
    
# style loss
def get_style_loss(pred_features, target_features, style_layers_dict):
    loss = 0
    for layer in style_layers_dict:
        pred_fea = pred_features[layer]
        pred_gram = gram_matrix(pred_fea)
        n, c, h, w = pred_fea.shape
        target_gram = gram_matrix(target_features[layer])
        layer_loss = style_layers_dict[layer] * F.mse_loss(pred_gram, target_gram)
        loss += layer_loss / (n*c*h*w)
    return loss
    
# CNN의 각 layer에서 출력값 얻기
style_features = get_features(sty_tensor, model_vgg, feature_layers)

# 각 레이어 출력값의 가중치 지정
style_layers_dict = {'conv1_1':0.75,
                     'conv2_1':0.5,
                     'conv3_1':0.25,
                     'conv4_1':0.25,
                     'conv5_1':0.25}

# loss 계산
style_loss = get_style_loss(input_features, style_features, style_layers_dict)

 

 최종 loss은 content와 style loss를 더하여 계산합니다. 각 loss의 가중치를 조절하여 style, content의 emphasis를 조절할 수 있습니다.

 

 

neural_loss = content_weight * content_loss + style_weight * style_loss

 

 아래 그림은 layer 깊이와 style loss 가중치에 따른 합성된 이미지 결과입니다.

 


참고자료

[1] https://arxiv.org/abs/1508.06576

코드 출처

[1] https://github.com/PacktPublishing/PyTorch-Computer-Vision-Cookbook/blob/master/Chapter08/Chapter8.ipynb

 

반응형