PyTorch에서 tensor의 곱을 연산하는 torch.mm과 torch.matmul의 각각의 특징과 차이점을 알아보자.
torch.mm
- 2D tensor와 2D tensor의 행렬 곱셈을 수행한다.
- broadcast를 지원하지 않는다.
torch.mm(input, mat2, *, out=None) → Tensor
input의 크기가 (n x m)인 경우, mat2의 크기는 (m x p)여야 하고, output의 크기는 (n x p)가 된다.
torch.matmul
- tensor와 tensor의 행렬 곱셈을 수행한다.
- broadcasting을 지원한다. 따라서 의도치 않은 결과가 나올 수 있다는 점에 주의해야 한다.
vector와 vector의 연산
tensor1 = torch.tensor([1,2,3])
tensor2 = torch.tensor([4,5,6])
print(torch.matmul(tensor1, tensor2))
# 출력: tensor(32)
두 인자가 모두 벡터인 경우, 내적 연산이 수행된다.
matrix와 vector의 연산
tensor1 = torch.tensor([[1,2,3], [4,5,6], [7,8,9]])
tensor2 = torch.tensor([1, 10, 100])
# case1
result1 = torch.matmul(tensor1, tensor2)
print(result1)
# 출력: tensor([321, 654, 987])
# case2
result2 = torch.mm(tensor1, torch.unsqueeze(tensor2, 1)).squeeze()
print(result2)
# 출력: tensor([321, 654, 987])
case1의 연산은 case2의 연산과 동일하다.
batched matrix와 broadcasted vector의 연산
3차원 이상의 tensor가 인자로 주어질 경우, 해당 tensor의 0차원은 batch로 간주한다.
tensor1 = torch.tensor(np.arange(24).reshape(2,3,4))
print(tensor1)
# 출력:
# tensor([[[ 0, 1, 2, 3],
# [ 4, 5, 6, 7],
# [ 8, 9, 10, 11]],
# [[12, 13, 14, 15],
# [16, 17, 18, 19],
# [20, 21, 22, 23]]])
tensor2 = torch.tensor(np.arange(4))
print(tensor2)
# 출력: tensor([0, 1, 2, 3])
# case1
result1 = torch.matmul(tensor1, tensor2)
print(result1)
# 출력:
# tensor([[ 14, 38, 62],
# [ 86, 110, 134]])
# case2
result2 = torch.stack((
torch.matmul(tensor1[0], tensor2),
torch.matmul(tensor1[1], tensor2),
))
print(result2)
# 출력:
# tensor([[ 14, 38, 62],
# [ 86, 110, 134]])
case1의 연산은 case2의 연산과 동일하다.
batched matrix와 batched matrix의 연산
tensor1 = torch.randn(10, 3, 4)
tensor2 = torch.randn(10, 4, 5)
print(torch.matmul(tensor1, tensor2).size())
# 출력: torch.Size([10, 3, 5])
batched matrix와 broadcasted matrix의 연산
tensor1 = torch.randn(10, 3, 4)
tensor2 = torch.randn(4, 5)
print(torch.matmul(tensor1, tensor2).size())
# 출력: torch.Size([10, 3, 5])
<참조>
- https://pytorch.org/docs/stable/generated/torch.matmul.html#torch.matmul
- https://pytorch.org/docs/stable/generated/torch.mm.html?highlight=matrix+multiplication
- https://stackoverflow.com/questions/73924697/whats-the-difference-between-torch-mm-torch-matmul-and-torch-mul
'AI > PyTorch' 카테고리의 다른 글
[Pytorch] Dataset과 DataLoader (0) | 2023.03.16 |
---|---|
[PyTorch] 파라미터 구현하기 (0) | 2023.03.16 |
[PyTorch] 신경망 모델과 torch.nn.module (0) | 2023.03.16 |
[PyTorch] torch.reshape과 torch.Tensor.view의 차이 (0) | 2023.03.14 |
[PyTorch] PyTorch 기초 (2) | 2023.03.13 |