LSTM Cell

위 그림은 LSTM Cell의 구조이다.
LSTM의 Cell을 구현한 코드는 아래와 같다.
from typing import Optional, Tuple
import torch
from torch import nn
class LSTMCell(nn.Module):
def __init__(self, input_size: int, hidden_size: int):
super().__init__()
self.hidden_lin = nn.Linear(hidden_size, 4 * hidden_size)
self.input_lin = nn.Linear(input_size, 4 * hidden_size, bias = False)
def forward(self, x: torch.Tensor, h: torch.Tensor, c: torch.Tensor):
ifgo = self.hidden_lin(h) + self.input_lin(x)
i, f, g, o = ifgo.chunk(4, dim = -1) # ifgo 텐서를 4등분한다.
c_next = torch.sigmoid(f) * c + torch.sigmoid(i) * torch.tanh(g)
h_next = torch.sigmoid(o) * torch.tanh(c_next)
return h_next, c_next

h와 x가 각각 self.hidden_lin, self.input_lin를 통해 선형 변환 한 뒤 합쳐지기 때문에 self.input_lin에 bias가 없어도 된다.
def forward(self, x: torch.Tensor, h: torch.Tensor, c: torch.Tensor):
ifgo = self.hidden_lin(h) + self.input_lin(x)

i, f, g, o = ifgo.chunk(4, dim = -1) # ifgo 텐서를 4등분한다.

c_next = torch.sigmoid(f) * c + torch.sigmoid(i) * torch.tanh(g)

h_next = torch.sigmoid(o) * torch.tanh(c_next)
return h_next, c_next

LSTM

위 그림은 2층짜리 LSTM의 구조이다.
LSTM의 구현 코드는 아래와 같다.
class LSTM(nn.Module):
def __init__(self, input_size: int, hidden_size: int, n_layers: int):
super().__init__()
self.n_layers = n_layers
self.hidden_size = hidden_size
self.cells = nn.ModuleList([LSTMCell(input_size, hidden_size)] + [LSTMCell(hidden_size, hidden_size) for _ in range(n_layers - 1)])
def forward(self, x: torch.Tensor, state: Optional[Tuple[torch.Tensor, torch.Tensor]] = None):
seq_len, batch_size = x.shape[:2] # x의 크기 : [seq_len, batch_size, input_size]
if state is None:
h = [x.new_zeros(batch_size, self.hidden_size) for _ in range(self.n_layers)]
c = [x.new_zeros(batch_size, self.hidden_size) for _ in range(self.n_layers)]
else:
(h, c) = state
h, c = list(torch.unbind(h)), list(torch.unbind(c))
out = []
for t in range(seq_len):
input = x[t]
for layer in range(self.n_layers):
h[layer], c[layer] = self.cells[layer](input, h[layer], c[layer])
input = h[layer]
out.append(h[-1])
out = torch.stack(out)
h = torch.stack(h)
c = torch.stack(c)
return out, (h, c)
코드를 자세히 살펴보자.
class LSTM(nn.Module):
def __init__(self, input_size: int, hidden_size: int, n_layers: int):
super().__init__()
self.n_layers = n_layers
self.hidden_size = hidden_size
self.cells = nn.ModuleList([LSTMCell(input_size, hidden_size)] + [LSTMCell(hidden_size, hidden_size) for _ in range(n_layers - 1)])
첫번째 layer의 LSTM Cell은 x를 입력으로 받지만,
두번째 이후 layer의 LSTM Cell은 h를 입력으로 받기 때문에 LSTM Cell의 input size를 달리 한다.


def forward(self, x: torch.Tensor, state: Optional[Tuple[torch.Tensor, torch.Tensor]] = None):
seq_len, batch_size = x.shape[:2] # x의 크기 : [seq_len, batch_size, input_size]

if state is None:
h = [x.new_zeros(batch_size, self.hidden_size) for _ in range(self.n_layers)]
c = [x.new_zeros(batch_size, self.hidden_size) for _ in range(self.n_layers)]
else:
(h, c) = state
h, c = list(torch.unbind(h)), list(torch.unbind(c))

state 파라미터를 통해 초기 hidden state와 cell state의 정보가 전달된 경우 해당 값을 사용하고,
전달되지 않은 경우 0 값으로 채운 텐서를 생성한다.
hidden state와 cell state의 크기는 [n_layers, batch_size, hidden_size]이다.
out = [] # 각 time step에서 제일 마지막 layer의 hidden state를 담는 리스트
for t in range(seq_len):
input = x[t]
for layer in range(self.n_layers):
h[layer], c[layer] = self.cells[layer](input, h[layer], c[layer])
input = h[layer]
out.append(h[-1])

LSTM Cell을 통해 h[layer]과 c[layer]의 값이 갱신된다.
out = torch.stack(out)
h = torch.stack(h)
c = torch.stack(c)
return out, (h, c)
h와 c는 각각 마지막 time step의 모든 layer에서의 hidden state와 cell state를 stack한 것이다.

'AI > Deep Learning' 카테고리의 다른 글
[Deep Learning] RNN의 구조 (2) | 2023.03.24 |
---|---|
[Deep Learning] modern CNN의 특징 (0) | 2023.03.24 |
[Deep Learning] Regularization (0) | 2023.03.21 |
[Deep Learning] 최적화 기법 (0) | 2023.03.21 |
[Deep Learning] 모델 최적화를 위한 중요한 개념들 (0) | 2023.03.21 |