User entry point for the TensorFlow Decision Forests (TF-DF) API.
Basic usage:
```python
# Imports
import tensorflow_decision_forests as tfdf
import pandas as pd
import tf_keras
from wurlitzer import sys_pipes

# Load a dataset into a Pandas DataFrame.
dataset_df = pd.read_csv("/tmp/penguins.csv")

# Display the first 3 examples.
dataset_df.head(3)

# Convert the Pandas DataFrame to a TensorFlow dataset.
tf_dataset = tfdf.keras.pd_dataframe_to_tf_dataset(dataset_df, label="species")

# Train a model.
model = tfdf.keras.RandomForestModel()
with sys_pipes():
    model.fit(tf_dataset)
# Note: The `sys_pipes` part is to display logs during training.

# Evaluate model.
# `test_dataset` is a held-out dataset prepared like `tf_dataset` (not defined here).
model.compile(metrics=["accuracy"])
model.evaluate(test_dataset)

# Save model.
model.save("/tmp/my_saved_model")
# ...

# Load a model: it loads as a generic Keras model.
loaded_model = tf_keras.models.load_model("/tmp/my_saved_model")
```
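The usage example evaluates on a test dataset that it never builds. One way to get one is to hold out part of `dataset_df` before conversion. The helper below is a hypothetical sketch (not a TF-DF function) that assumes a simple random per-row split is acceptable; it uses only the standard library and returns row positions rather than DataFrames.

```python
import random


def split_indices(num_rows, test_ratio=0.3, seed=1234):
    """Assigns each row position to a train or test split at random.

    Hypothetical helper (not part of TF-DF). Each row lands in the test split
    with probability `test_ratio`, so split sizes are only approximately
    proportional to the ratio.
    """
    if not 0.0 <= test_ratio <= 1.0:
        raise ValueError("test_ratio must be in [0, 1]")
    rng = random.Random(seed)  # Seeded so the split is reproducible.
    train_idx, test_idx = [], []
    for i in range(num_rows):
        if rng.random() < test_ratio:
            test_idx.append(i)
        else:
            train_idx.append(i)
    return train_idx, test_idx


train_idx, test_idx = split_indices(10, test_ratio=0.3)
print(len(train_idx) + len(test_idx))  # 10
```

With these positions, `dataset_df.iloc[train_idx]` and `dataset_df.iloc[test_idx]` give the two DataFrames; each can be converted with `tfdf.keras.pd_dataframe_to_tf_dataset(..., label="species")`, and the second passed to `model.evaluate`. A per-row coin flip keeps the sketch short; a stratified split would preserve class proportions more reliably.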
Modules
builder
module: Model builder.
check_version
module: Check that version of TensorFlow is compatible with TF-DF.
inspector
module: Model inspector.
keras
module: Decision Forest in a Keras Model.
model_plotter
module: Plotting of decision forest models.
py_tree
module: Decision trees stored as python objects.
tuner
module: Specification of the parameters of a tuner.
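The `py_tree` entry describes decision trees held as ordinary Python objects. As a rough illustration of that idea only, the sketch below defines two hypothetical node classes (`Leaf`, `NumericalSplit`) and walks from the root to a leaf. These are not the classes `tfdf.py_tree` defines, and the tree in the usage lines is made up rather than trained.

```python
from dataclasses import dataclass
from typing import Dict, Union


@dataclass
class Leaf:
    """Hypothetical terminal node holding a predicted class label."""
    value: str


@dataclass
class NumericalSplit:
    """Hypothetical internal node testing `example[feature] >= threshold`."""
    feature: str
    threshold: float
    pos_child: "Node"  # Followed when the condition holds.
    neg_child: "Node"  # Followed otherwise.


Node = Union[Leaf, NumericalSplit]


def predict(node: Node, example: Dict[str, float]) -> str:
    """Follows split conditions from `node` down to a leaf and returns its label."""
    while isinstance(node, NumericalSplit):
        if example[node.feature] >= node.threshold:
            node = node.pos_child
        else:
            node = node.neg_child
    return node.value


# A made-up two-leaf tree, for illustration only.
tree = NumericalSplit(feature="f1", threshold=2.5, pos_child=Leaf("A"), neg_child=Leaf("B"))
print(predict(tree, {"f1": 3.0}))  # A
print(predict(tree, {"f1": 1.0}))  # B
```

Keeping the tree as plain objects makes it easy to inspect or edit node by node in Python, at the cost of being far slower than a compiled inference engine.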
| Other Members | |
|---|---|
| version | `'1.11.0'` |
| compatible_tf_versions | `['2.18.0']` |
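This table pairs TF-DF 1.11.0 with TensorFlow 2.18.0, and the `check_version` module exists to catch mismatches. The sketch below shows one way such a comparison could work. It is a hypothetical helper, not the logic `check_version` actually uses, and it assumes that two version strings match when their numeric major.minor.patch parts are equal, ignoring trailing suffixes such as `rc1` or `+cpu`.

```python
import re
from typing import Iterable, Tuple


def version_core(version: str) -> Tuple[int, int, int]:
    """Extracts (major, minor, patch), ignoring suffixes like "rc1" or "+cpu"."""
    match = re.match(r"(\d+)\.(\d+)\.(\d+)", version)
    if match is None:
        raise ValueError(f"Unrecognized version string: {version!r}")
    return tuple(int(part) for part in match.groups())


def is_compatible_tf_version(installed: str, compatible: Iterable[str]) -> bool:
    """Hypothetical check: True if `installed` matches any listed version."""
    target = version_core(installed)
    return any(version_core(v) == target for v in compatible)


# Value taken from the `compatible_tf_versions` row (TF-DF 1.11.0).
COMPATIBLE_TF_VERSIONS = ["2.18.0"]
print(is_compatible_tf_version("2.18.0", COMPATIBLE_TF_VERSIONS))  # True
print(is_compatible_tf_version("2.17.1", COMPATIBLE_TF_VERSIONS))  # False
```

In an environment with both packages installed, the inputs would be `tf.__version__` and `tfdf.compatible_tf_versions`.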