tensoreflusso:: ops:: Calcola colpi accidentali
#include <candidate_sampling_ops.h>
Calcola gli ID delle posizioni in sampled_candidates che corrispondono a true_labels.
Riepilogo
Quando si eseguono NCE log-odds, il risultato di questa operazione dovrebbe essere passato attraverso un'operazione SparseToDense, quindi aggiunto ai logit dei candidati campionati. Ciò ha l'effetto di "rimuovere" le etichette campionate che corrispondono alle etichette reali assicurando al classificatore che si tratti di etichette campionate.
Argomenti:
- scope: un oggetto Scope
- true_classes: l'output true_classes di UnpackSparseLabels.
- sampled_candidates: l'output sampled_candidates di CandidateSampler.
- num_true: numero di etichette vere per contesto.
Attributi facoltativi (vedi Attrs
):
- seme: se seme o seme2 sono impostati su un valore diverso da zero, il generatore di numeri casuali viene seminato dal seme specificato. Altrimenti, viene seminato da un seme casuale.
- seed2: un secondo seme per evitare la collisione del seme.
Resi:
- Indici
Output
: un vettore di indici corrispondenti a righe di true_candidates. - ID
Output
: un vettore di ID di posizioni in sampled_candidates che corrispondono a true_label per la riga con l'indice corrispondente in indicis. - Pesi
Output
: un vettore della stessa lunghezza di indici e ID, in cui ogni elemento è -FLOAT_MAX.
Costruttori e distruttori | |
---|---|
ComputeAccidentalHits (const :: tensorflow::Scope & scope, :: tensorflow::Input true_classes, :: tensorflow::Input sampled_candidates, int64 num_true) | |
ComputeAccidentalHits (const :: tensorflow::Scope & scope, :: tensorflow::Input true_classes, :: tensorflow::Input sampled_candidates, int64 num_true, const ComputeAccidentalHits::Attrs & attrs) |
Attributi pubblici | |
---|---|
ids | |
indices | |
operation | |
weights |
Funzioni pubbliche statiche | |
---|---|
Seed (int64 x) | |
Seed2 (int64 x) |
Strutture | |
---|---|
tensorflow:: ops:: ComputeAccidentalHits:: Attrs | Setter di attributi facoltativi per ComputeAccidentalHits . |
Attributi pubblici
ID
::tensorflow::Output ids
indici
::tensorflow::Output indices
operazione
Operation operation
pesi
::tensorflow::Output weights
Funzioni pubbliche
Calcola colpi accidentali
ComputeAccidentalHits( const ::tensorflow::Scope & scope, ::tensorflow::Input true_classes, ::tensorflow::Input sampled_candidates, int64 num_true )
Calcola colpi accidentali
ComputeAccidentalHits( const ::tensorflow::Scope & scope, ::tensorflow::Input true_classes, ::tensorflow::Input sampled_candidates, int64 num_true, const ComputeAccidentalHits::Attrs & attrs )
Funzioni pubbliche statiche
Seme
Attrs Seed( int64 x )
Seme2
Attrs Seed2( int64 x )