Post

BackPropagation 정리

BackPropagation 정리

backpropagation

등장 배경

  • 1960-80년대 초

    단일 퍼셉트론과 같은 얕은 신경망에 초점이 맞춰져 있었음

    다층 퍼셉트론과 같은 깊은 신경망의 학습은 역전파 알고리즘이 개발되기 전까지는 해결되기 어려운 문제

  • 1986년

    Geoffery Hinton, David Rumelhart, Ronald J. Williams 등이 역전파 알고리즘을 발표하여, 다층 퍼셉트론을 비롯한 깊은 신경망 학습이 가능해짐

필요성

  • MNIST 데이터셋을 이용하여 숫자를 인식하는 신경망 모델

  • 입력층 뉴런 784개
  • 은닉층 뉴런 100개
  • 출력층 뉴런 10개

  • 계산량

  • 연결 가중치 개수 784×100+100×10=79,400
  • 편향 개수 100+10=110
  • 한 개의 이미지 손실 계산 79,400번의 연산량(가중치만 고려)
  • 한 개의 파라미터를 업데이트 하기 위한 연산 횟수: 79,400×60,000=4,764,000,000
  • 모든 파라미터를 업데이트 4,764,000,000×79,511=378,790,404,000,000
  • CPU 처리 속도가 초당 8억 5천만 번이라면 378,790,404,000,000850,000,000=445,635(123시간)

따라서 1 epoch에 약 123시간이 필요함

역전파 과정

예시 신경망

  • 입력층 뉴런 2개
  • 은닉층 뉴런 2개
  • 출력층 뉴런 1개
  • 활성화 함수: sigmoid
  • 손실 함수: mse
  • 모든 가중치와 학습률은 랜덤 배치
graph LR
    %% Input Layer
    x1((x₁))
    x2((x₂))
    
    %% Hidden Layer
    h1((h₁))
    h2((h₂))
    
    %% Output Layer
    o1((o₁))
    
    %% Connections from input to hidden layer
    x1 -- w₁ --> h1
    x1 -- w₂ --> h2
    x2 -- w₃ --> h1
    x2 -- w₄ --> h2
    
    %% Connections from hidden to output layer
    h1 -- w₅ --> o1
    h2 -- w₆ --> o1
    
    %% Styling
    classDef inputNode fill:#93c5fd,stroke:#1e40af,stroke-width:2px;
    classDef hiddenNode fill:#a5b4fc,stroke:#3730a3,stroke-width:2px;
    classDef outputNode fill:#c4b5fd,stroke:#5b21b6,stroke-width:2px;
    
    class x1,x2 inputNode;
    class h1,h2 hiddenNode;
    class o1 outputNode;

    %% Labels
    subgraph Input
        x1
        x2
    end
    
    subgraph Hidden
        h1
        h2
    end
    
    subgraph Output
        o1
    end

Output

Hidden

Input

w₁

w₂

w₃

w₄

w₅

w₆

x₁

x₂

h₁

h₂

o₁

1. 순전파 (Feedforward)

  • 입력값

    x1, x2

  • 입력값과 연결 가중치의 곱을 합하여 은닉층으로 전달

    z1=x1w1+x2w3 z2=x1w2+x2w4
  • 은닉층에서 활성화 함수 적용

    h1=σ(z1) h2=σ(z2)
  • 출력층으로 전달

    z3=h1w5+h2w6
  • 최종 출력값

    o1=σ(z3)

2. 손실 계산

  • 손실 함수 (Mean Squared Error)

    C=1ni=1n(yiyi^)2=12(yo1)2

3. 역전파 (Backpropagation)

가중치 업데이트 공식

w:=wlrC

lr 은 학습률(Learning Rate), C는 손실 함수 C에 대한 기울기

w5에 대한 미분

chain rule 적용

