AI/PyTorch

[PyTorch] torch.mm과 torch.matmul의 차이

sangwonYoon 2023. 3. 14. 00:11

PyTorch에서 tensor의 곱을 연산하는 torch.mm과 torch.matmul의 각각의 특징과 차이점을 알아보자.


torch.mm

  • 2D tensor와 2D tensor의 행렬 곱셈을 수행한다.
  • broadcast를 지원하지 않는다.
torch.mm(input, mat2, *, out=None) → Tensor

 

input의 크기가 (n x m)인 경우, mat2의 크기는 (m x p)여야 하고, output의 크기는 (n x p)가 된다.

 

 

torch.matmul

  • tensor와 tensor의 행렬 곱셈을 수행한다.
  • broadcasting을 지원한다. 따라서 의도치 않은 결과가 나올 수 있다는 점에 주의해야 한다.

 

vector와 vector의 연산

tensor1 = torch.tensor([1,2,3])
tensor2 = torch.tensor([4,5,6])

print(torch.matmul(tensor1, tensor2))
# 출력: tensor(32)

두 인자가 모두 벡터인 경우, 내적 연산이 수행된다.

 

matrix와 vector의 연산

tensor1 = torch.tensor([[1,2,3], [4,5,6], [7,8,9]])
tensor2 = torch.tensor([1, 10, 100])

# case1
result1 = torch.matmul(tensor1, tensor2)
print(result1)
# 출력: tensor([321, 654, 987])

# case2
result2 = torch.mm(tensor1, torch.unsqueeze(tensor2, 1)).squeeze()
print(result2)
# 출력: tensor([321, 654, 987])

case1의 연산은 case2의 연산과 동일하다.

 

batched matrix와 broadcasted vector의 연산

 

3차원 이상의 tensor가 인자로 주어질 경우, 해당 tensor의 0차원은 batch로 간주한다.

tensor1 = torch.tensor(np.arange(24).reshape(2,3,4))

print(tensor1)
# 출력:
# tensor([[[ 0,  1,  2,  3],
#          [ 4,  5,  6,  7],
#          [ 8,  9, 10, 11]],

#         [[12, 13, 14, 15],
#          [16, 17, 18, 19],
#          [20, 21, 22, 23]]])

tensor2 = torch.tensor(np.arange(4))

print(tensor2)
# 출력: tensor([0, 1, 2, 3])

# case1
result1 = torch.matmul(tensor1, tensor2)
print(result1)
# 출력: 
# tensor([[ 14,  38,  62],
#         [ 86, 110, 134]])

# case2
result2 = torch.stack((
	torch.matmul(tensor1[0], tensor2),
	torch.matmul(tensor1[1], tensor2),
))
print(result2)
# 출력:
# tensor([[ 14,  38,  62],
#         [ 86, 110, 134]])

case1의 연산은 case2의 연산과 동일하다.

 

batched matrix와 batched matrix의 연산

tensor1 = torch.randn(10, 3, 4)
tensor2 = torch.randn(10, 4, 5)

print(torch.matmul(tensor1, tensor2).size())
# 출력: torch.Size([10, 3, 5])

 

batched matrix와 broadcasted matrix의 연산

tensor1 = torch.randn(10, 3, 4)
tensor2 = torch.randn(4, 5)

print(torch.matmul(tensor1, tensor2).size())
# 출력: torch.Size([10, 3, 5])

 

 

<참조>

  • https://pytorch.org/docs/stable/generated/torch.matmul.html#torch.matmul
  • https://pytorch.org/docs/stable/generated/torch.mm.html?highlight=matrix+multiplication
  • https://stackoverflow.com/questions/73924697/whats-the-difference-between-torch-mm-torch-matmul-and-torch-mul