Gotowe estymatory

Zobacz na TensorFlow.org Uruchom w Google Colab Wyświetl źródło na GitHub Pobierz notatnik

Ten samouczek pokazuje, jak rozwiązać problem klasyfikacji tęczówki w TensorFlow za pomocą estymatorów. Estymator to starsza, wysokopoziomowa reprezentacja kompletnego modelu w TensorFlow. Aby uzyskać więcej informacji, zobacz Estymatory .

Najpierw najważniejsze rzeczy

Aby rozpocząć, najpierw zaimportujesz TensorFlow i kilka potrzebnych bibliotek.

import tensorflow as tf

import pandas as pd

Zestaw danych

Przykładowy program w tym dokumencie buduje i testuje model, który klasyfikuje kwiaty irysa na trzy różne gatunki na podstawie wielkości ich działek i płatków .

Wytrenujesz model za pomocą zestawu danych Iris. Zestaw danych Iris zawiera cztery cechy i jedną etykietę . Cztery cechy identyfikują następujące cechy botaniczne poszczególnych kwiatów irysa:

  • długość kielicha
  • sepal szerokość
  • długość płatka
  • szerokość płatka

Na podstawie tych informacji możesz zdefiniować kilka stałych przydatnych do analizowania danych:

CSV_COLUMN_NAMES = ['SepalLength', 'SepalWidth', 'PetalLength', 'PetalWidth', 'Species']
SPECIES = ['Setosa', 'Versicolor', 'Virginica']

Następnie pobierz i przeanalizuj zestaw danych Iris za pomocą Keras i Pandas. Pamiętaj, że przechowujesz różne zestawy danych do trenowania i testowania.

train_path = tf.keras.utils.get_file(
    "iris_training.csv", "https://storage.googleapis.com/download.tensorflow.org/data/iris_training.csv")
test_path = tf.keras.utils.get_file(
    "iris_test.csv", "https://storage.googleapis.com/download.tensorflow.org/data/iris_test.csv")

train = pd.read_csv(train_path, names=CSV_COLUMN_NAMES, header=0)
test = pd.read_csv(test_path, names=CSV_COLUMN_NAMES, header=0)
Downloading data from https://storage.googleapis.com/download.tensorflow.org/data/iris_training.csv
16384/2194 [================================================================================================================================================================================================================================] - 0s 0us/step
Downloading data from https://storage.googleapis.com/download.tensorflow.org/data/iris_test.csv
16384/573 [=========================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================] - 0s 0us/step

Możesz sprawdzić swoje dane, aby zobaczyć, że masz cztery zmiennoprzecinkowe kolumny funkcji i jedną etykietę int32.

train.head()

Dla każdego zestawu danych podziel etykiety, które model będzie wytrenowany do przewidywania.

train_y = train.pop('Species')
test_y = test.pop('Species')

# The label column has now been removed from the features.
train.head()

Przegląd programowania z estymatorami

Teraz, gdy masz już skonfigurowane dane, możesz zdefiniować model za pomocą estymatora TensorFlow. Estimator to dowolna klasa wywodząca się z tf.estimator.Estimator . TensorFlow udostępnia kolekcję tf.estimator (na przykład LinearRegressor ) do implementacji typowych algorytmów ML. Poza tym możesz napisać własne niestandardowe estymatory . Zaleca się korzystanie z gotowych estymatorów na początku.

Aby napisać program TensorFlow w oparciu o gotowe Estymatory, musisz wykonać następujące zadania:

  • Utwórz jedną lub więcej funkcji wejściowych.
  • Zdefiniuj kolumny elementów modelu.
  • Utwórz wystąpienie estymatora, określając kolumny funkcji i różne hiperparametry.
  • Wywołaj co najmniej jedną metodę w obiekcie Estimator, przekazując odpowiednią funkcję wejściową jako źródło danych.

Zobaczmy, jak te zadania są realizowane dla klasyfikacji tęczówki.

Utwórz funkcje wejściowe

Musisz utworzyć funkcje wejściowe, aby dostarczać dane do uczenia, oceny i przewidywania.

