AI/PyTorch

[PyTorch] torch.reshape과 torch.Tensor.view의 차이

sangwonYoon 2023. 3. 14. 01:48

PyTorch에서 tensor의 차원을 재구성하는 함수인 reshape 함수와 view 함수의 각각의 특징과 차이점에 대해 알아보자.


torch.reshape

  • contiguous 속성을 만족하는 tensor가 인자로 주어진 경우, 데이터를 복사하지 않고, 같은 메모리 공간을 공유하는 새로운 tensor를 반환한다.
  • 반면, contiguous 속성을 만족하지 않는 tensor가 인자로 주어진 경우, 데이터를 복사해 새로운 tensor를 생성해 반환한다.

 

contiguous란 무엇일까?

 

contiguous는 ‘인접한’이라는 뜻을 가진 단어로, tensor의 데이터들이 인덱스 순서대로 메모리가 인접해 있는지를 나타내는 속성이다.

tensor1 = torch.randn(2, 3)
tensor2 = tensor1.t()

r1 = [tensor1[i][j].data_ptr() for i in range(2) for j in range(3)]
print(r1)
# 출력: [140550742229568, 140550742229572, 140550742229576, 140550742229580, 140550742229584, 140550742229588]

r2 = [tensor2[i][j].data_ptr() for i in range(3) for j in range(2)]
print(r2)
# 출력: [140550742229568, 140550742229580, 140550742229572, 140550742229584, 140550742229576, 140550742229588]

tensor1은 torch.float32 자료형의 크기인 4바이트 간격으로 인덱스 순서대로 메모리가 인접해 있는 반면, tensor2는 인덱스 순서대로 메모리가 인접해 있지 않다.

print(tensor1.is_contiguous())
# 출력: True

print(tensor2.is_contiguous())
# 출력: False

 

tensor3 = tensor1.reshape(6)
tensor4 = tensor2.reshape(6)

tensor1[0][0] = 0

# tensor2는 tensor1과 같은 메모리 공간을 공유하므로, tensor2[0][0] 또한 값이 0이 된다.
print(tensor2)
# 출력:
# tensor([[ 0.0000,  0.5888],
#         [-1.0566,  0.2708],
#         [-0.6716,  0.8845]])

# tensor3 또한 tensor1과 같은 메모리 공간을 공유하므로, tensor3[0]의 값이 0이 된다.
print(tensor3)
# 출력: tensor([ 0.0000, -1.0566, -0.6716,  0.5888,  0.2708,  0.8845])

# tensor4는 tensor2의 값을 복사하여 생성된 tensor이므로, 영향을 받지 않는다.
print(tensor4)
# 출력: tensor([ 0.6098,  0.5888, -1.0566,  0.2708, -0.6716,  0.8845])
  • contiguous한 tensor인 tensor1로부터 생성된 tensor3tensor1과 같은 메모리 공간을 공유한다.
  • contiguous하지 않은 tensor인 tensor2로부터 생성된 tensor4tensor2와 독립된 공간에 존재한다.

 

torch.Tensor.view

  • torch.Tensor.view는 기존의 tensor와 새로운 tensor가 같은 메모리 공간을 공유한다.
tensor1 = torch.zeros(3, 2) 
tensor2 = tensor1.view(2, 3) 
tensor1.fill_(1)

print(tensor2)
# 출력:
# tensor([[1., 1., 1.],
#         [1., 1., 1.]])

tensor1의 값을 바꾸자, tensor1로부터 생성된 tensor2의 값도 바뀌는 것을 확인할 수 있다.

 

  • torch.Tensor.view는 contiguous 속성을 만족하는 tensor만 인자로 받는다.
tensor1 = torch.randn(2, 3)
tensor2 = tensor1.t()

print(tensor2.is_contiguous())
# 출력: False

tensor3 = tensor2.view(6)
# RuntimeError: view size is not compatible with input tensor's size and stride (at least one dimension spans across two contiguous subspaces). Use .reshape(...) instead.

contiguous하지 않은 tensor를 인자로 넣을 경우, torch.reshape 함수를 사용하라는 에러가 발생한다.

 

<참조>

  • https://pytorch.org/docs/stable/generated/torch.matmul.html#torch.matmul
  • https://pytorch.org/docs/stable/generated/torch.mm.html?highlight=matrix+multiplication
  • https://stackoverflow.com/questions/73924697/whats-the-difference-between-torch-mm-torch-matmul-and-torch-mul

'AI > PyTorch' 카테고리의 다른 글

[Pytorch] Dataset과 DataLoader  (0) 2023.03.16
[PyTorch] 파라미터 구현하기  (0) 2023.03.16
[PyTorch] 신경망 모델과 torch.nn.module  (0) 2023.03.16
[PyTorch] torch.mm과 torch.matmul의 차이  (0) 2023.03.14
[PyTorch] PyTorch 기초  (2) 2023.03.13