Esegue più predittori di ensemble di regressione additiva su istanze di input e
calcola l'aggiornamento ai logit memorizzati nella cache. È progettato per essere utilizzato durante l'allenamento. Attraversa gli alberi a partire dall'ID dell'albero memorizzato nella cache e dall'ID del nodo memorizzato nella cache e calcola gli aggiornamenti da inviare alla cache.
Costanti
Corda | OP_NAME | Il nome di questa operazione, come noto al motore principale di TensorFlow |
Metodi pubblici
statico BoostedTreesTrainingPredict | |
Uscita < TInt32 > | ID nodo () Tensore di rango 1 contenente i nuovi ID dei nodi nei nuovi tree_ids. |
Uscita < TFloat32 > | log parziali () Tensore di grado 2 contenente l'aggiornamento dei logit (rispetto ai valori memorizzati nella cache) per ogni esempio. |
Uscita < TInt32 > | IDalbero () Tensore di rango 1 contenente nuovi ID albero per ogni esempio. |
Metodi ereditati
Costanti
Stringa finale statica pubblica OP_NAME
Il nome di questa operazione, come noto al motore principale di TensorFlow
Metodi pubblici
public static BoostedTreesTrainingPredict create ( Scope scope, Operand <?> treeEnsembleHandle, Operand < TInt32 > cachedTreeIds, Operand < TInt32 > cachedNodeIds, Iterable< Operand < TInt32 >> bucketizedFeatures, Long logitsDimension)
Metodo factory per creare una classe che racchiude una nuova operazione BoostedTreesTrainingPredict.
Parametri
scopo | ambito attuale |
---|---|
cachedTreeIds | Tensore di rango 1 contenente gli ID degli alberi memorizzati nella cache che è l'albero di previsione iniziale. |
cachedNodeId | Tensore di rango 1 contenente l'ID del nodo memorizzato nella cache che è il nodo iniziale della previsione. |
Caratteristiche con bucket | Un elenco di tensori di rango 1 contenente l'ID bucket per ciascuna funzionalità. |
logitsDimension | scalare, dimensione dei logit, da utilizzare per la forma dei logit parziali. |
ritorna
- una nuova istanza di BoostedTreesTrainingPredict
output pubblico < TInt32 > nodeId ()
Tensore di rango 1 contenente i nuovi ID dei nodi nei nuovi tree_ids.
Output pubblico < TFloat32 > partialLogits ()
Tensore di grado 2 contenente l'aggiornamento dei logit (rispetto ai valori memorizzati nella cache) per ogni esempio.