본문 바로가기

내가 보려고 만든 Pytorch

내가 공부하려 만든 Pytorch6(view,squeeze)

View 
원소의 수를 유지하면서 텐서의 크기 변경.

파이토치 텐서의 뷰(View)는 넘파이에서의 리쉐이프(Reshape)와 같은 역할을 합니다. Reshape라는 이름에서 알 수 있듯이, 텐서의 크기(Shape)를 변경해주는 역할을 합니다.

 

1. 3d 텐서 생성

t = np.array([[[0, 1, 2],
               [3, 4, 5]],
              [[6, 7, 8],
               [9, 10, 11]]])
ft = torch.FloatTensor(t)
print(ft)
print(ft.shape)
print(ft.dim())

tensor([[[ 0.,  1.,  2.],
         [ 3.,  4.,  5.]],

        [[ 6.,  7.,  8.],
         [ 9., 10., 11.]]])
torch.Size([2, 2, 3])
3

 

2. 3d 텐서에서 2d 텐서로 변경

print(ft.view([-1,3]))
print(ft.view([-1,3]).shape)
print(ft.view([-1,3]).dim())

tensor([[ 0.,  1.,  2.],
        [ 3.,  4.,  5.],
        [ 6.,  7.,  8.],
        [ 9., 10., 11.]])
torch.Size([4, 3])
2

view([-1, 3])이 가지는 의미는 이와 같습니다. -1은 첫번째 차원은 사용자가 잘 모르겠으니 파이토치에 맡기겠다는 의미이고, 3은 두번째 차원의 길이는 3을 가지도록 하라는 의미입니다. 다시 말해 현재 3차원 텐서를 2차원 텐서로 변경하되 (?, 3)의 크기로 변경하라는 의미입니다. 

 

 

Squeeze

1인 차원을 제거한다.

스퀴즈는 차원이 1인 경우에는 해당 차원을 제거합니다.

 

1. 2d 행령 생성

ft = torch.FloatTensor([[0], [1], [2]])
print(ft)
print(ft.shape)
print(ft.dim())

tensor([[0.],
        [1.],
        [2.]])
torch.Size([3, 1])
2

 

2. squeeze 사용

print(ft.squeeze())
print(ft.squeeze().shape)
print(ft.squeeze().dim())

tensor([0., 1., 2.])
torch.Size([3])#1차원 벡터로 바뀜
1#차원축소

 

5차원 squeeeze

1. 5d 텐서 생성

# (A, B, 1, C, 1) 차원 형태 텐서
x = torch.ones(10, 5, 1, 3, 1)
print(x.dim())
print(x.shape)

5
torch.Size([10, 5, 1, 3, 1])

2. size가 1인 차원 전체 삭제 : (A, B, C) 차원 형태

x1 = x.squeeze() # torch.squeeze(x) 가능
x1.shape # torch.Size([10, 5, 3])
print(x1.dim())
print(x1.shape)

3
torch.Size([10, 5, 3])

 

3-1. size가 1인 차원 일부 삭제 : (A, B, 1, C) 차원 형태

x2 = x.squeeze(dim = 2) # x.squeeze(2) 가능
x2.shape # torch.Size([10, 5, 3, 1])
print(x2.dim())
print(x2.shape)

4
torch.Size([10, 5, 3, 1])

3-2

x3 = x.squeeze(dim = -1) # dim = 4와 동일한 결과
x3.shape # torch.Size([10, 5, 1, 3])
print(x3.dim())
print(x3.shape)

4
torch.Size([10, 5, 1, 3])

 

4. size가 1이 아닌 차원 삭제 시도(불가능)

x4 = x.squeeze(dim = 1)
x4.shape # torch.Size([10, 5, 1, 3, 1])
print(x4.dim())
print(x4.shape)

5
torch.Size([10, 5, 1, 3, 1])

 

 

 

출처 : https://wikidocs.net/52846