Краткое руководство по TensorFlow 2 для начинающих

Посмотреть на TensorFlow.org Запустить в Google Colab Посмотреть исходный код на GitHub Скачать блокнот

Это краткое введение использует Keras для:

  1. Загрузите готовый набор данных.
  2. Создайте модель машинного обучения нейронной сети, которая классифицирует изображения.
  3. Обучите эту нейронную сеть.
  4. Оцените точность модели.

Это руководство представляет собой записную книжку Google Colaboratory . Программы на Python запускаются непосредственно в браузере — отличный способ изучить и использовать TensorFlow. Чтобы следовать этому руководству, запустите блокнот в Google Colab, нажав кнопку в верхней части этой страницы.

  1. В Colab подключитесь к среде выполнения Python: в правом верхнем углу строки меню выберите ПОДКЛЮЧИТЬСЯ .
  2. Запустите все ячейки кода записной книжки. Выберите « Среда выполнения » > « Выполнить все» .

Настроить TensorFlow.

Импортируйте TensorFlow в свою программу, чтобы начать:

import tensorflow as tf
print("TensorFlow version:", tf.__version__)
TensorFlow version: 2.8.0-rc1

Если вы используете собственную среду разработки, а не Colab , см. руководство по установке для настройки TensorFlow для разработки.

Загрузите набор данных

Загрузите и подготовьте набор данных MNIST . Преобразуйте данные выборки из целых чисел в числа с плавающей запятой:

mnist = tf.keras.datasets.mnist

(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0

Построить модель машинного обучения

Создайте модель tf.keras.Sequential путем наложения слоев.

model = tf.keras.models.Sequential([
  tf.keras.layers.Flatten(input_shape=(28, 28)),
  tf.keras.layers.Dense(128, activation='relu'),
  tf.keras.layers.Dropout(0.2),
  tf.keras.layers.Dense(10)
])

Для каждого примера модель возвращает вектор оценок логитов или логарифмических шансов , по одному для каждого класса.

predictions = model(x_train[:1]).numpy()
predictions
array([[ 0.2760778 , -0.39324787, -0.17098302,  1.2016621 , -0.03416392,
         0.5461229 , -0.7203061 , -0.41886678, -0.59480035, -0.7580608 ]],
      dtype=float32)

Функция tf.nn.softmax преобразует эти логиты в вероятности для каждого класса:

tf.nn.softmax(predictions).numpy()
array([[0.11960829, 0.06124588, 0.0764901 , 0.30181262, 0.08770514,
        0.15668967, 0.04416083, 0.05969675, 0.05006609, 0.04252464]],
      dtype=float32)

Определите функцию потерь для обучения с использованием losses.SparseCategoricalCrossentropy , которая принимает вектор логитов и индекс True и возвращает скалярную потерю для каждого примера.

loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)

Эта потеря равна отрицательной логарифмической вероятности истинного класса: потеря равна нулю, если модель уверена в правильном классе.

Эта необученная модель дает вероятности, близкие к случайным (1/10 для каждого класса), поэтому начальная потеря должна быть близка к -tf.math.log(1/10) ~= 2.3 .

loss_fn(y_train[:1], predictions).numpy()
1.8534881

Перед началом обучения настройте и скомпилируйте модель с помощью Model.compile . Установите класс optimizer на adam , установите loss на функцию loss_fn , которую вы определили ранее, и укажите метрику, которая будет оцениваться для модели, установив параметр metrics на accuracy .

model.compile(optimizer='adam',
              loss=loss_fn,
              metrics=['accuracy'])

Обучите и оцените свою модель

Используйте метод Model.fit , чтобы настроить параметры модели и минимизировать потери:

model.fit(x_train, y_train, epochs=5)
Epoch 1/5
1875/1875 [==============================] - 4s 2ms/step - loss: 0.2950 - accuracy: 0.9143
Epoch 2/5
1875/1875 [==============================] - 3s 2ms/step - loss: 0.1451 - accuracy: 0.9567
Epoch 3/5
1875/1875 [==============================] - 4s 2ms/step - loss: 0.1080 - accuracy: 0.9668
Epoch 4/5
1875/1875 [==============================] - 4s 2ms/step - loss: 0.0906 - accuracy: 0.9717
Epoch 5/5
1875/1875 [==============================] - 4s 2ms/step - loss: 0.0749 - accuracy: 0.9761
<keras.callbacks.History at 0x7f062c606850>

Метод Model.evaluate проверяет производительность моделей, обычно на " Validation-set " или " Test-set ".

model.evaluate(x_test,  y_test, verbose=2)
313/313 - 1s - loss: 0.0783 - accuracy: 0.9755 - 588ms/epoch - 2ms/step
[0.07825208455324173, 0.9754999876022339]

Классификатор изображений теперь обучен с точностью ~ 98% на этом наборе данных. Чтобы узнать больше, прочитайте туториалы по TensorFlow .

Если вы хотите, чтобы ваша модель возвращала вероятность, вы можете обернуть обученную модель и прикрепить к ней softmax:

probability_model = tf.keras.Sequential([
  model,
  tf.keras.layers.Softmax()
])
probability_model(x_test[:5])
<tf.Tensor: shape=(5, 10), dtype=float32, numpy=
array([[2.72807270e-08, 2.42517650e-08, 7.75602894e-06, 1.28684027e-04,
        7.66215633e-11, 3.54162950e-07, 3.04894151e-14, 9.99857187e-01,
        2.32766553e-08, 5.97762892e-06],
       [7.37396704e-08, 4.73638036e-04, 9.99523997e-01, 7.20633352e-07,
        4.54133671e-17, 1.42298268e-06, 5.96959016e-09, 1.23534145e-13,
        7.77225608e-08, 6.98619169e-16],
       [1.95462448e-07, 9.99295831e-01, 1.02249986e-04, 1.86699708e-05,
        5.65737491e-06, 1.12115902e-06, 5.32719559e-06, 5.22767776e-04,
        4.79981136e-05, 1.76624681e-07],
       [9.99649286e-01, 1.80224735e-09, 3.73612856e-05, 1.52324446e-07,
        1.30824594e-06, 2.82781020e-05, 6.99703523e-05, 3.30940424e-07,
        2.13184350e-07, 2.13106396e-04],
       [1.53770895e-06, 1.72272063e-08, 1.98980865e-06, 3.97882580e-08,
        9.97192323e-01, 1.10544443e-05, 1.54713348e-06, 2.81727880e-05,
        3.48721733e-06, 2.75991508e-03]], dtype=float32)>

Вывод

Поздравляем! Вы обучили модель машинного обучения, используя готовый набор данных с помощью API Keras .

Дополнительные примеры использования Keras см. в туториалах . Чтобы узнать больше о построении моделей с помощью Keras, прочитайте руководства . Если вы хотите узнать больше о загрузке и подготовке данных, см. руководства по загрузке данных изображения или загрузке данных CSV .