In this colab, you will learn how to inspect and create the structure of a model directly. We assume you are familiar with the concepts introduced in the beginner and intermediate colabs.
In this colab, you will:
Train a Random Forest model and access its structure programmatically.
Create a Random Forest model by hand and use it as a classical model.
Setup
# Install TensorFlow Decision Forests.
pip install tensorflow_decision_forests
# Use wurlitzer to capture training logs.
pip install wurlitzer
import tensorflow_decision_forests as tfdf
import os
import numpy as np
import pandas as pd
import tensorflow as tf
import math
import collections
try:
  from wurlitzer import sys_pipes
except:
  from colabtools.googlelog import CaptureLog as sys_pipes
from IPython.core.magic import register_line_magic
from IPython.display import Javascript
WARNING:root:Failure to load the custom c++ tensorflow ops. This error is likely caused the version of TensorFlow and TensorFlow Decision Forests are not compatible. WARNING:root:TF Parameter Server distributed training not available.
The hidden code cell limits the output height in colab.
# Some of the model training logs can cover the full
# screen if not compressed to a smaller viewport.
# This magic allows setting a max height for a cell.
@register_line_magic
def set_cell_height(size):
  display(
      Javascript("google.colab.output.setIframeHeight(0, true, {maxHeight: " +
                 str(size) + "})"))
Train a simple Random Forest
We train a Random Forest like in the beginner colab:
# Download the dataset
!wget -q https://storage.googleapis.com/download.tensorflow.org/data/palmer_penguins/penguins.csv -O /tmp/penguins.csv
# Load a dataset into a Pandas Dataframe.
dataset_df = pd.read_csv("/tmp/penguins.csv")
# Show the first three examples.
print(dataset_df.head(3))
# Convert the pandas dataframe into a tf dataset.
dataset_tf = tfdf.keras.pd_dataframe_to_tf_dataset(dataset_df, label="species")
# Train the Random Forest
model = tfdf.keras.RandomForestModel(compute_oob_variable_importances=True)
model.fit(x=dataset_tf)
species island bill_length_mm bill_depth_mm flipper_length_mm \ 0 Adelie Torgersen 39.1 18.7 181.0 1 Adelie Torgersen 39.5 17.4 186.0 2 Adelie Torgersen 40.3 18.0 195.0 body_mass_g sex year 0 3750.0 male 2007 1 3800.0 female 2007 2 3250.0 female 2007 /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow_decision_forests/keras/core.py:1612: FutureWarning: In a future version of pandas all arguments of DataFrame.drop except for the argument 'labels' will be keyword-only features_dataframe = dataframe.drop(label, 1) 6/6 [==============================] - 4s 17ms/step [INFO kernel.cc:736] Start Yggdrasil model training [INFO kernel.cc:737] Collect training examples [INFO kernel.cc:392] Number of batches: 6 [INFO kernel.cc:393] Number of examples: 344 [INFO kernel.cc:759] Dataset: Number of records: 344 Number of columns: 8 Number of columns by type: NUMERICAL: 5 (62.5%) CATEGORICAL: 3 (37.5%) Columns: NUMERICAL: 5 (62.5%) 0: "bill_depth_mm" NUMERICAL num-nas:2 (0.581395%) mean:17.1512 min:13.1 max:21.5 sd:1.9719 1: "bill_length_mm" NUMERICAL num-nas:2 (0.581395%) mean:43.9219 min:32.1 max:59.6 sd:5.4516 2: "body_mass_g" NUMERICAL num-nas:2 (0.581395%) mean:4201.75 min:2700 max:6300 sd:800.781 3: "flipper_length_mm" NUMERICAL num-nas:2 (0.581395%) mean:200.915 min:172 max:231 sd:14.0411 6: "year" NUMERICAL mean:2008.03 min:2007 max:2009 sd:0.817166 CATEGORICAL: 3 (37.5%) 4: "island" CATEGORICAL has-dict vocab-size:4 zero-ood-items most-frequent:"Biscoe" 168 (48.8372%) 5: "sex" CATEGORICAL num-nas:11 (3.19767%) has-dict vocab-size:3 zero-ood-items most-frequent:"male" 168 (50.4505%) 7: "__LABEL" CATEGORICAL integerized vocab-size:4 no-ood-item Terminology: nas: Number of non-available (i.e. missing) values. ood: Out of dictionary. manually-defined: Attribute which type is manually defined by the user i.e. the type was not automatically inferred. tokenized: The attribute value is obtained through tokenization. 
has-dict: The attribute is attached to a string dictionary e.g. a categorical attribute stored as a string. vocab-size: Number of unique values. [INFO kernel.cc:762] Configure learner [INFO kernel.cc:787] Training config: learner: "RANDOM_FOREST" features: "bill_depth_mm" features: "bill_length_mm" features: "body_mass_g" features: "flipper_length_mm" features: "island" features: "sex" features: "year" label: "__LABEL" task: CLASSIFICATION [yggdrasil_decision_forests.model.random_forest.proto.random_forest_config] { num_trees: 300 decision_tree { max_depth: 16 min_examples: 5 in_split_min_examples_check: true missing_value_policy: GLOBAL_IMPUTATION allow_na_conditions: false categorical_set_greedy_forward { sampling: 0.1 max_num_items: -1 min_item_frequency: 1 } growing_strategy_local { } categorical { cart { } } num_candidate_attributes_ratio: -1 axis_aligned_split { } internal { sorting_strategy: PRESORTED } } winner_take_all_inference: true compute_oob_performances: true compute_oob_variable_importances: true adapt_bootstrap_size_ratio_for_maximum_training_duration: false } [INFO kernel.cc:790] Deployment config: num_threads: 6 [INFO kernel.cc:817] Train model [INFO random_forest.cc:315] Training random forest on 344 example(s) and 7 feature(s). 
[INFO random_forest.cc:628] Training of tree 1/300 (tree index:0) done accuracy:0.964286 logloss:1.28727 [INFO random_forest.cc:628] Training of tree 11/300 (tree index:10) done accuracy:0.956268 logloss:0.584301 [INFO random_forest.cc:628] Training of tree 22/300 (tree index:21) done accuracy:0.965116 logloss:0.378823 [INFO random_forest.cc:628] Training of tree 35/300 (tree index:34) done accuracy:0.968023 logloss:0.178185 [INFO random_forest.cc:628] Training of tree 46/300 (tree index:45) done accuracy:0.973837 logloss:0.170304 [INFO random_forest.cc:628] Training of tree 58/300 (tree index:57) done accuracy:0.973837 logloss:0.171223 [INFO random_forest.cc:628] Training of tree 70/300 (tree index:69) done accuracy:0.979651 logloss:0.169564 [INFO random_forest.cc:628] Training of tree 83/300 (tree index:82) done accuracy:0.976744 logloss:0.17074 [INFO random_forest.cc:628] Training of tree 96/300 (tree index:95) done accuracy:0.976744 logloss:0.0736925 [INFO random_forest.cc:628] Training of tree 106/300 (tree index:105) done accuracy:0.976744 logloss:0.0748649 [INFO random_forest.cc:628] Training of tree 117/300 (tree index:116) done accuracy:0.976744 logloss:0.074671 [INFO random_forest.cc:628] Training of tree 130/300 (tree index:129) done accuracy:0.976744 logloss:0.0736275 [INFO random_forest.cc:628] Training of tree 140/300 (tree index:139) done accuracy:0.976744 logloss:0.0727718 [INFO random_forest.cc:628] Training of tree 152/300 (tree index:151) done accuracy:0.976744 logloss:0.0715068 [INFO random_forest.cc:628] Training of tree 162/300 (tree index:161) done accuracy:0.976744 logloss:0.0708994 [INFO random_forest.cc:628] Training of tree 173/300 (tree index:172) done accuracy:0.976744 logloss:0.069447 [INFO random_forest.cc:628] Training of tree 184/300 (tree index:183) done accuracy:0.976744 logloss:0.0695926 [INFO random_forest.cc:628] Training of tree 195/300 (tree index:194) done accuracy:0.976744 logloss:0.0690138 [INFO random_forest.cc:628] 
Training of tree 205/300 (tree index:204) done accuracy:0.976744 logloss:0.0694597 [INFO random_forest.cc:628] Training of tree 217/300 (tree index:216) done accuracy:0.976744 logloss:0.068122 [INFO random_forest.cc:628] Training of tree 229/300 (tree index:228) done accuracy:0.976744 logloss:0.0687641 [INFO random_forest.cc:628] Training of tree 239/300 (tree index:238) done accuracy:0.976744 logloss:0.067988 [INFO random_forest.cc:628] Training of tree 250/300 (tree index:249) done accuracy:0.976744 logloss:0.0690187 [INFO random_forest.cc:628] Training of tree 260/300 (tree index:259) done accuracy:0.976744 logloss:0.0690134 [INFO random_forest.cc:628] Training of tree 270/300 (tree index:269) done accuracy:0.976744 logloss:0.0689877 [INFO random_forest.cc:628] Training of tree 280/300 (tree index:279) done accuracy:0.976744 logloss:0.0689845 [INFO random_forest.cc:628] Training of tree 290/300 (tree index:288) done accuracy:0.976744 logloss:0.0690742 [INFO random_forest.cc:628] Training of tree 300/300 (tree index:299) done accuracy:0.976744 logloss:0.068949 [INFO random_forest.cc:696] Final OOB metrics: accuracy:0.976744 logloss:0.068949 [INFO kernel.cc:828] Export model in log directory: /tmp/tmpoqki9pfl [INFO kernel.cc:836] Save model in resources [INFO kernel.cc:988] Loading model from path [INFO decision_forest.cc:590] Model loaded with 300 root(s), 5080 node(s), and 7 input feature(s). [INFO abstract_model.cc:993] Engine "RandomForestGeneric" built [INFO kernel.cc:848] Use fast generic engine <keras.callbacks.History at 0x7f09eaa9cb90>
Notice the compute_oob_variable_importances=True hyper-parameter in the model constructor. This option computes the Out-of-bag (OOB) variable importance during training. This is a popular permutation variable importance for Random Forest models.
Computing the OOB variable importance does not impact the final model, but it slows down training on large datasets.
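To make the idea of a permutation variable importance concrete, here is a minimal pure-Python sketch. It is not how TF-DF computes the OOB importance internally: the toy `predict` function, the random dataset, and the helper names are all hypothetical, and the importance is measured on the full dataset rather than on out-of-bag examples.

```python
import random

# Hypothetical toy "model": predicts class 1 when the first feature is >= 0.5.
# The second feature is ignored on purpose.
def predict(row):
  return 1 if row[0] >= 0.5 else 0

def accuracy(rows, labels):
  correct = sum(predict(r) == y for r, y in zip(rows, labels))
  return correct / len(rows)

def permutation_importance(rows, labels, feature_idx, rng):
  """Returns the accuracy drop observed after shuffling one feature column."""
  baseline = accuracy(rows, labels)
  column = [r[feature_idx] for r in rows]
  rng.shuffle(column)
  shuffled_rows = [
      r[:feature_idx] + [v] + r[feature_idx + 1:] for r, v in zip(rows, column)
  ]
  return baseline - accuracy(shuffled_rows, labels)

rng = random.Random(0)
rows = [[rng.random(), rng.random()] for _ in range(200)]
labels = [1 if r[0] >= 0.5 else 0 for r in rows]

importance_f0 = permutation_importance(rows, labels, 0, rng)
importance_f1 = permutation_importance(rows, labels, 1, rng)
print("feature 0 (used by the model):", importance_f0)
print("feature 1 (ignored by the model):", importance_f1)
```

Shuffling a feature the model relies on lowers the accuracy, while shuffling an ignored feature leaves it unchanged. This is why several features in the summary have an importance of exactly zero.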
Check the model summary:
%set_cell_height 300
model.summary()
<IPython.core.display.Javascript object> Model: "random_forest_model" _________________________________________________________________ Layer (type) Output Shape Param # ================================================================= ================================================================= Total params: 1 Trainable params: 0 Non-trainable params: 1 _________________________________________________________________ Type: "RANDOM_FOREST" Task: CLASSIFICATION Label: "__LABEL" Input Features (7): bill_depth_mm bill_length_mm body_mass_g flipper_length_mm island sex year No weights Variable Importance: MEAN_DECREASE_IN_ACCURACY: 1. "bill_length_mm" 0.151163 ################ 2. "island" 0.008721 # 3. "bill_depth_mm" 0.000000 4. "body_mass_g" 0.000000 5. "sex" 0.000000 6. "year" 0.000000 7. "flipper_length_mm" -0.002907 Variable Importance: MEAN_DECREASE_IN_AP_1_VS_OTHERS: 1. "bill_length_mm" 0.083305 ################ 2. "island" 0.007664 # 3. "flipper_length_mm" 0.003400 4. "bill_depth_mm" 0.002741 5. "body_mass_g" 0.000722 6. "sex" 0.000644 7. "year" 0.000000 Variable Importance: MEAN_DECREASE_IN_AP_2_VS_OTHERS: 1. "bill_length_mm" 0.508510 ################ 2. "island" 0.023487 3. "bill_depth_mm" 0.007744 4. "flipper_length_mm" 0.006008 5. "body_mass_g" 0.003017 6. "sex" 0.001537 7. "year" -0.000245 Variable Importance: MEAN_DECREASE_IN_AP_3_VS_OTHERS: 1. "island" 0.002192 ################ 2. "bill_length_mm" 0.001572 ############ 3. "bill_depth_mm" 0.000497 ####### 4. "sex" 0.000000 #### 5. "year" 0.000000 #### 6. "body_mass_g" -0.000053 #### 7. "flipper_length_mm" -0.000890 Variable Importance: MEAN_DECREASE_IN_AUC_1_VS_OTHERS: 1. "bill_length_mm" 0.071306 ################ 2. "island" 0.007299 # 3. "flipper_length_mm" 0.004506 # 4. "bill_depth_mm" 0.002124 5. "body_mass_g" 0.000548 6. "sex" 0.000480 7. "year" 0.000000 Variable Importance: MEAN_DECREASE_IN_AUC_2_VS_OTHERS: 1. "bill_length_mm" 0.108642 ################ 2. "island" 0.014493 ## 3. 
"bill_depth_mm" 0.007406 # 4. "flipper_length_mm" 0.005195 5. "body_mass_g" 0.001012 6. "sex" 0.000480 7. "year" -0.000053 Variable Importance: MEAN_DECREASE_IN_AUC_3_VS_OTHERS: 1. "island" 0.002126 ################ 2. "bill_length_mm" 0.001393 ########### 3. "bill_depth_mm" 0.000293 ##### 4. "sex" 0.000000 ### 5. "year" 0.000000 ### 6. "body_mass_g" -0.000037 ### 7. "flipper_length_mm" -0.000550 Variable Importance: MEAN_DECREASE_IN_PRAUC_1_VS_OTHERS: 1. "bill_length_mm" 0.083122 ################ 2. "island" 0.010887 ## 3. "flipper_length_mm" 0.003425 4. "bill_depth_mm" 0.002731 5. "body_mass_g" 0.000719 6. "sex" 0.000641 7. "year" 0.000000 Variable Importance: MEAN_DECREASE_IN_PRAUC_2_VS_OTHERS: 1. "bill_length_mm" 0.497611 ################ 2. "island" 0.024045 3. "bill_depth_mm" 0.007734 4. "flipper_length_mm" 0.006017 5. "body_mass_g" 0.003000 6. "sex" 0.001528 7. "year" -0.000243 Variable Importance: MEAN_DECREASE_IN_PRAUC_3_VS_OTHERS: 1. "island" 0.002187 ################ 2. "bill_length_mm" 0.001568 ############ 3. "bill_depth_mm" 0.000495 ####### 4. "sex" 0.000000 #### 5. "year" 0.000000 #### 6. "body_mass_g" -0.000053 #### 7. "flipper_length_mm" -0.000886 Variable Importance: MEAN_MIN_DEPTH: 1. "__LABEL" 3.479602 ################ 2. "year" 3.463891 ############### 3. "sex" 3.430498 ############### 4. "body_mass_g" 2.898112 ########### 5. "island" 2.388925 ######## 6. "bill_depth_mm" 2.336100 ####### 7. "bill_length_mm" 1.282960 8. "flipper_length_mm" 1.270079 Variable Importance: NUM_AS_ROOT: 1. "flipper_length_mm" 157.000000 ################ 2. "bill_length_mm" 76.000000 ####### 3. "bill_depth_mm" 52.000000 ##### 4. "island" 12.000000 5. "body_mass_g" 3.000000 Variable Importance: NUM_NODES: 1. "bill_length_mm" 778.000000 ################ 2. "bill_depth_mm" 463.000000 ######### 3. "flipper_length_mm" 414.000000 ######## 4. "island" 342.000000 ###### 5. "body_mass_g" 338.000000 ###### 6. "sex" 36.000000 7. 
"year" 19.000000 Variable Importance: SUM_SCORE: 1. "bill_length_mm" 36515.793787 ################ 2. "flipper_length_mm" 35120.434174 ############### 3. "island" 14669.408395 ###### 4. "bill_depth_mm" 14515.446617 ###### 5. "body_mass_g" 3485.330881 # 6. "sex" 354.201073 7. "year" 49.737758 Winner take all: true Out-of-bag evaluation: accuracy:0.976744 logloss:0.068949 Number of trees: 300 Total number of nodes: 5080 Number of nodes by tree: Count: 300 Average: 16.9333 StdDev: 3.10197 Min: 11 Max: 31 Ignored: 0 ---------------------------------------------- [ 11, 12) 6 2.00% 2.00% # [ 12, 13) 0 0.00% 2.00% [ 13, 14) 46 15.33% 17.33% ##### [ 14, 15) 0 0.00% 17.33% [ 15, 16) 70 23.33% 40.67% ######## [ 16, 17) 0 0.00% 40.67% [ 17, 18) 84 28.00% 68.67% ########## [ 18, 19) 0 0.00% 68.67% [ 19, 20) 46 15.33% 84.00% ##### [ 20, 21) 0 0.00% 84.00% [ 21, 22) 30 10.00% 94.00% #### [ 22, 23) 0 0.00% 94.00% [ 23, 24) 13 4.33% 98.33% ## [ 24, 25) 0 0.00% 98.33% [ 25, 26) 2 0.67% 99.00% [ 26, 27) 0 0.00% 99.00% [ 27, 28) 2 0.67% 99.67% [ 28, 29) 0 0.00% 99.67% [ 29, 30) 0 0.00% 99.67% [ 30, 31] 1 0.33% 100.00% Depth by leafs: Count: 2690 Average: 3.53271 StdDev: 1.06789 Min: 2 Max: 7 Ignored: 0 ---------------------------------------------- [ 2, 3) 545 20.26% 20.26% ###### [ 3, 4) 747 27.77% 48.03% ######## [ 4, 5) 888 33.01% 81.04% ########## [ 5, 6) 444 16.51% 97.55% ##### [ 6, 7) 62 2.30% 99.85% # [ 7, 7] 4 0.15% 100.00% Number of training obs by leaf: Count: 2690 Average: 38.3643 StdDev: 44.8651 Min: 5 Max: 155 Ignored: 0 ---------------------------------------------- [ 5, 12) 1474 54.80% 54.80% ########## [ 12, 20) 124 4.61% 59.41% # [ 20, 27) 48 1.78% 61.19% [ 27, 35) 74 2.75% 63.94% # [ 35, 42) 58 2.16% 66.10% [ 42, 50) 85 3.16% 69.26% # [ 50, 57) 96 3.57% 72.83% # [ 57, 65) 87 3.23% 76.06% # [ 65, 72) 49 1.82% 77.88% [ 72, 80) 23 0.86% 78.74% [ 80, 88) 30 1.12% 79.85% [ 88, 95) 23 0.86% 80.71% [ 95, 103) 42 1.56% 82.27% [ 103, 110) 62 2.30% 84.57% [ 110, 118) 115 
4.28% 88.85% # [ 118, 125) 115 4.28% 93.12% # [ 125, 133) 98 3.64% 96.77% # [ 133, 140) 49 1.82% 98.59% [ 140, 148) 31 1.15% 99.74% [ 148, 155] 7 0.26% 100.00% Attribute in nodes: 778 : bill_length_mm [NUMERICAL] 463 : bill_depth_mm [NUMERICAL] 414 : flipper_length_mm [NUMERICAL] 342 : island [CATEGORICAL] 338 : body_mass_g [NUMERICAL] 36 : sex [CATEGORICAL] 19 : year [NUMERICAL] Attribute in nodes with depth <= 0: 157 : flipper_length_mm [NUMERICAL] 76 : bill_length_mm [NUMERICAL] 52 : bill_depth_mm [NUMERICAL] 12 : island [CATEGORICAL] 3 : body_mass_g [NUMERICAL] Attribute in nodes with depth <= 1: 250 : bill_length_mm [NUMERICAL] 244 : flipper_length_mm [NUMERICAL] 183 : bill_depth_mm [NUMERICAL] 170 : island [CATEGORICAL] 53 : body_mass_g [NUMERICAL] Attribute in nodes with depth <= 2: 462 : bill_length_mm [NUMERICAL] 320 : flipper_length_mm [NUMERICAL] 310 : bill_depth_mm [NUMERICAL] 287 : island [CATEGORICAL] 162 : body_mass_g [NUMERICAL] 9 : sex [CATEGORICAL] 5 : year [NUMERICAL] Attribute in nodes with depth <= 3: 669 : bill_length_mm [NUMERICAL] 410 : bill_depth_mm [NUMERICAL] 383 : flipper_length_mm [NUMERICAL] 328 : island [CATEGORICAL] 286 : body_mass_g [NUMERICAL] 32 : sex [CATEGORICAL] 10 : year [NUMERICAL] Attribute in nodes with depth <= 5: 778 : bill_length_mm [NUMERICAL] 462 : bill_depth_mm [NUMERICAL] 413 : flipper_length_mm [NUMERICAL] 342 : island [CATEGORICAL] 338 : body_mass_g [NUMERICAL] 36 : sex [CATEGORICAL] 19 : year [NUMERICAL] Condition type in nodes: 2012 : HigherCondition 378 : ContainsBitmapCondition Condition type in nodes with depth <= 0: 288 : HigherCondition 12 : ContainsBitmapCondition Condition type in nodes with depth <= 1: 730 : HigherCondition 170 : ContainsBitmapCondition Condition type in nodes with depth <= 2: 1259 : HigherCondition 296 : ContainsBitmapCondition Condition type in nodes with depth <= 3: 1758 : HigherCondition 360 : ContainsBitmapCondition Condition type in nodes with depth <= 5: 2010 : HigherCondition 378 
: ContainsBitmapCondition Node format: NOT_SET Training OOB: trees: 1, Out-of-bag evaluation: accuracy:0.964286 logloss:1.28727 trees: 11, Out-of-bag evaluation: accuracy:0.956268 logloss:0.584301 trees: 22, Out-of-bag evaluation: accuracy:0.965116 logloss:0.378823 trees: 35, Out-of-bag evaluation: accuracy:0.968023 logloss:0.178185 trees: 46, Out-of-bag evaluation: accuracy:0.973837 logloss:0.170304 trees: 58, Out-of-bag evaluation: accuracy:0.973837 logloss:0.171223 trees: 70, Out-of-bag evaluation: accuracy:0.979651 logloss:0.169564 trees: 83, Out-of-bag evaluation: accuracy:0.976744 logloss:0.17074 trees: 96, Out-of-bag evaluation: accuracy:0.976744 logloss:0.0736925 trees: 106, Out-of-bag evaluation: accuracy:0.976744 logloss:0.0748649 trees: 117, Out-of-bag evaluation: accuracy:0.976744 logloss:0.074671 trees: 130, Out-of-bag evaluation: accuracy:0.976744 logloss:0.0736275 trees: 140, Out-of-bag evaluation: accuracy:0.976744 logloss:0.0727718 trees: 152, Out-of-bag evaluation: accuracy:0.976744 logloss:0.0715068 trees: 162, Out-of-bag evaluation: accuracy:0.976744 logloss:0.0708994 trees: 173, Out-of-bag evaluation: accuracy:0.976744 logloss:0.069447 trees: 184, Out-of-bag evaluation: accuracy:0.976744 logloss:0.0695926 trees: 195, Out-of-bag evaluation: accuracy:0.976744 logloss:0.0690138 trees: 205, Out-of-bag evaluation: accuracy:0.976744 logloss:0.0694597 trees: 217, Out-of-bag evaluation: accuracy:0.976744 logloss:0.068122 trees: 229, Out-of-bag evaluation: accuracy:0.976744 logloss:0.0687641 trees: 239, Out-of-bag evaluation: accuracy:0.976744 logloss:0.067988 trees: 250, Out-of-bag evaluation: accuracy:0.976744 logloss:0.0690187 trees: 260, Out-of-bag evaluation: accuracy:0.976744 logloss:0.0690134 trees: 270, Out-of-bag evaluation: accuracy:0.976744 logloss:0.0689877 trees: 280, Out-of-bag evaluation: accuracy:0.976744 logloss:0.0689845 trees: 290, Out-of-bag evaluation: accuracy:0.976744 logloss:0.0690742 trees: 300, Out-of-bag evaluation: 
accuracy:0.976744 logloss:0.068949
Notice the multiple variable importances with a name of the form MEAN_DECREASE_IN_*.
Plotting the model
Next, plot the model.
A Random Forest is a large model (this model has 300 trees and about 5,000 nodes; see the summary above). Therefore, plot only the first tree, and limit the nodes to depth 3.
tfdf.model_plotter.plot_model_in_colab(model, tree_idx=0, max_depth=3)
Inspect the model structure
The model structure and metadata are available through the inspector created by make_inspector().
inspector = model.make_inspector()
For our model, the available inspector fields are:
[field for field in dir(inspector) if not field.startswith("_")]
['MODEL_NAME', 'dataspec', 'evaluation', 'export_to_tensorboard', 'extract_all_trees', 'extract_tree', 'features', 'iterate_on_nodes', 'label', 'label_classes', 'model_type', 'num_trees', 'objective', 'specialized_header', 'task', 'training_logs', 'variable_importances', 'winner_take_all_inference']
Remember to see the API reference, or use ? for the built-in documentation.
?inspector.model_type
Some of the model metadata:
print("Model type:", inspector.model_type())
print("Number of trees:", inspector.num_trees())
print("Objective:", inspector.objective())
print("Input features:", inspector.features())
Model type: RANDOM_FOREST Number of trees: 300 Objective: Classification(label=__LABEL, class=None, num_classes=3) Input features: ["bill_depth_mm" (1; #0), "bill_length_mm" (1; #1), "body_mass_g" (1; #2), "flipper_length_mm" (1; #3), "island" (4; #4), "sex" (4; #5), "year" (1; #6)]
evaluation() is the evaluation of the model computed during training. The dataset used for this evaluation depends on the algorithm. For example, it can be the validation dataset or the out-of-bag dataset.
inspector.evaluation()
Evaluation(num_examples=344, accuracy=0.9767441860465116, loss=0.06894904488784283, rmse=None, ndcg=None, aucs=None)
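For the Random Forest trained here, the reported accuracy and loss match the out-of-bag metrics printed in the training logs. The following minimal sketch shows how an out-of-bag set arises for a single tree. It assumes each tree is trained on a bootstrap sample drawn with replacement and of the same size as the dataset; the function name is hypothetical and this is not TF-DF code.

```python
import random

def bootstrap_split(num_examples, rng):
  """Returns (in_bag, out_of_bag) example indices for one tree."""
  # Sample with replacement: some indices repeat, others are never drawn.
  in_bag = [rng.randrange(num_examples) for _ in range(num_examples)]
  drawn = set(in_bag)
  out_of_bag = [i for i in range(num_examples) if i not in drawn]
  return in_bag, out_of_bag

rng = random.Random(42)
num_examples = 344  # Size of the penguins dataset used above.
in_bag, out_of_bag = bootstrap_split(num_examples, rng)
oob_fraction = len(out_of_bag) / num_examples

print("distinct in-bag examples:", len(set(in_bag)))
print("out-of-bag examples:", len(out_of_bag))
print("out-of-bag fraction:", oob_fraction)
```

Under this assumption, roughly 1/e (about 37%) of the examples are left out of each tree's sample. Each example can then be evaluated using only the trees that never saw it, which gives an evaluation without a separate validation dataset.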
The variable importances are:
print("Available variable importances:")
for importance in inspector.variable_importances().keys():
  print("\t", importance)
Available variable importances: MEAN_DECREASE_IN_AUC_3_VS_OTHERS NUM_AS_ROOT MEAN_DECREASE_IN_AUC_2_VS_OTHERS MEAN_DECREASE_IN_AP_2_VS_OTHERS MEAN_DECREASE_IN_ACCURACY SUM_SCORE MEAN_DECREASE_IN_PRAUC_2_VS_OTHERS MEAN_DECREASE_IN_PRAUC_3_VS_OTHERS MEAN_DECREASE_IN_AP_3_VS_OTHERS MEAN_DECREASE_IN_AUC_1_VS_OTHERS MEAN_MIN_DEPTH MEAN_DECREASE_IN_PRAUC_1_VS_OTHERS NUM_NODES MEAN_DECREASE_IN_AP_1_VS_OTHERS
Different variable importances have different semantics. For example, a feature with a mean decrease in AUC of 0.05 means that removing this feature from the training dataset would reduce the AUC by 0.05, i.e. 5 percentage points.
# Mean decrease in AUC of the class 1 vs the others.
inspector.variable_importances()["MEAN_DECREASE_IN_AUC_1_VS_OTHERS"]
[("bill_length_mm" (1; #1), 0.0713061951754389), ("island" (4; #4), 0.007298519736842035), ("flipper_length_mm" (1; #3), 0.004505893640351366), ("bill_depth_mm" (1; #0), 0.0021244517543865804), ("body_mass_g" (1; #2), 0.0005482456140351033), ("sex" (4; #5), 0.00047971491228060437), ("year" (1; #6), 0.0)]
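The `_1_VS_OTHERS` suffix means that the multi-class problem is scored as class 1 against all the remaining classes. The sketch below illustrates this with a one-vs-others AUC, assuming the pairwise-ranking definition of AUC. The function name and the toy scores are made up for illustration and do not come from TF-DF.

```python
def auc_one_vs_others(scores, labels, positive_class):
  """AUC of `positive_class` against all the other classes.

  scores[i] is the predicted probability of `positive_class` for example i.
  """
  positives = [s for s, y in zip(scores, labels) if y == positive_class]
  negatives = [s for s, y in zip(scores, labels) if y != positive_class]
  # Fraction of (positive, negative) pairs ranked correctly; ties count half.
  wins = 0.0
  for p in positives:
    for n in negatives:
      if p > n:
        wins += 1.0
      elif p == n:
        wins += 0.5
  return wins / (len(positives) * len(negatives))

toy_labels = [1, 1, 2, 3, 2, 1]
scores_full = [0.9, 0.8, 0.2, 0.1, 0.3, 0.7]       # All features intact.
scores_shuffled = [0.9, 0.25, 0.2, 0.1, 0.3, 0.7]  # One feature shuffled (made up).

auc_full = auc_one_vs_others(scores_full, toy_labels, positive_class=1)
auc_shuffled = auc_one_vs_others(scores_shuffled, toy_labels, positive_class=1)
print("AUC with all features:", auc_full)
print("AUC with one feature shuffled:", auc_shuffled)
print("decrease in AUC:", auc_full - auc_shuffled)
```

In this toy case, one class-1 example falls below one example of another class after shuffling, so 8 of the 9 pairs remain correctly ranked and the AUC drops from 1 to 8/9.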
Finally, access the actual tree structure:
inspector.extract_tree(tree_idx=0)
Tree(NonLeafNode(condition=(bill_length_mm >= 43.25; miss=True), pos_child=NonLeafNode(condition=(island in ['Biscoe']; miss=True), pos_child=NonLeafNode(condition=(bill_depth_mm >= 17.225584030151367; miss=False), pos_child=LeafNode(value=ProbabilityValue([0.16666666666666666, 0.0, 0.8333333333333334],n=6.0)), neg_child=LeafNode(value=ProbabilityValue([0.0, 0.0, 1.0],n=104.0)), value=ProbabilityValue([0.00909090909090909, 0.0, 0.990909090909091],n=110.0)), neg_child=LeafNode(value=ProbabilityValue([0.0, 1.0, 0.0],n=61.0)), value=ProbabilityValue([0.005847953216374269, 0.3567251461988304, 0.6374269005847953],n=171.0)), neg_child=NonLeafNode(condition=(bill_depth_mm >= 15.100000381469727; miss=True), pos_child=NonLeafNode(condition=(flipper_length_mm >= 187.5; miss=True), pos_child=LeafNode(value=ProbabilityValue([1.0, 0.0, 0.0],n=104.0)), neg_child=NonLeafNode(condition=(bill_length_mm >= 42.30000305175781; miss=True), pos_child=LeafNode(value=ProbabilityValue([0.0, 1.0, 0.0],n=5.0)), neg_child=NonLeafNode(condition=(bill_length_mm >= 40.55000305175781; miss=True), pos_child=LeafNode(value=ProbabilityValue([0.8, 0.2, 0.0],n=5.0)), neg_child=LeafNode(value=ProbabilityValue([1.0, 0.0, 0.0],n=53.0)), value=ProbabilityValue([0.9827586206896551, 0.017241379310344827, 0.0],n=58.0)), value=ProbabilityValue([0.9047619047619048, 0.09523809523809523, 0.0],n=63.0)), value=ProbabilityValue([0.9640718562874252, 0.03592814371257485, 0.0],n=167.0)), neg_child=LeafNode(value=ProbabilityValue([0.0, 0.0, 1.0],n=6.0)), value=ProbabilityValue([0.930635838150289, 0.03468208092485549, 0.03468208092485549],n=173.0)), value=ProbabilityValue([0.47093023255813954, 0.19476744186046513, 0.33430232558139533],n=344.0)),label_classes={self.label_classes})
Extracting a tree is not efficient. If speed is important, the model can be inspected with the iterate_on_nodes() method instead. This method is a depth-first pre-order iterator over all the nodes of the model.
For example, the following code computes how many times each feature is used (this is a kind of structural variable importance):
# number_of_use[F] will be the number of nodes using feature F in their condition.
number_of_use = collections.defaultdict(lambda: 0)
# Iterate over all the nodes in a depth-first pre-order traversal.
for node_iter in inspector.iterate_on_nodes():
  if not isinstance(node_iter.node, tfdf.py_tree.node.NonLeafNode):
    # Skip the leaf nodes
    continue
  # Iterate over all the features used in the condition.
  # By default, splits are axis-aligned, i.e. each node tests a single feature.
  for feature in node_iter.node.condition.features():
    number_of_use[feature] += 1
print("Number of condition nodes per features:")
for feature, count in number_of_use.items():
  print("\t", feature.name, ":", count)
Number of condition nodes per features: bill_length_mm : 778 bill_depth_mm : 463 flipper_length_mm : 414 island : 342 body_mass_g : 338 year : 19 sex : 36
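The depth-first pre-order traversal used above can be sketched on a toy tree in pure Python. The `Leaf` and `Split` classes and the `iterate_pre_order` function are hypothetical stand-ins and are much simpler than TF-DF's py_tree nodes. The order in which the two children are visited is an assumption made for this illustration.

```python
# Hypothetical minimal node classes; TF-DF's py_tree nodes are richer.
class Leaf:
  def __init__(self, value):
    self.value = value

class Split:
  def __init__(self, feature, pos_child, neg_child):
    self.feature = feature
    self.pos_child = pos_child
    self.neg_child = neg_child

def iterate_pre_order(node, depth=0):
  """Yields (node, depth), visiting each node before its children."""
  yield node, depth
  if isinstance(node, Split):
    # Assumption: the negative child is visited before the positive child.
    yield from iterate_pre_order(node.neg_child, depth + 1)
    yield from iterate_pre_order(node.pos_child, depth + 1)

toy_tree = Split(
    "f1",
    pos_child=Split("f2", pos_child=Leaf("A"), neg_child=Leaf("B")),
    neg_child=Leaf("C"))

visit_order = []
toy_number_of_use = {}
for node, depth in iterate_pre_order(toy_tree):
  if isinstance(node, Split):
    visit_order.append(node.feature)
    toy_number_of_use[node.feature] = toy_number_of_use.get(node.feature, 0) + 1
  else:
    visit_order.append(node.value)

print("visit order:", visit_order)
print("number of use:", toy_number_of_use)
```

Because the traversal is a generator, nodes are produced one at a time and no full tree copy is needed, which is the kind of saving that makes node iteration cheaper than extracting whole trees.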
Creating a model by hand
In this section, you will create a small Random Forest model by hand. To make it extra easy, the model will contain only one simple tree:
3 label classes: Red, blue and green.
2 features: f1 (numerical) and f2 (string categorical)
f1>=1.5
├─(pos)─ f2 in ["cat","dog"]
│ ├─(pos)─ value: [0.8, 0.1, 0.1]
│ └─(neg)─ value: [0.1, 0.8, 0.1]
└─(neg)─ value: [0.1, 0.1, 0.8]
# Create the model builder
builder = tfdf.builder.RandomForestBuilder(
    path="/tmp/manual_model",
    objective=tfdf.py_tree.objective.ClassificationObjective(
        label="color", classes=["red", "blue", "green"]))
Each tree is added one by one.
# Some aliases
Tree = tfdf.py_tree.tree.Tree
SimpleColumnSpec = tfdf.py_tree.dataspec.SimpleColumnSpec
ColumnType = tfdf.py_tree.dataspec.ColumnType
# Nodes
NonLeafNode = tfdf.py_tree.node.NonLeafNode
LeafNode = tfdf.py_tree.node.LeafNode
# Conditions
NumericalHigherThanCondition = tfdf.py_tree.condition.NumericalHigherThanCondition
CategoricalIsInCondition = tfdf.py_tree.condition.CategoricalIsInCondition
# Leaf values
ProbabilityValue = tfdf.py_tree.value.ProbabilityValue
builder.add_tree(
    Tree(
        NonLeafNode(
            condition=NumericalHigherThanCondition(
                feature=SimpleColumnSpec(name="f1", type=ColumnType.NUMERICAL),
                threshold=1.5,
                missing_evaluation=False),
            pos_child=NonLeafNode(
                condition=CategoricalIsInCondition(
                    feature=SimpleColumnSpec(name="f2", type=ColumnType.CATEGORICAL),
                    mask=["cat", "dog"],
                    missing_evaluation=False),
                pos_child=LeafNode(value=ProbabilityValue(probability=[0.8, 0.1, 0.1], num_examples=10)),
                neg_child=LeafNode(value=ProbabilityValue(probability=[0.1, 0.8, 0.1], num_examples=20))),
            neg_child=LeafNode(value=ProbabilityValue(probability=[0.1, 0.1, 0.8], num_examples=30)))))
Conclude the tree writing.
builder.close()
[INFO kernel.cc:988] Loading model from path [INFO decision_forest.cc:590] Model loaded with 1 root(s), 5 node(s), and 2 input feature(s). [INFO kernel.cc:848] Use fast generic engine 2021-11-08 12:19:14.555155: W tensorflow/python/util/util.cc:368] Sets are not currently considered sequences, but this may change in the future, so consider avoiding using them. INFO:tensorflow:Assets written to: /tmp/manual_model/assets INFO:tensorflow:Assets written to: /tmp/manual_model/assets
Now you can open the model as a regular keras model, and make predictions:
manual_model = tf.keras.models.load_model("/tmp/manual_model")
[INFO kernel.cc:988] Loading model from path [INFO decision_forest.cc:590] Model loaded with 1 root(s), 5 node(s), and 2 input feature(s). [INFO kernel.cc:848] Use fast generic engine
examples = tf.data.Dataset.from_tensor_slices({
"f1": [1.0, 2.0, 3.0],
"f2": ["cat", "cat", "bird"]
}).batch(2)
predictions = manual_model.predict(examples)
print("predictions:\n",predictions)
predictions: [[0.1 0.1 0.8] [0.8 0.1 0.1] [0.1 0.8 0.1]]
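These predictions can be checked by walking the hand-made tree manually. The sketch below rewrites the tree as plain Python conditions without using TF-DF; the `predict_one` function and the `toy_` variable names are hypothetical.

```python
def predict_one(f1, f2):
  """Returns [P(red), P(blue), P(green)] for one example."""
  if f1 >= 1.5:
    # Positive branch of the root: test the categorical feature.
    if f2 in ("cat", "dog"):
      return [0.8, 0.1, 0.1]
    return [0.1, 0.8, 0.1]
  # Negative branch of the root.
  return [0.1, 0.1, 0.8]

# Same three examples as in the prediction cell above.
toy_examples = [(1.0, "cat"), (2.0, "cat"), (3.0, "bird")]
toy_predictions = [predict_one(f1, f2) for f1, f2 in toy_examples]
for example, prediction in zip(toy_examples, toy_predictions):
  print(example, "->", prediction)
```

The first example fails the f1>=1.5 test and lands in the negative leaf, the second passes both conditions, and the third passes the root condition but "bird" is not in ["cat", "dog"]. This matches the three rows printed by the model.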
Access the structure:
yggdrasil_model_path = manual_model.yggdrasil_model_path_tensor().numpy().decode("utf-8")
print("yggdrasil_model_path:",yggdrasil_model_path)
inspector = tfdf.inspector.make_inspector(yggdrasil_model_path)
print("Input features:", inspector.features())
yggdrasil_model_path: /tmp/manual_model/assets/ Input features: ["f1" (1; #1), "f2" (4; #2)]
And of course, you can plot this manually constructed model:
tfdf.model_plotter.plot_model_in_colab(manual_model)