Funkcja wejściowa to funkcja, która zwraca obiekt tf.data.Dataset , który wyprowadza następującą dwuelementową krotkę:

  • features - Słownik Pythona, w którym:
    • Każdy klucz to nazwa funkcji.
    • Każda wartość jest tablicą zawierającą wszystkie wartości tej funkcji.
  • label — tablica zawierająca wartości etykiety dla każdego przykładu.

Aby zademonstrować format funkcji input, oto prosta implementacja:

def input_evaluation_set():
    features = {'SepalLength': np.array([6.4, 5.0]),
                'SepalWidth':  np.array([2.8, 2.3]),
                'PetalLength': np.array([5.6, 3.3]),
                'PetalWidth':  np.array([2.2, 1.0])}
    labels = np.array([2, 1])
    return features, labels

Twoja funkcja wejściowa może generować słownik features i listę label w dowolny sposób. Zaleca się jednak korzystanie z interfejsu API Dataset firmy TensorFlow, który może analizować wszystkie rodzaje danych.

Interfejs Dataset API poradzi sobie z wieloma typowymi przypadkami. Na przykład, korzystając z interfejsu Dataset API, możesz łatwo odczytywać równolegle rekordy z dużej kolekcji plików i łączyć je w jeden strumień.

Aby uprościć ten przykład, załadujesz dane za pomocą pandas i zbudujesz potok wejściowy z tych danych w pamięci:

def input_fn(features, labels, training=True, batch_size=256):
    """An input function for training or evaluating"""
    # Convert the inputs to a Dataset.
    dataset = tf.data.Dataset.from_tensor_slices((dict(features), labels))

    # Shuffle and repeat if you are in training mode.
    if training:
        dataset = dataset.shuffle(1000).repeat()

    return dataset.batch(batch_size)

Zdefiniuj kolumny funkcji

Kolumna cech to obiekt opisujący sposób, w jaki model powinien wykorzystywać surowe dane wejściowe ze słownika cech. Tworząc model estymatora, przekazujesz mu listę kolumn funkcji, które opisują wszystkie funkcje, których ma używać model. Moduł tf.feature_column udostępnia wiele opcji reprezentacji danych w modelu.

W przypadku Iris te 4 surowe funkcje są wartościami liczbowymi, więc utworzysz listę kolumn funkcji, aby poinformować model Estimator, aby reprezentował każdą z czterech funkcji jako 32-bitowe wartości zmiennoprzecinkowe. Dlatego kod do utworzenia kolumny funkcji to:

# Feature columns describe how to use the input.
my_feature_columns = []
for key in train.keys():
    my_feature_columns.append(tf.feature_column.numeric_column(key=key))

Kolumny funkcji mogą być znacznie bardziej wyrafinowane niż te pokazane tutaj. Więcej informacji o kolumnach funkcji można znaleźć w tym przewodniku .

Teraz, gdy masz już opis, w jaki sposób model ma reprezentować surowe cechy, możesz zbudować estymator.

Utwórz estymator

Problem tęczówki to klasyczny problem klasyfikacyjny. Na szczęście TensorFlow udostępnia kilka gotowych estymatorów klasyfikatorów, w tym:

W przypadku problemu tęczówki najlepszym wyborem wydaje się tf.estimator.DNNClassifier . Oto jak utworzyłeś wystąpienie tego estymatora:

# Build a DNN with 2 hidden layers with 30 and 10 hidden nodes each.
classifier = tf.estimator.DNNClassifier(
    feature_columns=my_feature_columns,
    # Two hidden layers of 30 and 10 nodes respectively.
    hidden_units=[30, 10],
    # The model must choose between 3 classes.
    n_classes=3)
