Tensor.scatter_add_ 함수에 대해 이해해보자. parameter는 dim, index (LongTensor), src다. 이 함수의 기능은 dim-axis를 따라 self 텐서의 index에 src를 더해주는 함수다.
원본 문서에서 다음과 같이 좋은 예시를 보여주고 있다.
다음과 같이 src와 index 텐서를 매개변수로 함수를 호출했다고 하자.
src = torch.ones((2,5)) # [2x5]
index = torch.tensor([[0,1,2,0,0]]) # [1x5]
torch.zeros(3,5,dtype=src.dtype).scatter_add_(0,index,src)
"""
tensor([[1., 0., 0., 1., 1.],
[0., 1., 0., 0., 0.],
[0., 0., 1., 0., 0.]])
"""
2차원 행렬이기 때문에 좀 더 이해하기 쉬운데 쉽게 생각하면 index 텐서의 각 요소는 행의 선택을 의미하고 이 텐서의 index는 열의 선택을 의미한다고 볼 수 있다. 그래서 0행 0열, 1행, 1열, 2행 2열, 0행 3열, 0행 4열이 선택된다고 보면 된다. 이제i index 텐서의 0 axis를 확장해보자.
index = torch.tensor([[0, 1, 2, 0, 0], [0, 1, 2, 2, 2]]) # [2x5]
torch.zeros(3, 5, dtype=src.dtype).scatter_add_(0, index, src)
"""
tensor([[2., 0., 0., 1., 1.],
[0., 2., 0., 0., 0.],
[0., 0., 2., 1., 1.]])
"""
1차원이 추가되도 동일하게 적용할 수 있는데 [0행 0열, 1행 1열, 2행 2열, 0행 3열, 0행 4열], [0행 0열, 1행 1열, 2행 2열, 2행 3열, 2행 4열]이 되어 위와 같은 결과가 나오게 된다.
'R, Python' 카테고리의 다른 글
[Python, AI] 효율적인 리스트 연산 (0) | 2022.01.19 |
---|---|
[PyQt5] Pytorch3D를 이용한 Mesh Viewer 구현 (0) | 2020.04.13 |
[python] PyQt5로 이미지 뷰어 만들기 (간단한 라벨링 툴 만들기 1단계) (6) | 2019.11.14 |
[C/C++/C#] 콘솔에서 글자 색 변경하기 (예시: KMP 알고리즘) (0) | 2019.11.05 |
[python] 네이트판 웹 크롤러를 만들어보자! (0) | 2019.02.13 |
댓글