수학/딥러닝 이론

04-2. 오차역전파법 (2) - 간단한 역전파 계층 구현

AI 꿈나무 2020. 9. 12. 13:35
반응형

 (밑바닥부터 시작하는 딥러닝, 사이토고키)를 바탕으로 작성하였습니다.


 오차역전파법 (2) - 간단한 역전파의 계층 구현

 이전 포스팅에서는 역전파의 이해를 돕기 위해 계산 그래프의 역전파가 연쇄법칙에 따라 진행되는 모습을 공부하였습니다.

 이번 포스팅에서는 사과 쇼핑의 예를 통해 간단한 역전파의 계층을 구현해보겠습니다.

 


3. 역전파

 '+'와 'X' 등의 연산을 예로 들어 역전파의 구조를 설명하겠습니다.

 

 

 3.1 덧셈 노드의 역전파

 먼저 덧셈 노드의 역전파입니다. 여기에서는 $z = x + y$라는 식을 대상으로 역전파를 살펴보겠습니다. 우선, $z = x + y$의 미분은 다음과 같이 해석적으로 계산할 수 있습니다.

 

$$\frac{\partial z}{\partial x} = 1$$

$$\frac{\partial z}{\partial y} = 1$$

 

 이를 계산 그래프로는 다음처럼 그릴 수 있습니다.

 

덧셈 노드의 역전파 : 왼쪽이 순전파, 오른쪽이 역전파이다. 덧셈 노드의 역전파는 입력값을 그래도 흘러보낸다.

 

 상류에서 전해진 미분($\frac{\partial L}{\partial z}$)에 1을 곱하여 하류로 흘립니다. 덧셈 노드의 역전파는 1을 곱하기만 할 뿐이므로 입력된 값을 그대로 다음 노드로 보냅니다. 여기에서 L은 최종적으로 L 값을 출력하는 계산 그래프를 가정하기 때문입니다.

 

최종 출력으로 가는 계산의 중간에 덧셈 노드가 존재한다. 역전파에서는 국소적 미분이 가장 오른쪽의 출력에서 시작하여 노드를 타고 역방향(왼쪽)으로 전파된다.

 

 이제 구체적인 예를 살펴봅니다. '10 + 5 = 15'라는 계산이 있고, 상류에서 1.3이라는 값이 흘러옵니다. 이를 계산 그래프로 그려보겠습니다.

 

덧셈 노드 역전파의 구체적인 예

 

 덧셈 노드 역전파는 입력 신호를 다음 노드로 출력할 뿐이므로 1.3을 그대로 다음 노드로 전달합니다.

 

 

 3.2 곱셈 노드의 역전파

 이어서 곱셈 노드의 역전파를 설명하겠습니다. $z = xy$라는 식을 생각해봅시다. 이 식의 미분은 다음과 같습니다.

 

$$\frac{\partial z}{\partial x} = y$$

$$\frac{\partial z}{\partial y} = x$$

 

 계산 그래프는 다음과 같이 그릴 수 있습니다.

 

곱셈 노드의 역전파 : 왼쪽이 순전파, 오른쪽이 역전파다.

 

 곱셈 노드 역전파는 상류의 값에 순전파 때의 입력 신호들을 '서로 바꾼 값'을 곱해서 하류로 보냅니다.

 

 구체적인 예를 들어보겠습니다. '10 X 5 = 50'이라는 계산이 있고, 역전파 때 상류에서 1.3 값이 흘러온다고 합시다. 이를 계산 그래프로 그려보겠습니다.

 

곱셈 노드 역전파의 구체적인 예

 

 곱셈의 역전파는 순방향 입력 신호의 값이 필요합니다. 그래서 곱셈 노드를 구현할 때는 순전파의 입력 신호를 변수에 저장해둡시다.

 

 

 3.3 사과 쇼핑의 예

 사과 쇼핑 문제에서 사과의 가격, 사과의 개수, 소비세라는 세 변수 각각이 최종 금액에 어떻게 영향을 주느냐 풀어보겠습니다. 이는 '사과 가격에 대한 지불 금액의 미분', '사과 개수에 대한 지불 금액의 미분', '소비세에 대한 지불 금액의 미분'을 구하는 것에 해당합니다. 이를 계산 그래프의 역전파를 사용해서 풀면 다음과 같습니다.

 

