Exécute plusieurs prédicteurs d'ensemble de régression additive sur les instances d'entrée et
calcule la mise à jour des logits mis en cache. Il est conçu pour être utilisé pendant l’entraînement. Il parcourt les arbres en commençant par l'ID d'arbre mis en cache et l'ID de nœud mis en cache et calcule les mises à jour à pousser vers le cache.
Constantes
Chaîne | OP_NAME | Le nom de cette opération, tel que connu par le moteur principal TensorFlow |
Méthodes publiques
statique BoostedTreesTrainingPredict | |
Sortie < TInt32 > | ID de nœud () Tenseur de rang 1 contenant de nouveaux identifiants de nœuds dans les nouveaux tree_ids. |
Sortie < TFloat32 > | partielLogits () Tenseur de rang 2 contenant la mise à jour des logits (par rapport aux valeurs mises en cache stockées) pour chaque exemple. |
Sortie < TInt32 > | ID d'arbre () Tenseur de rang 1 contenant de nouveaux identifiants d'arbre pour chaque exemple. |
Méthodes héritées
Constantes
chaîne finale statique publique OP_NAME
Le nom de cette opération, tel que connu par le moteur principal TensorFlow
Méthodes publiques
public static BoostedTreesTrainingPredict créer ( Scope scope, Operand <?> treeEnsembleHandle, Operand < TInt32 > cachedTreeIds, Operand < TInt32 > cachedNodeIds, Iterable < Operand < TInt32 >> bucketizedFeatures, Long logitsDimension)
Méthode d'usine pour créer une classe encapsulant une nouvelle opération BoostedTreesTrainingPredict.
Paramètres
portée | portée actuelle |
---|---|
ID d'arbre mis en cache | Tenseur de rang 1 contenant les identifiants d'arbre mis en cache qui est l'arbre de prédiction de départ. |
ID de nœud mis en cache | Tenseur de rang 1 contenant l'identifiant du nœud mis en cache qui est le nœud de départ de la prédiction. |
bucketizedCaractéristiques | Une liste de Tensors de rang 1 contenant l'identifiant du compartiment pour chaque fonctionnalité. |
logitsDimension | scalaire, dimension des logits, à utiliser pour la forme des logits partiels. |
Retour
- une nouvelle instance de BoostedTreesTrainingPredict
Sortie publique < TInt32 > nodeIds ()
Tenseur de rang 1 contenant de nouveaux identifiants de nœuds dans les nouveaux tree_ids.
Sortie publique < TFloat32 > partialLogits ()
Tenseur de rang 2 contenant la mise à jour des logits (par rapport aux valeurs mises en cache stockées) pour chaque exemple.