본문 바로가기
R, Python

[Pytorch] scatter_add_ 함수 이해

by 방구석 몽상가 2021. 12. 3.

Tensor.scatter_add_ 함수에 대해 이해해보자. parameter는 dim, index (LongTensor), src다. 이 함수의 기능은 dim-axis를 따라 self 텐서의 index에 src를 더해주는 함수다. 

원본 문서에서 다음과 같이 좋은 예시를 보여주고 있다.

Example

다음과 같이 src와 index 텐서를 매개변수로 함수를 호출했다고 하자.

src = torch.ones((2,5)) # [2x5]
index = torch.tensor([[0,1,2,0,0]]) # [1x5]
torch.zeros(3,5,dtype=src.dtype).scatter_add_(0,index,src)
"""
tensor([[1., 0., 0., 1., 1.],
        [0., 1., 0., 0., 0.],
        [0., 0., 1., 0., 0.]])
"""

2차원 행렬이기 때문에 좀 더 이해하기 쉬운데 쉽게 생각하면 index 텐서의 각 요소는 행의 선택을 의미하고 이 텐서의 index는 열의 선택을 의미한다고 볼 수 있다. 그래서 0행 0열, 1행, 1열, 2행 2열, 0행 3열, 0행 4열이 선택된다고 보면 된다. 이제i index 텐서의 0 axis를 확장해보자.

index = torch.tensor([[0, 1, 2, 0, 0], [0, 1, 2, 2, 2]]) # [2x5]
torch.zeros(3, 5, dtype=src.dtype).scatter_add_(0, index, src)
"""
tensor([[2., 0., 0., 1., 1.],
        [0., 2., 0., 0., 0.],
        [0., 0., 2., 1., 1.]])
"""

1차원이 추가되도 동일하게 적용할 수 있는데 [0행 0열, 1행 1열, 2행 2열, 0행 3열, 0행 4열], [0행 0열, 1행 1열, 2행 2열, 2행 3열, 2행 4열]이 되어 위와 같은 결과가 나오게 된다.

 

댓글