사과 쇼핑의 역전파 예

 

 곱셈 노드의 역전파에서는 입력 신호를 서로 바꿔서 하류로 흘립니다.

 


 

 

 4. 단순한 계층 구현하기

 사과 쇼핑 예를 파이썬으로 구현하겠습니다. 여기에서 계산 그래프의 곱셈 노드를 'MulLayer', 덧셈 노드를 'AddLayer'라는 이름으로 구현하겠습니다.

 

 

 4.1 곱셈 계층

 모든 계층은 forward()와 backward()라는 공통의 메서드를 갖도록 구현하겠습니다. forward()는 순전파, backward()는 역전파를 처리합니다.

 

 곱셈 계층 구현

class MulLayer:
    def __init__(self):         # 변수 x와 y초기화
        self.x = None
        self.y = None
        
    def forward(self, x, y):
        self.x = x
        self.y = y
        out = x * y
        
        return out
        
    def backward(self, dout): # dout은 상류에서 넘어온 미분(dout)
        dx = dout * self.y    # x와 y를 바꾼다.
        dy = dout * self.x
        
        return dx, dy

 

 이 MulLayer를 사용해서 앞에서 본'사과 쇼핑'을 구현해보겠습니다.

 

사과 2개 구입

apple = 100
apple_num = 2
tax = 1.1

# 계층들
mul_apple_layer = MulLayer()
mul_tax_layer = MulLayer()

# 순전파
apple_price = mul_apple_layer.forward(apple, apple_num)
price = mul_tax_layer.forward(apple_price, tax)

#역전파
dprice = 1
dapple_price, dtax = mul_tax_layer.backward(dprice)
dapple, dapple_num = mul_apple_layer.backward(dapple_price)

 

 

 4.2 덧셈 계층

 덧셈 계층 구현

class AddLayer:
    def __init__(self):
        pass                    # 덧셈 계층에서는 초기화가 필요 없습니다.
        
    def forward(self, x, y):
        out = x + y
        
        return out
        
    def backward(self, dout):
        dx = dout * 1           # 상류에서 내려온 미분을 그대로 하류로 흘립니다.
        dy = dout * 1
        
        return dx, dy

 

 덧셈 계층과 곱셈 계층을 사용하여 사과 2개와 귤 3개를 사는 상황을 구현하겠습니다.

 

사과 2개와 귤 3개 구입

 

apple = 100
apple_num = 2
orange = 150
orange_num = 3
tax = 1.1

# 계층들
mul_apple_layer = MulLayer()
mul_orange_layer = MulLayer()
add_apple_orange_layer = AddLayer()
mul_tax_layer = MulLayer()

# 순전파
apple_price = mul_apple_layer.forward(apple, apple_num)  # (1)
orange_price = mul_orange_layer.forward(orange, orange_num)  #(2)
all_price = add_apple_orange_layer.forward(apple_price, orange_price)  # (3)
price = mul_tax_layer.forward(all_price, tax)  # (4)

# 역전파
dprice = 1
dall_price, dtax = mul_tax_layer.backward(dprice)  # (4)
dapple_price, dorange_price = add_apple_orange_layer.backward(dall_price)  # (3)
dorange, dorange_num = mul_orange_layer.backward(dorange_price)  # (2)
dapple, dapple_num = mul_apple_layer.backward(dapple_price)  # (1)


print(price) # 715
print(dapple_num, dapple, dorange, dorange_num, dtax) # 110, 2.2, 3.3, 165, 650

 

 필요한 계층을 만들어 순전파 메서드인 forward()를 적절한 순서로 호출합니다. 그런 다음 순전파와 반대 순서로 역전파 메서드인 backward()를 호출하면 원하는 미분이 나옵니다.


 

 이번 포스팅에서는 사과 쇼핑 예를 통해 역전파의 구조에 대해 알아보았습니다.

 다음 포스팅에서는 역전파 활성화 함수 계층을 구현해보겠습니다. 감사합니다.

 

반응형