Cw5=Co1o1z3z3w5
  1. Co1 구하기

    C=12(yo1)2

    이를 예측 값 o1에 대해 미분

    Co1=(yo1)
  2. o1z3 구하기

    시그모이드 함수 σ(z3)의 미분

    o1z3=o1(1o1)
  3. z3에 대한 w5의 편미분

    z3=h1w5+h2w6

    따라서

    z3w5=h1

    최종적으로 w5에 대한 손실 기울기

    Cw5=(yo1)o1(1o1)h1

유사한 방식으로 w6에 대한 미분을 구하면

Cw6=(yo1)o1(1o1)h2
graph LR
    %% Input Layer
    x1((x₁))
    x2((x₂))
    
    %% Hidden Layer
    h1((h₁))
    h2((h₂))
    
    %% Output Layer
    o1((o₁))
    
    %% Connections from input to hidden layer
    x1 -- w₁ --> h1
    x1 -- w₂ --> h2
    x2 -- w₃ --> h1
    x2 -- w₄ --> h2
    
    %% Connections from hidden to output layer
    h1 -- w₅' --> o1
    h2 -- w₆' --> o1
    
    %% Styling
    classDef inputNode fill:#93c5fd,stroke:#1e40af,stroke-width:2px;
    classDef hiddenNode fill:#a5b4fc,stroke:#3730a3,stroke-width:2px;
    classDef outputNode fill:#c4b5fd,stroke:#5b21b6,stroke-width:2px;
    
    class x1,x2 inputNode;
    class h1,h2 hiddenNode;
    class o1 outputNode;

    %% Labels
    subgraph Input
        x1
        x2
    end
    
    subgraph Hidden
        h1
        h2
    end
    
    subgraph Output
        o1
    end

Output

Hidden

Input

w₁

w₂

w₃

w₄

w₅'

w₆'

x₁

x₂

h₁

h₂

o₁

다음 레이어 w1에 대한 미분 계산

Cw1=Ch1h1z1z1w1
  1. 손실 함수에 대한 h1의 편미분

    h1z3에 영향을 미치므로, 다음과 같이 미분

    Ch1=Cz3z3h1

    이미 구한 Cz3

    Cz3=(yo1)o1(1o1)

    따라서

    Ch1=(yo1)o1(1o1)w5
  2. h1에 대한 z1의 편미분

    h1은 시그모이드 함수이므로

    h1z1=h1(1h1)
  3. z1에 대한 w1의 편미분

    z1은 입력값과 가중치 w1의 곱

    z1w1=x1

최종적으로 w1에 대한 손실 기울기

Cw1=(yo1)o1(1o1)w5h1(1h1)x1

위의 과정을 반복해서 각 가중치에 대한 손실 기울기를 계산하고, 가중치 업데이트

graph LR
    %% Input Layer
    x1((x₁))
    x2((x₂))
    
    %% Hidden Layer
    h1((h₁))
    h2((h₂))
    
    %% Output Layer
    o1((o₁))
    
    %% Connections from input to hidden layer
    x1 -- w₁' --> h1
    x1 -- w₂' --> h2
    x2 -- w₃' --> h1
    x2 -- w₄' --> h2
    
    %% Connections from hidden to output layer
    h1 -- w₅' --> o1
    h2 -- w₆' --> o1
    
    %% Styling
    classDef inputNode fill:#93c5fd,stroke:#1e40af,stroke-width:2px;
    classDef hiddenNode fill:#a5b4fc,stroke:#3730a3,stroke-width:2px;
    classDef outputNode fill:#c4b5fd,stroke:#5b21b6,stroke-width:2px;
    
    class x1,x2 inputNode;
    class h1,h2 hiddenNode;
    class o1 outputNode;

    %% Labels
    subgraph Input
        x1
        x2
    end
    
    subgraph Hidden
        h1
        h2
    end
    
    subgraph Output
        o1
    end

Output

Hidden

Input

w₁'

w₂'

w₃'

w₄'

w₅'

w₆'

x₁

x₂

h₁

h₂

o₁

References

This post is licensed under CC BY 4.0 by the author.