AI/Deep Learning

[Deep Learning] LSTM을 직접 구현해보자!

sangwonYoon 2023. 3. 29. 23:58

LSTM Cell

LSTM Cell의 구조

위 그림은 LSTM Cell의 구조이다.

LSTM의 Cell을 구현한 코드는 아래와 같다.

from typing import Optional, Tuple

import torch
from torch import nn

class LSTMCell(nn.Module):

    def __init__(self, input_size: int, hidden_size: int):
        super().__init__()
        self.hidden_lin = nn.Linear(hidden_size, 4 * hidden_size)
        self.input_lin = nn.Linear(input_size, 4 * hidden_size, bias = False)

    def forward(self, x: torch.Tensor, h: torch.Tensor, c: torch.Tensor):
        ifgo = self.hidden_lin(h) + self.input_lin(x)
        i, f, g, o = ifgo.chunk(4, dim = -1) # ifgo 텐서를 4등분한다.
        c_next = torch.sigmoid(f) * c + torch.sigmoid(i) * torch.tanh(g)
        h_next = torch.sigmoid(o) * torch.tanh(c_next)

        return h_next, c_next

 

h와 x가 각각 self.hidden_lin, self.input_lin를 통해 선형 변환 한 뒤 합쳐지기 때문에 self.input_lin에 bias가 없어도 된다.

 

    def forward(self, x: torch.Tensor, h: torch.Tensor, c: torch.Tensor):
        ifgo = self.hidden_lin(h) + self.input_lin(x)

 

        i, f, g, o = ifgo.chunk(4, dim = -1) # ifgo 텐서를 4등분한다.

 

        c_next = torch.sigmoid(f) * c + torch.sigmoid(i) * torch.tanh(g)

 

        h_next = torch.sigmoid(o) * torch.tanh(c_next)
	
        return h_next, c_next

 

LSTM

위 그림은 2층짜리 LSTM의 구조이다.

LSTM의 구현 코드는 아래와 같다.

class LSTM(nn.Module):

    def __init__(self, input_size: int, hidden_size: int, n_layers: int):
        super().__init__()
        self.n_layers = n_layers
        self.hidden_size = hidden_size
        self.cells = nn.ModuleList([LSTMCell(input_size, hidden_size)] + [LSTMCell(hidden_size, hidden_size) for _ in range(n_layers - 1)])

    def forward(self, x: torch.Tensor, state: Optional[Tuple[torch.Tensor, torch.Tensor]] = None):
        seq_len, batch_size = x.shape[:2] # x의 크기 : [seq_len, batch_size, input_size]

        if state is None:
            h = [x.new_zeros(batch_size, self.hidden_size) for _ in range(self.n_layers)]
            c = [x.new_zeros(batch_size, self.hidden_size) for _ in range(self.n_layers)]
        else:
            (h, c) = state
            h, c = list(torch.unbind(h)), list(torch.unbind(c))

        out = []
        for t in range(seq_len):
            input = x[t]
            for layer in range(self.n_layers):
                h[layer], c[layer] = self.cells[layer](input, h[layer], c[layer])
                input = h[layer]
            out.append(h[-1])

        out = torch.stack(out)
        h = torch.stack(h)
        c = torch.stack(c)

        return out, (h, c)

 

코드를 자세히 살펴보자.

class LSTM(nn.Module):

    def __init__(self, input_size: int, hidden_size: int, n_layers: int):
        super().__init__()
        self.n_layers = n_layers
        self.hidden_size = hidden_size
        self.cells = nn.ModuleList([LSTMCell(input_size, hidden_size)] + [LSTMCell(hidden_size, hidden_size) for _ in range(n_layers - 1)])

첫번째 layer의 LSTM Cell은 x를 입력으로 받지만,
두번째 이후 layer의 LSTM Cell은 h를 입력으로 받기 때문에 LSTM Cell의 input size를 달리 한다.

첫번째 layer의 LSTM Cell
두번째 이후 layer의 LSTM Cell

 

    def forward(self, x: torch.Tensor, state: Optional[Tuple[torch.Tensor, torch.Tensor]] = None):
        seq_len, batch_size = x.shape[:2] # x의 크기 : [seq_len, batch_size, input_size]

 

        if state is None:
            h = [x.new_zeros(batch_size, self.hidden_size) for _ in range(self.n_layers)]
            c = [x.new_zeros(batch_size, self.hidden_size) for _ in range(self.n_layers)]
        else:
            (h, c) = state
            h, c = list(torch.unbind(h)), list(torch.unbind(c))

state 파라미터를 통해 초기 hidden state와 cell state의 정보가 전달된 경우 해당 값을 사용하고,
전달되지 않은 경우 0 값으로 채운 텐서를 생성한다.

hidden state와 cell state의 크기는 [n_layers, batch_size, hidden_size]이다.

 

        out = [] # 각 time step에서 제일 마지막 layer의 hidden state를 담는 리스트
        for t in range(seq_len):
            input = x[t]
            for layer in range(self.n_layers):
                h[layer], c[layer] = self.cells[layer](input, h[layer], c[layer])
                input = h[layer]
            out.append(h[-1])

LSTM Cell을 통해 h[layer]과 c[layer]의 값이 갱신된다.

 

        out = torch.stack(out)
        h = torch.stack(h)
        c = torch.stack(c)

        return out, (h, c)

h와 c는 각각 마지막 time step의 모든 layer에서의 hidden state와 cell state를 stack한 것이다.