INFO:tensorflow:Using default config.
WARNING:tensorflow:Using temporary folder as model directory: /tmp/tmpxdgumb2t
INFO:tensorflow:Using config: {'_model_dir': '/tmp/tmpxdgumb2t', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true
graph_options {
  rewrite_options {
    meta_optimizer_iterations: ONE
  }
}
, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_checkpoint_save_graph_def': True, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}

Trenuj, oceniaj i prognozuj

Teraz, gdy masz już obiekt Estimator, możesz wywołać metody, aby wykonać następujące czynności:

  • Trenuj modelkę.
  • Oceń wyszkolony model.
  • Użyj wytrenowanego modelu, aby dokonać prognoz.

Trenuj modelkę

Trenuj model, wywołując metodę train estymatora w następujący sposób:

# Train the Model.
classifier.train(
    input_fn=lambda: input_fn(train, train_y, training=True),
    steps=5000)
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/training/training_util.py:397: Variable.initialized_value (from tensorflow.python.ops.variables) is deprecated and will be removed in a future version.
Instructions for updating:
Use Variable.read_value. Variables in 2.X are initialized automatically both in eager and graph (inside tf.defun) contexts.
INFO:tensorflow:Calling model_fn.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/keras/optimizer_v2/adagrad.py:84: calling Constant.__init__ (from tensorflow.python.ops.init_ops) with dtype is deprecated and will be removed in a future version.
Instructions for updating:
Call initializer instance with the dtype argument instead of passing it to the constructor
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0...
INFO:tensorflow:Saving checkpoints for 0 into /tmp/tmpxdgumb2t/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0...
INFO:tensorflow:loss = 1.6787335, step = 0
INFO:tensorflow:global_step/sec: 305.625
INFO:tensorflow:loss = 1.1945828, step = 100 (0.328 sec)
INFO:tensorflow:global_step/sec: 375.48
INFO:tensorflow:loss = 1.0221117, step = 200 (0.266 sec)
INFO:tensorflow:global_step/sec: 376.21
INFO:tensorflow:loss = 0.9240805, step = 300 (0.266 sec)
INFO:tensorflow:global_step/sec: 377.968
INFO:tensorflow:loss = 0.85917354, step = 400 (0.265 sec)
INFO:tensorflow:global_step/sec: 376.297
INFO:tensorflow:loss = 0.81545967, step = 500 (0.265 sec)
INFO:tensorflow:global_step/sec: 367.549
INFO:tensorflow:loss = 0.7771524, step = 600 (0.272 sec)
INFO:tensorflow:global_step/sec: 378.887
INFO:tensorflow:loss = 0.74371505, step = 700 (0.264 sec)
INFO:tensorflow:global_step/sec: 379.26
INFO:tensorflow:loss = 0.717993, step = 800 (0.264 sec)
INFO:tensorflow:global_step/sec: 370.102
INFO:tensorflow:loss = 0.6952705, step = 900 (0.270 sec)
INFO:tensorflow:global_step/sec: 373.034
INFO:tensorflow:loss = 0.68044865, step = 1000 (0.268 sec)
INFO:tensorflow:global_step/sec: 372.193
INFO:tensorflow:loss = 0.65181077, step = 1100 (0.269 sec)
INFO:tensorflow:global_step/sec: 339.238
INFO:tensorflow:loss = 0.6319051, step = 1200 (0.295 sec)
INFO:tensorflow:global_step/sec: 334.252
INFO:tensorflow:loss = 0.63433766, step = 1300 (0.299 sec)
INFO:tensorflow:global_step/sec: 343.436
INFO:tensorflow:loss = 0.61748827, step = 1400 (0.291 sec)
INFO:tensorflow:global_step/sec: 346.575
INFO:tensorflow:loss = 0.606356, step = 1500 (0.288 sec)
INFO:tensorflow:global_step/sec: 351.362
INFO:tensorflow:loss = 0.59807724, step = 1600 (0.285 sec)
INFO:tensorflow:global_step/sec: 366.628
INFO:tensorflow:loss = 0.5832784, step = 1700 (0.273 sec)
INFO:tensorflow:global_step/sec: 367.034
INFO:tensorflow:loss = 0.5664347, step = 1800 (0.273 sec)
INFO:tensorflow:global_step/sec: 372.339
INFO:tensorflow:loss = 0.5684726, step = 1900 (0.268 sec)
INFO:tensorflow:global_step/sec: 368.957
INFO:tensorflow:loss = 0.56011164, step = 2000 (0.271 sec)
INFO:tensorflow:global_step/sec: 373.128
INFO:tensorflow:loss = 0.5483226, step = 2100 (0.268 sec)
INFO:tensorflow:global_step/sec: 377.334
INFO:tensorflow:loss = 0.5447233, step = 2200 (0.265 sec)
INFO:tensorflow:global_step/sec: 370.421
INFO:tensorflow:loss = 0.5358016, step = 2300 (0.270 sec)
INFO:tensorflow:global_step/sec: 367.076
INFO:tensorflow:loss = 0.53145075, step = 2400 (0.273 sec)
INFO:tensorflow:global_step/sec: 373.596
INFO:tensorflow:loss = 0.50931674, step = 2500 (0.268 sec)
INFO:tensorflow:global_step/sec: 368.939
INFO:tensorflow:loss = 0.5253717, step = 2600 (0.271 sec)
INFO:tensorflow:global_step/sec: 354.814
INFO:tensorflow:loss = 0.52558273, step = 2700 (0.282 sec)
INFO:tensorflow:global_step/sec: 372.243
INFO:tensorflow:loss = 0.51422054, step = 2800 (0.269 sec)
INFO:tensorflow:global_step/sec: 366.891
INFO:tensorflow:loss = 0.49747026, step = 2900 (0.272 sec)
INFO:tensorflow:global_step/sec: 370.952
INFO:tensorflow:loss = 0.49974674, step = 3000 (0.270 sec)
INFO:tensorflow:global_step/sec: 364.158
INFO:tensorflow:loss = 0.4978399, step = 3100 (0.275 sec)
INFO:tensorflow:global_step/sec: 365.383
INFO:tensorflow:loss = 0.5030147, step = 3200 (0.273 sec)
INFO:tensorflow:global_step/sec: 366.791
INFO:tensorflow:loss = 0.4772169, step = 3300 (0.273 sec)
INFO:tensorflow:global_step/sec: 372.438
INFO:tensorflow:loss = 0.46993533, step = 3400 (0.269 sec)
INFO:tensorflow:global_step/sec: 371.25
INFO:tensorflow:loss = 0.47242266, step = 3500 (0.269 sec)
INFO:tensorflow:global_step/sec: 369.725
INFO:tensorflow:loss = 0.46513358, step = 3600 (0.271 sec)
INFO:tensorflow:global_step/sec: 371.002
INFO:tensorflow:loss = 0.4762191, step = 3700 (0.270 sec)
INFO:tensorflow:global_step/sec: 369.304
INFO:tensorflow:loss = 0.44923267, step = 3800 (0.271 sec)
INFO:tensorflow:global_step/sec: 369.344
INFO:tensorflow:loss = 0.45467538, step = 3900 (0.271 sec)
INFO:tensorflow:global_step/sec: 375.58
INFO:tensorflow:loss = 0.46056622, step = 4000 (0.266 sec)
INFO:tensorflow:global_step/sec: 347.461
INFO:tensorflow:loss = 0.4489282, step = 4100 (0.288 sec)
INFO:tensorflow:global_step/sec: 368.435
INFO:tensorflow:loss = 0.45647347, step = 4200 (0.272 sec)
INFO:tensorflow:global_step/sec: 369.159
INFO:tensorflow:loss = 0.4444633, step = 4300 (0.271 sec)
INFO:tensorflow:global_step/sec: 371.995
INFO:tensorflow:loss = 0.44425523, step = 4400 (0.269 sec)
INFO:tensorflow:global_step/sec: 373.586
INFO:tensorflow:loss = 0.44025964, step = 4500 (0.268 sec)
INFO:tensorflow:global_step/sec: 373.136
INFO:tensorflow:loss = 0.44341013, step = 4600 (0.269 sec)
INFO:tensorflow:global_step/sec: 369.751
INFO:tensorflow:loss = 0.42856425, step = 4700 (0.269 sec)
INFO:tensorflow:global_step/sec: 364.219
INFO:tensorflow:loss = 0.44144967, step = 4800 (0.275 sec)
INFO:tensorflow:global_step/sec: 372.675
INFO:tensorflow:loss = 0.42951846, step = 4900 (0.268 sec)
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 5000...
INFO:tensorflow:Saving checkpoints for 5000 into /tmp/tmpxdgumb2t/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 5000...
INFO:tensorflow:Loss for final step: 0.42713496.
<tensorflow_estimator.python.estimator.canned.dnn.DNNClassifierV2 at 0x7fad05e33910>

Zauważ, że opakowujesz wywołanie input_fn w lambda , aby przechwycić argumenty, jednocześnie udostępniając funkcję input, która nie przyjmuje żadnych argumentów, zgodnie z oczekiwaniami estymatora. Argument steps mówi metodzie, aby zatrzymać trenowanie po kilku krokach trenowania.

Oceń wyszkolony model

Teraz, gdy model został przeszkolony, możesz uzyskać statystyki dotyczące jego wydajności. Poniższy blok kodu ocenia dokładność wytrenowanego modelu na danych testowych:

eval_result = classifier.evaluate(
    input_fn=lambda: input_fn(test, test_y, training=False))

print('\nTest set accuracy: {accuracy:0.3f}\n'.format(**eval_result))
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Starting evaluation at 2022-01-26T06:41:28
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Restoring parameters from /tmp/tmpxdgumb2t/model.ckpt-5000
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Inference Time : 0.40087s
INFO:tensorflow:Finished evaluation at 2022-01-26-06:41:28
INFO:tensorflow:Saving dict for global step 5000: accuracy = 0.8666667, average_loss = 0.49953422, global_step = 5000, loss = 0.49953422
INFO:tensorflow:Saving 'checkpoint_path' summary for global step 5000: /tmp/tmpxdgumb2t/model.ckpt-5000

Test set accuracy: 0.867

W przeciwieństwie do wywołania metody train , argument steps nie został przekazany do oceny. input_fn dla eval zwraca tylko jedną epokę danych.

Słownik eval_result zawiera również średnią average_loss (średnią stratę na próbkę), loss (średnią stratę na mini-partię) oraz wartość global_step estymatora (liczbę iteracji uczących, które przeszedł).

Wykonywanie prognoz (wnioskowanie) z wytrenowanego modelu

Masz teraz wyszkolony model, który daje dobre wyniki oceny. Teraz możesz użyć wytrenowanego modelu, aby przewidzieć gatunek kwiatu tęczówki na podstawie niektórych nieoznakowanych pomiarów. Podobnie jak w przypadku treningu i oceny, prognozy dokonujesz za pomocą pojedynczego wywołania funkcji:

# Generate predictions from the model
expected = ['Setosa', 'Versicolor', 'Virginica']
predict_x = {
    'SepalLength': [5.1, 5.9, 6.9],
    'SepalWidth': [3.3, 3.0, 3.1],
    'PetalLength': [1.7, 4.2, 5.4],
    'PetalWidth': [0.5, 1.5, 2.1],
}

def input_fn(features, batch_size=256):
    """An input function for prediction."""
    # Convert the inputs to a Dataset without labels.
    return tf.data.Dataset.from_tensor_slices(dict(features)).batch(batch_size)

predictions = classifier.predict(
    input_fn=lambda: input_fn(predict_x))

Metoda predict zwraca iterację Pythona, dając słownik wyników przewidywania dla każdego przykładu. Poniższy kod wyświetla kilka prognoz i ich prawdopodobieństw:

for pred_dict, expec in zip(predictions, expected):
    class_id = pred_dict['class_ids'][0]
    probability = pred_dict['probabilities'][class_id]

    print('Prediction is "{}" ({:.1f}%), expected "{}"'.format(
        SPECIES[class_id], 100 * probability, expec))
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Restoring parameters from /tmp/tmpxdgumb2t/model.ckpt-5000
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
Prediction is "Setosa" (84.4%), expected "Setosa"
Prediction is "Versicolor" (49.3%), expected "Versicolor"
Prediction is "Virginica" (57.7%), expected "Virginica"