AllToAll

공개 최종 수업 AllToAll

TPU 복제본 간에 데이터를 교환하는 작업입니다.

각 복제본에서 입력은 `split_dimension`을 따라 `split_count` 블록으로 분할되고 group_location이 지정된 다른 복제본으로 전송됩니다. 다른 복제본으로부터 'split_count' - 1개의 블록을 받은 후 출력으로 'concat_dimension'을 따라 블록을 연결합니다.

예를 들어 TPU 복제본이 2개 있다고 가정합니다. 복제본 0은 입력: `[[A, B]]`를 수신합니다. 복제본 1은 입력: `[[C, D]]`를 수신합니다.

group_location=`[[0, 1]]` concat_dimension=0 분할_dimension=1 분할_count=2

복제본 0의 출력: `[[A], [C]]` 복제본 1의 출력: `[[B], [D]]`

상수

OP_NAME TensorFlow 핵심 엔진에서 알려진 이 작업의 이름

공개 방법

출력 <T>
출력 ()
텐서의 기호 핸들을 반환합니다.
static <T는 TType을 확장합니다. > AllToAll <T>
생성 ( 범위 범위, 피연산자 <T> 입력, 피연산자 < TInt32 > groupAssignment, Long concatDimension, Long SplitDimension, Long SplitCount)
새로운 AllToAll 작업을 래핑하는 클래스를 생성하는 팩토리 메서드입니다.
출력 <T>
출력 ()
교환된 결과입니다.

상속된 메서드

상수

공개 정적 최종 문자열 OP_NAME

TensorFlow 핵심 엔진에서 알려진 이 작업의 이름

상수 값: "AllToAll"

공개 방법

공개 출력 <T> asOutput ()

텐서의 기호 핸들을 반환합니다.

TensorFlow 작업에 대한 입력은 다른 TensorFlow 작업의 출력입니다. 이 메서드는 입력 계산을 나타내는 기호 핸들을 얻는 데 사용됩니다.

public static AllToAll <T> create ( 범위 범위, 피연산자 <T> 입력, 피연산자 < TInt32 > groupAssignment, Long concatDimension, Long SplitDimension, Long SplitCount)

새로운 AllToAll 작업을 래핑하는 클래스를 생성하는 팩토리 메서드입니다.

매개변수
범위 현재 범위
입력 합계에 대한 로컬 입력입니다.
그룹할당 [num_groups, num_replicas_per_group] 형태의 int32 텐서. `group_location[i]`는 i번째 하위 그룹의 복제본 ID를 나타냅니다.
concatDimension 연결할 차원 번호입니다.
분할차원 분할할 차원 번호입니다.
분할 개수 분할 수, 이 수는 하위 그룹 크기(group_location.get_shape()[1])와 동일해야 합니다.
보고
  • AllToAll의 새 인스턴스

공개 출력 <T> 출력 ()

교환된 결과입니다.